prediction/clustering page with user columns choice
continuous-integration/drone/push Build is passing Details

pull/22/head
remrem 10 months ago
parent f2fa040de0
commit a6fb8d2b35

@ -6,10 +6,13 @@ sys.path.append('./back/')
import clustering_csv as cc import clustering_csv as cc
import prediction as p import prediction as p
def display_prediction_results(df, targetCol): def handle_column_multiselect(df, method_name):
df_cols.remove(col) selected_columns = st.multiselect(f"Select the columns you want for {method_name}:", df.columns.tolist(), placeholder="Select dataset columns")
original_col = df[col] return selected_columns
predicted_col = p.getColumnsForPredictionAndPredict(df, df_cols, "Route Type", "Linear Regression")
def display_prediction_results(df, targetCol, sourceColumns, method):
original_col = df[targetCol]
predicted_col = p.getColumnsForPredictionAndPredict(df, sourceColumns, targetCol, method)
new_df = pd.DataFrame() new_df = pd.DataFrame()
new_df['Original'] = original_col new_df['Original'] = original_col
@ -19,22 +22,50 @@ def display_prediction_results(df, targetCol):
if 'df' in st.session_state: if 'df' in st.session_state:
df = st.session_state.df df = st.session_state.df
df_cols = df.columns.tolist()
st.write("# 🔮 Prediction") st.write("# 🔮 Prediction")
if st.button("K-means"): tab1, tab2 = st.tabs(["Clustering", "Predictions"])
st.pyplot(cc.launch_cluster_knn(df, ["Route Type", "Traffic Control"]))
with tab1:
st.header("Clustering")
selected_columns = handle_column_multiselect(df, "clustering")
tab_names = ["K-means", "DBSCAN"]
tab11, tab12 = st.tabs(tab_names)
with tab11:
if st.button(f"Start {tab_names[0]}"):
st.pyplot(cc.launch_cluster_knn(df, selected_columns))
with tab12:
if st.button(f"Start {tab_names[1]}"):
st.pyplot(cc.launch_cluster_dbscan(df, selected_columns))
with tab2:
st.header("Predictions")
target_column = st.selectbox(
"Target column:",
df.columns.tolist(),
index=None,
placeholder="Select target column"
)
if st.button("DBSCAN"): if target_column != None:
st.pyplot(cc.launch_cluster_dbscan(df, ["Route Type", "Traffic Control"])) selected_columns_p = handle_column_multiselect(df, "predictions")
tab_names = ["Linear Regression", "Random Forest"]
tab21, tab22 = st.tabs(tab_names)
if st.button("Linear Regression"): with tab21:
col = "Route Type" if st.button(f"Start {tab_names[0]}"):
display_prediction_results(df, col) st.write(target_column)
st.write(selected_columns_p)
display_prediction_results(df, target_column, selected_columns_p, tab_names[0])
if st.button("Random Forest"): with tab22:
col = "Route Type" if st.button(f"Start {tab_names[1]}"):
display_prediction_results(df, col) display_prediction_results(df, target_column, selected_columns_p, tab_names[1])
else: else:
st.write("Please clean your dataset.") st.write("Please clean your dataset.")

Loading…
Cancel
Save