You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
3nar/demo3.py

95 lines
2.4 KiB

import os
import sys
parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(parent_dir_name + "/3nar/code")
from nnnar import *
from knn import *
import random
import matplotlib.pyplot as plt
import numpy as np
from time import time
#########################
# #
# function #
# #
#########################
def replaceRandomPixel(image, nb, value=255):
    """Overwrite ``nb`` randomly chosen pixels of ``image`` with ``value``, in place.

    Positions are drawn independently with ``random.randint`` (row first, then
    column), so the same pixel can be drawn more than once and fewer than
    ``nb`` distinct pixels may end up modified.

    Args:
        image: array indexed as ``image[y, x]`` (a trailing channel axis is
            optional); it is modified in place.
        nb: number of random draws.
        value: value written to every channel of each drawn pixel. The default
            255 gives a white pixel for an 8-bit image.

    Returns:
        The same ``image`` object, for convenience.
    """
    # Take the bounds from the image itself instead of relying on module-level
    # ``height``/``width`` globals, so the function works on any image.
    img_height, img_width = image.shape[0], image.shape[1]
    for _ in range(nb):
        y = random.randint(0, img_height - 1)
        x = random.randint(0, img_width - 1)
        # A scalar broadcasts over the channel axis (RGB, RGBA or grayscale).
        image[y, x] = value
    return image
#########################
# #
# programme #
# #
#########################
# Read the source picture into a NumPy array (np.array makes a copy).
image = np.array(plt.imread('./data/img.jpg'))
# Untouched reference copy, shown in the first panel of the final figure.
imageOg = image.copy()
height, width, _channels = image.shape  # raises ValueError unless the image is 3-D
print("image lue")
# Damage the working copy: 1000 random draws, each overwriting one pixel.
image = replaceRandomPixel(image, 1000)
def _neighborDifferences(img):
    """Return, for every pixel, the mean L1 colour distance to its 4-neighbours.

    For pixel (y, x) the score is the average, over the up/down/left/right
    neighbours that lie inside the image, of ``sum(|img[y, x] - img[ny, nx]|)``
    taken across the channel axis. Interior pixels average over 4 neighbours,
    edge pixels over 3, corner pixels over 2, and a 1x1 image scores 0.
    """
    px = img.astype(int)  # signed ints so uint8 subtraction cannot wrap around
    if px.ndim == 2:
        px = px[:, :, np.newaxis]  # treat grayscale as a single channel
    h, w = px.shape[0], px.shape[1]
    total = np.zeros((h, w))
    count = np.zeros((h, w))
    # vert[i] = distance between row i+1 and row i, shape (h-1, w).
    vert = np.abs(px[1:] - px[:-1]).sum(axis=2)
    total[1:] += vert   # each pixel compared with the one above it
    total[:-1] += vert  # each pixel compared with the one below it
    count[1:] += 1
    count[:-1] += 1
    # horiz[:, j] = distance between column j+1 and column j, shape (h, w-1).
    horiz = np.abs(px[:, 1:] - px[:, :-1]).sum(axis=2)
    total[:, 1:] += horiz   # compared with the left neighbour
    total[:, :-1] += horiz  # compared with the right neighbour
    count[:, 1:] += 1
    count[:, :-1] += 1
    return total / np.maximum(count, 1)


# Score every pixel, borders included; a high score suggests a corrupted pixel.
differences = _neighborDifferences(image)
print("image corompue")
# Find the corrupted pixels algorithmically: a pixel whose mean colour
# distance to its neighbours exceeds this threshold is treated as corrupted.
CORRUPTION_THRESHOLD = 200
isCorrupted = differences > CORRUPTION_THRESHOLD
corrupted_pixels = np.where(isCorrupted)
yCor = corrupted_pixels[0]
xCor = corrupted_pixels[1]
badPix = list(zip(yCor, xCor))
# Complement of badPix, taken from the mask rather than by testing membership
# in a list for every pixel (which was O(H*W*len(badPix))). argwhere yields
# row-major order and .tolist() gives plain ints, matching the old comprehension.
goodPix = [(y, x) for y, x in np.argwhere(~isCorrupted).tolist()]
print("pixel corompu trouvé")
# Feed every clean pixel to the model as a training sample:
# input = its (x, y) coordinates (column first), target = its colour.
# NOTE: the timer starts before training, so the reported time covers model
# construction and point insertion as well as inference.
t = time()
# Choose the model from the first CLI argument; without one, fall back to
# Nnnar instead of crashing with IndexError.
if len(sys.argv) > 1 and sys.argv[1] == "knn":
    model = Knn()
else:
    # Constructor arguments kept as-is; their meaning is defined in nnnar.py.
    model = Nnnar(2,0,260,100)
for y, x in goodPix:
    model.addPoint(np.array([x, y]), np.array(image[y, x]))
# Replace each flagged pixel with the model's prediction at its coordinates.
image2 = image.copy()
for y, x in badPix:
    coord = np.array([x, y])
    # presumably 8 = number of neighbours queried — TODO confirm in knn.py/nnnar.py
    image2[y, x] = model.getValueOfPoint(coord, 8)
print("temps d'infération des",str(len(badPix)),"pixels:",str(round(time()-t,3)),"s")
# Side-by-side comparison: original, corrupted and repaired images.
# Each entry is (subplot position code, image to show, panel title).
panels = (
    (131, imageOg, 'orininal'),
    (132, image, 'corompu'),
    (133, image2, 'corrigé'),
)
plt.figure(figsize=(13, 6))
for position, picture, caption in panels:
    plt.subplot(position)
    plt.imshow(picture)
    plt.axis('off')
    plt.title(caption)
plt.show()