# Source code for simplegan.losses.pix2pix_loss

import tensorflow as tf

"""
Loss functions used in the pix2pix model.
"""

# Names exported by `from <module> import *`.
__all__ = ["pix2pix_generator_loss", "pix2pix_discriminator_loss"]

# Shared binary cross-entropy loss object used by both loss functions in this
# module. from_logits=True means the predictions passed to it are treated as
# raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def pix2pix_generator_loss(disc_fake_output, fake, real, l):
    r"""Compute the pix2pix generator objective.

    The objective is the adversarial binary cross-entropy term plus the mean
    absolute error between ``real`` and ``fake`` scaled by ``l``.

    Args:
        disc_fake_output (tensor): Discriminator logits for the generated
            (fake) samples.
        fake (tensor): Values produced by the generator.
        real (tensor): Real target values.
        l (int): Weight applied to the L1 term.

    Return:
        A tuple ``(total_gen_loss, gan_loss, l1_loss)``: the combined
        generator loss, the adversarial loss and the unscaled L1 loss.
    """
    # Adversarial term: the discriminator's logits on the generated samples
    # are scored against an all-ones target.
    real_labels = tf.ones_like(disc_fake_output)
    adversarial_term = cross_entropy(real_labels, disc_fake_output)

    # Reconstruction term: mean absolute difference between target and output.
    abs_error = tf.math.abs(real - fake)
    reconstruction_term = tf.math.reduce_mean(abs_error)

    combined = adversarial_term + l * reconstruction_term
    return combined, adversarial_term, reconstruction_term
def pix2pix_discriminator_loss(disc_real_output, disc_fake_output):
    r"""Compute the pix2pix discriminator loss.

    The loss is the sum of two binary cross-entropy terms: the logits for
    real samples scored against an all-ones target, and the logits for
    generated samples scored against an all-zeros target.

    Args:
        disc_real_output (tensor): Discriminator logits for the real samples.
        disc_fake_output (tensor): Discriminator logits for the generated
            (fake) samples.

    Return:
        total loss of discriminator
    """
    # Real samples should be classified as real (target = 1).
    real_loss = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)

    # Generated samples should be classified as fake (target = 0). Using an
    # all-ones target here would train the discriminator to accept fakes as
    # real, so the target must be zeros.
    generated_loss = cross_entropy(
        tf.zeros_like(disc_fake_output), disc_fake_output
    )

    total_loss = real_loss + generated_loss
    return total_loss