You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

137 lines
5.6 KiB

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import os
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import cv2 # Importer la bibliothèque OpenCV
# Enable memory growth on every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # Iterating an empty list is a no-op, so no separate "any GPU?" check is needed.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Report the failure and carry on with TensorFlow's current settings.
    print(err)
# Load the UTKFace data
def load_utkface_dataset(dataset_path, target_size=(128, 128)):
    """Load the ``.jpg`` images of a UTKFace-style directory with their ages.

    The age is read from the file name: it is the integer before the first
    underscore (e.g. ``25_0_1_xxx.jpg`` -> 25).

    Args:
        dataset_path: Directory containing the image files.
        target_size: ``(height, width)`` every image is resized to when it is
            loaded. Defaults to ``(128, 128)``.

    Returns:
        A tuple ``(images, ages)`` of NumPy arrays. ``images`` holds the pixel
        data divided by 255.0, ``ages`` the matching integer ages. Both arrays
        are empty when the directory contains no usable image.

    Raises:
        OSError: If ``dataset_path`` cannot be listed.
    """
    images = []
    ages = []
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # sample order (and therefore seeded runs) reproducible.
    for file_name in sorted(os.listdir(dataset_path)):
        if not file_name.endswith(".jpg"):
            continue
        # Parse the age before decoding the image, so that a stray file with a
        # non-conforming name is skipped instead of aborting the whole load.
        try:
            age = int(file_name.split("_")[0])
        except ValueError:
            print(f"Skipping {file_name!r}: file name does not start with an integer age")
            continue
        img_path = os.path.join(dataset_path, file_name)
        img = load_img(img_path, target_size=target_size)
        images.append(img_to_array(img) / 255.0)
        ages.append(age)
    return np.array(images), np.array(ages)
# Define the generator
def build_generator():
    """Build the generator network.

    A 100-dimensional input vector is projected by a Dense layer and reshaped
    to an 8x8x128 feature map. Four identical stages follow
    (UpSampling2D -> 3x3 Conv2D -> BatchNormalization -> ReLU) with 128, 64,
    32 and 16 filters, and a final 3-filter convolution with a tanh
    activation produces the output image.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    stack = [
        Dense(128 * 8 * 8, activation="relu", input_dim=100),
        Reshape((8, 8, 128)),
    ]

    # One upsampling stage per filter count.
    for filters in (128, 64, 32, 16):
        stack += [
            UpSampling2D(),
            Conv2D(filters, kernel_size=3, padding="same"),
            BatchNormalization(momentum=0.8),
            Activation("relu"),
        ]

    # Output head: 3 channels squashed by tanh.
    stack += [
        Conv2D(3, kernel_size=3, padding="same"),
        Activation("tanh"),
    ]
    return Sequential(stack)
# Define the discriminator
def build_discriminator():
    """Build the discriminator network.

    The model takes 128x128x3 images and applies four 3x3 convolutions
    (32, 64, 128 and 256 filters; the first three with stride 2), each
    followed by LeakyReLU(0.2) and Dropout(0.25). All but the first
    convolution are batch-normalised, and the second one is zero-padded by
    one row and one column first. A single sigmoid unit produces the output.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    stack = [
        # Stage 1: no batch normalisation on the raw input features.
        Conv2D(32, kernel_size=3, strides=2, input_shape=(128, 128, 3), padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        # Stage 2: padded by one row/column on the bottom/right.
        Conv2D(64, kernel_size=3, strides=2, padding="same"),
        ZeroPadding2D(padding=((0, 1), (0, 1))),
        BatchNormalization(momentum=0.8),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
    ]

    # Stages 3 and 4 differ only in filter count and stride.
    for filters, stride in ((128, 2), (256, 1)):
        stack.extend([
            Conv2D(filters, kernel_size=3, strides=stride, padding="same"),
            BatchNormalization(momentum=0.8),
            LeakyReLU(alpha=0.2),
            Dropout(0.25),
        ])

    # Classification head.
    stack.extend([
        Flatten(),
        Dense(1, activation='sigmoid'),
    ])
    return Sequential(stack)
# Compile the GAN model
def build_gan(generator, discriminator):
    """Stack generator and discriminator into one compiled model.

    The returned model maps a 100-dimensional input vector to the
    discriminator's output for the image generated from that vector.

    Side effect: ``discriminator.trainable`` is set to ``False`` on the model
    that was passed in, and it is left that way when this function returns.

    Args:
        generator: Model mapping a ``(100,)`` input to an image.
        discriminator: Model mapping an image to a single validity score.

    Returns:
        The combined ``Model``, compiled with binary cross-entropy loss and
        ``Adam(0.0002, 0.5)``.
    """
    # Freeze the discriminator before compiling the stacked model.
    discriminator.trainable = False

    latent = Input(shape=(100,))
    validity = discriminator(generator(latent))

    gan = Model(latent, validity)
    gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return gan
# Train the GAN model
def train_gan(generator, discriminator, combined, epochs, batch_size, save_interval, dataset_path):
    """Run the adversarial training loop.

    Each iteration draws one random batch of real images, generates an equally
    sized batch of fake images, trains the discriminator on both, and then
    trains the generator through ``combined`` so that its images are scored as
    real. Every ``save_interval`` iterations (including iteration 0) sample
    images and the generator weights are written to disk.

    Args:
        generator: Generator model; must accept ``(batch, 100)`` noise.
        discriminator: Discriminator model; it is (re)compiled here.
        combined: Compiled generator+discriminator stack used to train the
            generator.
        epochs: Number of training iterations (one batch each).
        batch_size: Number of real and of fake images per iteration.
        save_interval: Iteration period for saving samples and the model.
        dataset_path: Directory passed to ``load_utkface_dataset``.

    Raises:
        ValueError: If no image could be loaded from ``dataset_path``.
    """
    latent_dim = 100  # must match the generator's input size

    X_train, _ = load_utkface_dataset(dataset_path)
    if X_train.shape[0] == 0:
        # Fail with a clear message instead of the opaque error that
        # np.random.randint(0, 0, ...) would raise below.
        raise ValueError(f"No images could be loaded from {dataset_path!r}")
    X_train = (X_train - 0.5) * 2  # rescale [0, 1] images to [-1, 1], the range of tanh

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    # Compile the discriminator while it is trainable. build_gan() leaves
    # ``discriminator.trainable = False``; compiling in that state would freeze
    # the discriminator's own training step so it would never learn. The flag
    # is restored afterwards so the caller-visible state is unchanged.
    was_trainable = discriminator.trainable
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
    discriminator.trainable = was_trainable

    for epoch in range(epochs):
        # --- Discriminator step: real batch vs. freshly generated batch ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0: suppress the per-call progress bar, which would otherwise
        # be printed on every iteration.
        gen_imgs = generator.predict(noise, verbose=0)
        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Element-wise mean of [loss, accuracy] over the two half-steps.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step: push the discriminator's verdict towards "real" ---
        g_loss = combined.train_on_batch(noise, valid)

        print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {100*d_loss[1]}] [G loss: {g_loss}]")
        if epoch % save_interval == 0:
            save_images(generator, epoch)
            save_model(generator, epoch)
# Save the generated images
def save_images(generator, epoch, output_dir="images"):
    """Generate 25 sample images and write them as PNG files.

    Files are named ``img_<epoch>_<i>.png`` (``i`` from 0 to 24) inside
    ``output_dir``, which is created if it does not exist.

    Args:
        generator: Model whose ``predict`` maps ``(25, 100)`` noise to RGB
            images with values in [-1, 1].
        epoch: Iteration number used in the file names.
        output_dir: Destination directory.
    """
    noise = np.random.normal(0, 1, (25, 100))
    gen_imgs = generator.predict(noise)
    gen_imgs = 0.5 * gen_imgs + 0.5  # rescale [-1, 1] -> [0, 1]
    # Convert to 8-bit explicitly rather than handing float data to imwrite
    # and depending on its implicit depth conversion.
    gen_imgs = np.clip(np.rint(gen_imgs * 255), 0, 255).astype(np.uint8)

    os.makedirs(output_dir, exist_ok=True)
    for i, img in enumerate(gen_imgs):
        path = f"{output_dir}/img_{epoch}_{i}.png"
        # OpenCV expects BGR channel order; the generator output is treated as RGB.
        written = cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        if not written:
            # imwrite signals failure through its return value only.
            print(f"Warning: could not write {path}")
# Save the generator model
def save_model(generator, epoch, output_dir="models"):
    """Write the generator to ``<output_dir>/generator_epoch_<epoch>.h5``.

    The directory is created when missing, and a confirmation line is printed
    once ``generator.save`` returns.

    Args:
        generator: Object exposing a ``save(path)`` method.
        epoch: Iteration number embedded in the file name and the message.
        output_dir: Destination directory.
    """
    os.makedirs(output_dir, exist_ok=True)
    target = f"{output_dir}/generator_epoch_{epoch}.h5"
    generator.save(target)
    print(f"Modèle générateur sauvegardé à l'époque {epoch}")
if __name__ == "__main__":
    # Adjust to wherever the dataset is stored.
    dataset_path = "./UTKFace/part1/part1"

    # Build the two networks, then the stacked model used to train the generator.
    generator = build_generator()
    discriminator = build_discriminator()
    combined = build_gan(generator, discriminator)

    # Training hyper-parameters, gathered in one place.
    training_options = {
        "epochs": 10000,
        "batch_size": 32,
        "save_interval": 200,
        "dataset_path": dataset_path,
    }
    train_gan(generator, discriminator, combined, **training_options)