Source code for simplegan.losses.wasserstein_loss

import tensorflow as tf

# Public API of this module: the two WGAN loss functions defined below.
__all__ = ["wgan_discriminator_loss", "wgan_generator_loss"]


def wgan_discriminator_loss(real_output, fake_output):
    r"""Wasserstein loss for the discriminator (critic).

    The critic should assign higher scores to real samples than to generated
    ones, i.e. maximize ``mean(D(real)) - mean(D(fake))``. Assuming the caller
    minimizes the returned value (as is conventional for a loss), this function
    returns the negation of that objective: ``mean(D(fake)) - mean(D(real))``.
    This is the same sign convention as ``wgan_generator_loss``, which returns
    ``-mean(D(fake))`` and so rewards the generator for raising the critic's
    score on generated samples.

    Args:
        real_output (tensor): a tensor representing the real logits of the discriminator
        fake_output (tensor): a tensor representing the fake logits of the discriminator

    Return:
        total discriminator loss, ``mean(fake_output) - mean(real_output)``
    """
    # Fake minus real. The reverse order (real - fake), when minimized, teaches
    # the critic to score fakes above reals. That contradicts the generator
    # loss, which pushes critic scores on fakes upward.
    total_loss = tf.math.reduce_mean(fake_output) - tf.math.reduce_mean(real_output)
    return total_loss
def wgan_generator_loss(fake_output):
    r"""Wasserstein loss for the generator.

    Minimizing this value raises the average critic score that the
    discriminator assigns to generated samples.

    Args:
        fake_output (tensor): a tensor representing the fake logits of the discriminator

    Return:
        generator loss, the negated mean of ``fake_output``
    """
    mean_fake_score = tf.math.reduce_mean(fake_output)
    return -mean_fake_score