import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import argparse
import sys
sys.path.append(".")
from models import UNet, PatchGANDiscriminator
from losses import GeneratorLoss, DiscriminatorLoss
from dataloader import CustomDataset, transform
def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d, num_workers,
val_freq, batch_size, accum_iter, lr, lr_d):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")
if torch.cuda.device_count() > 0:
print(f"{torch.cuda.device_count()} GPU(s)")
if torch.cuda.device_count() > 1:
print("multi-GPU training is currently not supported.")
    # Create an instance of the dataset and split it into training and validation sets
    dataset = CustomDataset(root_dir, transform=transform)
    # Use 80% of the data for training and 20% for validation
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
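    # Note: random_split is unseeded here, so the train/val split differs between runs;
    # pass generator=torch.Generator().manual_seed(<seed>) for a reproducible split.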
print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples")
    # Create data loaders for training and validation
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # Create instances of the U-Net and discriminator models
print("Creating models")
unet_model = UNet()
discriminator_model = PatchGANDiscriminator(input_channels=4)
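    # input_channels=4 presumably pairs the 3 RGB channels with one conditioning channel
    # (assumption based on the [:, :3, :, :] slicing below; see models.py and losses.py
    # for how the discriminator input is actually assembled).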
if load_model_g:
unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
print(f'loaded {load_model_g} for unet_model')
if load_model_d:
discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
print(f'loaded {load_model_d} for discriminator_model')
print("Adjusting models to device")
unet_model = unet_model.to(device)
discriminator_model = discriminator_model.to(device)
    # Create loss functions
print("Creating loss functions")
generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
adversarial_weight=0.05, device=device)
discriminator_loss_func = DiscriminatorLoss(discriminator_model)
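    # The small adversarial weight (0.05) keeps the GAN term from overpowering the L1 and
    # perceptual reconstruction terms, a common balance in pix2pix-style training.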
    # Create Adam optimizers for the generator and discriminator
print("Creating optimizers")
optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)
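    # Note: GAN setups often use Adam with betas=(0.5, 0.999) (as in DCGAN/pix2pix) for
    # stability; this script keeps the PyTorch default of (0.9, 0.999).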
# Training and validation loop
best_val_loss = float('inf')
print(f"Training for {num_epochs} epochs")
for epoch in range(start_epoch, num_epochs):
# Training
unet_model.train()
discriminator_model.train()
        for batch_idx, batch in enumerate(train_dataloader, start=1):
            print(f"Training Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(train_dataloader)}]")
source_images, target_images = batch
source_images = source_images.to(device)
target_images = target_images.to(device)
            # Forward pass: the U-Net predicts a residual that is added to the RGB channels of the input
            output_images = unet_model(source_images)
            output_images = output_images + source_images[:, :3, :, :]
            # Discriminator pass (generator output is detached so its gradients stay out of the generator)
            discriminator_loss = discriminator_loss_func(output_images.detach(), target_images, source_images)
            # Scale by accum_iter so accumulated gradients match the generator's scaling
            discriminator_loss = discriminator_loss / accum_iter
            discriminator_loss.backward()
            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
                optimizer_d.step()
                optimizer_d.zero_grad()
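            # With accum_iter > 1, gradients accumulate over several batches before each optimizer
            # step, giving an effective batch size of batch_size * accum_iter.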
            # Generator pass: combined L1 + perceptual + adversarial loss, scaled for gradient accumulation
            generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(output_images, target_images,
                                                                              source_images)
            generator_loss, l1_loss, per_loss, adv_loss = [i / accum_iter for i in
                                                           [generator_loss, l1_loss, per_loss, adv_loss]]
            generator_loss.backward()
if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
optimizer_g.step()
optimizer_g.zero_grad()
        # Save the latest checkpoints at the end of each epoch so training can be resumed
        torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
        torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')
        # Validation
        if epoch % val_freq == 0:
            unet_model.eval()
            # The discriminator feeds the adversarial term of the generator loss, so switch it to eval mode too
            discriminator_model.eval()
total_val_loss = 0.0
with torch.no_grad():
for val_batch in val_dataloader:
val_source_images, val_target_images = val_batch
val_source_images = val_source_images.to(device)
val_target_images = val_target_images.to(device)
                    # Forward pass (apply the same residual connection as in training)
                    val_output_images = unet_model(val_source_images) + val_source_images[:, :3, :, :]
# Calculate the loss
generator_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
val_source_images)
total_val_loss += generator_loss.item()
average_val_loss = total_val_loss / len(val_dataloader)
# Print validation information
            print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss:.4f}')
# Save the model with the best validation loss
if average_val_loss < best_val_loss:
best_val_loss = average_val_loss
torch.save(unet_model.state_dict(), 'best_unet_model.pth')
torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')
if __name__ == "__main__":
# Define command-line arguments
parser = argparse.ArgumentParser(description="Training Script")
parser.add_argument("--root_dir", type=str, default='image',
help="Path to the training data. Note the format: To use the dataloader, the directory should be filled with folders containing image files of various ages, where the file name is {age}.jpg")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch, if scripts is resumed")
parser.add_argument("--num_epochs", type=int, default=50, help="End epoch")
parser.add_argument("--load_model_g", type=str, default='recent_unet_model.pth',
help="Path to pretrained generator model. Leave blank to train from scratch")
parser.add_argument("--load_model_d", type=str, default='recent_discriminator_model.pth',
help="Path to pretrained discriminator model. Leave blank to train from scratch")
parser.add_argument("--num_workers", type=int, default=10, help="Number of workers")
parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
parser.add_argument("--accum_iter", type=int, default=2, help="Number of batches after which weights are updated")
parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for generator")
parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for discriminator")
# Parse command-line arguments
args = parser.parse_args()
    # Call the training function with parsed arguments
train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d)
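# Example invocation (assuming this file is saved as train.py):
#   python train.py --root_dir ./image --num_epochs 50 --batch_size 3 --accum_iter 2 --lr 1e-5 --lr_d 1e-5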