Upgrade face agging

master
p09ba 1 month ago
parent f65cbf5b5a
commit 079544aa8c

@ -1,12 +1,21 @@
import tkinter as tk import tkinter as tk
from tkinter import filedialog, messagebox from tkinter import filedialog, messagebox
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk from PIL import Image, ImageTk
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import cv2 import cv2
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
import numpy as np
import cv2
from tensorflow.keras.models import load_model
from tensorflow.keras.losses import MeanSquaredError from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.utils import get_custom_objects from tensorflow.keras.utils import get_custom_objects
from torch.autograd import Variable
import torch
import torchvision.transforms as transforms
from models import UNet
def mse(y_true, y_pred): def mse(y_true, y_pred):
@ -16,7 +25,11 @@ get_custom_objects().update({'mse': mse})
# Chargement des modèles # Chargement des modèles
face_aging_model = load_model("face_aging_model.h5", custom_objects={"mse": mse}) face_aging_model = load_model("face_aging_model.h5", custom_objects={"mse": mse})
face_aging_autoencoder = load_model("face_aging_autoencoder.h5", custom_objects={"mse": mse})
device = "cuda" if torch.cuda.is_available() else "cpu"
face_aging_autoencoder = UNet().to(device)
face_aging_autoencoder.load_state_dict(torch.load("best_unet_model_dl.pth", map_location=device))
def load_image(): def load_image():
file_path = filedialog.askopenfilename(title="Sélectionner une image", filetypes=[("Image Files", "*.jpg;*.png;*.jpeg")]) file_path = filedialog.askopenfilename(title="Sélectionner une image", filetypes=[("Image Files", "*.jpg;*.png;*.jpeg")])
@ -45,17 +58,62 @@ def predict_age_from_model(model, image_path):
prediction = model.predict(img_array) prediction = model.predict(img_array)
return prediction[0][0] return prediction[0][0]
def apply_aging_effect(model, image_path): def sliding_window_tensor(input_tensor, your_model):
img = Image.open(image_path).resize((128, 128)) window_size,stride=512,256
img_array = np.array(img) / 255.0 input_tensor = input_tensor.to(next(your_model.parameters()).device)
predicted_img = model.predict(np.expand_dims(img_array, axis=0))
predicted_img = np.clip(predicted_img[0], 0, 1) n, c, h, w = input_tensor.size()
predicted_img = (predicted_img * 255).astype(np.uint8) output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)
return Image.fromarray(predicted_img) count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)
for y in range(0, h - window_size + 1, stride):
for x in range(0, w - window_size + 1, stride):
window = input_tensor[:, :, y:y + window_size, x:x + window_size]
input_variable = Variable(window, requires_grad=False)
with torch.no_grad():
output = your_model(input_variable)
output_tensor[:, :, y:y + window_size, x:x + window_size] += output
count_tensor[:, :, y:y + window_size, x:x + window_size] += 1
count_tensor = torch.clamp(count_tensor, min=1.0)
output_tensor /= count_tensor
return output_tensor.cpu()
def process_image(your_model, image, source_age, target_age=80, ):
image = np.array(image)
image_original = image.copy()
image = transforms.ToTensor()(image)
source_age_channel = torch.full_like(image[:1, :, :], source_age / 100)
target_age_channel = torch.full_like(image[:1, :, :], target_age / 100)
input_tensor = torch.cat([image, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
image_original = transforms.ToTensor()(image_original)
# performing actions on image
aged_image = sliding_window_tensor(input_tensor, your_model)
image_original += aged_image.squeeze(0)
image_original = torch.clamp(image_original, 0, 1)
return transforms.functional.to_pil_image(image_original)
def apply_aging_effect(model, image_path, age_source, age_cible):
image = Image.open(image_path)
return process_image(model, image, source_age=age_source, target_age=age_cible)
def show_aged_image(): def show_aged_image():
if 'image_path' in globals(): if 'image_path' in globals():
aged_image = apply_aging_effect(face_aging_autoencoder, image_path) predict_age = predict_age_from_model(face_aging_model, image_path)
target_age = target_age_var.get() # Récupérer l'âge cible depuis le slider
aged_image = apply_aging_effect(face_aging_autoencoder, image_path, predict_age, target_age)
aged_image_tk = ImageTk.PhotoImage(aged_image.resize((256, 256))) aged_image_tk = ImageTk.PhotoImage(aged_image.resize((256, 256)))
aged_panel.configure(image=aged_image_tk) aged_panel.configure(image=aged_image_tk)
aged_panel.image = aged_image_tk aged_panel.image = aged_image_tk
@ -77,7 +135,14 @@ load_button.pack(pady=5)
predict_button = tk.Button(root, text="Prédire l'âge", command=predict_age) predict_button = tk.Button(root, text="Prédire l'âge", command=predict_age)
predict_button.pack(pady=5) predict_button.pack(pady=5)
age_button = tk.Button(root, text="Afficher l'image vieillie", command=show_aged_image) # Ajout d'une variable pour stocker l'âge cible
target_age_var = tk.IntVar(value=80)
# Ajout d'une barre de défilement pour choisir l'âge cible
age_slider = tk.Scale(root, from_=5, to=100, orient="horizontal", label="Âge cible", variable=target_age_var, resolution=5)
age_slider.pack(pady=5)
age_button = tk.Button(root, text="Afficher l'image transformée", command=show_aged_image)
age_button.pack(pady=5) age_button.pack(pady=5)
# Affichage de l'image vieillie # Affichage de l'image vieillie

Binary file not shown.

@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import antialiased_cnns
class DownLayer(nn.Module):
def __init__(self, in_channels, out_channels):
super(DownLayer, self).__init__()
self.layer = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=1),
antialiased_cnns.BlurPool(in_channels, stride=2),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.LeakyReLU(inplace=True)
)
def forward(self, x):
return self.layer(x)
class UpLayer(nn.Module):
def __init__(self, in_channels, out_channels):
super(UpLayer, self).__init__()
# Conv transpose upsampling
self.blur_upsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
antialiased_cnns.BlurPool(out_channels, stride=1)
)
self.layer = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.LeakyReLU(inplace=True)
)
def forward(self, x, skip):
x = self.blur_upsample(x)
x = torch.cat([x, skip], dim=1) # Concatenate with skip connection
return self.layer(x)
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.init_conv = nn.Sequential(
nn.Conv2d(5, 64, kernel_size=3, padding=1), # output: 512 x 512 x 64
nn.LeakyReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1), # output: 512 x 512 x 64
nn.LeakyReLU(inplace=True)
)
self.down1 = DownLayer(64, 128) # output: 256 x 256 x 128
self.down2 = DownLayer(128, 256) # output: 128 x 128 x 256
self.down3 = DownLayer(256, 512) # output: 64 x 64 x 512
self.down4 = DownLayer(512, 1024) # output: 32 x 32 x 1024
self.up1 = UpLayer(1024, 512) # output: 64 x 64 x 512
self.up2 = UpLayer(512, 256) # output: 128 x 128 x 256
self.up3 = UpLayer(256, 128) # output: 256 x 256 x 128
self.up4 = UpLayer(128, 64) # output: 512 x 512 x 64
self.final_conv = nn.Conv2d(64, 3, kernel_size=1) # output: 512 x 512 x 3
def forward(self, x):
x0 = self.init_conv(x)
x1 = self.down1(x0)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
x = self.up4(x, x0)
x = self.final_conv(x)
return x
Loading…
Cancel
Save