You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
1.9 KiB
61 lines
1.9 KiB
import numpy as np
|
|
|
|
from point import *
|
|
|
|
|
|
class Knn:
    """Distance-weighted k-nearest-neighbour regressor / classifier.

    Samples are stored in ``self.space``, a 1-D NumPy array of ``Point``
    objects. During prediction each neighbour contributes with weight
    ``1 / (distance + 1)``, so nearer points count more. Presumably distances
    are non-negative, so this weight never divides by zero — confirm in
    ``Point.getDistFromCoord``.
    """

    def __init__(self):
        # Empty to start; numpy promotes it to object dtype on the first addPoint.
        self.space = np.array([])

    def reset(self):
        """Discard every stored point."""
        self.space = np.array([])

    def addPoint(self, coord, value):
        """Store a new sample ``Point(coord, value)``.

        NOTE: ``np.append`` copies the whole array, so each call is O(n) and
        building a space of n points costs O(n^2) overall.
        """
        self.space = np.append(self.space, Point(coord, value))

    def getValueOfPoint(self, coord, nbNearest):
        """Predict a value vector for ``coord`` by distance-weighted averaging.

        Each of the ``nbNearest`` closest points contributes its ``value``
        sequence with weight ``1 / (distance + 1)``. The result is the
        component-wise weighted mean. All stored values are assumed to have
        the same length as the nearest point's value.

        Returns:
            A list with one averaged entry per component, or ``None`` when
            fewer than ``nbNearest`` points are stored (``getNNearest`` has
            already printed an error in that case).

        Raises:
            ValueError: if ``nbNearest`` is smaller than 1.
        """
        points, dists = self.getNNearest(coord, nbNearest)
        if points is None:
            # Propagate getNNearest's "not enough points" signal instead of
            # crashing on a None subscript.
            return None
        weights = [1 / (dist + 1) for dist in dists]
        total = sum(weights)
        size = len(points[0].value)
        return [
            sum(weight * point.value[i] for weight, point in zip(weights, points))
            / total
            for i in range(size)
        ]

    def getLabelOfPoint(self, coord, nbNearest):
        """Classify ``coord`` by a distance-weighted vote of its neighbours.

        A point's label is the first element of its ``value``. Each of the
        ``nbNearest`` closest points adds ``1 / (distance + 1)`` to its
        label's score. Labels must be hashable.

        Returns:
            The label with the highest score, or ``None`` when fewer than
            ``nbNearest`` points are stored. On an exact score tie, the label
            whose representative is nearest wins, because neighbours are
            visited nearest first.

        Raises:
            ValueError: if ``nbNearest`` is smaller than 1.
        """
        points, dists = self.getNNearest(coord, nbNearest)
        if points is None:
            return None
        scores = {}
        for point, dist in zip(points, dists):
            label = point.value[0]
            scores[label] = scores.get(label, 0) + 1 / (dist + 1)
        return max(scores, key=scores.get)

    def getNNearest(self, coord, nbNearest):
        """Return the ``nbNearest`` stored points closest to ``coord``.

        Distances come from each point's ``getDistFromCoord(coord)``.

        Returns:
            ``(points, distances)``: two parallel lists ordered by increasing
            distance, where ties keep insertion order. Returns ``(None, None)``
            after printing an error when fewer than ``nbNearest`` points are
            stored.

        Raises:
            ValueError: if ``nbNearest`` is smaller than 1.
        """
        import heapq

        if nbNearest < 1:
            raise ValueError("nbNearest must be at least 1")
        if nbNearest > len(self.space):
            print("Error: not enough points")
            return None, None
        dists = [point.getDistFromCoord(coord) for point in self.space]
        # O(n log k) selection. nsmallest is stable, so among equal distances
        # the earlier-added points are kept; the output is sorted ascending.
        nearest = heapq.nsmallest(nbNearest, range(len(dists)), key=dists.__getitem__)
        return [self.space[i] for i in nearest], [dists[i] for i in nearest]
|
|
|