Source code for simplegan.losses.infogan_loss

import tensorflow as tf


__all__ = ["auxillary_loss"]


[docs]def auxillary_loss(disc_target, disc_preds, cont_trg, cont_pred): r""" Args: disc_target (tensor): a tensor representing the discriminator target ouput disc_preds (tensor): a tensor representing the prediction of discriminator cont_trg (tensor): a tensor representing the content of the target cont_pred (tensor): a tensor representing the predicted content Return: a tensor representing the auxillary loss """ disc_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=disc_target, logits=disc_preds) ) cont_loss = tf.reduce_mean(tf.reduce_sum(tf.square(cont_trg - cont_pred), axis=1)) aux_loss = disc_loss + cont_loss return aux_loss