You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

115 lines
6.1 KiB

from maskedImage import MaskedImage
import numpy as np
import random
class ImageRelation:
def __init__(self,input,output,patchSize):
self.input = input
self.output = output
self.patchSize = patchSize
def randomize(self):
# on crée un assignation zone patch de manière random
self.field = np.zeros((self.input.height, self.input.width, 3), dtype=int)
self.field[:,:,2] = MaskedImage.DSCALE # tant que la réel distance est pas calculé, elle est concidéré comme maximal
self.field[:,:,0] = np.random.randint(0,self.output.width,(self.input.height,self.input.width))
self.field[:,:,1] = np.random.randint(0,self.output.height,(self.input.height,self.input.width))
self.initialize() # on calcule la vrai distance et on change les patch si ils ne conviennent pas
def initializeFromImageRelation(self,imRel):
# on crée l'assignation zone patch à partir des assignations précédentes
self.field = np.zeros((self.input.height, self.input.width, 3), dtype=int)
fx = int(self.input.width/imRel.input.width)
fy = int(self.input.height/imRel.input.height)
for y in range(self.input.height):
for x in range(self.input.width):
xl = min(int(x/fx),imRel.input.width-1)
yl = min(int(y/fy),imRel.input.height-1)
self.field[y,x] = (imRel.field[yl,xl,0]*fx, imRel.field[yl,xl,1]*fy, MaskedImage.DSCALE)
self.initialize() # on calcule la vrai distance et on change si ils ne conviennent pas
def initialize(self):
for y in range(self.input.height):
for x in range(self.input.width):
self.field[y,x,2] = self.distance(x,y,self.field[y,x,0],self.field[y,x,1]) # calcule la vrai distance des patchs
iter= 0
maxIter = 10 # au cas ou pour ne pas rester bloqué par manque de chance
while (self.field[y,x,2] == MaskedImage.DSCALE and iter<maxIter): # tant qu'on a pas trouvé un patch qui n'est pas dans le trou
self.field[y,x] = (random.randint(0,self.output.width),random.randint(0,self.output.height),self.distance(x,y,self.field[y,x,0],self.field[y,x,1]))
iter += 1
def findBestPatch(self,nbPass):
# recherche les meilleurs patch pour toutes les zones
for i in range(nbPass):
for y in range(self.input.height-1): # on cherche le meilleur patch à droit et en bas
for x in range(self.input.width-1):
if (self.field[y,x,2]>0):
self.findBestPatchFroOne(x,y,i)
for y in range(self.input.height-1,0,-1): # on cherche le meilleur patche en haut et à gauche
for x in range(self.input.width-1,0,-1):
if (self.field[y,x,2]>0):
self.findBestPatchFroOne(x,y,-i)
def findBestPatchFroOne(self,x,y,direction):
# recherche le meilleur patch pour une zone en particulier dans un sens (en haut et à gauche | à droite et en bas) + cherche random
# horizontale
if (0<x-direction<self.input.width):
xp = self.field[y,x-direction,0] +direction
yp = self.field[y,x-direction,1]
dp = self.distance(x,y,xp,yp)
if (dp<self.field[y,x,2]):
self.field[y,x] = (xp,yp,dp)
# verticale
if (0<y-direction<self.input.height):
xp = self.field[y-direction,x,0]
yp = self.field[y-direction,x,1] + direction
dp = self.distance(x,y,xp,yp)
if (dp<self.field[y,x,2]):
self.field[y,x] = (xp,yp,dp)
# recherche random
# (on fait la cherche random ici et pas dans la méthode parente pour pouvoir faire deux recherche random au lieux de 1
# ce qui ne côute pas grand chose et augmente la qualité de l'inpainting)
zoneRecherche = max(self.output.height,self.output.width)
while zoneRecherche>0:
xp = self.field[y,x,0] + random.randint(0,2*zoneRecherche)-zoneRecherche
yp = self.field[y,x,1] + random.randint(0,2*zoneRecherche)-zoneRecherche
xp = max(0,min(self.output.width-1,xp))
yp = max(0,min(self.output.height-1,yp))
dp = self.distance(x,y,xp,yp)
if (dp<self.field[y,x,2]):
self.field[y,x] = (xp,yp,dp)
zoneRecherche = int(zoneRecherche/2)
def distance(self,x,y,xp,yp):
# simple interface avec la function distance
return distance(self.input,x,y,self.output,xp,yp,self.patchSize)
def distance(source, xs, ys, target, xt, yt, patchSize): # cette function est ici et pas dans function pour éviter les import circulaire et car ce fichier est le seul qui l'utilise
# calcule la distance entre deux patch (pas la distance physique mais la distance des valeurs de leurs pixels)
# en utilisant la somme des erreurs au carré
# la valeurs maximal de cette fonction est MaskedImage.DSCALE et n'est retourné en pratique presque que quand
# le patch à un problème (qu'il est en partie dans un trou ou qu'il sort du cadre de l'image)
dMax = 255*255*3
distance, wsum = 0, dMax * (patchSize*2)**2
for dy in range(-patchSize, patchSize):
yks, ykt = ys + dy, yt + dy
if not (1 <= yks < source.height - 1 and 1 <= ykt < target.height - 1):
distance += dMax * patchSize * 2
continue
for dx in range(-patchSize, patchSize):
xks, xkt = xs + dx, xt + dx
if not (1 <= xks < source.width - 1 and 1 <= xkt < target.width - 1):
distance += dMax
continue
if source.containsMask(xks, yks,patchSize):
distance += dMax
continue
if target.containsMask(xkt, ykt,patchSize):
distance += dMax
continue
d = np.sum((source.image[yks, xks] - target.image[ykt, xkt]) ** 2)
distance += d
return int(MaskedImage.DSCALE * distance / wsum)