You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
1.4 KiB
61 lines
1.4 KiB
import os
import sys
from time import time

import matplotlib.pyplot as plt
import numpy as np

# The NNNAR implementation is not an installed package: it lives in
# <parent of this script's directory>/3nar/code, so that directory is appended
# to the module search path before importing it.
parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(os.path.join(parent_dir_name, "3nar", "code"))

# Import only what the script uses; this import must follow the sys.path tweak.
from nnnar import Nnnar  # noqa: E402


# Benchmark configuration.
maxVal = 1_000      # number of random (x, y) training points (training-set size)
nbTest = 1_000      # number of random query points in the test pool
nbPts = 10          # test-pool prefix sizes are multiples of nbTest / nbPts
nbModel = 10        # subdivision counts are multiples of nbMaxSubDiv / nbModel
nbMaxSubDiv = 100   # largest subdivision count passed to Nnnar


# Build the benchmark data.
#   train: maxVal rows of (x, y) pairs, each value uniform on [0, 1).
#   test:  nbTest query abscissae, each uniform on [0, 1).
# One vectorised draw per array replaces element-by-element Python loops.
# Values come from the global NumPy random stream in row-major order
# (x0, y0, x1, y1, ..., then the test values). The arrays have shapes
# (maxVal, 2) and (nbTest,) and dtype float64.
train = np.random.rand(maxVal, 2)
test = np.random.rand(nbTest)


def testModel(model, train ,test):
|
|
model.reset()
|
|
t = time()
|
|
# Entrainement du modèle
|
|
for i in range(len(train)):
|
|
model.addPoint(np.array([train[i,0]]), np.array([train[i,1]]))
|
|
# Test du modèles
|
|
for i in range(len(test)):
|
|
model.getValueOfPoint(np.array([test[i]]), 5)[0]
|
|
return time() - t
|
|
|
|
|
|
# Benchmark: for each subdivision setting, time training + inference on
# growing prefixes of the test set and plot one curve per setting.

# Test-set prefix sizes nbTest/nbPts, 2*nbTest/nbPts, ..., nbTest. The end
# point is included so the full test set is benchmarked too. The list is
# computed once because it does not depend on the model.
idxs = [round(k * nbTest / nbPts) for k in range(1, nbPts + 1)]

# Subdivision counts nbMaxSubDiv/nbModel, ..., nbMaxSubDiv (end point included).
for step in range(1, nbModel + 1):
    nbSub = round(step * nbMaxSubDiv / nbModel)
    print(nbSub)
    model = Nnnar(1, 0, 1, nbSub)
    # testModel resets the model itself, so one instance serves every prefix.
    res = [testModel(model, train, test[:idx]) for idx in idxs]
    plt.plot(idxs, res, label=f'NNNAR({nbSub})')

plt.xlabel('Number of points inferred')
plt.ylabel('Time (s)')
plt.xticks(idxs)
plt.legend()
plt.show()