diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000..c118a0f
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,61 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import os
+import random
+from pathlib import Path
+
+
+# Define the transformations
+transform = transforms.Compose([
+    transforms.RandomRotation(degrees=10),
+    transforms.Resize((512, 512)),
+    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+    transforms.ToTensor(),
+])
+
+class CustomDataset(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.image_folders = [folder for folder in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, folder))]
+
+    def __len__(self):
+        return len(self.image_folders)
+
+    def __getitem__(self, idx):
+        folder_name = self.image_folders[idx]
+        folder_path = os.path.join(self.root_dir, folder_name)
+
+        # Get the list of image filenames in the folder
+        image_filenames = os.listdir(folder_path)
+
+        # Pick two distinct random images from the folder
+        source_image_name, target_image_name = random.sample(image_filenames, 2)
+
+        # File names encode the age, which is normalized to [0, 1]
+        source_age = int(Path(source_image_name).stem) / 100
+        target_age = int(Path(target_image_name).stem) / 100
+
+        source_image_path = os.path.join(folder_path, source_image_name)
+        target_image_path = os.path.join(folder_path, target_image_name)
+
+        source_image = Image.open(source_image_path).convert('RGB')
+        target_image = Image.open(target_image_path).convert('RGB')
+
+        # Apply the same random augmentations to both images by reseeding the RNG
+        if self.transform:
+            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
+            torch.manual_seed(seed)
+            source_image = self.transform(source_image)
+            torch.manual_seed(seed)
+            target_image = self.transform(target_image)
+
+        # Constant planes holding the normalized source and target ages
+        source_age_channel = torch.full_like(source_image[:1, :, :], source_age)
+        target_age_channel = torch.full_like(source_image[:1, :, :], target_age)
+
+        # Concatenate the age channels with the source image
+        source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0)
+
+        return source_image, target_image
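
Illustrative usage of the dataset (editorial sketch, not part of the diff; the folder names and ages below are hypothetical). CustomDataset expects one sub-folder per identity under root_dir, each holding photos named {age}.jpg, and returns a 5-channel source tensor (RGB plus constant source-age and target-age planes) together with a 3-channel target:

    # Assumed layout: images/person_001/{23.jpg, 47.jpg, 71.jpg}, images/person_002/..., etc.
    from dataloader import CustomDataset, transform

    dataset = CustomDataset('images', transform=transform)
    source, target = dataset[0]
    print(source.shape)  # torch.Size([5, 512, 512])
    print(target.shape)  # torch.Size([3, 512, 512])
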
diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000..dabe226
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import lpips
+
+class GeneratorLoss(nn.Module):
+    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
+                 device="cpu"):
+        super(GeneratorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.l1_weight = l1_weight
+        self.perceptual_weight = perceptual_weight
+        self.adversarial_weight = adversarial_weight
+        self.criterion_l1 = nn.L1Loss()
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)
+
+    def forward(self, output, target, source):
+        # L1 loss
+        l1_loss = self.criterion_l1(output, target)
+
+        # Perceptual loss (normalize=True rescales [0, 1] inputs to the [-1, 1] range LPIPS expects)
+        perceptual_loss = torch.mean(self.criterion_perceptual(output, target, normalize=True))
+
+        # Adversarial loss, conditioned on the target age channel
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
+        fake_prediction = self.discriminator_model(fake_input)
+
+        adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))
+
+        # Combine losses
+        generator_loss = (self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss +
+                          self.adversarial_weight * adversarial_loss)
+
+        return generator_loss, l1_loss, perceptual_loss, adversarial_loss
+
+class DiscriminatorLoss(nn.Module):
+    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=0.5):
+        super(DiscriminatorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.fake_weight = fake_weight
+        self.real_weight = real_weight
+        self.mock_weight = mock_weight
+
+    def forward(self, output, target, source):
+        # Adversarial loss: channel 3 of `source` is the source-age plane, channel 4 the target-age plane
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)  # prediction img with target age
+        real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1)  # target img with target age
+
+        # Mismatched pairs the discriminator should also reject
+        mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1)  # source img with target age
+        mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age
+        mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1)  # prediction img with source age
+        mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age (duplicates mock_input2, doubling its effective weight)
+
+        fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
+        mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
+                                                          self.discriminator_model(mock_input2),
+                                                          self.discriminator_model(mock_input3),
+                                                          self.discriminator_model(mock_input4))
+
+        discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
+                              self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4)))
+
+        return discriminator_loss
diff --git a/models.py b/models.py
index 3806319..aa287d9 100644
--- a/models.py
+++ b/models.py
@@ -74,3 +74,26 @@ class UNet(nn.Module):
         x = self.up4(x, x0)
         x = self.final_conv(x)
         return x
+
+
+class PatchGANDiscriminator(nn.Module):
+    def __init__(self, input_channels=3):
+        super(PatchGANDiscriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
+            # Output layer with 1 channel of per-patch real/fake logits
+        )
+
+    def forward(self, x):
+        return self.model(x)
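
A quick shape check for the discriminator (editorial sketch, not part of the diff). The PatchGAN input has 4 channels -- an RGB image concatenated with a single age plane, matching the torch.cat calls in losses.py -- and the three stride-2 convolutions followed by the stride-1 output convolution turn a 512x512 input into a 63x63 grid of per-patch real/fake logits:

    import torch
    from models import PatchGANDiscriminator

    disc = PatchGANDiscriminator(input_channels=4)
    x = torch.randn(1, 4, 512, 512)
    print(disc(x).shape)  # torch.Size([1, 1, 63, 63])
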
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..5887643
--- /dev/null
+++ b/train.py
@@ -0,0 +1,163 @@
+import torch
+import torch.optim as optim
+from torch.utils.data import DataLoader, random_split
+import argparse
+
+import sys
+sys.path.append(".")
+
+from models import UNet, PatchGANDiscriminator
+from losses import GeneratorLoss, DiscriminatorLoss
+from dataloader import CustomDataset, transform
+
+
+def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d, num_workers,
+                val_freq, batch_size, accum_iter, lr, lr_d):
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(f"device: {device}")
+    if torch.cuda.device_count() > 0:
+        print(f"{torch.cuda.device_count()} GPU(s)")
+    if torch.cuda.device_count() > 1:
+        print("multi-GPU training is currently not supported.")
+
+    # Create an instance of the dataset and split it into training and validation sets
+    dataset = CustomDataset(root_dir, transform=transform)
+
+    # Use 80% of the data for training and 20% for validation
+    train_size = int(0.8 * len(dataset))
+    val_size = len(dataset) - train_size
+    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+
+    print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples")
+    # Create data loaders for training and validation
+    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+    # Create instances of the U-Net, discriminator, and loss models
+    print("Creating models")
+    unet_model = UNet()
+    discriminator_model = PatchGANDiscriminator(input_channels=4)
+
+    if load_model_g:
+        unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
+        print(f'loaded {load_model_g} for unet_model')
+    if load_model_d:
+        discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
+        print(f'loaded {load_model_d} for discriminator_model')
+
+    print("Moving models to device")
+    unet_model = unet_model.to(device)
+    discriminator_model = discriminator_model.to(device)
+
+    # Create loss functions
+    print("Creating loss functions")
+    generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
+                                        adversarial_weight=0.05, device=device)
+    discriminator_loss_func = DiscriminatorLoss(discriminator_model)
+
+    # Create instances of the Adam optimizer
+    print("Creating optimizers")
+    optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
+    optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)
+
+    # Training and validation loop
+    best_val_loss = float('inf')
+
+    print(f"Training for {num_epochs} epochs")
+    for epoch in range(start_epoch, num_epochs):
+        # Training
+        unet_model.train()
+        discriminator_model.train()
+        batch_idx = 0
+        for batch in train_dataloader:
+            print(f"Training Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx + 1}/{len(train_dataloader)}]")
+            batch_idx += 1
+            source_images, target_images = batch
+
+            source_images = source_images.to(device)
+            target_images = target_images.to(device)
+
+            # Forward pass: the U-Net predicts a residual that is added to the source RGB
+            output_images = unet_model(source_images)
+            output_images += source_images[:, :3, :, :]
+
+            # Discriminator pass
+            discriminator_loss = discriminator_loss_func(output_images.detach(), target_images, source_images)
+            discriminator_loss = discriminator_loss / accum_iter  # scale for gradient accumulation
+            discriminator_loss.backward()
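+
+            # Gradient accumulation: the optimizers step only every accum_iter
+            # batches (or on the last batch), so the effective batch size is
+            # batch_size * accum_iter; dividing each loss by accum_iter keeps
+            # gradient magnitudes comparable to a single larger batch.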
+            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
+                optimizer_d.step()
+                optimizer_d.zero_grad()
+
+            # Generator pass
+            generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(output_images, target_images,
+                                                                              source_images)
+            generator_loss, l1_loss, per_loss, adv_loss = [i / accum_iter for i in
+                                                           [generator_loss, l1_loss, per_loss, adv_loss]]
+            generator_loss.backward()
+
+            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
+                optimizer_g.step()
+                optimizer_g.zero_grad()
+
+        # Save the most recent weights once per epoch
+        torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
+        torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')
+
+        # Validation
+        if epoch % val_freq == 0:
+            unet_model.eval()
+            total_val_loss = 0.0
+            with torch.no_grad():
+                for val_batch in val_dataloader:
+                    val_source_images, val_target_images = val_batch
+                    val_source_images = val_source_images.to(device)
+                    val_target_images = val_target_images.to(device)
+
+                    # Forward pass (add the residual back to the source RGB, as in training)
+                    val_output_images = unet_model(val_source_images)
+                    val_output_images += val_source_images[:, :3, :, :]
+
+                    # Calculate the loss
+                    generator_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
+                                                                  val_source_images)
+                    total_val_loss += generator_loss.item()
+
+            average_val_loss = total_val_loss / len(val_dataloader)
+
+            # Print validation information
+            print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss}')
+
+            # Save the models with the best validation loss
+            if average_val_loss < best_val_loss:
+                best_val_loss = average_val_loss
+                torch.save(unet_model.state_dict(), 'best_unet_model.pth')
+                torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')
+
+
+if __name__ == "__main__":
+    # Define command-line arguments
+    parser = argparse.ArgumentParser(description="Training Script")
+    parser.add_argument("--root_dir", type=str, default='image',
+                        help="Path to the training data: a directory of folders, each filled with "
+                             "image files of various ages, where each file name is {age}.jpg")
+    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch, if training is resumed")
+    parser.add_argument("--num_epochs", type=int, default=50, help="End epoch")
+    parser.add_argument("--load_model_g", type=str, default='recent_unet_model.pth',
+                        help="Path to a pretrained generator model. Leave blank to train from scratch")
+    parser.add_argument("--load_model_d", type=str, default='recent_discriminator_model.pth',
+                        help="Path to a pretrained discriminator model. Leave blank to train from scratch")
+    parser.add_argument("--num_workers", type=int, default=10, help="Number of data-loading workers")
+    parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
+    parser.add_argument("--accum_iter", type=int, default=2, help="Number of batches after which weights are updated")
+    parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
+    parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for the generator")
+    parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for the discriminator")
+
+    # Parse command-line arguments
+    args = parser.parse_args()
+
+    # Call the training function with parsed arguments
+    train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
+                args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d)
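
For reference, a minimal inference sketch (editorial, not part of the diff; the input file name and the ages are hypothetical, and clamping the result to [0, 1] is an assumption):

    import torch
    from PIL import Image
    from torchvision import transforms
    from models import UNet

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet()
    model.load_state_dict(torch.load('best_unet_model.pth', map_location=device))
    model = model.to(device).eval()

    to_tensor = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
    image = to_tensor(Image.open('face.jpg').convert('RGB')).to(device)

    source_age, target_age = 25 / 100, 70 / 100  # normalized as in dataloader.py
    x = torch.cat([image,
                   torch.full_like(image[:1], source_age),
                   torch.full_like(image[:1], target_age)], dim=0).unsqueeze(0)

    with torch.no_grad():
        # the U-Net predicts a residual that is added back to the source RGB
        aged = (model(x) + x[:, :3]).clamp(0, 1).squeeze(0)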