You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.1 KiB
45 lines
1.1 KiB
from nnnar import *
|
|
from knn import *
|
|
import matplotlib.pyplot as plt
|
|
from random import *
|
|
|
|
# The two regressors being compared; both receive the same samples and queries.
# NOTE(review): the meaning of Nnnar's positional arguments (1, 0, 1000, 10) is
# not visible in this file — confirm against the nnnar module.
knn, nnnar = Knn(), Nnnar(1, 0, 1000, 10)
def getRandomParam():
    """Return a list of three coefficients, each drawn from [-5, 5).

    The values come from three successive ``random()`` calls, in list order,
    each mapped from [0, 1) to [-5, 5) via ``r * 10 - 5``.
    """
    return [random() * 10 - 5 for _ in range(3)]
def applyFunction(x, param):
    """Evaluate the quadratic ``a*x^2 + b*x + c`` at *x*.

    Args:
        x: abscissa at which the polynomial is evaluated.
        param: indexable ``[a, b, c]`` of coefficients (as produced by
            ``getRandomParam``); only the first three entries are read.

    Returns:
        ``param[0]*x**2 + param[1]*x + param[2]``.
    """
    # The former ``pow(x*param[0], 2)`` squared the coefficient as well,
    # computing (a*x)^2 rather than the a*x^2 shown in the plot title.
    a, b, c = param[0], param[1], param[2]
    return a * x ** 2 + b * x + c
# Draw random coefficients and sample the target function on nbPoints evenly
# spaced abscissae in [minx, maxx) (maxx itself is never sampled). Each sample
# is stored in x / y and fed to both regressors as a pair of 1-element arrays.
# NOTE(review): np is not imported explicitly in this file; it presumably
# arrives through one of the star imports — confirm.
param = getRandomParam()
nbPoints, minx, maxx = 100, 0, 1000
x, y = [], []

for i in (minx + (maxx - minx) * k / nbPoints for k in range(nbPoints)):
    x.append(i)
    y.append(applyFunction(i, param))
    knn.addPoint(np.array([i]), np.array([y[-1]]))
    nnnar.addPoint(np.array([i]), np.array([y[-1]]))
# Query both regressors at nbInfer random abscissae, accumulate nnnar's
# absolute error against the true function, and scatter both predictions.
error = 0
nbInfer = 20
for i in range(nbInfer):
    # Draw the query over the same interval the samples span. The former
    # ``random()*maxx`` ignored minx and was only correct while minx == 0.
    xt = minx + random() * (maxx - minx)
    # NOTE(review): the second argument (2) is presumably the neighbour count
    # k — confirm against the nnnar / knn modules.
    yt = nnnar.getValueOfPoint(np.array([xt]), 2)
    yr = applyFunction(xt, param)
    error += abs(yt[0] - yr)
    ytk = knn.getValueOfPoint(np.array([xt]), 2)
    # Red crosses: nnnar predictions; green crosses: knn predictions.
    plt.plot(xt, yt[0], 'xr')
    plt.plot(xt, ytk[0], 'xg')
# Report nnnar's mean absolute error relative to the magnitude of f(maxx).
# abs() keeps the ratio non-negative: for some randomly drawn coefficients
# f(maxx) is negative, which previously yielded a negative "error".
print("Error: ", (error / nbInfer) / abs(applyFunction(maxx, param)))
# Draw the true curve over the scattered predictions, label the axes, and
# title the figure with the drawn coefficients before displaying it.
plt.plot(x, y)
plt.xlabel("X")
plt.ylabel("Y")
plt.title("f(x) = {}*x^2 + {}*x + {}".format(param[0], param[1], param[2]))
plt.show()