You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
miner/backend/kmeans_strategy.py

22 lines
935 B

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
def perform_kmeans_clustering(data, data_name, n_clusters, n_init, max_iter,
                              *, init="random", random_state=111):
    """Cluster selected columns of *data* with k-means and plot the result.

    Parameters
    ----------
    data
        Container indexed by ``data_name`` whose selection supports
        ``.to_numpy()`` (presumably a pandas DataFrame -- confirm with callers).
    data_name
        Selector for the feature columns to cluster on (presumably a list of
        column names).
    n_clusters, n_init, max_iter
        Forwarded unchanged to ``sklearn.cluster.KMeans``.
    init
        Centroid initialisation method forwarded to ``KMeans``. Defaults to
        ``"random"``, the value previously hard-coded.
    random_state
        Seed forwarded to ``KMeans``. Defaults to ``111``, the value previously
        hard-coded, so results stay reproducible.

    Returns
    -------
    matplotlib.figure.Figure
        A 2-D scatter plot when exactly two feature columns are selected,
        otherwise a 3-D scatter plot of the first three columns. Points are
        coloured by cluster label (``viridis`` colormap), and cluster centres
        are drawn as large black ``X`` markers.

    Raises
    ------
    ValueError
        If the selection does not yield a 2-D array with at least two
        feature columns.
    """
    x = data[data_name].to_numpy()
    # Validate before fitting: the plotting code indexes at least two columns,
    # so fail fast with a clear message instead of an IndexError after the fit.
    if x.ndim != 2 or x.shape[1] < 2:
        raise ValueError(
            "perform_kmeans_clustering needs at least two feature columns; "
            f"got an array of shape {x.shape}"
        )

    kmeans = KMeans(
        n_clusters=n_clusters,
        init=init,
        n_init=n_init,
        max_iter=max_iter,
        random_state=random_state,
    )
    labels = kmeans.fit_predict(x)
    centers = kmeans.cluster_centers_

    fig = plt.figure()
    if x.shape[1] == 2:
        # Draw on this figure's own axes rather than pyplot's implicit
        # "current" axes, which could belong to a different figure.
        ax = fig.add_subplot()
        ax.scatter(x[:, 0], x[:, 1], c=labels, s=50, cmap="viridis")
        ax.scatter(centers[:, 0], centers[:, 1], c="black", s=200, marker="X")
    else:
        # With more than three features only the first three are plotted;
        # the clustering itself still used every selected column.
        ax = fig.add_subplot(projection="3d")
        ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=labels, s=50, cmap="viridis")
        ax.scatter(
            centers[:, 0], centers[:, 1], centers[:, 2],
            c="black", s=200, marker="X",
        )
    return fig