From 405439564147fc6430b9a933e156fe1e480372da Mon Sep 17 00:00:00 2001 From: bastien Date: Tue, 25 Jun 2024 19:54:35 +0200 Subject: [PATCH] update --- frontend/pages/prediction_classification.py | 23 +++++++++++++++++++++ frontend/pages/prediction_regression.py | 7 +++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/frontend/pages/prediction_classification.py b/frontend/pages/prediction_classification.py index 5aaf52f..20ae5e1 100644 --- a/frontend/pages/prediction_classification.py +++ b/frontend/pages/prediction_classification.py @@ -4,6 +4,8 @@ from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score from sklearn.preprocessing import LabelEncoder import pandas as pd +import matplotlib.pyplot as plt + st.header("Prediction: Classification") @@ -60,5 +62,26 @@ if "data" in st.session_state: prediction = label_encoders[target_name].inverse_transform(prediction) st.write("Prediction:", prediction[0]) + + + + + fig = plt.figure() + dataframe_sorted = pd.concat([X, y], axis=1).sort_values(by=data_name) + + X = dataframe_sorted[data_name[0]] + y = dataframe_sorted[target_name] + + prediction_array_y = [ + model.predict(pd.DataFrame([[dataframe_sorted[data_name[0]].iloc[i]]], columns=data_name))[0] + for i in range(dataframe_sorted.shape[0]) + ] + + plt.scatter(dataframe_sorted[data_name[0]], dataframe_sorted[target_name], color='b') + plt.scatter(dataframe_sorted[data_name[0]], prediction_array_y, color='r') + + st.pyplot(fig) + + else: st.error("File not loaded") diff --git a/frontend/pages/prediction_regression.py b/frontend/pages/prediction_regression.py index 42acf34..6d125e0 100644 --- a/frontend/pages/prediction_regression.py +++ b/frontend/pages/prediction_regression.py @@ -41,8 +41,8 @@ if "data" in st.session_state: ] plt.scatter(dataframe_sorted[data_name[0]], dataframe_sorted[target_name], color='b') - plt.scatter(dataframe_sorted[data_name[0]], prediction_array_y, color='r') - else: + plt.plot(dataframe_sorted[data_name[0]], prediction_array_y, color='r') + elif len(data_name) == 2: ax = fig.add_subplot(111, projection='3d') prediction_array_y = [ @@ -51,10 +51,9 @@ if "data" in st.session_state: ] ax.scatter(dataframe_sorted[data_name[0]], dataframe_sorted[data_name[1]], dataframe_sorted[target_name], color='b') - ax.scatter(dataframe_sorted[data_name[0]], dataframe_sorted[data_name[1]], prediction_array_y, color='r') + ax.plot(dataframe_sorted[data_name[0]], dataframe_sorted[data_name[1]], prediction_array_y, color='r') st.pyplot(fig) - else: st.error("File not loaded")