import argparse
import sys

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

sys.path.append(".")
from models import UNet, PatchGANDiscriminator
from losses import GeneratorLoss, DiscriminatorLoss
from dataloader import CustomDataset, transform


def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d,
                num_workers, val_freq, batch_size, accum_iter, lr, lr_d):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    if torch.cuda.device_count() > 0:
        print(f"{torch.cuda.device_count()} GPU(s)")
    if torch.cuda.device_count() > 1:
        print("multi-GPU training is currently not supported.")

    # Create an instance of the dataset and split it into training and validation sets
    dataset = CustomDataset(root_dir, transform=transform)

    # Use 80% of the data for training and 20% for validation
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples")

    # Create data loaders for training and validation
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Create instances of the U-Net and discriminator models
    print("Creating models")
    unet_model = UNet()
    discriminator_model = PatchGANDiscriminator(input_channels=4)
    if load_model_g:
        unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
        print(f"loaded {load_model_g} for unet_model")
    if load_model_d:
        discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
        print(f"loaded {load_model_d} for discriminator_model")
    print("Adjusting models to device")
    unet_model = unet_model.to(device)
    discriminator_model = discriminator_model.to(device)

    # Create loss functions
    print("Creating loss functions")
    generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
                                        adversarial_weight=0.05, device=device)
    discriminator_loss_func = DiscriminatorLoss(discriminator_model)

    # Create instances of the Adam optimizer
    print("Creating optimizers")
    optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)

    # Training and validation loop
    best_val_loss = float("inf")
    print(f"Training for {num_epochs} epochs")
    for epoch in range(start_epoch, num_epochs):
        # Training
        unet_model.train()
        discriminator_model.train()
        for batch_idx, batch in enumerate(train_dataloader, start=1):
            print(f"Training Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(train_dataloader)}]")
            source_images, target_images = batch
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            # Forward pass: the U-Net predicts a residual that is added to the RGB
            # channels of the source image (out-of-place, so autograd is not broken
            # by an in-place modification of the model output)
            output_images = unet_model(source_images)
            output_images = output_images + source_images[:, :3, :, :]

            # Discriminator pass: detach the generator output so only the
            # discriminator receives gradients; scale by accum_iter so the
            # accumulated gradient matches that of a single large batch
            discriminator_loss = discriminator_loss_func(output_images.detach(), target_images, source_images)
            discriminator_loss = discriminator_loss / accum_iter
            discriminator_loss.backward()
            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
                optimizer_d.step()
                optimizer_d.zero_grad()

            # Generator pass: calculate the loss, scaled for gradient accumulation
            generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(output_images, target_images,
                                                                              source_images)
            generator_loss, l1_loss, per_loss, adv_loss = [i / accum_iter for i in
                                                           [generator_loss, l1_loss, per_loss, adv_loss]]
            generator_loss.backward()
            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
                optimizer_g.step()
                optimizer_g.zero_grad()

        # Save the most recent weights at the end of each epoch
        torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
        torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')

        # Validation
        if epoch % val_freq == 0:
            unet_model.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for val_batch in val_dataloader:
                    val_source_images, val_target_images = val_batch
                    val_source_images = val_source_images.to(device)
                    val_target_images = val_target_images.to(device)

                    # Forward pass, with the same residual connection as in training
                    val_output_images = unet_model(val_source_images)
                    val_output_images = val_output_images + val_source_images[:, :3, :, :]

                    # Calculate the loss
                    generator_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
                                                                  val_source_images)
                    total_val_loss += generator_loss.item()
            average_val_loss = total_val_loss / len(val_dataloader)

            # Print validation information
            print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss}')

            # Save the models with the best validation loss
            if average_val_loss < best_val_loss:
                best_val_loss = average_val_loss
                torch.save(unet_model.state_dict(), 'best_unet_model.pth')
                torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')


if __name__ == "__main__":
    # Define command-line arguments
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("--root_dir", type=str, default='image',
                        help="Path to the training data. Note the format: to use the dataloader, the directory "
                             "should contain folders of image files of various ages, where each file is named "
                             "{age}.jpg")
    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch, if training is resumed")
    parser.add_argument("--num_epochs", type=int, default=50, help="End epoch")
    parser.add_argument("--load_model_g", type=str, default='recent_unet_model.pth',
                        help="Path to a pretrained generator model. Pass an empty string to train from scratch")
    parser.add_argument("--load_model_d", type=str, default='recent_discriminator_model.pth',
                        help="Path to a pretrained discriminator model. Pass an empty string to train from scratch")
    parser.add_argument("--num_workers", type=int, default=10, help="Number of data-loading workers")
    parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
    parser.add_argument("--accum_iter", type=int, default=2,
                        help="Number of batches after which weights are updated")
    parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
    parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for the generator")
    parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for the discriminator")

    # Parse command-line arguments
    args = parser.parse_args()

    # Call the training function with the parsed arguments
    train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
                args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d)
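
# Example invocation (a sketch: the script name "train.py" and the flag values are
# illustrative assumptions, not fixed by this repository):
#
#   python train.py --root_dir image --num_epochs 50 --batch_size 3 \
#       --accum_iter 2 --lr 1e-5 --lr_d 1e-5 --load_model_g "" --load_model_d ""
#
# With accum_iter=2, gradients from two batches of size 3 are accumulated before
# each optimizer step, giving an effective batch size of 6.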