You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
1.9 KiB
85 lines
1.9 KiB
import os
|
|
import sys
|
|
parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(parent_dir_name + "/3nar/code")
|
|
|
|
from nnnar import *
|
|
from knn import *
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from time import time
|
|
|
|
|
|
# Model initialisation: the first command-line argument selects the benchmark.
#   "comp"            -> time a Knn and an Nnnar model side by side
#   "knn"             -> time a Knn model alone
#   anything else     -> time an Nnnar model alone (also used when no
#                        argument is given, instead of raising IndexError)
mode = sys.argv[1] if len(sys.argv) > 1 else ""
comp = mode == "comp"
if comp:
    model1 = Knn()
    model2 = Nnnar(1, 0, 1, 50)
elif mode == "knn":
    # Chained with elif so that comparison mode does not also build an
    # unused single model.
    model = Knn()
else:
    model = Nnnar(1, 0, 1, 50)
|
|
|
|
# Benchmark dimensions.
maxVal = 20_000  # total number of training points generated
nbTest = 100     # number of query points evaluated in every run
nbPts = 10       # divisor used to step the training-set size

# Creation of the training data: maxVal rows of (x, y), each coordinate
# drawn uniformly from [0, 1). Generated in one vectorised call rather than
# appending to a Python list point by point.
train = np.random.rand(maxVal, 2)

# Creation of the test data: nbTest query values drawn uniformly from [0, 1).
test = np.random.rand(nbTest)
|
|
|
|
def testModel(model, train, test) -> float:
    """Time how long `model` takes to ingest `train` and answer `test`.

    The model is reset, fed every training row one point at a time, and then
    queried once per test value. Each query passes 5 as its second argument
    (presumably the neighbour count -- confirm against the model classes).

    Args:
        model: Object exposing ``reset()``, ``addPoint(x, y)`` and
            ``getValueOfPoint(x, k)``.
        train: Sequence of rows; column 0 is the input and column 1 the
            associated value of each training point.
        test: Sequence of scalar query values.

    Returns:
        Elapsed time in seconds covering both the training and query phases
        (the reset is not timed).
    """
    # perf_counter is monotonic and high-resolution, which suits interval
    # timing better than the wall clock returned by time().
    from time import perf_counter

    model.reset()
    start = perf_counter()

    # Training phase: every point is wrapped in a one-element array.
    for row in train:
        model.addPoint(np.array([row[0]]), np.array([row[1]]))

    # Query phase: only the duration matters, so the result is discarded.
    for query in test:
        model.getValueOfPoint(np.array([query]), 5)

    return perf_counter() - start


# The name starts with "test"; opt out of pytest collection explicitly.
testModel.__test__ = False
|
|
|
|
# Run the benchmark on growing prefixes of the training data and plot the
# timings. Comparison mode times both models on identical data and labels the
# curves; otherwise only the single selected model is timed and plotted.
if comp:
    benchmarks = [('KNN', model1), ('3NAR', model2)]
else:
    benchmarks = [(None, model)]

# Training-set sizes to measure.
# NOTE(review): range(1, nbPts) stops at (nbPts - 1) / nbPts of maxVal, so the
# full training set is never timed -- confirm this is intended.
idxs = [round((i * maxVal) / nbPts) for i in range(1, nbPts)]

timings = {label: [] for label, _ in benchmarks}
for idx in idxs:
    print(idx)
    # Models are timed back to back on the same slice for each size.
    for label, candidate in benchmarks:
        timings[label].append(testModel(candidate, train[:idx], test))

plt.xlabel('Number of training points')
plt.ylabel('Time (s)')
# Curves are plotted against their position; ticks are relabelled with the
# actual training-set sizes.
plt.xticks(range(len(idxs)), idxs)

for label, _ in benchmarks:
    plt.plot(timings[label], label=label)
if comp:
    plt.legend()
plt.show()