diff --git a/frontend/clusters.py b/frontend/clusters.py new file mode 100644 index 0000000..ac2af4c --- /dev/null +++ b/frontend/clusters.py @@ -0,0 +1,63 @@ +from sklearn.cluster import DBSCAN, KMeans +import numpy as np + +class DBSCAN_cluster(): + def __init__(self, eps, min_samples,data): + self.eps = eps + self.min_samples = min_samples + self.data = data + self.labels = np.array([]) + + def run(self): + dbscan = DBSCAN(eps=self.eps, min_samples=self.min_samples) + self.labels = dbscan.fit_predict(self.data) + return self.labels + + def get_stats(self): + unique_labels = np.unique(self.labels) + stats = [] + for label in unique_labels: + if label == -1: + continue + cluster_points = self.data[self.labels == label] + num_points = len(cluster_points) + density = num_points / (np.max(cluster_points, axis=0) - np.min(cluster_points, axis=0)).prod() + stats.append({ + "cluster": label, + "num_points": num_points, + "density": density + }) + + return stats + + +class KMeans_cluster(): + def __init__(self, n_clusters, n_init, max_iter, data): + self.n_clusters = n_clusters + self.n_init = n_init + self.max_iter = max_iter + self.data = data + self.labels = np.array([]) + self.centers = [] + + def run(self): + kmeans = KMeans(n_clusters=self.n_clusters, init="random", n_init=self.n_init, max_iter=self.max_iter, random_state=111) + self.labels = kmeans.fit_predict(self.data) + self.centers = kmeans.cluster_centers_ + return self.labels + + + def get_stats(self): + unique_labels = np.unique(self.labels) + stats = [] + + for label in unique_labels: + cluster_points = self.data[self.labels == label] + num_points = len(cluster_points) + center = self.centers[label] + stats.append({ + 'cluster': label, + 'num_points': num_points, + 'center': center + }) + return stats diff --git a/frontend/pages/clustering_dbscan.py b/frontend/pages/clustering_dbscan.py index d06b10a..7ca16f6 100644 --- a/frontend/pages/clustering_dbscan.py +++ b/frontend/pages/clustering_dbscan.py @@ -1,10 +1,9 @@ import streamlit as st import matplotlib.pyplot as plt -from sklearn.cluster import DBSCAN +from clusters import DBSCAN_cluster st.header("Clustering: dbscan") - if "data" in st.session_state: data = st.session_state.data @@ -17,8 +16,9 @@ if "data" in st.session_state: if len(data_name) >= 2 and len(data_name) <=3: x = data[data_name].to_numpy() - dbscan = DBSCAN(eps=eps, min_samples=min_samples) - y_dbscan = dbscan.fit_predict(x) + dbscan = DBSCAN_cluster(eps,min_samples,x) + y_dbscan = dbscan.run() + st.table(dbscan.get_stats()) fig = plt.figure() if len(data_name) == 2: @@ -28,8 +28,5 @@ if "data" in st.session_state: ax = fig.add_subplot(projection='3d') ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=y_dbscan, s=50, cmap="viridis") st.pyplot(fig) - - - else: st.error("file not loaded") \ No newline at end of file diff --git a/frontend/pages/clustering_kmeans.py b/frontend/pages/clustering_kmeans.py index c61bf40..63c7d55 100644 --- a/frontend/pages/clustering_kmeans.py +++ b/frontend/pages/clustering_kmeans.py @@ -1,10 +1,9 @@ import streamlit as st -from sklearn.cluster import KMeans import matplotlib.pyplot as plt +from clusters import KMeans_cluster st.header("Clustering: kmeans") - if "data" in st.session_state: data = st.session_state.data @@ -23,21 +22,22 @@ if "data" in st.session_state: if len(data_name) >= 2 and len(data_name) <=3: x = data[data_name].to_numpy() - kmeans = KMeans(n_clusters=n_clusters, init="random", n_init=n_init, max_iter=max_iter, random_state=111) - y_kmeans = kmeans.fit_predict(x) + kmeans = KMeans_cluster(n_clusters, n_init, max_iter, x) + y_kmeans = kmeans.run() + + st.table(kmeans.get_stats()) + centers = kmeans.centers fig = plt.figure() if len(data_name) == 2: ax = fig.add_subplot(projection='rectilinear') plt.scatter(x[:, 0], x[:, 1], c=y_kmeans, s=50, cmap="viridis") - centers = kmeans.cluster_centers_ plt.scatter(centers[:, 0], centers[:, 1], c="black", s=200, marker="X") else: ax = fig.add_subplot(projection='3d') ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=y_kmeans, s=50, cmap="viridis") - centers = kmeans.cluster_centers_ - ax.scatter(centers[:, 0], centers[:, 1],centers[:, 2], c="black", s=200, marker="X") + ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], c="black", s=200, marker="X") st.pyplot(fig) else: