import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from glob import glob
import random
import cv2
# Define constants
IMAGE_SIZE = 128  # UTKFace crops are commonly 200x200, but the U-Net below needs a power-of-two size so its skip connections align
BATCH_SIZE = 10
EPOCHS = 10
LATENT_DIM = 100  # Unused in this script
BASE_DIR = "UTKFace/part3/part3"  # Update this to your UTKFace dataset path
# Function to load and preprocess the UTKFace dataset
def load_utkface_data(base_dir, target_size=(IMAGE_SIZE, IMAGE_SIZE)):
    # UTKFace filename format: [age]_[gender]_[race]_[date&time].jpg
    images = []
    ages = []
    image_paths = glob(os.path.join(base_dir, "*.jpg"))
    for img_path in image_paths:
        try:
            # Extract age from the filename
            filename = os.path.basename(img_path)
            age = int(filename.split("_")[0])
            # Load and preprocess the image
            img = load_img(img_path, target_size=target_size)
            img_array = img_to_array(img)
            img_array = (img_array - 127.5) / 127.5  # Normalize to [-1, 1]
            images.append(img_array)
            ages.append(age)
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            continue
    return np.array(images), np.array(ages)
# Load the dataset
print("Loading UTKFace dataset...")
images, ages = load_utkface_data(BASE_DIR)
print(f"Loaded {len(images)} images with age information")
# Create an age-paired dataset for training
def create_age_pairs(images, ages, min_age_gap=10, max_age_gap=40, batch_size=10000):
    # Note: batch_size here caps the number of *pairs* yielded at a time and is
    # unrelated to the training BATCH_SIZE above.
    young_images = []
    old_images = []
    # Group image indices by age
    age_to_images = {}
    for i, age in enumerate(ages):
        if age not in age_to_images:
            age_to_images[age] = []
        age_to_images[age].append(i)
    # Create pairs with the specified age gap
    for young_age in sorted(age_to_images.keys()):
        for old_age in sorted(age_to_images.keys()):
            age_gap = old_age - young_age
            if min_age_gap <= age_gap <= max_age_gap:
                for young_idx in age_to_images[young_age]:
                    young_images.append(images[young_idx])
                    # Randomly select an older face with the matching age gap
                    old_idx = random.choice(age_to_images[old_age])
                    old_images.append(images[old_idx])
                    # Yield in batches to avoid memory issues
                    if len(young_images) >= batch_size:
                        yield np.array(young_images), np.array(old_images)
                        young_images, old_images = [], []
    if young_images and old_images:
        yield np.array(young_images), np.array(old_images)
# Usage
print("Creating age-paired training data...")
young_faces, old_faces = next(create_age_pairs(images, ages))
print(f"Created {len(young_faces)} young-old face pairs for training")
# Split into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    young_faces, old_faces, test_size=0.2, random_state=42)
# Build the age progression model (a modified U-Net architecture)
def build_age_progression_model():
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
    # Encoder path: five stride-2 convolutions, each halving the resolution
    e1 = layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same')(inputs)
    e1 = layers.LeakyReLU(alpha=0.2)(e1)
    e2 = layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same')(e1)
    e2 = layers.BatchNormalization()(e2)
    e2 = layers.LeakyReLU(alpha=0.2)(e2)
    e3 = layers.Conv2D(256, (4, 4), strides=(2, 2), padding='same')(e2)
    e3 = layers.BatchNormalization()(e3)
    e3 = layers.LeakyReLU(alpha=0.2)(e3)
    e4 = layers.Conv2D(512, (4, 4), strides=(2, 2), padding='same')(e3)
    e4 = layers.BatchNormalization()(e4)
    e4 = layers.LeakyReLU(alpha=0.2)(e4)
    e5 = layers.Conv2D(512, (4, 4), strides=(2, 2), padding='same')(e4)
    e5 = layers.BatchNormalization()(e5)
    e5 = layers.LeakyReLU(alpha=0.2)(e5)
    # Decoder path with skip connections; each upsampled map must match the
    # corresponding encoder map, which is why IMAGE_SIZE must be a power of two
    d1 = layers.UpSampling2D(size=(2, 2))(e5)
    d1 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(d1)
    d1 = layers.BatchNormalization()(d1)
    d1 = layers.Concatenate()([d1, e4])
    d2 = layers.UpSampling2D(size=(2, 2))(d1)
    d2 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(d2)
    d2 = layers.BatchNormalization()(d2)
    d2 = layers.Concatenate()([d2, e3])
    d3 = layers.UpSampling2D(size=(2, 2))(d2)
    d3 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(d3)
    d3 = layers.BatchNormalization()(d3)
    d3 = layers.Concatenate()([d3, e2])
    d4 = layers.UpSampling2D(size=(2, 2))(d3)
    d4 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(d4)
    d4 = layers.BatchNormalization()(d4)
    d4 = layers.Concatenate()([d4, e1])
    d5 = layers.UpSampling2D(size=(2, 2))(d4)
    # tanh output matches the [-1, 1] normalization of the inputs
    outputs = layers.Conv2D(3, (3, 3), padding='same', activation='tanh')(d5)
    return models.Model(inputs=inputs, outputs=outputs)
# Build and compile the model
print("Building age progression model...")
model = build_age_progression_model()
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    loss='mae'  # Mean Absolute Error, a pixel-wise loss for image generation
)
model.summary()
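# Optional sanity check (illustrative): the U-Net must map images back to
# their own resolution for the pixel-wise loss against the target faces to work.
assert model.output_shape == (None, IMAGE_SIZE, IMAGE_SIZE, 3)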
# Create a callback for saving the model
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='age_progression_model_best.h5',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)
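# Optional (not part of the original training setup): early stopping is a
# common companion to checkpointing; add it to the callbacks list below to
# enable it. The patience value is an illustrative choice.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)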
# Train the model
print("Training the age progression model...")
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[checkpoint_callback]
)
# Plot training history
plt.figure(figsize=(12, 4))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.savefig('training_history.png')
plt.close()
# Function to use the trained model for inference
def age_progress_face(model, face_image_path, output_path=None):
    # Load and preprocess the input image
    img = load_img(face_image_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
    img_array = img_to_array(img)
    img_array = (img_array - 127.5) / 127.5  # Normalize to [-1, 1]
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    # Generate the aged face
    aged_face = model.predict(img_array)
    # Convert back to uint8 pixel values in [0, 255]
    aged_face = np.clip(aged_face[0] * 127.5 + 127.5, 0, 255).astype(np.uint8)
    # Save the result if an output path is provided (OpenCV expects BGR order)
    if output_path:
        cv2.imwrite(output_path, cv2.cvtColor(aged_face, cv2.COLOR_RGB2BGR))
    return aged_face
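# Note: the model only ever sees faces framed like the UTKFace training crops,
# so inference works best on test images cropped to the face in a similar way.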
# Example usage after training
print("Testing the model with a sample image...")
# Load the best model
best_model = models.load_model('age_progression_model_best.h5')
# Test with a sample image (you'll need to update this path)
sample_image_path = "sample_young_face.jpg" # Update with your test image path
output_path = "aged_face_result.jpg"
try:
    aged_face = age_progress_face(best_model, sample_image_path, output_path)
    print(f"Aged face saved to {output_path}")
    # Display the original and aged faces side by side
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(load_img(sample_image_path, target_size=(IMAGE_SIZE, IMAGE_SIZE)))
    axes[0].set_title("Original Face")
    axes[0].axis("off")
    axes[1].imshow(aged_face)
    axes[1].set_title("Aged Face")
    axes[1].axis("off")
    plt.tight_layout()
    plt.savefig("comparison.png")
    plt.show()
except Exception as e:
    print(f"Error testing the model: {e}")
print("Training and testing complete!")