diff --git a/frontend/pages/prediction_classification.py b/frontend/pages/prediction_classification.py index 20ae5e1..c11d7ee 100644 --- a/frontend/pages/prediction_classification.py +++ b/frontend/pages/prediction_classification.py @@ -1,11 +1,11 @@ import streamlit as st from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split -from sklearn.metrics import accuracy_score +from sklearn.metrics import accuracy_score,confusion_matrix from sklearn.preprocessing import LabelEncoder import pandas as pd import matplotlib.pyplot as plt - +import seaborn as sns st.header("Prediction: Classification") @@ -63,24 +63,20 @@ if "data" in st.session_state: st.write("Prediction:", prediction[0]) + if len(data_name) == 1: + fig = plt.figure() + y_pred = [model.predict(pd.DataFrame([pred_value[0]], columns=data_name)) for pred_value in X.values.tolist()] + print([x[0] for x in X.values.tolist()]) + cm = confusion_matrix(y, y_pred) + sns.heatmap(cm, annot=True, fmt="d") - fig = plt.figure() - dataframe_sorted = pd.concat([X, y], axis=1).sort_values(by=data_name) - - X = dataframe_sorted[data_name[0]] - y = dataframe_sorted[target_name] - - prediction_array_y = [ - model.predict(pd.DataFrame([[dataframe_sorted[data_name[0]].iloc[i]]], columns=data_name))[0] - for i in range(dataframe_sorted.shape[0]) - ] + plt.xlabel('Predicted') + plt.ylabel('True') - plt.scatter(dataframe_sorted[data_name[0]], dataframe_sorted[target_name], color='b') - plt.scatter(dataframe_sorted[data_name[0]], prediction_array_y, color='r') + st.pyplot(fig) - st.pyplot(fig) else: diff --git a/frontend/pages/prediction_regression.py b/frontend/pages/prediction_regression.py index 6d125e0..e06fa12 100644 --- a/frontend/pages/prediction_regression.py +++ b/frontend/pages/prediction_regression.py @@ -2,7 +2,6 @@ import streamlit as st from sklearn.linear_model import LinearRegression import pandas as pd import matplotlib.pyplot as plt -import numpy as np st.header("Prediction: Regression") @@ -31,6 +30,7 @@ if "data" in st.session_state: fig = plt.figure() dataframe_sorted = pd.concat([X, y], axis=1).sort_values(by=data_name) + if len(data_name) == 1: X = dataframe_sorted[data_name[0]] y = dataframe_sorted[target_name]