From 816bf3a2374e413b85026fa4405fdf898b22df2f Mon Sep 17 00:00:00 2001 From: rem Date: Tue, 25 Jun 2024 00:35:54 +0200 Subject: [PATCH] add PCA for dimensions reduction on clustering --- src/back/clustering_csv.py | 16 ++++++++++++---- src/pages/prediction.py | 10 +++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/back/clustering_csv.py b/src/back/clustering_csv.py index 4b3b6fb..b6a95de 100644 --- a/src/back/clustering_csv.py +++ b/src/back/clustering_csv.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt from sklearn.cluster import KMeans, DBSCAN from sklearn.datasets import make_blobs, make_moons from mpl_toolkits.mplot3d import Axes3D +from sklearn.decomposition import PCA def visualize_clusters_2d(X, labels, centers=None, title="Clusters"): plt.figure(figsize=(10, 7)) @@ -56,8 +57,11 @@ def calculate_cluster_statistics_dbscan(X, labels): }) return stats -def launch_cluster_knn(df, array_columns, n=3): +def launch_cluster_knn(df, array_columns, n=3, dimensions=2): X = df[array_columns].values + if len(array_columns) > 3: + pca = PCA(dimensions) + X = pca.fit_transform(df) kmeans = KMeans(n_clusters=n, random_state=42) labels_kmeans = kmeans.fit_predict(X) @@ -66,19 +70,23 @@ def launch_cluster_knn(df, array_columns, n=3): # print(f"Cluster {stat['cluster']}: {stat['num_points']} points, Center: {stat['center']}") stats_kmeans = calculate_cluster_statistics_kmeans(X, labels_kmeans, centers_kmeans) - if len(array_columns) == 3: + if dimensions == 3: return visualize_clusters_3d(X, labels_kmeans, centers_kmeans, title="K-Means Clustering 3D") else: return visualize_clusters_2d(X, labels_kmeans, centers_kmeans, title="K-Means Clustering") -def launch_cluster_dbscan(df, array_columns): +def launch_cluster_dbscan(df, array_columns, dimensions=2): X = df[array_columns].values + if len(array_columns) > 3: + pca = PCA(dimensions) + X = pca.fit_transform(df) + dbscan = DBSCAN(eps=0.2, min_samples=5) labels_dbscan = dbscan.fit_predict(X) stats_dbscan = calculate_cluster_statistics_dbscan(X, labels_dbscan) # for stat in stats_dbscan: # print(f"Cluster {stat['cluster']}: {stat['num_points']} points, Density: {stat['density']}") - if len(array_columns) == 3: + if dimensions == 3: return visualize_clusters_3d(X, labels_dbscan, title="DBSCAN Clustering 3D") else: return visualize_clusters_2d(X, labels_dbscan, title="DBSCAN Clustering") diff --git a/src/pages/prediction.py b/src/pages/prediction.py index e2a95ab..d2b8805 100644 --- a/src/pages/prediction.py +++ b/src/pages/prediction.py @@ -30,18 +30,22 @@ if 'df' in st.session_state: with tab1: st.header("Clustering") selected_columns = handle_column_multiselect(df, "clustering") - + + if len(selected_columns) >= 3: + dimensions = st.radio("Reduce to dimensions X with PCA:",[2,3],index=0) + else: + dimensions = 2 tab_names = ["K-means", "DBSCAN"] tab11, tab12 = st.tabs(tab_names) with tab11: if st.button(f"Start {tab_names[0]}"): - st.pyplot(cc.launch_cluster_knn(df, selected_columns)) + st.pyplot(cc.launch_cluster_knn(df, selected_columns, dimensions=dimensions)) with tab12: if st.button(f"Start {tab_names[1]}"): - st.pyplot(cc.launch_cluster_dbscan(df, selected_columns)) + st.pyplot(cc.launch_cluster_dbscan(df, selected_columns, dimensions)) with tab2: st.header("Predictions")