You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
miner/frontend/pages/clustering.py

36 lines
1.1 KiB

import streamlit as st
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
# Streamlit page: run K-Means on two numeric columns of the DataFrame stored
# in ``st.session_state.data`` and show the clusters as a scatter plot.
st.header("Clustering")

if "data" in st.session_state:
    data = st.session_state.data

    # Parameter widgets are grouped in a form so their values are only sent
    # to the script when the user presses the submit button.
    with st.form("my_form"):
        row1 = st.columns([1, 1, 1])
        n_clusters = row1[0].selectbox("Number of clusters", range(1, 10))
        data_name = row1[1].multiselect(
            "Data Name",
            data.select_dtypes(include="number").columns,
            max_selections=2,
        )
        n_init = row1[2].number_input("n_init", step=1, min_value=1)

        # Second row of parameters. The original placed this widget in
        # ``row1[0]`` and left ``row2`` unused.
        row2 = st.columns([1, 1])
        max_iter = row2[0].number_input("max_iter", step=1, min_value=1)

        st.form_submit_button('launch')

    # The 2-D scatter plot needs exactly two selected columns.
    if len(data_name) == 2:
        x = data[data_name].to_numpy()

        if len(x) < n_clusters:
            # KMeans raises ValueError when asked for more clusters than rows.
            st.error("not enough rows for the requested number of clusters")
        else:
            kmeans = KMeans(
                n_clusters=n_clusters,
                init='random',
                n_init=n_init,
                max_iter=max_iter,
                random_state=111,  # fixed seed so reruns give the same clusters
            )
            kmeans.fit(x)

            # Draw on the explicit Axes rather than pyplot's implicit global
            # "current axes".
            fig, ax = plt.subplots(figsize=(12, 8))
            ax.scatter(x[:, 0], x[:, 1], s=100, c=kmeans.labels_, cmap='Set1')
            centers = kmeans.cluster_centers_
            ax.scatter(centers[:, 0], centers[:, 1], s=400, marker='*', color='k')
            st.pyplot(fig)
            # Release the figure; otherwise every rerun leaves one more
            # figure registered in pyplot.
            plt.close(fig)
else:
    st.error("file not loaded")