import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras import layers
from ..datasets.load_mnist import load_mnist
from ..datasets.load_cifar10 import load_cifar10
from ..datasets.load_custom_data import load_custom_data_with_labels
from ..losses.minmax_loss import gan_discriminator_loss, gan_generator_loss
from ..losses.infogan_loss import auxillary_loss
import datetime
from tqdm import tqdm
import logging
import imageio
logging.getLogger("tensorflow").setLevel(logging.ERROR)
### Silence imageio precision warnings raised when writing float images
def silence_imageio_warning(*args, **kwargs):
pass
imageio.core.util._precision_warn = silence_imageio_warning
__all__ = ["InfoGAN"]
"""
References:
-> https://arxiv.org/abs/1606.03657
"""
class InfoGAN:
r"""`InfoGAN <https://arxiv.org/abs/1606.03657>`_ model
Args:
        noise_dim (int, optional): represents the dimension of the noise prior from which latent vectors are sampled. Defaults to ``100``
        code_dim (int, list, optional): dimension of the interpretable representation. Defaults to ``2``
        dropout_rate (float, optional): represents the amount of dropout regularization to be applied. Defaults to ``0.4``
        gen_channels (int, list, optional): represents the number of filters in the generator network. Defaults to ``[128, 64]``
        disc_channels (int, list, optional): represents the number of filters in the discriminator network. Defaults to ``[64, 128]``
        kernel_size (int, tuple, optional): represents the size of the kernel to perform the convolution. Defaults to ``(5, 5)``
activation (str, optional): type of non-linearity to be applied. Defaults to ``leaky_relu``
kernel_initializer (str, optional): initialization of kernel weights. Defaults to ``glorot_uniform``
kernel_regularizer (str, optional): type of regularization to be applied to the weights. Defaults to ``None``
gen_path (str, optional): path to generator checkpoint to load model weights. Defaults to ``None``
disc_path (str, optional): path to discriminator checkpoint to load model weights. Defaults to ``None``
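    Example:
        A minimal usage sketch (assuming the class is importable from the
        installed package; adapt the import to your setup)::

            model = InfoGAN(noise_dim=100, code_dim=2)
            train_ds = model.load_data(use_mnist=True)
            model.fit(train_ds=train_ds, epochs=10)
            samples = model.generate_samples(n_samples=5)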
"""
def __init__(
self,
noise_dim=100,
code_dim=2,
dropout_rate=0.4,
gen_channels=[128, 64],
disc_channels=[64, 128],
kernel_size=(5, 5),
activation="leaky_relu",
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
gen_path=None,
disc_path=None,
):
        self.image_size = None
        self.config = locals()
        self.n_classes = None
        self.noise_dim = noise_dim
        self.code_dim = code_dim
        # Built lazily by __load_model(); initialized here so that
        # generate_samples() can detect an untrained model.
        self.gen_model = None
        self.disc_model = None
    def load_data(
self,
data_dir=None,
use_mnist=False,
use_cifar10=False,
batch_size=32,
img_shape=(64, 64),
):
r"""Load data to train the model
Args:
data_dir (str, optional): string representing the directory to load data from. Defaults to ``None``
use_mnist (bool, optional): use the MNIST dataset to train the model. Defaults to ``False``
use_cifar10 (bool, optional): use the CIFAR10 dataset to train the model. Defaults to ``False``
batch_size (int, optional): mini batch size for training the model. Defaults to ``32``
img_shape (int, tuple, optional): shape of the image when loading data from custom directory. Defaults to ``(64, 64)``
Return:
            a tensorflow dataset object representing the training dataset
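        Example:
            A minimal sketch using the bundled MNIST loader::

                train_ds = model.load_data(use_mnist=True, batch_size=64)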
"""
if use_mnist:
train_data = load_mnist()
self.n_classes = 10
elif use_cifar10:
train_data = load_cifar10()
self.n_classes = 10
else:
train_data, train_labels = load_custom_data_with_labels(data_dir, img_shape)
self.n_classes = np.unique(train_labels).shape[0]
self.image_size = train_data.shape[1:]
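        # Scale pixel values to [-1, 1] to match the generator's tanh output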
train_data = (train_data - 127.5) / 127.5
train_ds = (
tf.data.Dataset.from_tensor_slices(train_data)
.shuffle(10000)
.batch(batch_size)
)
return train_ds
    def get_sample(self, data=None, n_samples=1, save_dir=None):
r"""View sample of the data
Args:
data (tf.data object): dataset to load samples from
n_samples (int, optional): number of samples to load. Defaults to ``1``
save_dir (str, optional): directory to save the sample images. Defaults to ``None``
Return:
            ``None`` if ``save_dir`` is provided (samples are written to disk), otherwise a numpy array of samples with shape (n_samples, img_shape)
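        Example:
            A minimal sketch, assuming ``train_ds`` was returned by ``load_data``::

                samples = model.get_sample(data=train_ds, n_samples=4)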
"""
assert data is not None, "Data not provided"
sample_images = []
data = data.unbatch()
for img in data.take(n_samples):
img = img.numpy()
sample_images.append(img)
sample_images = np.array(sample_images)
if save_dir is None:
return sample_images
assert os.path.exists(save_dir), "Directory does not exist"
for i, sample in enumerate(sample_images):
imageio.imwrite(os.path.join(save_dir, "sample_" + str(i) + ".jpg"), sample)
def conv_block(
self,
inputs,
filters,
kernel_size,
strides=(2, 2),
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
padding="same",
activation="leaky_relu",
use_batch_norm=True,
conv_type="normal",
):
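        # Shared building block: a (transposed) convolution, optional batch
        # normalization, and a leaky ReLU or tanh activation.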
if conv_type == "transpose":
x = layers.Conv2DTranspose(
filters,
kernel_size,
strides=strides,
padding=padding,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
)(inputs)
else:
x = layers.Conv2D(
filters,
kernel_size,
strides=strides,
padding=padding,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
)(inputs)
if use_batch_norm:
x = layers.BatchNormalization()(x)
if activation == "leaky_relu":
x = layers.LeakyReLU()(x)
elif activation == "tanh":
x = tf.keras.activations.tanh(x)
return x
    def discriminator(self):
r"""Discriminator module for InfoGAN. Use it as a regular TensorFlow 2.0 Keras Model.
Return:
A tf.keras model
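        Example:
            A minimal sketch; ``load_data`` must be called first so that
            ``image_size`` and ``n_classes`` are set. The model outputs the
            validity score, a softmax over the discrete code, and the
            continuous code::

                disc = model.discriminator()
                disc.summary()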
"""
disc_channels = self.config["disc_channels"]
activation = self.config["activation"]
kernel_initializer = self.config["kernel_initializer"]
kernel_regularizer = self.config["kernel_regularizer"]
kernel_size = self.config["kernel_size"]
image_input = layers.Input(self.image_size)
img = self.conv_block(
image_input,
filters=disc_channels[0],
kernel_size=kernel_size,
strides=(2, 2),
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
)
for i in range(1, len(disc_channels)):
img = self.conv_block(
img,
filters=disc_channels[i],
kernel_size=kernel_size,
strides=(2, 2),
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
)
flatten = layers.Flatten()(img)
valid = layers.Dense(1)(flatten)
conv_out = self.conv_block(
img,
filters=disc_channels[-1],
kernel_size=kernel_size,
strides=(1, 1),
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
)
conv_out = layers.Flatten()(conv_out)
discrete_out = layers.Dense(self.n_classes, activation="softmax")(conv_out)
cont_out = layers.Dense(self.code_dim)(conv_out)
disc_model = tf.keras.Model(
inputs=image_input, outputs=[valid, discrete_out, cont_out]
)
return disc_model
    def generator(self):
r"""Generator module for InfoGAN. Use it as a regular TensorFlow 2.0 Keras Model.
Return:
A tf.keras model
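        Example:
            A minimal sketch; the generator input concatenates the noise
            vector, a one-hot label, and the continuous code, so its length is
            ``noise_dim + n_classes + code_dim``. ``load_data`` must be called
            first so that ``image_size`` and ``n_classes`` are set::

                gen = model.generator()
                gen.summary()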
"""
gen_channels = self.config["gen_channels"]
activation = self.config["activation"]
kernel_initializer = self.config["kernel_initializer"]
kernel_regularizer = self.config["kernel_regularizer"]
kernel_size = self.config["kernel_size"]
input_shape = self.noise_dim + self.n_classes + self.code_dim
input_noise = layers.Input(shape=input_shape)
        _input = layers.Dense(
            (self.image_size[0] // 4)
            * (self.image_size[1] // 4)
            * (gen_channels[0] * 2),
            use_bias=False,
        )(input_noise)
_input = layers.BatchNormalization()(_input)
_input = layers.LeakyReLU()(_input)
img = layers.Reshape(
(
(self.image_size[0] // 4),
(self.image_size[1] // 4),
(gen_channels[0] * 2),
)
)(_input)
for i in range(len(gen_channels)):
img = self.conv_block(
img,
filters=gen_channels[i],
kernel_size=kernel_size,
strides=(2, 2),
conv_type="transpose",
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
)
img = self.conv_block(
img,
filters=self.image_size[-1],
kernel_size=kernel_size,
strides=(1, 1),
use_batch_norm=False,
activation="tanh",
conv_type="transpose",
)
gen_model = tf.keras.Model(input_noise, img)
return gen_model
def __load_model(self):
self.gen_model, self.disc_model = self.generator(), self.discriminator()
if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")
    def fit(
self,
train_ds=None,
epochs=100,
gen_optimizer="Adam",
disc_optimizer="Adam",
verbose=1,
gen_learning_rate=0.0001,
disc_learning_rate=0.0002,
beta_1=0.5,
tensorboard=False,
save_model=None,
):
r"""Function to train the model
Args:
train_ds (tf.data object): training data
epochs (int, optional): number of epochs to train the model. Defaults to ``100``
gen_optimizer (str, optional): optimizer used to train generator. Defaults to ``Adam``
disc_optimizer (str, optional): optimizer used to train discriminator. Defaults to ``Adam``
verbose (int, optional): 1 - prints training outputs, 0 - no outputs. Defaults to ``1``
gen_learning_rate (float, optional): learning rate of the generator optimizer. Defaults to ``0.0001``
disc_learning_rate (float, optional): learning rate of the discriminator optimizer. Defaults to ``0.0002``
            beta_1 (float, optional): exponential decay rate for the first moment estimates. Used only when the ``Adam`` optimizer is selected. Defaults to ``0.5``
tensorboard (bool, optional): if true, writes loss values to ``logs/gradient_tape`` directory
which aids visualization. Defaults to ``False``
save_model (str, optional): Directory to save the trained model. Defaults to ``None``
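        Example:
            A minimal sketch, assuming ``train_ds`` was returned by ``load_data``::

                model.fit(train_ds=train_ds, epochs=10, tensorboard=True)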
"""
assert train_ds is not None, "No Input data found"
self.__load_model()
kwargs = {}
kwargs["learning_rate"] = gen_learning_rate
if gen_optimizer == "Adam":
kwargs["beta_1"] = beta_1
gen_optimizer = getattr(tf.keras.optimizers, gen_optimizer)(**kwargs)
kwargs = {}
kwargs["learning_rate"] = disc_learning_rate
if disc_optimizer == "Adam":
kwargs["beta_1"] = beta_1
disc_optimizer = getattr(tf.keras.optimizers, disc_optimizer)(**kwargs)
if tensorboard:
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = "logs/gradient_tape/" + current_time + "/train"
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
steps = 0
generator_loss = tf.keras.metrics.Mean()
discriminator_loss = tf.keras.metrics.Mean()
total_batches = tf.data.experimental.cardinality(train_ds).numpy()
for epoch in range(epochs):
generator_loss.reset_states()
discriminator_loss.reset_states()
pbar = tqdm(total=total_batches, desc="Epoch - " + str(epoch + 1))
for data in train_ds:
batch_size = data.shape[0]
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
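                    # Build the InfoGAN latent vector: Gaussian noise, a random
                    # one-hot class label, and a continuous code.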
Z = np.random.randn(batch_size, self.noise_dim)
label_input = tf.keras.utils.to_categorical(
(np.random.randint(0, self.n_classes, batch_size)),
self.n_classes,
)
code_input = np.random.randn(batch_size, self.code_dim)
c = np.concatenate((Z, label_input, code_input), axis=1)
gen_imgs = self.gen_model(c, training=True)
real_output, _, _ = self.disc_model(data, training=True)
fake_output, discrete, cont_out = self.disc_model(
gen_imgs, training=True
)
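                    # The auxiliary (mutual-information) term of the InfoGAN
                    # objective is added to both the generator and the
                    # discriminator losses.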
info_loss = auxillary_loss(
discrete, label_input, code_input, cont_out
)
gen_loss = gan_generator_loss(fake_output) + info_loss
disc_loss = (
gan_discriminator_loss(real_output, fake_output) + info_loss
)
generator_grads = gen_tape.gradient(
gen_loss, self.gen_model.trainable_variables
)
discriminator_grads = disc_tape.gradient(
disc_loss, self.disc_model.trainable_variables
)
gen_optimizer.apply_gradients(
zip(generator_grads, self.gen_model.trainable_variables)
)
disc_optimizer.apply_gradients(
zip(discriminator_grads, self.disc_model.trainable_variables)
)
generator_loss.update_state(gen_loss)
discriminator_loss.update_state(disc_loss)
pbar.update(1)
pbar.set_postfix(
disc_loss=discriminator_loss.result().numpy(),
gen_loss=generator_loss.result().numpy(),
)
steps += 1
if tensorboard:
with train_summary_writer.as_default():
tf.summary.scalar(
"discr_loss", disc_loss.numpy(), step=steps
)
tf.summary.scalar("genr_loss", gen_loss.numpy(), step=steps)
pbar.close()
del pbar
            if verbose:
                print(
                    "Epoch:",
                    epoch + 1,
                    "D_loss:",
                    discriminator_loss.result().numpy(),
                    "G_loss:",
                    generator_loss.result().numpy(),
                )
        if save_model is not None:
            assert isinstance(save_model, str), "Not a valid directory"
            # os.path.join handles paths with or without a trailing slash
            self.gen_model.save_weights(
                os.path.join(save_model, "generator_checkpoint")
            )
            self.disc_model.save_weights(
                os.path.join(save_model, "discriminator_checkpoint")
            )
    def generate_samples(self, n_samples=1, save_dir=None):
r"""Generate samples using the trained model
Args:
n_samples (int, optional): number of samples to generate. Defaults to ``1``
save_dir (str, optional): directory to save the generated images. Defaults to ``None``
Return:
            ``None`` if ``save_dir`` is provided (images are written to disk), otherwise a numpy array of generated samples
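        Example:
            A minimal sketch; train the model (or pass checkpoint paths at
            construction) before sampling::

                imgs = model.generate_samples(n_samples=8)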
"""
if self.gen_model is None:
self.__load_model()
Z = np.random.randn(n_samples, self.noise_dim)
label_input = tf.keras.utils.to_categorical(
(np.random.randint(0, self.n_classes, n_samples)), self.n_classes
)
code_input = np.random.randn(n_samples, self.code_dim)
seed = np.concatenate((Z, label_input, code_input), axis=1)
generated_samples = self.gen_model(seed).numpy()
if save_dir is None:
return generated_samples
assert os.path.exists(save_dir), "Directory does not exist"
for i, sample in enumerate(generated_samples):
imageio.imwrite(os.path.join(save_dir, "sample_" + str(i) + ".jpg"), sample)