You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
4.0 KiB
104 lines
4.0 KiB
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose
|
|
from tensorflow.keras.models import Model
|
|
from tensorflow.keras.optimizers import Adam
|
|
from tensorflow.keras.preprocessing.image import load_img, img_to_array, ImageDataGenerator
|
|
from tensorflow.keras.callbacks import Callback
|
|
import matplotlib.pyplot as plt
|
|
import cv2
|
|
from tensorflow.keras.models import load_model
|
|
from tensorflow.keras.losses import MeanSquaredError
|
|
|
|
# --- Training configuration --------------------------------------------------
# Folder containing the UTKFace .jpg images, relative to the working directory.
data_dir = "./UTKFace/part1/part1"
# Number of images fed to the model per training step.
batch_size = 32
# (height, width) every image is resized to; raised to 512x512 for higher resolution.
image_size = (512, 512)
|
|
|
|
def create_dataframe(data_dir):
    """Return a DataFrame listing the ``.jpg`` files directly inside ``data_dir``.

    The scan is not recursive and the suffix match is case-sensitive.

    Args:
        data_dir: Directory to scan.

    Returns:
        A ``pandas.DataFrame`` with a single ``'filename'`` column holding
        ``os.path.join(data_dir, name)`` for every entry whose name ends in
        ``.jpg``, ordered by name. If nothing matches, the frame is empty but
        still has the ``'filename'`` column.
    """
    # os.listdir returns entries in arbitrary, platform-dependent order; sorting
    # makes the row order (and thus any "first image" picked from it) reproducible.
    filepaths = [
        os.path.join(data_dir, name)
        for name in sorted(os.listdir(data_dir))
        if name.endswith(".jpg")
    ]
    return pd.DataFrame({'filename': filepaths})
|
|
|
|
def data_generator(df, image_size=(512, 512), batch_size=32, shuffle=True):
    """Create a Keras iterator yielding ``(images, images)`` batches for autoencoder training.

    Args:
        df: DataFrame whose ``'filename'`` column holds the image paths.
        image_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per batch.
        shuffle: Whether rows are shuffled (default ``True``, as before).

    Returns:
        The iterator returned by ``ImageDataGenerator.flow_from_dataframe``.
        Pixel values are multiplied by 1/255, which maps 8-bit images to
        [0, 1]. Each batch is an ``(x, x)`` pair: the reconstruction target
        is the input itself.
    """
    datagen = ImageDataGenerator(rescale=1./255)
    # class_mode='input' makes the iterator yield (x, x), which is what an
    # autoencoder's loss needs. class_mode=None yields images only (intended
    # for predict()), leaving fit() without targets to compute the loss against.
    return datagen.flow_from_dataframe(
        df,
        x_col='filename',
        class_mode='input',
        target_size=image_size,
        batch_size=batch_size,
        shuffle=shuffle,
    )
|
|
|
|
class ImageLogger(Callback):
    """Keras callback that saves a reconstruction of one sample image after every epoch.

    At the end of each epoch the wrapped autoencoder reconstructs
    ``sample_image``. The result is written to
    ``<output_dir>/epoch_<epoch + 1>.png`` so training progress can be
    inspected visually.

    Args:
        autoencoder_model: Model used to reconstruct the sample image.
        sample_image: A single image array without a batch axis (one is added
            before prediction). The training script here scales it to [0, 1]
            by dividing by 255; presumably it must match the model's input
            shape — TODO confirm for other callers.
        output_dir: Directory for the PNG snapshots; created if it does not exist.
    """

    def __init__(self, autoencoder_model, sample_image, output_dir="progress_images"):
        # Let the Keras Callback base class set up its own state (model/params
        # bookkeeping) before adding ours.
        super().__init__()
        self.autoencoder_model = autoencoder_model
        self.sample_image = sample_image
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Reconstruct the sample image and save it as ``epoch_<epoch + 1>.png``."""
        # Add a batch axis of size 1, predict silently, then take the single result.
        batch = np.expand_dims(self.sample_image, axis=0)
        reconstructed = self.autoencoder_model.predict(batch, verbose=0)[0]
        # Assumes model outputs lie in [0, 1] (the autoencoder in this file ends
        # in a sigmoid). Round rather than truncate, and clip for safety before
        # converting to 8-bit pixels.
        pixels = np.clip(np.rint(reconstructed * 255), 0, 255).astype(np.uint8)
        output_path = os.path.join(self.output_dir, f"epoch_{epoch + 1}.png")
        plt.imsave(output_path, pixels)
        print(f"Image sauvegardée à l'époque {epoch + 1}")
|
|
|
|
def build_autoencoder(input_shape=(512, 512, 3), latent_dim=512, learning_rate=0.001):
    """Build and compile a convolutional autoencoder.

    Encoder: four 3x3 ReLU convolutions (32, 64, 128, 256 filters). The last
    three use stride 2, so height and width shrink by a factor of 8. The
    feature map is then flattened and projected to a ``latent_dim``-wide
    dense code.

    Decoder: a dense layer re-expands the code to the encoder's final
    feature-map shape. Three stride-2 transposed convolutions upsample back to
    the input resolution, and a final 3x3 transposed convolution with a
    sigmoid emits one channel per input channel, with values in [0, 1].

    Args:
        input_shape: ``(height, width, channels)`` of the input images. Height
            and width must be divisible by 8. Defaults to the previously
            hard-coded ``(512, 512, 3)``.
        latent_dim: Width of the dense bottleneck (default 512, as before).
        learning_rate: Adam learning rate (default 0.001, as before).

    Returns:
        A Keras ``Model`` compiled with Adam and mean-squared-error loss.

    Raises:
        ValueError: If the input height or width is not divisible by 8.

    Note:
        With the default 512x512 input the flattened encoder output has
        64 * 64 * 256 = 1,048,576 features. Each of the two dense layers
        therefore holds about 537 million weights (over 1 billion parameters
        in total). Lower ``latent_dim`` or the input size if memory is a concern.
    """
    height, width, channels = input_shape
    if height % 8 or width % 8:
        raise ValueError(
            f"input height and width must be divisible by 8, got {height}x{width}"
        )
    # Spatial size left after the three stride-2 encoder convolutions.
    bottleneck_h, bottleneck_w = height // 8, width // 8

    input_img = Input(shape=(height, width, channels))

    # Encoder
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
    x = Flatten()(x)
    latent = Dense(latent_dim, activation='relu')(x)

    # Decoder: re-expand the code to the encoder's last feature-map shape,
    # then upsample by 2 three times to recover the input resolution.
    x = Dense(bottleneck_h * bottleneck_w * 256, activation='relu')(latent)
    x = Reshape((bottleneck_h, bottleneck_w, 256))(x)
    x = Conv2DTranspose(256, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
    x = Conv2DTranspose(128, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
    x = Conv2DTranspose(64, (3, 3), activation='relu', padding='same', strides=(2, 2))(x)
    output_img = Conv2DTranspose(channels, (3, 3), activation='sigmoid', padding='same')(x)

    autoencoder = Model(input_img, output_img)
    autoencoder.compile(optimizer=Adam(learning_rate=learning_rate), loss='mse')
    return autoencoder
|
|
|
|
# Build a DataFrame with the image paths.
df = create_dataframe(data_dir)

# Defensive cleanup: drop missing entries and keep only paths that are
# existing regular files.
df = df.dropna()
df = df[df['filename'].map(os.path.isfile)]

# Fail early with a clear message (instead of an IndexError on iloc[0] below),
# and before spending time building the model.
if df.empty:
    raise FileNotFoundError(f"Aucune image .jpg trouvée dans {data_dir!r}")

# Pick one sample image to visualise reconstruction progress across epochs.
# It is scaled to [0, 1], matching the generator's 1/255 rescaling.
sample_image = load_img(df['filename'].iloc[0], target_size=image_size)
sample_image = img_to_array(sample_image) / 255.0

autoencoder = build_autoencoder()

# Reduced number of epochs to save time.
epochs = 20

# Stream training batches from disk instead of loading every image into memory.
train_generator = data_generator(df, image_size=image_size, batch_size=batch_size)

# Train, saving a reconstruction of the sample image after each epoch.
autoencoder.fit(train_generator, epochs=epochs, callbacks=[ImageLogger(autoencoder, sample_image)])

# Save the trained model.
autoencoder.save("face_aging_autoencoder_2.h5")

print("Modèle entraîné et sauvegardé avec succès !")
|