# Source code for simplegan.losses.minmax_loss

import tensorflow as tf


# Public API of this module: the discriminator and generator loss functions.
__all__ = [
    "gan_discriminator_loss",
    "gan_generator_loss",
]

# Shared binary cross-entropy loss object used by both loss functions in this
# module. ``from_logits=True`` means callers pass raw (pre-sigmoid)
# discriminator outputs; the sigmoid is applied inside the loss.
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)


def gan_discriminator_loss(real_output, fake_output):
    r"""Compute the standard GAN discriminator loss.

    The discriminator is scored against target labels of 1 for real samples
    and 0 for generated samples. Both terms are binary cross-entropy
    computed on raw logits.

    Args:
        real_output (tensor): A tensor representing the real logits of discriminator
        fake_output (tensor): A tensor representing the fake logits of discriminator

    Returns:
        A tensor holding the sum of the real-sample loss and the fake-sample
        loss. Under the default reduction of ``cross_entropy`` each term, and
        therefore the sum, is a scalar.
    """
    # Targets: "real" (1) for real logits, "fake" (0) for generated logits.
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)

    real_loss = cross_entropy(real_targets, real_output)
    fake_loss = cross_entropy(fake_targets, fake_output)
    return real_loss + fake_loss


def gan_generator_loss(fake_output):
    r"""Compute the GAN generator loss.

    The generator is rewarded when the discriminator labels its samples as
    real, so every target is 1. For a logit ``x`` this evaluates
    ``-log(sigmoid(x))`` per element, which is the "non-saturating"
    generator objective rather than the literal min-max term
    ``log(1 - sigmoid(x))``.

    Args:
        fake_output (tensor): A tensor representing the fake logits of discriminator

    Returns:
        A tensor representing the generator loss. It is a scalar under the
        default reduction of ``cross_entropy``.
    """
    # Pretend every generated sample is real: the generator wants D(G(z)) -> 1.
    real_targets = tf.ones_like(fake_output)
    return cross_entropy(real_targets, fake_output)