rewrite + conflict

master
remrem 1 year ago
commit 3882a5f0b3

@ -1,4 +1,5 @@
#!/usr/bin/python3 #!/usr/bin/python3
import os
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -12,7 +13,20 @@ from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
# from sklearn.externals.joblib import parallel_backend
# main
def main():
    """Run the interactive workflow: choose a model, load the data, train.

    The steps run in a fixed order: the user is asked for a model, the
    matching estimator is built, ``data.csv`` is read and split into
    features and labels, and both are handed to ``training``.
    """
    # User input comes first, before the dataset is read.
    model = model_switch(prompt_display())
    # Load the dataset and separate the features from the labels.
    x, y = get_xy_from_dataframe(read_dataset("data.csv"))
    # Fit the chosen model on the data.
    training(model, x, y)
# Open dataset with panda # Open dataset with panda
def read_dataset(filename): def read_dataset(filename):
@ -20,54 +34,101 @@ def read_dataset(filename):
return df return df
# Drop useless columns and return x and y # Drop useless columns and return x and y
def get_xy_from_dataframe(
    df,
    target="class",
    ignored_columns=(
        "obj_ID", "field_ID", "run_ID", "rerun_ID",
        "cam_col", "plate", "MJD", "fiber_ID",
    ),
):
    """Split a dataframe into a feature frame and a label array.

    Args:
        df: DataFrame holding both the feature columns and the label column.
        target: Name of the label column. Defaults to ``"class"``.
        ignored_columns: Columns removed from the features in addition to
            ``target``. Defaults to the identifier columns that used to be
            hard-coded in the body.

    Returns:
        A ``(x, y)`` tuple: ``x`` is ``df`` without the ignored columns and
        without ``target``; ``y`` is the ``target`` column as an array.

    Raises:
        KeyError: If ``target`` or one of the ignored columns is missing.
    """
    # The label column must never leak into the features, so it is always
    # dropped together with the ignored columns.
    x = df.drop(columns=[*ignored_columns, target])
    y = df[target].values
    return x, y
x, y = get_xy_from_dataset("data.csv") # Ask for model choice
def prompt_display():
    """Show the model menu and return the user's choice as an integer.

    Input that is not an integer does not abort the program: the user is
    asked again until an integer is entered. Whether the number matches a
    menu entry is not checked here.

    Returns:
        int: The number typed by the user.
    """
    print("""Choose a model:
(1) - KNN
(2) - Tree
(3) - RandomForestClassifier
(4) - SGD
(5) - Linear SVC""")
    while True:
        answer = input()
        try:
            return int(answer)
        except ValueError:
            # A typo should not end the session with a traceback.
            print("Please enter a number.")
def model_switch(choice):
    """Build the classifier that corresponds to a menu choice.

    Args:
        choice: Menu number:
            1 - KNeighborsClassifier,
            2 - DecisionTreeClassifier,
            3 - RandomForestClassifier,
            4 - SGDClassifier,
            5 - svm.SVC with a linear kernel.

    Returns:
        A newly constructed estimator for ``choice``.

    Raises:
        ValueError: If ``choice`` is not one of the menu numbers.
    """
    # Factories rather than instances: only the selected estimator is built.
    factories = {
        1: lambda: KNeighborsClassifier(),
        2: lambda: DecisionTreeClassifier(random_state=0, max_depth=20),
        3: lambda: RandomForestClassifier(n_estimators=100, criterion='entropy'),
        4: lambda: SGDClassifier(max_iter=1000, tol=0.01),
        5: lambda: svm.SVC(kernel='linear', C=1.0),
    }
    try:
        factory = factories[choice]
    except KeyError:
        # Tell the user what went wrong and which values are accepted.
        raise ValueError(
            f"Invalid model choice {choice!r}: expected one of {sorted(factories)}"
        ) from None
    return factory()
def plot_columns_hist(columns):
    """Plot histograms of ``columns`` and show the figure.

    Args:
        columns: Object exposing ``hist()``; presumably the feature
            DataFrame - TODO confirm, no caller is visible here.
    """
    # Plot the argument itself; the body used to reference a name ``x``
    # that is not a parameter of this function.
    columns.hist()
    plt.show()
Xtest = Xtest.values
def printPredictedValues(ypredit,ytest):
    """Print every prediction next to its true label, flagged right or wrong.

    Args:
        ypredit: Sequence of predicted labels.
        ytest: Sequence of true labels, aligned with ``ypredit``. Pairs are
            formed positionally and stop at the shorter sequence.
    """
    for predicted, actual in zip(ypredit, ytest):
        # Green check for a correct prediction, red dot for a wrong one.
        marker = "✅" if predicted == actual else "🔴"
        print(f"{marker} Prédit/Réel: ", predicted, actual)
def printStatValues(ypredit,ytest):
    """Print how the wrong predictions are spread over the predicted classes.

    Every prediction that differs from its true label is counted, and the
    predicted class is tallied when it is "GALAXY", "STAR" or "QSO". The
    share of each class among the wrong predictions is printed as a
    percentage. When there is no wrong prediction, a short notice is printed
    instead of dividing by zero.

    Args:
        ypredit: Sequence of predicted labels.
        ytest: Sequence of true labels, aligned with ``ypredit``.
    """
    wrong_by_class = {"GALAXY": 0, "STAR": 0, "QSO": 0}
    wrong_total = 0
    for predicted, actual in zip(ypredit, ytest):
        if predicted != actual:
            wrong_total += 1
            # Labels outside the three known classes still count as wrong
            # but are not attributed to any class.
            if predicted in wrong_by_class:
                wrong_by_class[predicted] += 1

    print("Répartition des prédiction fausses : ")
    if wrong_total == 0:
        # A perfect score leaves nothing to break down.
        print("Aucune prédiction fausse")
        return

    galaxy_pct = wrong_by_class["GALAXY"] * 100 / wrong_total
    star_pct = wrong_by_class["STAR"] * 100 / wrong_total
    qso_pct = wrong_by_class["QSO"] * 100 / wrong_total
    print("Galaxy : ", galaxy_pct, "%", "Star :", star_pct, "%", "QSO : ", qso_pct, "%")
# Train model
def training(model, x, y):
    """Fit ``model`` on a train split, then serve an interactive results menu.

    The data is split with ``test_size=0.25`` and ``random_state=0``; the
    model is fitted on the training part and predicts the test part. The
    menu is then displayed before every read and loops until the user
    enters 0.

    Args:
        model: Estimator exposing ``fit`` and ``predict``.
        x: Features; the split parts must expose ``.values``
            (e.g. a DataFrame).
        y: Labels aligned with ``x``.

    Raises:
        Exception: If the menu choice is a number other than 0, 1, 2 or 3.
        ValueError: If the menu input is not an integer.
    """
    Xtrain, Xtest, ytrain, ytest = train_test_split(x, y, test_size=0.25, random_state=0)
    Xtrain = Xtrain.values
    Xtest = Xtest.values

    # 1-D feature arrays are turned into a single-column 2-D array
    # before fitting and predicting.
    if len(Xtrain.shape) < 2:
        Xtrain = Xtrain.reshape(-1, 1)
    if len(Xtest.shape) < 2:
        Xtest = Xtest.reshape(-1, 1)

    # NOTE: an earlier attempt to fit LinearSVC under joblib's threading
    # backend was disabled; every model is fitted the same way.
    model.fit(Xtrain, ytrain)
    ypredit = model.predict(Xtest)

    os.system("clear")
    while True:
        # Always show the menu before reading, so the user never has to
        # type a choice blind.
        print(" Rentre un chiffre:\n\n1 - Stats %\n2 - Stats raw\n3 - accuracy_score")
        print("0 - QUIT")
        res = int(input())
        if res == 0:
            break
        if res not in (1, 2, 3):
            raise Exception('Mauvaise saisie')

        os.system("clear")
        if res == 1:
            printStatValues(ypredit, ytest)
        elif res == 2:
            printPredictedValues(ypredit, ytest)
        else:
            print(accuracy_score(ytest, ypredit))
# Run the interactive workflow only when executed as a script, so that
# importing this module does not start prompting for input.
if __name__ == "__main__":
    main()

Loading…
Cancel
Save