You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
50 lines
1.8 KiB
50 lines
1.8 KiB
import os
|
|
import numpy as np
|
|
from tensorflow.keras.models import Sequential
|
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
|
|
from tensorflow.keras.losses import MeanSquaredError
|
|
from tensorflow.keras.preprocessing.image import img_to_array, load_img
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
def load_utkface_dataset(dataset_path, target_size=(128, 128)):
    """Load UTKFace-style images and their age labels from a flat directory.

    The age label is parsed from the first underscore-separated field of each
    file name (e.g. ``"25_0_1_20170116174525125.jpg"`` -> 25). Files are
    visited in sorted order because ``os.listdir`` returns entries in an
    arbitrary order; sorting makes the returned arrays, and therefore any
    seeded train/test split built on them, reproducible.

    Args:
        dataset_path: Directory containing the ``.jpg``/``.jpeg`` images.
        target_size: ``(height, width)`` each image is resized to. Defaults to
            ``(128, 128)``, the input size used by ``build_model``.

    Returns:
        A ``(images, ages)`` tuple of NumPy arrays: ``images`` holds the
        resized images with pixel values divided by 255, and ``ages`` holds the
        matching integer labels. Both are empty when no usable image is found.
    """
    images = []
    ages = []
    skipped = []

    for filename in sorted(os.listdir(dataset_path)):
        if not filename.lower().endswith((".jpg", ".jpeg")):
            continue

        # A file whose name does not start with an integer age is skipped
        # rather than aborting the whole load with a ValueError from int().
        try:
            age = int(filename.split("_", 1)[0])
        except ValueError:
            skipped.append(filename)
            continue

        img_path = os.path.join(dataset_path, filename)
        img = load_img(img_path, target_size=target_size)
        images.append(img_to_array(img) / 255.0)
        ages.append(age)

    if skipped:
        print(f"Skipped {len(skipped)} file(s) without a leading age field, e.g. {skipped[0]!r}")

    return np.array(images), np.array(ages)
|
|
|
|
def build_model():
    """Build and compile a small CNN that predicts age from a 128x128x3 image.

    The network has three Conv2D + MaxPooling2D stages (32, 64 and 128 filters,
    3x3 kernels, ReLU, 2x2 pooling). A 128-unit ReLU dense layer follows, and
    the model ends in a single output unit with no activation. It is compiled
    with the Adam optimizer, mean-squared-error loss, and mean absolute error
    as a metric.

    Returns:
        A compiled ``Sequential`` model whose input shape is (128, 128, 3).
    """
    # Layers are created in the same order as a flat literal list would create
    # them, so auto-generated layer names are unchanged.
    feature_extractor = [
        Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
        MaxPooling2D((2, 2)),
    ]
    for n_filters in (64, 128):
        feature_extractor.append(Conv2D(n_filters, (3, 3), activation='relu'))
        feature_extractor.append(MaxPooling2D((2, 2)))

    regression_head = [
        Flatten(),
        Dense(128, activation='relu'),
        Dense(1),  # single unit, no activation: the regressed value
    ]

    net = Sequential(feature_extractor + regression_head)
    net.compile(optimizer='adam', loss=MeanSquaredError(), metrics=['mae'])
    return net
|
|
|
|
def train_model(dataset_path, model_path, epochs=10, test_size=0.2, random_state=42):
    """Train the age-regression CNN on a UTKFace directory and save it.

    Args:
        dataset_path: Directory passed to ``load_utkface_dataset``.
        model_path: Destination path handed to ``model.save``.
        epochs: Number of training epochs.
        test_size: Fraction of samples held out as validation data.
        random_state: Seed for the train/validation split.

    Returns:
        The trained model.

    Raises:
        ValueError: If no usable images are found in ``dataset_path``.
    """
    images, ages = load_utkface_dataset(dataset_path)

    # Fail early with an actionable message. Otherwise train_test_split would
    # raise a less specific ValueError about having zero samples.
    if len(images) == 0:
        raise ValueError(f"No usable images found in {dataset_path!r}")

    X_train, X_test, y_train, y_test = train_test_split(
        images, ages, test_size=test_size, random_state=random_state
    )
    # The split returns copies; drop the full arrays so they are not held
    # alongside the training data for the whole fit.
    del images, ages

    model = build_model()
    model.fit(X_train, y_train, epochs=epochs, validation_data=(X_test, y_test))

    model.save(model_path)
    print(f"Model saved to {model_path}")
    return model
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Both paths are relative to the current working
    # directory.
    dataset_path, model_path = "./UTKFace/part1/part1", "face_aging_model.h5"
    train_model(dataset_path, model_path)