parent
079544aa8c
commit
b78546745c
@ -0,0 +1,61 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# Define the transformations
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.RandomRotation(degrees=10),
|
||||||
|
transforms.Resize((512, 512)),
|
||||||
|
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
])
|
||||||
|
|
||||||
|
class CustomDataset(Dataset):
|
||||||
|
def __init__(self, root_dir, transform=None):
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.transform = transform
|
||||||
|
self.image_folders = [folder for folder in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, folder))]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_folders)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
folder_name = self.image_folders[idx]
|
||||||
|
folder_path = os.path.join(self.root_dir, folder_name)
|
||||||
|
|
||||||
|
# # Get the list of image filenames in the folder
|
||||||
|
image_filenames = os.listdir(folder_path)
|
||||||
|
|
||||||
|
# Pick two random assets from the folder
|
||||||
|
source_image_name, target_image_name = random.sample(image_filenames, 2)
|
||||||
|
|
||||||
|
source_age = int(Path(source_image_name).stem) / 100
|
||||||
|
target_age = int(Path(target_image_name).stem) / 100
|
||||||
|
|
||||||
|
# Randomly select two assets from the folder
|
||||||
|
source_image_path = os.path.join(folder_path, source_image_name)
|
||||||
|
target_image_path = os.path.join(folder_path, target_image_name)
|
||||||
|
|
||||||
|
source_image = Image.open(source_image_path).convert('RGB')
|
||||||
|
target_image = Image.open(target_image_path).convert('RGB')
|
||||||
|
|
||||||
|
# Apply the same random crop and augmentations to both assets
|
||||||
|
if self.transform:
|
||||||
|
seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
source_image = self.transform(source_image)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
target_image = self.transform(target_image)
|
||||||
|
|
||||||
|
source_age_channel = torch.full_like(source_image[:1, :, :], source_age)
|
||||||
|
target_age_channel = torch.full_like(source_image[:1, :, :], target_age)
|
||||||
|
|
||||||
|
# Concatenate the age channels with the source_image
|
||||||
|
source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0)
|
||||||
|
|
||||||
|
return source_image, target_image
|
@ -0,0 +1,70 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import lpips
|
||||||
|
|
||||||
|
class GeneratorLoss(nn.Module):
|
||||||
|
def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
|
||||||
|
device="cpu"):
|
||||||
|
super(GeneratorLoss, self).__init__()
|
||||||
|
self.discriminator_model = discriminator_model
|
||||||
|
self.l1_weight = l1_weight
|
||||||
|
self.perceptual_weight = perceptual_weight
|
||||||
|
self.adversarial_weight = adversarial_weight
|
||||||
|
self.criterion_l1 = nn.L1Loss()
|
||||||
|
self.criterion_adversarial = nn.BCEWithLogitsLoss()
|
||||||
|
self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)
|
||||||
|
|
||||||
|
def forward(self, output, target, source):
|
||||||
|
# L1 loss
|
||||||
|
|
||||||
|
l1_loss = self.criterion_l1(output, target)
|
||||||
|
|
||||||
|
# Perceptual loss
|
||||||
|
perceptual_loss = torch.mean(self.criterion_perceptual(output, target))
|
||||||
|
|
||||||
|
# Adversarial loss
|
||||||
|
fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
|
||||||
|
fake_prediction = self.discriminator_model(fake_input)
|
||||||
|
|
||||||
|
adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))
|
||||||
|
|
||||||
|
# Combine losses
|
||||||
|
generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \
|
||||||
|
self.adversarial_weight * adversarial_loss
|
||||||
|
|
||||||
|
return generator_loss, l1_loss, perceptual_loss, adversarial_loss
|
||||||
|
|
||||||
|
class DiscriminatorLoss(nn.Module):
|
||||||
|
def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
|
||||||
|
super(DiscriminatorLoss, self).__init__()
|
||||||
|
self.discriminator_model = discriminator_model
|
||||||
|
self.criterion_adversarial = nn.BCEWithLogitsLoss()
|
||||||
|
self.fake_weight = fake_weight
|
||||||
|
self.real_weight = real_weight
|
||||||
|
self.mock_weight = mock_weight
|
||||||
|
|
||||||
|
def forward(self, output, target, source):
|
||||||
|
# Adversarial loss
|
||||||
|
fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1) # prediction img with target age
|
||||||
|
real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1) # target img with target age
|
||||||
|
|
||||||
|
mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1) # source img with target age
|
||||||
|
mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1) # target img with source age
|
||||||
|
mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1) # prediction img with source age
|
||||||
|
mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1) # target img with target age
|
||||||
|
|
||||||
|
fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
|
||||||
|
mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
|
||||||
|
self.discriminator_model(mock_input2),
|
||||||
|
self.discriminator_model(mock_input3),
|
||||||
|
self.discriminator_model(mock_input4))
|
||||||
|
|
||||||
|
discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
|
||||||
|
self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
|
||||||
|
self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
|
||||||
|
self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
|
||||||
|
self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
|
||||||
|
self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4))
|
||||||
|
)
|
||||||
|
|
||||||
|
return discriminator_loss
|
@ -0,0 +1,163 @@
|
|||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader, random_split
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.append(".")
|
||||||
|
|
||||||
|
from models import UNet, PatchGANDiscriminator
|
||||||
|
from losses import GeneratorLoss, DiscriminatorLoss
|
||||||
|
from dataloader import CustomDataset, transform
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d, num_workers,
|
||||||
|
val_freq, batch_size, accum_iter, lr, lr_d):
|
||||||
|
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"device: {device}")
|
||||||
|
if torch.cuda.device_count() > 0:
|
||||||
|
print(f"{torch.cuda.device_count()} GPU(s)")
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
print("multi-GPU training is currently not supported.")
|
||||||
|
|
||||||
|
# Create instances of the dataset and split into scripts and validation sets
|
||||||
|
dataset = CustomDataset(root_dir, transform=transform)
|
||||||
|
|
||||||
|
# Assuming you want to use 80% of the data for scripts and 20% for validation
|
||||||
|
train_size = int(0.8 * len(dataset))
|
||||||
|
val_size = len(dataset) - train_size
|
||||||
|
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||||
|
|
||||||
|
print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples")
|
||||||
|
# Create data loaders for scripts and validation
|
||||||
|
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
|
||||||
|
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
|
||||||
|
|
||||||
|
# Create instances of the U-Net, discriminator, and loss models
|
||||||
|
print("Creating models")
|
||||||
|
unet_model = UNet()
|
||||||
|
discriminator_model = PatchGANDiscriminator(input_channels=4)
|
||||||
|
|
||||||
|
if load_model_g:
|
||||||
|
unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
|
||||||
|
print(f'loaded {load_model_g} for unet_model')
|
||||||
|
if load_model_d:
|
||||||
|
discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
|
||||||
|
print(f'loaded {load_model_d} for discriminator_model')
|
||||||
|
|
||||||
|
print("Adjusting models to device")
|
||||||
|
unet_model = unet_model.to(device)
|
||||||
|
discriminator_model = discriminator_model.to(device)
|
||||||
|
|
||||||
|
# Create loss
|
||||||
|
print("Creating loss functions")
|
||||||
|
generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
|
||||||
|
adversarial_weight=0.05, device=device)
|
||||||
|
discriminator_loss_func = DiscriminatorLoss(discriminator_model)
|
||||||
|
|
||||||
|
# Create instances of the Adam optimizer
|
||||||
|
print("Creating optimizers")
|
||||||
|
optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
|
||||||
|
optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)
|
||||||
|
|
||||||
|
# Training and validation loop
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
|
||||||
|
print(f"Training for {num_epochs} epochs")
|
||||||
|
for epoch in range(start_epoch, num_epochs):
|
||||||
|
# Training
|
||||||
|
unet_model.train()
|
||||||
|
discriminator_model.train()
|
||||||
|
batch_idx = 0
|
||||||
|
for batch in train_dataloader:
|
||||||
|
print(f"Training Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx + 1}/{len(train_dataloader)}]")
|
||||||
|
batch_idx += 1
|
||||||
|
source_images, target_images = batch
|
||||||
|
|
||||||
|
source_images = source_images.to(device)
|
||||||
|
target_images = target_images.to(device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
output_images = unet_model(source_images)
|
||||||
|
output_images += source_images[:, :3, :, :]
|
||||||
|
|
||||||
|
# Discriminator pass
|
||||||
|
discriminator_loss = discriminator_loss_func(output_images.detach(), target_images, source_images)
|
||||||
|
discriminator_loss.backward()
|
||||||
|
|
||||||
|
if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
|
||||||
|
optimizer_d.step()
|
||||||
|
optimizer_d.zero_grad()
|
||||||
|
|
||||||
|
# Generator pass
|
||||||
|
# Calculate the loss
|
||||||
|
generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(output_images, target_images,
|
||||||
|
source_images)
|
||||||
|
generator_loss, l1_loss, per_loss, adv_loss = [i / accum_iter for i in
|
||||||
|
[generator_loss, l1_loss, per_loss, adv_loss]]
|
||||||
|
generator_loss.backward()
|
||||||
|
|
||||||
|
if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
|
||||||
|
optimizer_g.step()
|
||||||
|
optimizer_g.zero_grad()
|
||||||
|
|
||||||
|
|
||||||
|
torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
|
||||||
|
torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if epoch % val_freq == 0:
|
||||||
|
unet_model.eval()
|
||||||
|
total_val_loss = 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
for val_batch in val_dataloader:
|
||||||
|
val_source_images, val_target_images = val_batch
|
||||||
|
val_source_images = val_source_images.to(device)
|
||||||
|
val_target_images = val_target_images.to(device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
val_output_images = unet_model(val_source_images)
|
||||||
|
|
||||||
|
# Calculate the loss
|
||||||
|
generator_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
|
||||||
|
val_source_images)
|
||||||
|
total_val_loss += generator_loss.item()
|
||||||
|
|
||||||
|
average_val_loss = total_val_loss / len(val_dataloader)
|
||||||
|
|
||||||
|
# Print validation information
|
||||||
|
print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss}')
|
||||||
|
|
||||||
|
# Save the model with the best validation loss
|
||||||
|
if average_val_loss < best_val_loss:
|
||||||
|
best_val_loss = average_val_loss
|
||||||
|
torch.save(unet_model.state_dict(), 'best_unet_model.pth')
|
||||||
|
torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Define command-line arguments
|
||||||
|
parser = argparse.ArgumentParser(description="Training Script")
|
||||||
|
parser.add_argument("--root_dir", type=str, default='image',
|
||||||
|
help="Path to the training data. Note the format: To use the dataloader, the directory should be filled with folders containing image files of various ages, where the file name is {age}.jpg")
|
||||||
|
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch, if scripts is resumed")
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=50, help="End epoch")
|
||||||
|
parser.add_argument("--load_model_g", type=str, default='recent_unet_model.pth',
|
||||||
|
help="Path to pretrained generator model. Leave blank to train from scratch")
|
||||||
|
parser.add_argument("--load_model_d", type=str, default='recent_discriminator_model.pth',
|
||||||
|
help="Path to pretrained discriminator model. Leave blank to train from scratch")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=10, help="Number of workers")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
|
||||||
|
parser.add_argument("--accum_iter", type=int, default=2, help="Number of batches after which weights are updated")
|
||||||
|
parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
|
||||||
|
parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for generator")
|
||||||
|
parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for discriminator")
|
||||||
|
|
||||||
|
# Parse command-line arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Call the scripts function with parsed arguments
|
||||||
|
train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
|
||||||
|
args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d)
|
Loading…
Reference in new issue