You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
1.8 KiB
67 lines
1.8 KiB
import os
|
|
import sys
|
|
parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(parent_dir_name + "/3nar/code")
|
|
|
|
from nnnar import *
|
|
from knn import *
|
|
import numpy as np
|
|
import pandas as pd
|
|
from time import time
|
|
|
|
def runOneTest(trainTestRatio, model, k=5, csvPath='./data/Iris.csv', seed=None):
    """Train ``model`` on a random split of the data set and return its test accuracy.

    The first CSV column is dropped (presumably the row Id of Iris.csv), the
    last column is used as the label, and every remaining column is min-max
    scaled to the range [0, 100] before the split.

    Args:
        trainTestRatio: fraction of the rows used for training; the remaining
            rows form the test set.
        model: object exposing ``addPoint(coord, value)`` and
            ``getLabelOfPoint(coord, k)``; training points are added to it.
        k: value passed as the second argument of ``getLabelOfPoint``
            (default 5, as previously hard-coded).
        csvPath: path or file-like object handed to ``pandas.read_csv``.
        seed: ``random_state`` for the train/test split; ``None`` keeps the
            previous unseeded behaviour.

    Returns:
        Percentage (0-100) of test rows whose predicted label matches.

    Raises:
        ValueError: if the split leaves no rows for the test set
            (e.g. ``trainTestRatio == 1``).
    """
    df = pd.read_csv(csvPath)
    # Drop the first column (presumably the row identifier).
    df = df.iloc[:, 1:]

    # Data normalisation: min-max scale every feature column to [0, 100].
    features = df.iloc[:, :-1].astype(float)
    low = features.min()
    span = features.max() - low
    # A constant column would otherwise give 0/0 = NaN; map it to 0 instead.
    span = span.replace(0, 1)
    features = (features - low) / span * 100
    df = pd.concat([features, df.iloc[:, -1]], axis=1)

    # Build the training and test sets.
    train = df.sample(frac=trainTestRatio, random_state=seed)
    test = df.drop(train.index)
    if test.empty:
        # Checked before training so no work is wasted on an unusable split.
        raise ValueError("trainTestRatio leaves no rows for the test set")

    # Train the model.
    trainCoord = train.iloc[:, :-1].to_numpy()
    trainValue = train.iloc[:, -1].to_numpy()
    for coord, value in zip(trainCoord, trainValue):
        model.addPoint(np.array(coord), np.array([value]))

    # Evaluate the model on the held-out rows.
    testCoord = test.iloc[:, :-1].to_numpy()
    testValue = test.iloc[:, -1].to_numpy()
    nbError = 0
    for coord, value in zip(testCoord, testValue):
        if model.getLabelOfPoint(np.array(coord), k) != value:
            nbError += 1

    return 100 - nbError / len(testCoord) * 100
|
|
|
|
|
|
# Benchmark driver: repeat a random 80/20 train/test evaluation and report the
# mean accuracy and the elapsed time.
#
# Usage: python <script> [knn]
# The argument "knn" selects Knn; any other argument, or none at all, selects
# Nnnar (previously a missing argument crashed with IndexError).
from time import perf_counter  # monotonic clock, unaffected by system clock changes

modelName = sys.argv[1] if len(sys.argv) > 1 else ""

if modelName == "knn":
    model = Knn()
else:
    # NOTE(review): argument meanings are defined in nnnar.py; 4 presumably
    # matches the 4 Iris features and 0 / 100.001 the [0, 100] range produced
    # by runOneTest's scaling - confirm.
    model = Nnnar(4, 0, 100.001, 10)

t1 = perf_counter()

nbRepetition = 100
accuracy = 0  # running sum of per-run accuracies (percent)
for _ in range(nbRepetition):
    # Presumably discards the points added during the previous run - verify in the model classes.
    model.reset()
    accuracy += runOneTest(0.8, model)

t2 = perf_counter()

print("accuracy moyenne:", str(accuracy / nbRepetition))
print("delta temps:", str(t2 - t1))