Merge files using strategies
continuous-integration/drone/push Build is passing Details

pull/15/head
Clément FRÉVILLE 10 months ago
parent 9da6e2d594
commit 01ef19a2f8

@ -1,25 +1,40 @@
from sklearn.cluster import DBSCAN, KMeans from sklearn.cluster import DBSCAN, KMeans
import numpy as np import numpy as np
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Any, Optional
class DBSCAN_cluster(): @dataclass
def __init__(self, eps, min_samples,data): class ClusterResult:
labels: np.array
centers: Optional[np.array]
statistics: list[dict[str, Any]]
class Cluster(ABC):
@abstractmethod
def run(self, data: np.array) -> ClusterResult:
pass
class DBSCANCluster(Cluster):
def __init__(self, eps: float = 0.5, min_samples: int = 5):
self.eps = eps self.eps = eps
self.min_samples = min_samples self.min_samples = min_samples
self.data = data
self.labels = np.array([])
def run(self): #@typing.override
def run(self, data: np.array) -> ClusterResult:
dbscan = DBSCAN(eps=self.eps, min_samples=self.min_samples) dbscan = DBSCAN(eps=self.eps, min_samples=self.min_samples)
self.labels = dbscan.fit_predict(self.data) labels = dbscan.fit_predict(data)
return self.labels return ClusterResult(labels, None, self.get_statistics(data, labels))
def get_stats(self): def get_statistics(self, data: np.array, labels: np.array) -> list[dict[str, Any]]:
unique_labels = np.unique(self.labels) unique_labels = np.unique(labels)
stats = [] stats = []
for label in unique_labels: for label in unique_labels:
if label == -1: if label == -1:
continue continue
cluster_points = self.data[self.labels == label] cluster_points = data[labels == label]
num_points = len(cluster_points) num_points = len(cluster_points)
density = num_points / (np.max(cluster_points, axis=0) - np.min(cluster_points, axis=0)).prod() density = num_points / (np.max(cluster_points, axis=0) - np.min(cluster_points, axis=0)).prod()
stats.append({ stats.append({
@ -27,37 +42,42 @@ class DBSCAN_cluster():
"num_points": num_points, "num_points": num_points,
"density": density "density": density
}) })
return stats return stats
class KMeans_cluster(): def __str__(self) -> str:
def __init__(self, n_clusters, n_init, max_iter, data): return "DBScan"
class KMeansCluster(Cluster):
def __init__(self, n_clusters: int = 8, n_init: int = 1, max_iter: int = 300):
self.n_clusters = n_clusters self.n_clusters = n_clusters
self.n_init = n_init self.n_init = n_init
self.max_iter = max_iter self.max_iter = max_iter
self.data = data
self.labels = np.array([])
self.centers = []
def run(self): #@typing.override
def run(self, data: np.array) -> ClusterResult:
kmeans = KMeans(n_clusters=self.n_clusters, init="random", n_init=self.n_init, max_iter=self.max_iter, random_state=111) kmeans = KMeans(n_clusters=self.n_clusters, init="random", n_init=self.n_init, max_iter=self.max_iter, random_state=111)
self.labels = kmeans.fit_predict(self.data) labels = kmeans.fit_predict(data)
self.centers = kmeans.cluster_centers_ centers = kmeans.cluster_centers_
return self.labels return ClusterResult(labels, centers, self.get_statistics(data, labels, centers))
def get_statistics(self, data: np.array, labels: np.array, centers: np.array) -> list[dict[str, Any]]:
def get_stats(self): unique_labels = np.unique(labels)
unique_labels = np.unique(self.labels)
stats = [] stats = []
for label in unique_labels: for label in unique_labels:
cluster_points = self.data[self.labels == label] cluster_points = data[labels == label]
num_points = len(cluster_points) num_points = len(cluster_points)
center = self.centers[label] center = centers[label]
stats.append({ stats.append({
'cluster': label, "cluster": label,
'num_points': num_points, "num_points": num_points,
'center': center "center": center,
}) })
return stats return stats
def __str__(self) -> str:
return "KMeans"
CLUSTERING_STRATEGIES = [DBSCANCluster(), KMeansCluster()]

@ -0,0 +1,48 @@
import streamlit as st
import matplotlib.pyplot as plt
from clusters import DBSCANCluster, KMeansCluster, CLUSTERING_STRATEGIES
st.header("Clustering")
if "data" in st.session_state:
data = st.session_state.data
general_row = st.columns([1, 1])
clustering = general_row[0].selectbox("Clustering method", CLUSTERING_STRATEGIES)
data_name = general_row[1].multiselect("Data Name",data.select_dtypes(include="number").columns, max_selections=3)
with st.form("cluster_form"):
if isinstance(clustering, KMeansCluster):
row1 = st.columns([1, 1, 1])
clustering.n_clusters = row1[0].number_input("Number of clusters", min_value=1, max_value=data.shape[0], value=clustering.n_clusters)
clustering.n_init = row1[1].number_input("n_init", min_value=1, value=clustering.n_init)
clustering.max_iter = row1[2].number_input("max_iter", min_value=1, value=clustering.max_iter)
elif isinstance(clustering, DBSCANCluster):
clustering.eps = st.slider("eps", min_value=0.0001, max_value=1.0, step=0.1, value=clustering.eps)
clustering.min_samples = st.number_input("min_samples", min_value=1, value=clustering.min_samples)
st.form_submit_button("Launch")
if len(data_name) >= 2 and len(data_name) <=3:
x = data[data_name].to_numpy()
result = clustering.run(x)
st.table(result.statistics)
fig = plt.figure()
if len(data_name) == 2:
ax = fig.add_subplot(projection='rectilinear')
plt.scatter(x[:, 0], x[:, 1], c=result.labels, s=50, cmap="viridis")
if result.centers is not None:
plt.scatter(result.centers[:, 0], result.centers[:, 1], c="black", s=200, marker="X")
else:
ax = fig.add_subplot(projection='3d')
ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=result.labels, s=50, cmap="viridis")
if result.centers is not None:
ax.scatter(result.centers[:, 0], result.centers[:, 1], result.centers[:, 2], c="black", s=200, marker="X")
st.pyplot(fig)
else:
st.error("file not loaded")

@ -1,32 +0,0 @@
import streamlit as st
import matplotlib.pyplot as plt
from clusters import DBSCAN_cluster
st.header("Clustering: dbscan")
if "data" in st.session_state:
data = st.session_state.data
with st.form("my_form"):
data_name = st.multiselect("Data Name", data.select_dtypes(include="number").columns, max_selections=3)
eps = st.slider("eps", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
min_samples = st.number_input("min_samples", step=1, min_value=1, value=5)
st.form_submit_button("launch")
if len(data_name) >= 2 and len(data_name) <=3:
x = data[data_name].to_numpy()
dbscan = DBSCAN_cluster(eps,min_samples,x)
y_dbscan = dbscan.run()
st.table(dbscan.get_stats())
fig = plt.figure()
if len(data_name) == 2:
ax = fig.add_subplot(projection='rectilinear')
plt.scatter(x[:, 0], x[:, 1], c=y_dbscan, s=50, cmap="viridis")
else:
ax = fig.add_subplot(projection='3d')
ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=y_dbscan, s=50, cmap="viridis")
st.pyplot(fig)
else:
st.error("file not loaded")

@ -1,44 +0,0 @@
import streamlit as st
import matplotlib.pyplot as plt
from clusters import KMeans_cluster
st.header("Clustering: kmeans")
if "data" in st.session_state:
data = st.session_state.data
with st.form("my_form"):
row1 = st.columns([1,1,1])
n_clusters = row1[0].selectbox("Number of clusters", range(1,data.shape[0]))
data_name = row1[1].multiselect("Data Name",data.select_dtypes(include="number").columns, max_selections=3)
n_init = row1[2].number_input("n_init",step=1,min_value=1)
row2 = st.columns([1,1])
max_iter = row1[0].number_input("max_iter",step=1,min_value=1)
st.form_submit_button("launch")
if len(data_name) >= 2 and len(data_name) <=3:
x = data[data_name].to_numpy()
kmeans = KMeans_cluster(n_clusters, n_init, max_iter, x)
y_kmeans = kmeans.run()
st.table(kmeans.get_stats())
centers = kmeans.centers
fig = plt.figure()
if len(data_name) == 2:
ax = fig.add_subplot(projection='rectilinear')
plt.scatter(x[:, 0], x[:, 1], c=y_kmeans, s=50, cmap="viridis")
plt.scatter(centers[:, 0], centers[:, 1], c="black", s=200, marker="X")
else:
ax = fig.add_subplot(projection='3d')
ax.scatter(x[:, 0], x[:, 1],x[:, 2], c=y_kmeans, s=50, cmap="viridis")
ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], c="black", s=200, marker="X")
st.pyplot(fig)
else:
st.error("file not loaded")
Loading…
Cancel
Save