import matplotlib.pyplot as plt
from sklearn.cluster import KMeans


def perform_kmeans_clustering(data, data_name, n_clusters, n_init, max_iter):
    """Run k-means on selected columns and return a scatter plot of the result.

    The selected columns are clustered with ``KMeans`` (random initialisation,
    fixed ``random_state=111`` so repeated calls give the same labels). Points
    are coloured by cluster label and the cluster centres are drawn as large
    black "X" markers.

    Args:
        data: Object indexable by ``data_name`` whose result exposes
            ``.to_numpy()`` -- presumably a pandas DataFrame (confirm against
            callers).
        data_name: Column selector passed to ``data[...]``. It must select at
            least two columns; a selection of exactly two columns produces a
            2-D plot, anything wider produces a 3-D plot.
        n_clusters: Number of clusters, forwarded to ``KMeans``.
        n_init: Number of initialisations, forwarded to ``KMeans``.
        max_iter: Iteration cap per run, forwarded to ``KMeans``.

    Returns:
        The matplotlib figure containing the scatter plot.

    Raises:
        ValueError: If the selection is not a 2-D array with at least two
            columns.
    """
    x = data[data_name].to_numpy()

    # Validate before fitting or opening a figure: a failure further down
    # would otherwise leave an orphaned figure registered with pyplot.
    if x.ndim != 2 or x.shape[1] < 2:
        raise ValueError(
            "data_name must select at least two columns to plot; "
            f"got an array of shape {x.shape}"
        )

    kmeans = KMeans(
        n_clusters=n_clusters,
        init="random",
        n_init=n_init,
        max_iter=max_iter,
        random_state=111,
    )
    y_kmeans = kmeans.fit_predict(x)
    centers = kmeans.cluster_centers_

    fig = plt.figure()
    # Branch on the actual number of feature columns rather than
    # len(data_name), which measures the selector and not the data.
    if x.shape[1] == 2:
        ax = fig.add_subplot(projection='rectilinear')
        # Draw on the explicit axes instead of pyplot's implicit "current"
        # axes so the result does not depend on global pyplot state.
        ax.scatter(x[:, 0], x[:, 1], c=y_kmeans, s=50, cmap="viridis")
        ax.scatter(centers[:, 0], centers[:, 1], c="black", s=200, marker="X")
    else:
        # Clustering used every selected column; only the first three are
        # plotted.
        ax = fig.add_subplot(projection='3d')
        ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=y_kmeans, s=50, cmap="viridis")
        ax.scatter(
            centers[:, 0],
            centers[:, 1],
            centers[:, 2],
            c="black",
            s=200,
            marker="X",
        )
    return fig