rewrite + conflict

master
remrem 1 year ago
commit 3882a5f0b3

@ -1,5 +1,6 @@
#!/usr/bin/python3 #!/usr/bin/python3
import numpy as np import os
import numpy as np
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import sklearn as sk import sklearn as sk
@ -12,7 +13,20 @@ from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
# from sklearn.externals.joblib import parallel_backend
# main
def main():
    """Run the interactive classification workflow from start to finish.

    Asks the user which model to use, builds it, loads ``data.csv``,
    separates features from labels and passes everything to ``training``.
    """
    # User input comes first, before any data is read.
    model = model_switch(prompt_display())
    # Load the dataset and split it into features (x) and labels (y).
    x, y = get_xy_from_dataframe(read_dataset("data.csv"))
    # Fit the chosen model and open the results menu.
    training(model, x, y)
# Open dataset with panda # Open dataset with panda
def read_dataset(filename): def read_dataset(filename):
@ -20,54 +34,101 @@ def read_dataset(filename):
return df return df
# Drop useless columns and return x and y # Drop useless columns and return x and y
def get_xy_from_dataset(filename): def get_xy_from_dataframe(df):
df = read_dataset(filename)
x = df.drop(['obj_ID','field_ID','run_ID','rerun_ID','cam_col','plate','MJD','fiber_ID','class'],axis=1) x = df.drop(['obj_ID','field_ID','run_ID','rerun_ID','cam_col','plate','MJD','fiber_ID','class'],axis=1)
y = df['class'].values y = df['class'].values
return x, y return x, y
x, y = get_xy_from_dataset("data.csv")
x.hist() # Ask for model choice
#plt.show() def prompt_display():
print("""Choose a model:
print("""Choose a model:
(1) - KNN (1) - KNN
(2) - Tree (2) - Tree
(3) - RandomForestClassifier (3) - RandomForestClassifier
(4) - SGD (4) - SGD
(5) - Linear SVC""") (5) - Linear SVC""")
res = int(input()) return int(input())
if (res == 1): def model_switch(choice):
model = KNeighborsClassifier() if (choice == 1):
elif (res == 2): model = KNeighborsClassifier()
model = DecisionTreeClassifier(random_state=0, max_depth=20) elif (choice == 2):
elif (res == 3): model = DecisionTreeClassifier(random_state=0, max_depth=20)
model = RandomForestClassifier(n_estimators=100 ,criterion='entropy') elif (choice == 3):
elif (res == 4): model = RandomForestClassifier(n_estimators=100 ,criterion='entropy')
model = SGDClassifier(max_iter=1000, tol=0.01) elif (choice == 4):
elif (res == 5): model = SGDClassifier(max_iter=1000, tol=0.01)
model = svm.SVC(kernel='linear', C = 1.0) elif (choice == 5):
else: model = svm.SVC(kernel='linear', C = 1.0)
raise Exception('RENTRE LE BON NOMBRE GROS CON') else:
raise Exception('RENTRE LE BON NOMBRE GROS CON')
Xtrain, Xtest, ytrain, ytest = train_test_split(x, y,test_size=0.25, random_state=0)
Xtrain = Xtrain.values
Xtest = Xtest.values
if len(Xtrain.shape) < 2:
Xtrain = Xtrain.reshape(-1, 1)
if len(Xtest.shape) < 2: return model
Xtest = Xtest.reshape(-1, 1)
model.fit(Xtrain,ytrain)
ypredit = model.predict(Xtest) def plot_columns_hist(columns):
# print(ypredit) x.hist()
print(accuracy_score(ytest, ypredit)) plt.show()
def printPredictedValues(ypredit,ytest):
    """Print every prediction next to its true label, one line per sample.

    Lines for correct predictions start with a check mark, lines for wrong
    ones with a red dot.

    Args:
        ypredit: Sequence of predicted labels.
        ytest: Sequence of true labels, expected to have the same length as
            ``ypredit``. Iteration stops at the shorter of the two.
    """
    for predicted, actual in zip(ypredit, ytest):
        # Explicit branches instead of a conditional expression: the two
        # print calls are side effects, not values.
        if predicted == actual:
            print("✅ Prédit/Réel: ", predicted, actual)
        else:
            print("🔴 Prédit/Réel: ", predicted, actual)
def printStatValues(ypredit,ytest):
    """Print how the wrong predictions are distributed across classes.

    Each sample whose prediction differs from the true label counts as one
    wrong prediction. Wrong predictions are bucketed by the *predicted*
    label ("GALAXY", "STAR" or "QSO"); any other predicted label still counts
    toward the total but lands in no bucket. Each bucket is printed as a
    percentage of the total number of wrong predictions.

    When there is no wrong prediction at all (perfect model or empty input)
    every percentage is reported as 0.0 instead of dividing by zero.

    Args:
        ypredit: Sequence of predicted labels.
        ytest: Sequence of true labels, expected to have the same length as
            ``ypredit``. Iteration stops at the shorter of the two.
    """
    wrong_by_class = {"GALAXY": 0, "STAR": 0, "QSO": 0}
    wrong_total = 0
    for predicted, actual in zip(ypredit, ytest):
        if predicted != actual:
            wrong_total += 1
            if predicted in wrong_by_class:
                wrong_by_class[predicted] += 1

    def share(count):
        # Guard the denominator: with zero wrong predictions there is
        # nothing to distribute.
        return count * 100 / wrong_total if wrong_total else 0.0

    print("Répartition des prédiction fausses : ")
    print("Galaxy : ",share(wrong_by_class["GALAXY"]),"%","Star :",share(wrong_by_class["STAR"]),"%","QSO : ",share(wrong_by_class["QSO"]),"%")
# Train model
def training(model, x, y):
    """Fit ``model`` on a train split and browse the results interactively.

    The data is split 75/25 with a fixed seed, the model is fitted on the
    training part and used to predict the held-out part. A text menu then
    lets the user inspect those predictions until they enter ``0``.

    The menu is shown before the first choice is read, and an invalid choice
    (unknown number or non-numeric text) prints 'Mauvaise saisie' and asks
    again rather than aborting, so a typo does not throw away a fitted model.

    Args:
        model: Estimator exposing ``fit(X, y)`` and ``predict(X)``.
        x: Feature table whose split parts expose ``.values``
            (e.g. a pandas DataFrame).
        y: Labels aligned with ``x``.
    """
    Xtrain, Xtest, ytrain, ytest = train_test_split(
        x, y, test_size=0.25, random_state=0
    )
    Xtrain = Xtrain.values
    Xtest = Xtest.values
    # Promote a 1-D feature array to a single-column 2-D array.
    if len(Xtrain.shape) < 2:
        Xtrain = Xtrain.reshape(-1, 1)
    if len(Xtest.shape) < 2:
        Xtest = Xtest.reshape(-1, 1)

    # NOTE: an earlier experiment fitted LinearSVC inside joblib's
    # threading parallel_backend; it is disabled and every model is
    # fitted the same way.
    model.fit(Xtrain, ytrain)
    ypredit = model.predict(Xtest)

    os.system("clear")
    while True:
        print(" Rentre un chiffre:\n\n1 - Stats %\n2 - Stats raw\n3 - accuracy_score")
        print("0 - QUIT")
        try:
            res = int(input())
        except ValueError:
            # Non-numeric entry: report it and show the menu again.
            print('Mauvaise saisie')
            continue
        if res == 0:
            break
        if res == 1:
            os.system("clear")
            printStatValues(ypredit, ytest)
        elif res == 2:
            os.system("clear")
            printPredictedValues(ypredit, ytest)
        elif res == 3:
            os.system("clear")
            print(accuracy_score(ytest, ypredit))
        else:
            print('Mauvaise saisie')
# Start the interactive workflow only when executed as a script, so that
# importing this module does not prompt for input or train a model.
if __name__ == "__main__":
    main()

Loading…
Cancel
Save