You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

62 lines
2.5 KiB

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.callbacks import Callback
import matplotlib.pyplot as plt
# Enable eager execution: functions that would normally run as tf.function
# graphs are executed eagerly instead.
tf.config.run_functions_eagerly(True)
# Directory holding the UTKFace images; handed to load_utkface_images below.
data_dir = "./UTKFace/part1/part1"
# Image size for the dataset; same value as the loader's default image_size.
image_size = (128, 128)
def load_utkface_images(data_dir, image_size=(128, 128)):
    """Load every ``.jpg`` image in *data_dir* together with its age label.

    The age label is parsed from the file name: it is the integer that
    precedes the first underscore (e.g. ``25_xxx.jpg`` -> ``25``).

    Args:
        data_dir: Directory to scan (not recursive).
        image_size: Forwarded to ``load_img`` as ``target_size``.

    Returns:
        A tuple ``(images, ages)`` of NumPy arrays. Pixel values are divided
        by 255. Both arrays are empty when no usable file is found.
    """
    images, ages = [], []
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # sample order reproducible for callers that index or split by position.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(".jpg"):
            continue
        try:
            # Age is the first underscore-separated field of the file name.
            age = int(filename.split("_")[0])
        except ValueError:
            # A stray .jpg without an age prefix should not abort the whole load.
            print(f"Skipping {filename!r}: file name does not start with an integer age")
            continue
        img = load_img(os.path.join(data_dir, filename), target_size=image_size)
        images.append(img_to_array(img) / 255.0)  # Normalize images
        ages.append(age)
    return np.array(images), np.array(ages)
# Load the dataset once at import time. image_size is passed explicitly so the
# module-level constant and the loaded data cannot silently drift apart.
X, y = load_utkface_images(data_dir, image_size)
class ImageLogger(Callback):
    """Keras callback that saves a reconstruction of one image after each epoch.

    Args:
        autoencoder_model: Model whose ``predict`` is called on the sample.
        sample_image: Single image (no batch axis); a batch axis is added
            before prediction.
        output_dir: Directory for the ``epoch_<n>.png`` files; created if it
            does not exist.
    """

    def __init__(self, autoencoder_model, sample_image, output_dir="progress_images"):
        # Initialise the base Callback so its own attributes are set up.
        super().__init__()
        self.autoencoder_model = autoencoder_model
        self.sample_image = sample_image
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Predict the sample image and write it to ``epoch_<epoch + 1>.png``."""
        batch = np.expand_dims(self.sample_image, axis=0)
        reconstructed_image = self.autoencoder_model.predict(batch)[0]
        # Clip before the uint8 cast: values outside [0, 1] would otherwise
        # wrap around and show up as speckled pixels in the saved image.
        reconstructed_image = (np.clip(reconstructed_image, 0.0, 1.0) * 255).astype(np.uint8)
        output_path = os.path.join(self.output_dir, f"epoch_{epoch + 1}.png")
        plt.imsave(output_path, reconstructed_image)
        print(f"Image sauvegardée à l'époque {epoch + 1}")
# Load the previously saved model.
# custom_objects maps the name 'mse' to a MeanSquaredError instance
# (presumably so the loss identifier stored in the file can be resolved -- confirm).
model_path = "face_aging_autoencoder.h5"
autoencoder = load_model(model_path, custom_objects={'mse': tf.keras.losses.MeanSquaredError()})
# Recompile the model after loading (new Adam optimizer, MSE loss).
autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
# Pick a sample image used to visualise training progress.
sample_image = X[0]
# Number of additional epochs to train for.
additional_epochs = 20
# Continue training with the ImageLogger callback; X is both input and target.
autoencoder.fit(X, X, epochs=additional_epochs, batch_size=32, validation_split=0.2, callbacks=[ImageLogger(autoencoder, sample_image)])
# Save the updated model under a new file name.
autoencoder.save("face_aging_autoencoder_updated.h5")
print("Modèle mis à jour et sauvegardé avec succès !")