You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
3nar/demo2.py

52 lines
1.1 KiB

import os
import sys
parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(parent_dir_name + "/3nar/code")
from nnnar import *
from knn import *
import matplotlib.pyplot as plt
from random import *
# Model initialisation.
# The first command-line argument selects the regressor: "knn" picks Knn;
# anything else -- including running the script with no argument at all,
# which previously crashed with IndexError -- picks Nnnar.
if len(sys.argv) > 1 and sys.argv[1] == "knn":
    model = Knn()
else:
    # NOTE(review): the meaning of Nnnar's positional arguments (1, 0, 1000, 10)
    # is defined in nnnar.py, not visible here; the 1000 happens to match maxx
    # used below -- confirm whether it is an input range.
    model = Nnnar(1, 0, 1000, 10)
def getRandomParam():
    """Return a list of three random coefficients [a, b, c].

    Each coefficient is ``random() * 10 - 5``, i.e. it lies in [-5, 5).
    """
    return [random() * 10 - 5 for _ in range(3)]
def applyFunction(x, param):
    """Evaluate the quadratic ``a*x^2 + b*x + c`` at ``x``.

    Args:
        x: Point at which to evaluate the polynomial.
        param: Indexable ``[a, b, c]``; only the first three items are used.

    Returns:
        ``param[0] * x**2 + param[1] * x + param[2]``.
    """
    # Square x alone, not the product a*x: (a*x)**2 == a^2 * x^2 would
    # disagree with the "a*x^2 + b*x + c" plot title and could never produce
    # a downward-opening parabola when a < 0.
    return param[0] * x ** 2 + param[1] * x + param[2]
import numpy as np  # used directly below; don't rely on the star imports re-exporting it

# Draw a random quadratic and sample it on a regular grid to train the model.
param = getRandomParam()
nbPoints = 100
minx = 0
maxx = 1000
x = []
y = []
for k in range(nbPoints):
    # Space the samples so that both minx and maxx are on the grid: the
    # inference points below are drawn from [minx, maxx], so the model never
    # has to extrapolate beyond its last training sample.
    xi = minx + (maxx - minx) * k / max(nbPoints - 1, 1)
    x.append(xi)
    y.append(applyFunction(xi, param))
    model.addPoint(np.array([xi]), np.array([y[-1]]))

# Query the model at random points in the sampled range and accumulate the
# absolute prediction error.
error = 0
nbInfer = 20
xInfer = []
yInfer = []
for _ in range(nbInfer):
    xt = uniform(minx, maxx)
    # NOTE(review): the meaning of the second argument (2) is defined by the
    # model's getValueOfPoint -- presumably a neighbour count; confirm in
    # knn.py / nnnar.py.
    yt = model.getValueOfPoint(np.array([xt]), 2)
    yr = applyFunction(xt, param)
    error += abs(yt[0] - yr)
    xInfer.append(xt)
    yInfer.append(yt[0])
plt.plot(xInfer, yInfer, 'xr')

# Report the mean absolute error relative to the curve's magnitude. f(maxx)
# is not used as the scale because it can be negative or zero (giving a
# negative or infinite "error"); the largest |f(x)| on the grid is >= 0.
scale = max(abs(v) for v in y) or 1.0  # f == 0 everywhere: report absolute error
print("Error: ", (error / nbInfer) / scale)

# Plot the true curve under the predicted points.
plt.plot(x, y)
plt.title("f(x) = " + str(param[0]) + "*x^2 + " + str(param[1]) + "*x + " + str(param[2]))
plt.xlabel("X")
plt.ylabel("Y")
plt.show()