The Wasserstein Generative Adversarial Network (WGAN) is a variant of the Generative Adversarial Network (GAN) with a modified training algorithm. A GAN is a method for training a generative model by pitting a generator network against a discriminator network. Martin Arjovsky, Soumith Chintala, and Léon Bottou introduced the WGAN in 2017, and it is widely used to generate realistic images.
Like a standard GAN, a WGAN uses deep neural networks for both the generator and the discriminator. In a WGAN the discriminator is called the critic. The key differences from a standard GAN are the loss function and how the critic is constrained. The original WGAN clips the critic's weights, and the later WGAN-GP variant replaces clipping with a gradient penalty. WGANs were introduced to address the training instability and mode collapse of standard GANs. They use the Wasserstein distance, which gives a more meaningful and smoother measure of the distance between two distributions.
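The advantage of the Wasserstein distance over the Jensen-Shannon (JS) or Kullback-Leibler (KL) divergence is that it stays finite and changes smoothly even when the real and generated distributions do not overlap. When the two distributions have disjoint supports, the JS divergence is stuck at the constant log 2 and the KL divergence is infinite, so neither tells the generator which way to move. The Wasserstein distance still reflects how far apart the distributions are.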
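The following is a hypothetical example that makes the difference concrete. Real data is a point mass at 0 and generated data is a point mass at θ. The helper names `wasserstein_1d` and `js_divergence` are our own and do not come from the article. For two 1-D samples of equal size, W1 is the mean absolute difference of the sorted samples.

```python
import numpy as np

def wasserstein_1d(x, y):
    """W1 distance between two 1-D empirical distributions with equal sample counts.

    With equal sizes, the optimal transport plan pairs sorted values,
    so W1 is the mean absolute difference of the sorted samples.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    return float(np.mean(np.abs(x - y)))

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # terms where a == 0 contribute nothing
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Histograms over the two locations {0, theta}: the supports never overlap.
p_real = np.array([1.0, 0.0])  # all real mass at 0
q_fake = np.array([0.0, 1.0])  # all generated mass at theta
for theta in [0.5, 2.0, 8.0]:
    w1 = wasserstein_1d(np.zeros(100), np.full(100, theta))
    js = js_divergence(p_real, q_fake)
    print(f"theta={theta}: W1={w1:.3f}, JS={js:.3f}")
```

W1 equals θ, so it shrinks as the generated mass moves toward the real mass. JS stays at log 2 ≈ 0.693 for every θ and gives no signal about direction.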
The algorithm is stated as follows:
[Figure: WGAN training algorithm]
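For reference, Algorithm 1 of the WGAN paper (Arjovsky et al., 2017) can be summarized as below. In this notation, f_w is the critic with weights w, g_θ is the generator, x^{(i)} are real samples, z^{(i)} are drawn from the noise prior p(z), m is the batch size, c is the clipping parameter, and α is the learning rate. The paper's defaults are α = 0.00005, c = 0.01, m = 64, and n_critic = 5 critic updates per generator update. The first line is the Kantorovich-Rubinstein dual form of the Wasserstein distance that the critic approximates.

$$
\begin{aligned}
W(\mathbb{P}_r, \mathbb{P}_\theta) &= \sup_{\lVert f \rVert_L \le 1} \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_\theta}[f(x)] \\
\text{critic, repeated } n_{\text{critic}} \text{ times:}\quad g_w &\leftarrow \nabla_w \Big[\frac{1}{m}\sum_{i=1}^{m} f_w\big(x^{(i)}\big) - \frac{1}{m}\sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big)\Big] \\
w &\leftarrow \operatorname{clip}\big(w + \alpha \cdot \operatorname{RMSProp}(w, g_w),\, -c,\, c\big) \\
\text{generator:}\quad g_\theta &\leftarrow -\nabla_\theta \frac{1}{m}\sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big) \\
\theta &\leftarrow \theta - \alpha \cdot \operatorname{RMSProp}(\theta, g_\theta)
\end{aligned}
$$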
The steps to generate images using WGANs are discussed below.
For the implementation, the required Python libraries are NumPy, Keras, and Matplotlib.
To define the Wasserstein loss function, we use the following method. Our goal is to minimize the Wasserstein distance between the distribution of generated samples and the distribution of real samples. The critic estimates this distance as the difference between its average score on real samples and its average score on generated samples. It should therefore output large scores for real samples and small scores for generated ones. Because the loss is an average over the batch, we compute it with `backend.mean()`.
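A common way to write this loss in Keras is `backend.mean(y_true * y_pred)`. The NumPy sketch below mirrors that computation so it can be checked without building a model. It assumes that real samples are labeled -1 and generated samples +1. With that convention, minimizing the loss raises the critic's scores on real samples and lowers them on generated samples.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    """Mean of label * critic score (NumPy mirror of backend.mean(y_true * y_pred)).

    Assumed label convention: -1 for real samples, +1 for generated samples.
    """
    return float(np.mean(np.asarray(y_true, dtype=float) * np.asarray(y_pred, dtype=float)))

# Critic outputs are unbounded scores, not probabilities.
real_scores = np.array([2.0, 3.0, 4.0])
fake_scores = np.array([-1.0, 0.0, -2.0])

loss_real = wasserstein_loss(-np.ones(3), real_scores)  # -3.0: low when real scores are high
loss_fake = wasserstein_loss(np.ones(3), fake_scores)   # -1.0: low when fake scores are low
# The generator is trained with the "real" label (-1) on its own samples,
# so it lowers its loss by raising the critic's score on generated samples.
loss_gen = wasserstein_loss(-np.ones(3), fake_scores)   # 1.0
```

The two critic terms sum to mean(fake scores) - mean(real scores), which is the negative of the critic's Wasserstein estimate. Minimizing that sum therefore maximizes the estimate.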
First, we load the training images from the MNIST dataset and keep only one class. We use the digit 5 here, but any digit works.
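The following is a NumPy sketch of this preprocessing. In Keras the arrays would come from `keras.datasets.mnist.load_data()`, which returns `(x_train, y_train), (x_test, y_test)`. Here we use random stand-in arrays with the same shapes. The helper name `select_digit` is hypothetical. Scaling pixels to [-1, 1] is our assumption, chosen to match a generator with a tanh output.

```python
import numpy as np

def select_digit(images, labels, digit=5):
    """Keep one class, add a channel axis, and scale pixels from [0, 255] to [-1, 1]."""
    selected = images[labels == digit]
    x = np.expand_dims(selected, axis=-1).astype("float32")
    return (x - 127.5) / 127.5

# Stand-in data shaped like MNIST: 20 grayscale 28x28 images, half of them labeled 5.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(20, 28, 28), dtype=np.uint8)
labels = np.array([5, 3] * 10)
X = select_digit(images, labels, digit=5)
print(X.shape)  # (10, 28, 28, 1)
```

On the real MNIST training set, the same filtering gives the `(5421, 28, 28, 1)` array printed in the output.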
Next, we randomly select real samples from the dataset X chosen above.
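The following is a minimal sketch of the sampling step. It assumes real samples are labeled -1, and `generate_real_samples` is a hypothetical name. Indices are drawn with replacement, which is typical for this kind of minibatch sampling.

```python
import numpy as np

def generate_real_samples(dataset, n_samples, rng):
    """Pick n_samples random images (with replacement) and label them -1 ('real')."""
    ix = rng.integers(0, dataset.shape[0], n_samples)
    X = dataset[ix]
    y = -np.ones((n_samples, 1))
    return X, y

rng = np.random.default_rng(42)
dataset = rng.uniform(-1, 1, size=(100, 28, 28, 1))  # stand-in for the selected digits
X_real, y_real = generate_real_samples(dataset, 64, rng)
print(X_real.shape, y_real.shape)  # (64, 28, 28, 1) (64, 1)
```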
Now it is time to define the critic (discriminator) model. We update the critic more often than the generator so that it stays close to optimal. Otherwise its estimate of the Wasserstein distance is poor and the generator can easily fool it. Before defining the critic, we need a clip constraint for its weights. The Wasserstein estimate is only valid if the critic is Lipschitz continuous. The original WGAN enforces this by clipping every weight into a small box (a hypercube) such as [-0.01, 0.01] after each update.
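In Keras, this constraint is typically a subclass of `keras.constraints.Constraint` whose `__call__` returns `backend.clip(weights, -clip_value, clip_value)`. It is attached to layers through their `kernel_constraint` argument. The NumPy sketch below isolates the clipping operation itself. The class name `ClipConstraint` and the value 0.01 (the WGAN paper's default) are assumptions, because the article's own code is not shown.

```python
import numpy as np

class ClipConstraint:
    """Clip every weight into the box [-clip_value, clip_value] (NumPy stand-in)."""

    def __init__(self, clip_value=0.01):
        self.clip_value = clip_value

    def __call__(self, weights):
        # Values outside the box are moved to the nearest bound; others are unchanged.
        return np.clip(weights, -self.clip_value, self.clip_value)

    def get_config(self):
        # Keras constraints report their settings through get_config().
        return {"clip_value": self.clip_value}

const = ClipConstraint(0.01)
w = np.array([-0.5, -0.005, 0.0, 0.02, 3.0])
clipped = const(w)
print(clipped)
```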
Then we define the critic itself.
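The article's critic code is not reproduced, so the following layout is an assumption. It is the common DCGAN-style critic for 28x28 images: two stride-2 `Conv2D` layers with `padding='same'` and 64 filters, a flatten, and a single linear output unit. The output has no sigmoid, because the critic produces a score rather than a probability. With `'same'` padding, a Keras `Conv2D` output has spatial size ceil(input / stride). The sketch uses that rule to trace the shapes.

```python
import math

def conv_same_output(size, stride):
    """Spatial size after a Conv2D with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)

# Hypothetical critic: 28x28x1 -> Conv2D(64, stride 2) -> Conv2D(64, stride 2) -> Flatten -> Dense(1)
size, filters = 28, 64
shapes = [(size, size, 1)]
for _ in range(2):
    size = conv_same_output(size, 2)
    shapes.append((size, size, filters))
flat = size * size * filters
print(shapes)  # [(28, 28, 1), (14, 14, 64), (7, 7, 64)]
print(flat)    # 3136 features feed one linear output unit
```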
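The generator model works in the opposite direction. It takes a point from the latent space, maps it to a low-resolution 7x7 feature map, and upsamples that map to a 28x28 image. Generating coarse structure first and refining it while upsampling makes the image easier to model accurately.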
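The following shape trace assumes a common generator design, because the stripped code is not available. A `Dense` layer projects the latent vector to 7x7x128 values, which are reshaped into a feature map. Two stride-2 `Conv2DTranspose` layers with `padding='same'` then each double the spatial size, and a final convolution maps to one channel with a tanh activation. The latent size 50 and the filter count 128 are hypothetical.

```python
def conv_transpose_same_output(size, stride):
    """Spatial size after a Conv2DTranspose with padding='same': size * stride."""
    return size * stride

latent_dim = 50                  # hypothetical latent vector length
n_filters = 128                  # hypothetical number of feature maps
dense_units = 7 * 7 * n_filters  # Dense output, reshaped to (7, 7, n_filters)

size = 7
shapes = [(latent_dim,), (dense_units,), (size, size, n_filters)]
for _ in range(2):               # two stride-2 transposed convolutions
    size = conv_transpose_same_output(size, 2)
    shapes.append((size, size, n_filters))
shapes.append((size, size, 1))   # final convolution to one channel (tanh output)
print(shapes)
```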
The following method defines the model used to update the generator. We use RMSProp (Root Mean Square Propagation) as the optimizer instead of Adam. The WGAN authors reported that training sometimes became unstable with momentum-based optimizers such as Adam.
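RMSProp divides each gradient by a running root-mean-square of recent gradients and, by default, uses no momentum. The sketch below is a textbook form of one update step in NumPy, not Keras's exact implementation. The defaults `rho=0.9` and `eps=1e-7` are assumed, and Keras's `RMSprop` also defaults to `rho=0.9`. The learning rate 0.00005 is the value used in the WGAN paper.

```python
import numpy as np

def rmsprop_step(w, grad, v, lr=5e-5, rho=0.9, eps=1e-7):
    """One textbook RMSProp update; returns the new weights and the running average."""
    v = rho * v + (1.0 - rho) * grad ** 2   # running mean of squared gradients
    w = w - lr * grad / (np.sqrt(v) + eps)  # per-parameter scaled step, no momentum
    return w, v

w = np.array([0.5, -0.2])
v = np.zeros_like(w)
grad = np.array([1.0, -100.0])
w_new, v = rmsprop_step(w, grad, v, lr=0.1)
delta = w_new - w
print(delta)
```

On this first step both parameters move by the same amount, about 0.316, even though their gradients differ by a factor of 100. Only the sign of each gradient sets the direction.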
To generate fake samples, we need points in the latent space. Given the latent dimension and the number of samples, we draw random latent vectors and ask the generator to predict the corresponding images.
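The following is a NumPy sketch of this step. It assumes a standard normal latent space and labels generated samples +1. A hypothetical `toy_generator` (a fixed random projection followed by tanh) stands in for the Keras model. With Keras, the generator call would be `generator.predict(z)`.

```python
import numpy as np

def generate_latent_points(latent_dim, n_samples, rng):
    """Draw n_samples points from a standard normal latent space."""
    return rng.standard_normal((n_samples, latent_dim))

def generate_fake_samples(generator, latent_dim, n_samples, rng):
    """Map latent points through the generator and label the results +1 ('fake')."""
    z = generate_latent_points(latent_dim, n_samples, rng)
    X = generator(z)  # with Keras: generator.predict(z)
    y = np.ones((n_samples, 1))
    return X, y

rng = np.random.default_rng(0)
latent_dim = 50
W = rng.standard_normal((latent_dim, 28 * 28)) * 0.1  # hypothetical fixed weights

def toy_generator(z):
    """Stand-in generator: linear projection to 784 pixels, tanh, reshape to images."""
    return np.tanh(z @ W).reshape(-1, 28, 28, 1)

X_fake, y_fake = generate_fake_samples(toy_generator, latent_dim, 16, rng)
print(X_fake.shape, y_fake.shape)  # (16, 28, 28, 1) (16, 1)
```

In Keras, each `predict` call prints its own `1/1 [====...]` progress line. The log below shows five such lines per iteration, which is likely one per critic update.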
Now it is time to train the model. Remember that we update the critic more often than the generator, so that the critic gives a more accurate estimate of the Wasserstein distance. The generated images are saved to the working directory, where you can inspect them.
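The following toy one-dimensional WGAN in NumPy shows the structure of the training loop without Keras. It is entirely hypothetical and is not the article's model. Real data is drawn from N(3, 1), the generator shifts standard normal noise by a learnable θ, and the critic is the linear function f(x) = w·x with w clipped to [-0.1, 0.1]. As in the WGAN algorithm, the critic takes n_critic = 5 steps per generator step. For brevity, plain gradient steps replace RMSProp. The clip value is also larger than the paper's 0.01 so that the toy converges within a few hundred iterations.

```python
import numpy as np

def train_toy_wgan(n_iter=300, n_critic=5, batch=256, clip=0.1,
                   lr_critic=0.1, lr_gen=1.0, seed=0):
    """Toy 1-D WGAN: real ~ N(3, 1), generator g(z) = z + theta, critic f(x) = w * x."""
    rng = np.random.default_rng(seed)
    theta, w = 0.0, 0.0
    history = []
    for _ in range(n_iter):
        # Critic: several ascent steps on mean(f(real)) - mean(f(fake)), clipping each time.
        for _ in range(n_critic):
            real = rng.normal(3.0, 1.0, batch)
            fake = rng.standard_normal(batch) + theta
            grad_w = real.mean() - fake.mean()  # gradient of the objective w.r.t. w
            w = float(np.clip(w + lr_critic * grad_w, -clip, clip))
        # Generator: minimize -mean(f(g(z))) = -w * mean(z + theta); gradient w.r.t. theta is -w.
        theta -= lr_gen * (-w)
        history.append(theta)
    return theta, w, history

theta, w, history = train_toy_wgan()
print(round(float(np.mean(history[-100:])), 2))  # average of the last 100 thetas, near 3
```

In this run, θ climbs toward 3 and then oscillates around it. A linear critic can only compare means, so the toy matches the mean of the data and nothing more. A convolutional critic can compare much richer features.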
We use the following plotting functions. The loss history plot is saved to your working directory.
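The following is a minimal sketch of a loss-history plot. It assumes the training loop appends c1, c2, and g to three lists on every iteration. The function name `plot_history` and the default file name are hypothetical. `matplotlib.use("Agg")` lets the function save files on machines without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: write image files without a display
import matplotlib.pyplot as plt

def plot_history(c1_hist, c2_hist, g_hist, filename="plot_line_plot_loss.png"):
    """Plot critic losses on real (c1) and fake (c2) samples and the generator loss (g)."""
    fig, ax = plt.subplots()
    ax.plot(c1_hist, label="crit_real")
    ax.plot(c2_hist, label="crit_fake")
    ax.plot(g_hist, label="gen")
    ax.set_xlabel("iteration")
    ax.set_ylabel("loss")
    ax.legend()
    fig.savefig(filename)
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory
    return filename
```

For example, `plot_history([-13.690, -28.276, -39.209], [-4.848, 0.991, -34.840], [18.497, 16.891, 22.131])` plots the first three iterations of the log below.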
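Now run the complete pipeline as follows. For each iteration, the log reports the critic loss on real samples (c1), the critic loss on generated samples (c2), and the generator loss (g):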
Output:
```
11490434/11490434 [==============================] - 0s 0us/step
(5421, 28, 28, 1)
1/1 [==============================] - 1s 882ms/step
1/1 [==============================] - 0s 106ms/step
1/1 [==============================] - 0s 50ms/step
1/1 [==============================] - 0s 25ms/step
1/1 [==============================] - 0s 36ms/step
>1, c1=-13.690, c2=-4.848 g=18.497
1/1 [==============================] - 0s 24ms/step
1/1 [==============================] - 0s 23ms/step
1/1 [==============================] - 0s 44ms/step
1/1 [==============================] - 0s 24ms/step
1/1 [==============================] - 0s 33ms/step
>2, c1=-28.276, c2=0.991 g=16.891
1/1 [==============================] - 0s 57ms/step
1/1 [==============================] - 0s 33ms/step
1/1 [==============================] - 0s 70ms/step
1/1 [==============================] - 0s 113ms/step
1/1 [==============================] - 0s 49ms/step
>3, c1=-39.209, c2=-34.840 g=22.131
```
During training, the model saves plots of the samples it generates. We can merge these plots as follows:
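If you have the generated image arrays rather than the saved PNG files, one way to merge them is to tile them into a single grid array. You can then show the grid with Matplotlib, for example `plt.imshow(tile_images(X, 5, 5), cmap="gray_r")`. `tile_images` is a hypothetical helper and is not the article's code.

```python
import numpy as np

def tile_images(images, n_rows, n_cols):
    """Arrange (n, h, w) or (n, h, w, 1) images into one (n_rows*h, n_cols*w) array, row by row."""
    images = np.asarray(images)
    if images.ndim == 4:
        images = images[..., 0]  # drop the single channel axis
    n, h, w = images.shape
    if n < n_rows * n_cols:
        raise ValueError("not enough images to fill the grid")
    grid = images[: n_rows * n_cols].reshape(n_rows, n_cols, h, w)
    # (rows, cols, h, w) -> (rows, h, cols, w) so each grid row becomes one horizontal strip
    return grid.transpose(0, 2, 1, 3).reshape(n_rows * h, n_cols * w)
```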
Output:
As we can see, before epoch 300 the generated images are unclear and do not resemble the digit 5. After that, the generator produces fake digits that look realistic, and the images become clearer as training progresses. In the early stages, the generator is still adjusting to compete with the critic, so its output differs only slightly from its random initialization. After several epochs, the generator has adapted and produces good results.
And the loss graph is as follows:
[Figure: loss history plot]