Source code for simplegan.gan.wgan

import os
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dropout, BatchNormalization
from tensorflow.keras.layers import LeakyReLU, Conv2DTranspose, Dense, Reshape, Flatten
from tensorflow.keras import Model
from ..datasets.load_cifar10 import load_cifar10
from ..datasets.load_mnist import load_mnist
from ..datasets.load_custom_data import load_custom_data
from ..datasets.load_cifar100 import load_cifar100
from .dcgan import DCGAN
from ..losses.wasserstein_loss import wgan_discriminator_loss, wgan_generator_loss
import cv2
import numpy as np
import datetime
from ..datasets.load_lsun import load_lsun
import imageio
from tqdm.auto import tqdm

### Silence Imageio warnings
def silence_imageio_warning(*args, **kwargs):
    pass


imageio.core.util._precision_warn = silence_imageio_warning

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

__all__ = ["WGAN"]

"""
References:
-> https://arxiv.org/abs/1701.07875
"""


class WGAN(DCGAN):
    r"""`WGAN <https://arxiv.org/abs/1701.07875>`_ model

    Args:
        noise_dim (int, optional): represents the dimension of the prior to sample values. Defaults to ``100``
        dropout_rate (float, optional): represents the amount of dropout regularization to be applied. Defaults to ``0.4``
        gen_channels (int, list, optional): represents the number of filters in the generator network. Defaults to ``[64, 32, 16]``
        disc_channels (int, list, optional): represents the number of filters in the discriminator network. Defaults to ``[16, 32, 64]``
        kernel_size (int, tuple, optional): represents the size of the kernel used to perform the convolution. Defaults to ``(5, 5)``
        activation (str, optional): type of non-linearity to be applied. Defaults to ``relu``
        kernel_initializer (str, optional): initialization of kernel weights. Defaults to ``glorot_uniform``
        kernel_regularizer (str, optional): type of regularization to be applied to the weights. Defaults to ``None``
        gen_path (str, optional): path to generator checkpoint to load model weights. Defaults to ``None``
        disc_path (str, optional): path to discriminator checkpoint to load model weights. Defaults to ``None``
    """

    def __init__(
        self,
        noise_dim=100,
        dropout_rate=0.4,
        gen_channels=[64, 32, 16],
        disc_channels=[16, 32, 64],
        kernel_size=(5, 5),
        activation="relu",
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        gen_path=None,
        disc_path=None,
    ):

        DCGAN.__init__(
            self,
            noise_dim,
            dropout_rate,
            gen_channels,
            disc_channels,
            kernel_size,
            activation,
            kernel_initializer,
            kernel_regularizer,
            gen_path,
            disc_path,
        )

    def __load_model(self):

        self.gen_model, self.disc_model = self.generator(), self.discriminator()

        if self.config["gen_path"] is not None:
            self.gen_model.load_weights(self.config["gen_path"])
            print("Generator checkpoint restored")
        if self.config["disc_path"] is not None:
            self.disc_model.load_weights(self.config["disc_path"])
            print("Discriminator checkpoint restored")
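
    # Note: ``fit`` below enforces the Lipschitz constraint by clipping the
    # critic's *gradients* to [-0.01, 0.01]. Algorithm 1 of the WGAN paper
    # instead clips the critic's *weights* after each update. A minimal
    # sketch of that variant (a hypothetical helper, not part of simplegan):
    #
    #   def _clip_critic_weights(self, clip_value=0.01):
    #       for w in self.disc_model.trainable_variables:
    #           w.assign(tf.clip_by_value(w, -clip_value, clip_value))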

    def fit(
        self,
        train_ds=None,
        epochs=100,
        gen_optimizer="RMSprop",
        disc_optimizer="RMSprop",
        verbose=1,
        gen_learning_rate=5e-5,
        disc_learning_rate=5e-5,
        beta_1=0.5,
        tensorboard=False,
        save_model=None,
    ):
        r"""Function to train the model

        Args:
            train_ds (tf.data object): training data
            epochs (int, optional): number of epochs to train the model. Defaults to ``100``
            gen_optimizer (str, optional): optimizer used to train the generator. Defaults to ``RMSprop``
            disc_optimizer (str, optional): optimizer used to train the discriminator. Defaults to ``RMSprop``
            verbose (int, optional): 1 - prints training outputs, 0 - no outputs. Defaults to ``1``
            gen_learning_rate (float, optional): learning rate of the generator optimizer. Defaults to ``5e-5``
            disc_learning_rate (float, optional): learning rate of the discriminator optimizer. Defaults to ``5e-5``
            beta_1 (float, optional): decay rate of the first moment estimate. Used only if the ``Adam`` optimizer is set. Defaults to ``0.5``
            tensorboard (bool, optional): if True, writes loss values to the ``logs/gradient_tape`` directory to aid visualization. Defaults to ``False``
            save_model (str, optional): directory to save the trained model. Defaults to ``None``
        """

        assert (
            train_ds is not None
        ), "Initialize training data through train_ds parameter"

        self.__load_model()

        kwargs = {}
        kwargs["learning_rate"] = gen_learning_rate
        if gen_optimizer == "Adam":
            kwargs["beta_1"] = beta_1
        gen_optimizer = getattr(tf.keras.optimizers, gen_optimizer)(**kwargs)

        kwargs = {}
        kwargs["learning_rate"] = disc_learning_rate
        if disc_optimizer == "Adam":
            kwargs["beta_1"] = beta_1
        disc_optimizer = getattr(tf.keras.optimizers, disc_optimizer)(**kwargs)

        if tensorboard:
            current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
            train_log_dir = "logs/gradient_tape/" + current_time + "/train"
            train_summary_writer = tf.summary.create_file_writer(train_log_dir)

        steps = 0
        generator_loss = tf.keras.metrics.Mean()
        discriminator_loss = tf.keras.metrics.Mean()

        try:
            total = tf.data.experimental.cardinality(train_ds).numpy()
        except Exception:
            total = 0

        for epoch in range(epochs):

            generator_loss.reset_states()
            discriminator_loss.reset_states()

            pbar = tqdm(total=total, desc="Epoch - " + str(epoch + 1))
            for data in train_ds:

                # Update the critic 5 times per generator update, clipping
                # its gradients to [-0.01, 0.01] (see the note above on
                # gradient vs. weight clipping).
                for _ in range(5):

                    with tf.GradientTape() as tape:
                        Z = tf.random.normal([data.shape[0], self.noise_dim])
                        fake = self.gen_model(Z)
                        fake_logits = self.disc_model(fake)
                        real_logits = self.disc_model(data)
                        D_loss = wgan_discriminator_loss(real_logits, fake_logits)

                    gradients = tape.gradient(
                        D_loss, self.disc_model.trainable_variables
                    )
                    clipped_gradients = [
                        tf.clip_by_value(grad, -0.01, 0.01) for grad in gradients
                    ]
                    disc_optimizer.apply_gradients(
                        zip(clipped_gradients, self.disc_model.trainable_variables)
                    )

                    discriminator_loss(D_loss)

                # Single generator update per batch
                with tf.GradientTape() as tape:
                    Z = tf.random.normal([data.shape[0], self.noise_dim])
                    fake = self.gen_model(Z)
                    fake_logits = self.disc_model(fake)
                    G_loss = wgan_generator_loss(fake_logits)

                gradients = tape.gradient(G_loss, self.gen_model.trainable_variables)
                gen_optimizer.apply_gradients(
                    zip(gradients, self.gen_model.trainable_variables)
                )

                generator_loss(G_loss)

                steps += 1
                pbar.update(1)
                pbar.set_postfix(
                    disc_loss=discriminator_loss.result().numpy(),
                    gen_loss=generator_loss.result().numpy(),
                )

                if tensorboard:
                    with train_summary_writer.as_default():
                        tf.summary.scalar("discr_loss", D_loss.numpy(), step=steps)
                        tf.summary.scalar("genr_loss", G_loss.numpy(), step=steps)

            pbar.close()
            del pbar

            if verbose == 1:
                print(
                    "Epoch:",
                    epoch + 1,
                    "D_loss:",
                    discriminator_loss.result().numpy(),
                    "G_loss:",
                    generator_loss.result().numpy(),
                )

        if save_model is not None:
            assert isinstance(save_model, str), "Not a valid directory"
            if save_model[-1] != "/":
                self.gen_model.save_weights(save_model + "/generator_checkpoint")
                self.disc_model.save_weights(save_model + "/discriminator_checkpoint")
            else:
                self.gen_model.save_weights(save_model + "generator_checkpoint")
                self.disc_model.save_weights(save_model + "discriminator_checkpoint")
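
# Example usage (a sketch; it assumes the inherited DCGAN helpers
# ``load_data`` and ``generate_samples`` behave as in the simplegan README --
# treat their exact signatures as assumptions, not guarantees):
#
#   from simplegan.gan import WGAN
#
#   gan = WGAN()
#   train_ds = gan.load_data(use_mnist=True)      # assumed inherited loader
#   gan.fit(train_ds=train_ds, epochs=50, save_model="./ckpt")
#   samples = gan.generate_samples(n_samples=5)   # assumed inherited sampler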