From 5f960df83842e866833d863b63d08c0a0348acfd Mon Sep 17 00:00:00 2001 From: clfreville2 Date: Fri, 7 Jun 2024 11:58:52 +0200 Subject: [PATCH] Support Pandas linear regression --- frontend/mvstrategy.py | 13 +++++++++++-- frontend/pages/normalization.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/frontend/mvstrategy.py b/frontend/mvstrategy.py index 81db2f8..fb7cc6c 100644 --- a/frontend/mvstrategy.py +++ b/frontend/mvstrategy.py @@ -14,11 +14,11 @@ class MVStrategy(ABC): return df @staticmethod - def list_available(series: Series) -> list['MVStrategy']: + def list_available(df: DataFrame, series: Series) -> list['MVStrategy']: """Get all the strategies that can be used.""" choices = [DropStrategy(), ModeStrategy()] if is_numeric_dtype(series): - choices.extend((MeanStrategy(), MedianStrategy())) + choices.extend((MeanStrategy(), MedianStrategy(), LinearRegressionStrategy())) return choices @@ -68,3 +68,12 @@ class ModeStrategy(PositionStrategy): def __str__(self) -> str: return "Use mode" + + +class LinearRegressionStrategy(MVStrategy): + def apply(self, df: DataFrame, label: str, series: Series) -> DataFrame: + series.interpolate(inplace=True) + return df + + def __str__(self) -> str: + return "Use linear regression" diff --git a/frontend/pages/normalization.py b/frontend/pages/normalization.py index 4f20c7a..7dd5b84 100644 --- a/frontend/pages/normalization.py +++ b/frontend/pages/normalization.py @@ -7,7 +7,7 @@ if "data" in st.session_state: for column, series in data.items(): missing_count = series.isna().sum() - choices = MVStrategy.list_available(series) + choices = MVStrategy.list_available(data, series) option = st.selectbox( f"Missing values of {column} ({missing_count})", choices,