|
|
|
@ -18,18 +18,18 @@ if "data" in st.session_state:
|
|
|
|
|
max_iter = row1[0].number_input("max_iter",step=1,min_value=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
st.form_submit_button('launch')
|
|
|
|
|
st.form_submit_button("launch")
|
|
|
|
|
|
|
|
|
|
if len(data_name) == 2:
|
|
|
|
|
x = data[data_name].to_numpy()
|
|
|
|
|
|
|
|
|
|
kmeans = KMeans(n_clusters=n_clusters, init='random', n_init=n_init, max_iter=max_iter, random_state=111)
|
|
|
|
|
kmeans = KMeans(n_clusters=n_clusters, init="random", n_init=n_init, max_iter=max_iter, random_state=111)
|
|
|
|
|
y_kmeans = kmeans.fit_predict(x)
|
|
|
|
|
|
|
|
|
|
fig, ax = plt.subplots(figsize=(12,8))
|
|
|
|
|
plt.scatter(x[:, 0], x[:, 1], c=y_kmeans, s=50, cmap='viridis')
|
|
|
|
|
plt.scatter(x[:, 0], x[:, 1], c=y_kmeans, s=50, cmap="viridis")
|
|
|
|
|
centers = kmeans.cluster_centers_
|
|
|
|
|
plt.scatter(centers[:, 0], centers[:, 1], c='black', s=200, marker='X')
|
|
|
|
|
plt.scatter(centers[:, 0], centers[:, 1], c="black", s=200, marker="X")
|
|
|
|
|
st.pyplot(fig)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|