diff --git a/src/pages/evaluate.py b/src/pages/evaluate.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/pages/prediction.py b/src/pages/prediction.py index 6a87c01..e2a95ab 100644 --- a/src/pages/prediction.py +++ b/src/pages/prediction.py @@ -6,10 +6,13 @@ sys.path.append('./back/') import clustering_csv as cc import prediction as p -def display_prediction_results(df, targetCol): - df_cols.remove(col) - original_col = df[col] - predicted_col = p.getColumnsForPredictionAndPredict(df, df_cols, "Route Type", "Linear Regression") +def handle_column_multiselect(df, method_name): + selected_columns = st.multiselect(f"Select the columns you want for {method_name}:", df.columns.tolist(), placeholder="Select dataset columns") + return selected_columns + +def display_prediction_results(df, targetCol, sourceColumns, method): + original_col = df[targetCol] + predicted_col = p.getColumnsForPredictionAndPredict(df, sourceColumns, targetCol, method) new_df = pd.DataFrame() new_df['Original'] = original_col @@ -19,22 +22,50 @@ def display_prediction_results(df, targetCol): if 'df' in st.session_state: df = st.session_state.df - df_cols = df.columns.tolist() st.write("# 🔮 Prediction") - if st.button("K-means"): - st.pyplot(cc.launch_cluster_knn(df, ["Route Type", "Traffic Control"])) + tab1, tab2 = st.tabs(["Clustering", "Predictions"]) + + with tab1: + st.header("Clustering") + selected_columns = handle_column_multiselect(df, "clustering") + + + tab_names = ["K-means", "DBSCAN"] + tab11, tab12 = st.tabs(tab_names) + + with tab11: + if st.button(f"Start {tab_names[0]}"): + st.pyplot(cc.launch_cluster_knn(df, selected_columns)) + + with tab12: + if st.button(f"Start {tab_names[1]}"): + st.pyplot(cc.launch_cluster_dbscan(df, selected_columns)) + + with tab2: + st.header("Predictions") + target_column = st.selectbox( + "Target column:", + df.columns.tolist(), + index=None, + placeholder="Select target column" + ) - if st.button("DBSCAN"): - st.pyplot(cc.launch_cluster_dbscan(df, ["Route Type", "Traffic Control"])) + if target_column != None: + selected_columns_p = handle_column_multiselect(df, "predictions") + + tab_names = ["Linear Regression", "Random Forest"] + tab21, tab22 = st.tabs(tab_names) - if st.button("Linear Regression"): - col = "Route Type" - display_prediction_results(df, col) + with tab21: + if st.button(f"Start {tab_names[0]}"): + st.write(target_column) + st.write(selected_columns_p) + display_prediction_results(df, target_column, selected_columns_p, tab_names[0]) - if st.button("Random Forest"): - col = "Route Type" - display_prediction_results(df, col) + with tab22: + if st.button(f"Start {tab_names[1]}"): + display_prediction_results(df, target_column, selected_columns_p, tab_names[1]) else: st.write("Please clean your dataset.")