import os
import random
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Define constants
# UTKFace crops are 200x200, but 200 is not divisible by 2**5, so the five
# stride-2 encoder blocks below would not line up with the decoder's skip
# connections. Resizing to 128 keeps every skip connection shape-compatible.
IMAGE_SIZE = 128
BATCH_SIZE = 10
EPOCHS = 10
LATENT_DIM = 100  # not used by the U-Net below; kept for reference
BASE_DIR = "UTKFace/part3/part3"  # Update this to your UTKFace dataset path
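
# Optional (a sketch, not part of the original script): fix the random seeds so
# the random young/old pairing below is reproducible across runs.
# random.seed(42)
# np.random.seed(42)
# tf.random.set_seed(42)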

# Function to load and preprocess UTKFace dataset
def load_utkface_data(base_dir, target_size=(IMAGE_SIZE, IMAGE_SIZE)):
    # UTKFace filename format: [age]_[gender]_[race]_[date&time].jpg
    images = []
    ages = []

    image_paths = glob(os.path.join(base_dir, "*.jpg"))
    for img_path in image_paths:
        try:
            # Extract age from filename
            filename = os.path.basename(img_path)
            age = int(filename.split("_")[0])

            # Load and preprocess image
            img = load_img(img_path, target_size=target_size)
            img_array = img_to_array(img)
            img_array = (img_array - 127.5) / 127.5  # Normalize to [-1, 1]

            images.append(img_array)
            ages.append(age)
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            continue

    return np.array(images), np.array(ages)

# Load the dataset
print("Loading UTKFace dataset...")
images, ages = load_utkface_data(BASE_DIR)
print(f"Loaded {len(images)} images with age information")
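
# Optional sanity check (a minimal sketch): confirm the loader produced
# correctly shaped, [-1, 1]-normalized images before going further.
if len(images) > 0:
    assert images.shape[1:] == (IMAGE_SIZE, IMAGE_SIZE, 3)
    assert images.min() >= -1.0 and images.max() <= 1.0
    print(f"Age range in dataset: {ages.min()}-{ages.max()}")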

# Create age-paired dataset for training
def create_age_pairs(images, ages, min_age_gap=10, max_age_gap=40, batch_size=10000):
    young_images = []
    old_images = []

    # Group image indices by age
    age_to_images = {}
    for i, age in enumerate(ages):
        if age not in age_to_images:
            age_to_images[age] = []
        age_to_images[age].append(i)

    # Create pairs with the specified age gap
    for young_age in sorted(age_to_images.keys()):
        for old_age in sorted(age_to_images.keys()):
            age_gap = old_age - young_age
            if min_age_gap <= age_gap <= max_age_gap:
                for young_idx in age_to_images[young_age]:
                    young_images.append(images[young_idx])
                    # Randomly select an older face
                    old_idx = random.choice(age_to_images[old_age])
                    old_images.append(images[old_idx])

                    # Yield in batches to avoid memory issues
                    if len(young_images) >= batch_size:
                        yield np.array(young_images), np.array(old_images)
                        young_images, old_images = [], []

    if young_images and old_images:
        yield np.array(young_images), np.array(old_images)

# Usage
print("Creating age-paired training data...")
young_faces, old_faces = next(create_age_pairs(images, ages))
print(f"Created {len(young_faces)} young-old face pairs for training")

# Split into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    young_faces, old_faces, test_size=0.2, random_state=42)
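
# Note: next() above only consumes the generator's first batch (at most 10,000
# pairs). A minimal sketch for training on every batch instead, assuming the
# model defined below (train_on_batch is the standard per-batch Keras API):
#
#     for young_batch, old_batch in create_age_pairs(images, ages):
#         model.train_on_batch(young_batch, old_batch)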

# Build the age progression model (using a modified U-Net architecture)
def build_age_progression_model():
    # Encoder input
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

    # Encoder path: each block halves the spatial size (128 -> 64 -> 32 -> 16 -> 8 -> 4)
    e1 = layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same')(inputs)
    e1 = layers.LeakyReLU(alpha=0.2)(e1)

    e2 = layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same')(e1)
    e2 = layers.BatchNormalization()(e2)
    e2 = layers.LeakyReLU(alpha=0.2)(e2)

    e3 = layers.Conv2D(256, (4, 4), strides=(2, 2), padding='same')(e2)
    e3 = layers.BatchNormalization()(e3)
    e3 = layers.LeakyReLU(alpha=0.2)(e3)

    e4 = layers.Conv2D(512, (4, 4), strides=(2, 2), padding='same')(e3)
    e4 = layers.BatchNormalization()(e4)
    e4 = layers.LeakyReLU(alpha=0.2)(e4)

    e5 = layers.Conv2D(512, (4, 4), strides=(2, 2), padding='same')(e4)
    e5 = layers.BatchNormalization()(e5)
    e5 = layers.LeakyReLU(alpha=0.2)(e5)

    # Decoder path with skip connections: each block doubles the spatial size
    # and concatenates the encoder feature map of matching resolution
    d1 = layers.UpSampling2D(size=(2, 2))(e5)
    d1 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(d1)
    d1 = layers.BatchNormalization()(d1)
    d1 = layers.Concatenate()([d1, e4])

    d2 = layers.UpSampling2D(size=(2, 2))(d1)
    d2 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(d2)
    d2 = layers.BatchNormalization()(d2)
    d2 = layers.Concatenate()([d2, e3])

    d3 = layers.UpSampling2D(size=(2, 2))(d2)
    d3 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(d3)
    d3 = layers.BatchNormalization()(d3)
    d3 = layers.Concatenate()([d3, e2])

    d4 = layers.UpSampling2D(size=(2, 2))(d3)
    d4 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(d4)
    d4 = layers.BatchNormalization()(d4)
    d4 = layers.Concatenate()([d4, e1])

    d5 = layers.UpSampling2D(size=(2, 2))(d4)
    # tanh output matches the [-1, 1] range of the normalized training images
    outputs = layers.Conv2D(3, (3, 3), padding='same', activation='tanh')(d5)

    model = models.Model(inputs=inputs, outputs=outputs)
    return model
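
# Optional shape check (a sketch): the U-Net should return an image the same
# size as its input, i.e. an output shape of (None, IMAGE_SIZE, IMAGE_SIZE, 3).
# _check = build_age_progression_model()
# print(_check.output_shape)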

# Build and compile the model
print("Building age progression model...")
model = build_age_progression_model()
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    loss='mae'  # Mean Absolute Error for image generation
)
model.summary()

# Create a callback for saving the model
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='age_progression_model_best.h5',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)

# Train the model
print("Training the age progression model...")
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[checkpoint_callback]
)
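
# Optional (a sketch, not part of the original script): early stopping can be
# added next to the checkpoint callback if validation loss plateaus, e.g.
#
#     early_stop = tf.keras.callbacks.EarlyStopping(
#         monitor='val_loss', patience=3, restore_best_weights=True)
#
# and passed as callbacks=[checkpoint_callback, early_stop] to model.fit().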

# Plot training history
plt.figure(figsize=(12, 4))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.savefig('training_history.png')
plt.close()

# Function to use the model for inference
def age_progress_face(model, face_image_path, output_path=None):
    # Load and preprocess the input image
    img = load_img(face_image_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
    img_array = img_to_array(img)
    img_array = (img_array - 127.5) / 127.5  # Normalize to [-1, 1]
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension

    # Generate aged face
    aged_face = model.predict(img_array)

    # Convert back to uint8 format
    aged_face = ((aged_face[0] * 127.5) + 127.5).astype(np.uint8)

    # Save the result if output path is provided
    if output_path:
        cv2.imwrite(output_path, cv2.cvtColor(aged_face, cv2.COLOR_RGB2BGR))

    return aged_face

# Example usage after training
print("Testing the model with a sample image...")

# Load the best model
best_model = models.load_model('age_progression_model_best.h5')

# Test with a sample image (you'll need to update this path)
sample_image_path = "sample_young_face.jpg"
output_path = "aged_face_result.jpg"

try:
    aged_face = age_progress_face(best_model, sample_image_path, output_path)
    print(f"Aged face saved to {output_path}")

    # Display original and aged faces
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(load_img(sample_image_path, target_size=(IMAGE_SIZE, IMAGE_SIZE)))
    axes[0].set_title("Original Face")
    axes[0].axis("off")

    axes[1].imshow(aged_face)
    axes[1].set_title("Aged Face")
    axes[1].axis("off")

    plt.tight_layout()
    plt.savefig("comparison.png")
    plt.show()
except Exception as e:
    print(f"Error testing the model: {e}")

print("Training and testing complete!")