You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
2.0 KiB
43 lines
2.0 KiB
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, ConfusionMatrixDisplay, roc_curve, auc, RocCurveDisplay
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.model_selection import learning_curve
|
|
import numpy as np
|
|
from sklearn import metrics
|
|
|
|
# Fonction pour calculer la matrice de confusion, le rapport de classification et l'exactitude
|
|
def calculateMatrix(y_true, y_pred):
    """Summarise how well predicted labels match the true labels.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels, aligned element-wise with ``y_true``.

    Returns:
        A 3-tuple ``(accuracy, confusion, report)``:
        the value of ``accuracy_score``, the array produced by
        ``confusion_matrix``, and the text produced by
        ``classification_report``. The report is built with
        ``zero_division=1``, so a metric whose denominator is zero is
        reported as 1 instead of triggering a warning.
    """
    # The three metrics are evaluated in the same order as they are returned.
    scores = (
        accuracy_score(y_true, y_pred),
        confusion_matrix(y_true, y_pred),
        classification_report(y_true, y_pred, zero_division=1),
    )
    return scores
|
|
|
|
# Fonction pour afficher une matrice de confusion
|
|
def seeMatrix(matrix, classes):
    """Draw a precomputed confusion matrix as a blue heatmap and show it.

    Args:
        matrix: Confusion matrix to display (for example the second element
            returned by ``calculateMatrix``).
        classes: Labels for the matrix rows/columns, forwarded to
            ``ConfusionMatrixDisplay`` as ``display_labels``.
    """
    ConfusionMatrixDisplay(
        confusion_matrix=matrix,
        display_labels=classes,
    ).plot(cmap=plt.cm.Blues)
    plt.show()
|
|
|
|
# Fonction pour tracer la courbe ROC
|
|
def rocCurve(y_test, y_pred, *, estimator_name='example estimator',
             pos_label=None, show=True):
    """Plot the ROC curve of a binary classifier, annotated with its AUC.

    Args:
        y_test: True binary labels.
        y_pred: Scores for the positive class (probabilities or
            decision-function values). Hard 0/1 predictions are accepted,
            but they yield a curve with only one intermediate point.
        estimator_name: Legend label for the curve. The default keeps the
            original placeholder text for backward compatibility.
        pos_label: Label of the positive class, forwarded to
            ``sklearn.metrics.roc_curve``. Required when the labels are not
            {0, 1} or {-1, 1}.
        show: When True (default), call ``plt.show()`` so the figure is
            actually displayed, as the module's other plotting helpers do.
    """
    # roc_curve also returns the decision thresholds; the plot does not use them.
    fpr, tpr, _ = metrics.roc_curve(y_test, y_pred, pos_label=pos_label)
    roc_auc = metrics.auc(fpr, tpr)

    display = metrics.RocCurveDisplay(
        fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=estimator_name
    )
    display.plot()

    if show:
        # Without this the figure never appears when the code runs as a
        # plain script; seeMatrix and seeRocCurve both end with plt.show().
        plt.show()
|
|
|
|
# Fonction pour visualiser la courbe d'apprentissage
|
|
def seeRocCurve(model, X_train, y_train, learning_reps):
    """Plot and show the learning curve of ``model`` (not a ROC curve).

    The name is kept for existing callers. The function scores ``model``
    with accuracy at increasing training-set sizes via
    ``sklearn.model_selection.learning_curve``. It then plots the mean
    training score and the mean cross-validation score against the number
    of training samples.

    Args:
        model: Estimator to evaluate; ``learning_curve`` clones it, so it
            does not need to be fitted beforehand.
        X_train: Feature matrix.
        y_train: Target labels.
        learning_reps: Value passed as ``cv`` to ``learning_curve``
            (e.g. a number of folds or a CV splitter).
    """
    # Training and cross-validation accuracy for each training-set size;
    # n_jobs=-1 lets scikit-learn use all available CPU cores.
    train_sizes, train_scores, test_scores = learning_curve(
        model, X_train, y_train, cv=learning_reps, scoring='accuracy', n_jobs=-1
    )

    # Score arrays have one row per training size and one column per fold:
    # average over the folds.
    train_scores_mean = np.mean(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)

    # Start a fresh figure so the curves are not drawn onto axes left active
    # by an earlier plot that was never shown.
    plt.figure()
    plt.plot(train_sizes, train_scores_mean, label='Entrainement')
    plt.plot(train_sizes, test_scores_mean, label='Test')
    plt.xlabel("Taille de l'ensemble d'entrainement")
    plt.ylabel('Score')
    plt.title("Courbe d'apprentissage")
    plt.legend()
    plt.show()