import torch import torch.nn as nn import lpips class GeneratorLoss(nn.Module): def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05, device="cpu"): super(GeneratorLoss, self).__init__() self.discriminator_model = discriminator_model self.l1_weight = l1_weight self.perceptual_weight = perceptual_weight self.adversarial_weight = adversarial_weight self.criterion_l1 = nn.L1Loss() self.criterion_adversarial = nn.BCEWithLogitsLoss() self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device) def forward(self, output, target, source): # L1 loss l1_loss = self.criterion_l1(output, target) # Perceptual loss perceptual_loss = torch.mean(self.criterion_perceptual(output, target)) # Adversarial loss fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1) fake_prediction = self.discriminator_model(fake_input) adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction)) # Combine losses generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \ self.adversarial_weight * adversarial_loss return generator_loss, l1_loss, perceptual_loss, adversarial_loss class DiscriminatorLoss(nn.Module): def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5): super(DiscriminatorLoss, self).__init__() self.discriminator_model = discriminator_model self.criterion_adversarial = nn.BCEWithLogitsLoss() self.fake_weight = fake_weight self.real_weight = real_weight self.mock_weight = mock_weight def forward(self, output, target, source): # Adversarial loss fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1) # prediction img with target age real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1) # target img with target age mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1) # source img with target age mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1) # target img with source age mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1) # prediction img with source age mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1) # target img with target age fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input) mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1), self.discriminator_model(mock_input2), self.discriminator_model(mock_input3), self.discriminator_model(mock_input4)) discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) + self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) + self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) + self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) + self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) + self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4)) ) return discriminator_loss