You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
165 lines
6.4 KiB
165 lines
6.4 KiB
from matplotlib.widgets import RectangleSelector
|
|
import matplotlib.pyplot as plt
|
|
from random import randint
|
|
import numpy as np
|
|
|
|
def initialPatchMatch(img, x1, y1, x2, y2, patchSize=129):
    """Fill the zero-valued hole around a rectangle with best-matching patches.

    Pixels whose channels sum to zero are treated as "hole" (unknown) pixels.
    Starting from the border of the rectangle ``(x1, y1)-(x2, y2)``, a hole
    pixel is picked at random, a square zone around it is compared (on its
    known pixels only) against candidate source windows, and the hole pixels
    of the zone are overwritten with the pixels of the best window found by a
    greedy neighbourhood search with doubling step.  The process repeats until
    no reachable hole pixel is left, then the seams are blurred.

    The image is modified in place and also returned.

    Args:
        img: array of shape (height, width, channels); zero pixels mark the
            hole.  Any numeric dtype is accepted.
        x1, y1, x2, y2: integer corners (inclusive) of the rectangle whose
            border seeds the fill.  The corner order does not matter.
        patchSize: nominal side of the square patch.  The effective side is
            ``2 * int(patchSize / 2) + 1`` so the patch has a centre pixel.

    Returns:
        The same ``img`` object, with the hole filled.

    Raises:
        ValueError: if ``patchSize`` is smaller than 1, if the patch does not
            fit inside the image, or if there is something to fill but no
            hole-free window exists to copy from.
    """
    if patchSize < 1:
        raise ValueError("patchSize must be at least 1")

    semiPatch = int(patchSize / 2)
    # Always use the odd side so that slices and masks agree, even when the
    # caller passes an even patchSize.
    side = 2 * semiPatch + 1
    height, width, _ = img.shape
    if side > min(width, height):
        raise ValueError(
            "patch of side %d does not fit in a %dx%d image" % (side, width, height)
        )

    x1, x2 = min(x1, x2), max(x1, x2)
    y1, y2 = min(y1, y2), max(y1, y2)

    # 8-connected neighbourhood (the centre itself is never useful).
    NEIGHBOURS = ((-1, -1), (-1, 0), (0, -1), (1, -1),
                  (-1, 1), (0, 1), (1, 0), (1, 1))

    # Explicit hole mask: tracking the hole separately from the pixel values
    # means a pixel legitimately filled with black is not mistaken for a hole
    # again, which guarantees that the main loop terminates.
    hole = img.sum(axis=2) == 0

    # Summed-area table of the hole mask -> number of hole pixels in every
    # side x side window.  validOrigin[y0, x0] is True when the window whose
    # top-left corner is (x0, y0) contains no hole pixel.  The hole only
    # shrinks, so a window valid now stays valid for the whole run.
    integral = np.zeros((height + 1, width + 1), dtype=np.int64)
    integral[1:, 1:] = hole.cumsum(axis=0).cumsum(axis=1)
    holeCount = (integral[side:, side:] - integral[:-side, side:]
                 - integral[side:, :-side] + integral[:-side, :-side])
    validOrigin = holeCount == 0
    validYs, validXs = np.nonzero(validOrigin)
    maxOriginX = width - side
    maxOriginY = height - side

    patchCoordFound = []  # top-left corners of the source windows used so far
    edges = []            # seam pixels to blur once the fill is complete

    def inBounds(x, y):
        """Return True when (x, y) is a pixel of the image."""
        return 0 <= x < width and 0 <= y < height

    def getDist(pValue1, pValue2):
        """Sum of squared differences, computed in float64.

        Subtracting in the image dtype would wrap around for uint8 images.
        """
        diff = (np.asarray(pValue1, dtype=np.float64)
                - np.asarray(pValue2, dtype=np.float64))
        return float(np.sum(diff * diff))

    def initializePerimeter(finish=False):
        """List the (x, y) pixels of the rectangle border.

        With ``finish=True`` the one-pixel ring just outside each side is
        listed as well (used for the final smoothing pass).  Points may lie
        outside the image; callers filter them.
        """
        perimeter = []
        for x in range(x1, x2 + 1):
            perimeter.append((x, y1))
            perimeter.append((x, y2))
            if finish:
                perimeter.append((x, y1 - 1))
                perimeter.append((x, y2 + 1))
        for y in range(y1 + 1, y2):
            perimeter.append((x1, y))
            perimeter.append((x2, y))
            if finish:
                perimeter.append((x1 - 1, y))
                perimeter.append((x2 + 1, y))
        return perimeter

    def zoneOrigin(px, py):
        """Top-left corner of the zone centred on (px, py), kept in the image."""
        zx = min(max(px - semiPatch, 0), maxOriginX)
        zy = min(max(py - semiPatch, 0), maxOriginY)
        return zx, zy

    def patchDist(origin, known, zoneKnownValues):
        """Distance between a source window and the known pixels of the zone."""
        ox, oy = origin
        candidate = img[oy:oy + side, ox:ox + side]
        return getDist(candidate[known], zoneKnownValues)

    def getRandomPatch():
        """Pick a starting window: a previously used one, else a random valid one."""
        if patchCoordFound:
            return patchCoordFound[randint(0, len(patchCoordFound) - 1)]
        pick = randint(0, len(validXs) - 1)
        return int(validXs[pick]), int(validYs[pick])

    def getBestNeighbourPatch(known, zoneKnownValues, dist, patch, offset):
        """Try the 8 windows at distance ``offset`` from ``patch``.

        Returns ``(bestPatch, dist)`` where ``bestPatch`` is None when no
        neighbour is strictly better than the incoming ``dist``.
        """
        bestPatch = None
        px, py = patch
        for dx, dy in NEIGHBOURS:
            nx = px + dx * offset
            ny = py + dy * offset
            if not (0 <= nx <= maxOriginX and 0 <= ny <= maxOriginY):
                continue
            # Never copy from a window that overlaps the hole.
            if not validOrigin[ny, nx]:
                continue
            nDist = patchDist((nx, ny), known, zoneKnownValues)
            if nDist < dist:
                dist = nDist
                bestPatch = (nx, ny)
        return bestPatch, dist

    def getBestPatchForZone(known, zoneKnownValues):
        """Greedy search for the source window that best matches the zone.

        The step is doubled while nothing improves and reset to 1 after each
        improvement; the distance strictly decreases, so the search ends.
        Returns the top-left corner of the best window found.
        """
        patch = getRandomPatch()
        dist = patchDist(patch, known, zoneKnownValues)
        offset = 1
        while offset < min(width, height):
            nPatch, nDist = getBestNeighbourPatch(
                known, zoneKnownValues, dist, patch, offset)
            if nPatch is not None:
                patch = nPatch
                dist = nDist
                offset = 1
            else:
                offset *= 2
        patchCoordFound.append(patch)
        return patch

    def addEdge(edges, zoneTopLeft):
        """Record the zone borders that fall inside the rectangle.

        Uses the whole zone rather than only its filled pixels (the original
        note suggests this is intentional for irregularly shaped areas).
        """
        zx, zy = zoneTopLeft
        topInside = y1 <= zy <= y2
        bottomInside = y1 <= zy + side <= y2
        for xx in range(max(zx, x1), min(zx + side, x2 + 1)):
            if topInside:
                edges.append((xx, zy))
            if bottomInside:
                edges.append((xx, zy + side))
        leftInside = x1 <= zx <= x2
        rightInside = x1 <= zx + side <= x2
        for yy in range(max(zy, y1), min(zy + side, y2 + 1)):
            if leftInside:
                edges.append((zx, yy))
            if rightInside:
                edges.append((zx + side, yy))
        return edges

    def smoothEdges(edges):
        """Replace every seam pixel by the mean of its in-image neighbours.

        Pixels are processed in order and in place, so later pixels see the
        already smoothed values of earlier ones.
        """
        offsets = np.array(NEIGHBOURS)
        for ex, ey in list(edges) + initializePerimeter(True):
            # The outer ring of the rectangle may fall outside the image.
            if not inBounds(ex, ey):
                continue
            neighbours = offsets + np.array([ex, ey])
            inside = ((neighbours[:, 0] >= 0) & (neighbours[:, 0] < width)
                      & (neighbours[:, 1] >= 0) & (neighbours[:, 1] < height))
            neighbours = neighbours[inside]
            if len(neighbours) > 0:
                img[ey, ex] = img[neighbours[:, 1], neighbours[:, 0]].mean(axis=0)

    # Only hole pixels need work; border pixels that are already known are
    # dropped immediately instead of being retried forever.
    perimeter = {(x, y) for x, y in initializePerimeter()
                 if inBounds(x, y) and hole[y, x]}
    if perimeter and len(validXs) == 0:
        raise ValueError(
            "no hole-free %dx%d window to copy from" % (side, side))

    while perimeter:
        px, py = tuple(perimeter)[np.random.randint(len(perimeter))]
        zx, zy = zoneOrigin(px, py)
        edges = addEdge(edges, (zx, zy))

        zoneWindow = (slice(zy, zy + side), slice(zx, zx + side))
        known = ~hole[zoneWindow]
        zoneKnownValues = img[zoneWindow][known]

        sx, sy = getBestPatchForZone(known, zoneKnownValues)

        # Copy the pixels of the *best* window found into the hole pixels.
        toFill = ~known
        img[zoneWindow][toFill] = img[sy:sy + side, sx:sx + side][toFill]

        fillY, fillX = np.nonzero(toFill)
        fillX = fillX + zx
        fillY = fillY + zy
        hole[zoneWindow] = False
        perimeter.difference_update(zip(fillX.tolist(), fillY.tolist()))

        # Hole pixels touching what was just filled become the new frontier.
        for dx, dy in NEIGHBOURS:
            nx = fillX + dx
            ny = fillY + dy
            ok = (nx >= 0) & (nx < width) & (ny >= 0) & (ny < height)
            nx = nx[ok]
            ny = ny[ok]
            still = hole[ny, nx]
            perimeter.update(zip(nx[still].tolist(), ny[still].tolist()))

    smoothEdges(edges)
    return img
|