You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
3nar/plotComplexite3.py

61 lines
1.4 KiB

import os
import sys
# Make the project's `3nar/code` directory importable (needed by the
# `nnnar` import below): resolve this script's real location, go up two
# directory levels, and append the sub-path to the module search path.
parent_dir_name = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(f"{parent_dir_name}/3nar/code")
from nnnar import *
import matplotlib.pyplot as plt
import numpy as np
from time import time
maxVal = 1_000
nbTest = 1000
nbPts = 10
nbModel = 10
nbMaxSubDiv = 100
# Création des données d'entrainement
train = []
test = []
for i in range(maxVal):
x = np.random.rand()
y = np.random.rand()
train.append([x,y])
for i in range(nbTest):
x = np.random.rand()
test.append(x)
train = np.array(train)
test = np.array(test)
def testModel(model, train ,test):
model.reset()
t = time()
# Entrainement du modèle
for i in range(len(train)):
model.addPoint(np.array([train[i,0]]), np.array([train[i,1]]))
# Test du modèles
for i in range(len(test)):
model.getValueOfPoint(np.array([test[i]]), 5)[0]
return time() - t
# Benchmark sweep: nbModel models whose subdivision counts are evenly spaced
# up to and including nbMaxSubDiv. Each model is timed on nbPts growing
# prefixes of `test`, the last prefix being the whole test set. Both sweeps
# start at 1 to avoid a 0-subdivision model and an empty query prefix.

# The prefix sizes do not depend on the model, so compute them once.
idxs = [round(k * nbTest / nbPts) for k in range(1, nbPts + 1)]
for model_idx in range(1, nbModel + 1):
    nbSub = round(model_idx * nbMaxSubDiv / nbModel)
    print(nbSub)  # progress indicator
    model = Nnnar(1, 0, 1, nbSub)
    res = [testModel(model, train, test[:idx]) for idx in idxs]
    # Plot against the actual number of queries so the x-axis is to scale.
    plt.plot(idxs, res, label=f'NNNAR({nbSub})')
plt.xlabel('Number of points inferred')
plt.ylabel('Time (s)')
plt.xticks(idxs)
plt.legend()
plt.show()