"""Interactive patch-based inpainting of a rectangular image region.

Running this file as a script shows ``asset/vache.png``; dragging a rectangle
with the left mouse button erases that region on a copy of the image and
refills it with patches copied from the rest of the image.
"""
from random import randint

import numpy as np


def doPatchMatch(img, x1, y1, x2, y2, patchSize=129):
    """Erase the rectangle ``[x1, x2] x [y1, y2]`` of *img* and refill it in place.

    The hole is filled one square "zone" at a time.  Each zone is placed
    around a random hole pixel that touches an already-known pixel.  The
    zone's hole pixels are copied from the source patch whose pixels best
    match the zone's known pixels (sum of squared differences).  That patch
    is found by a local search over neighbouring positions with doubling
    offsets.  Source patches never contain hole pixels.  Finally the seams
    between zones and the rectangle's border are smoothed.

    Args:
        img: Image array indexed ``img[row, column, ...]``; modified in place.
        x1, y1, x2, y2: Inclusive column/row bounds of the region to refill.
            They may be given in either order and are clipped to the image.
        patchSize: Side of the square patches in pixels, clamped to
            ``[1, min(width, height)]``.

    Returns:
        *img* itself, after filling.  If no known pixel touches the hole
        (e.g. the rectangle covers the whole image), the remaining hole
        pixels are left at 0.
    """
    height, width = img.shape[:2]
    x1, x2 = sorted((int(x1), int(x2)))
    y1, y2 = sorted((int(y1), int(y2)))
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, width - 1), min(y2, height - 1)
    if x1 > x2 or y1 > y2:
        return img  # the rectangle lies entirely outside the image

    size = max(1, min(int(patchSize), width, height))
    half = size // 2
    # Directions (dx, dy) tried by the local search at each offset.
    neighbours = ((-1, -1), (-1, 0), (0, -1), (1, -1),
                  (-1, 1), (0, 1), (1, 0), (1, 1))

    # Track the hole with an explicit mask instead of detecting black pixels:
    # copied pixels may legitimately be black (which previously could loop
    # forever), and black pixels outside the rectangle are real content.
    hole = np.zeros((height, width), dtype=bool)
    hole[y1:y2 + 1, x1:x2 + 1] = True
    img[y1:y2 + 1, x1:x2 + 1] = 0

    def outline(xa, ya, xb, yb):
        """List the (x, y) pixels on the border of the rectangle [xa, xb] x [ya, yb]."""
        points = [(x, y) for x in range(xa, xb + 1) for y in (ya, yb)]
        points += [(x, y) for y in range(ya + 1, yb) for x in (xa, xb)]
        return points

    def holeFront():
        """Return (xs, ys) of hole pixels that have at least one known 8-neighbour."""
        # The hole never leaves the rectangle, so only look at the rectangle
        # grown by one pixel (clipped to the image).
        ry0, ry1 = max(y1 - 1, 0), min(y2 + 2, height)
        rx0, rx1 = max(x1 - 1, 0), min(x2 + 2, width)
        sub = hole[ry0:ry1, rx0:rx1]
        rows, cols = sub.shape
        known = np.pad(~sub, 1, constant_values=False)  # outside image = unknown
        touching = np.zeros_like(sub)
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                touching |= known[1 + dy:1 + dy + rows, 1 + dx:1 + dx + cols]
        ys, xs = np.nonzero(sub & touching)
        return xs + rx0, ys + ry0

    def zoneAround(cx, cy):
        """Top-left corner of the in-image zone containing pixel (cx, cy), centred when possible."""
        zx = min(max(int(cx) - half, 0), width - size)
        zy = min(max(int(cy) - half, 0), height - size)
        return zx, zy

    def zoneSeams(zx, zy):
        """Pixels of the zone's outline (rows zy, zy+size; columns zx, zx+size) inside the rectangle."""
        points = []
        for xx in range(zx, zx + size):
            if x1 <= xx <= x2:
                if y1 <= zy <= y2:
                    points.append((xx, zy))
                if y1 <= zy + size <= y2:
                    points.append((xx, zy + size))
        for yy in range(zy, zy + size):
            if y1 <= yy <= y2:
                if x1 <= zx <= x2:
                    points.append((zx, yy))
                if x1 <= zx + size <= x2:
                    points.append((zx + size, yy))
        return points

    def isCleanSource(px, py):
        """True if the patch with top-left (px, py) is inside the image and hole-free."""
        return (0 <= px <= width - size and 0 <= py <= height - size
                and not hole[py:py + size, px:px + size].any())

    def randomCleanSource():
        """Return a uniformly random hole-free patch corner, or None if none exists."""
        # Summed-area table of the hole mask: each window sum is the number of
        # hole pixels covered by the patch with that top-left corner.
        sat = np.pad(hole.cumsum(axis=0).cumsum(axis=1), ((1, 0), (1, 0)))
        covered = (sat[size:, size:] - sat[:-size, size:]
                   - sat[size:, :-size] + sat[:-size, :-size])
        ys, xs = np.nonzero(covered == 0)
        if len(xs) == 0:
            return None
        i = randint(0, len(xs) - 1)
        return int(xs[i]), int(ys[i])

    def distance(zone, known, px, py):
        """Sum of squared differences between *zone* and a source patch over the zone's known pixels."""
        # ``zone`` is float64, so this cannot wrap around like uint8 arithmetic.
        diff = zone - img[py:py + size, px:px + size]
        return float(np.sum(diff[known] ** 2))

    def findBestSource(zone, known, usedSources):
        """Return the best hole-free source corner found for *zone*, or None."""
        if usedSources:
            # Seed the search from a source already used for an earlier zone.
            px, py = usedSources[randint(0, len(usedSources) - 1)]
        else:
            start = randomCleanSource()
            if start is None:
                return None
            px, py = start
        best = distance(zone, known, px, py)
        offset = 1
        while offset < min(width, height) / 2:
            bx, by = px, py
            for dx, dy in neighbours:
                nx, ny = px + dx * offset, py + dy * offset
                if isCleanSource(nx, ny):
                    d = distance(zone, known, nx, ny)
                    if d < best:
                        best, bx, by = d, nx, ny
            if (bx, by) != (px, py):
                px, py, offset = bx, by, 1  # improved: restart with small steps
            else:
                offset *= 2
        usedSources.append((px, py))
        return px, py

    def fillZone(zx, zy, usedSources):
        """Fill the hole pixels of the zone at (zx, zy) and mark them known."""
        zoneView = img[zy:zy + size, zx:zx + size]  # a view: writes reach img
        zoneHole = hole[zy:zy + size, zx:zx + size].copy()
        known = ~zoneHole
        source = findBestSource(zoneView.astype(np.float64), known, usedSources)
        if source is not None:
            px, py = source
            zoneView[zoneHole] = img[py:py + size, px:px + size][zoneHole]
        elif known.any():
            # No hole-free patch exists anywhere: use the mean colour of the
            # zone's known pixels so the fill still makes progress.
            zoneView[zoneHole] = zoneView[known].mean(axis=0)
        # Always mark the zone filled so every iteration shrinks the hole.
        hole[zy:zy + size, zx:zx + size] = False

    def smoothEdges(points):
        """Replace each listed in-image pixel, in order, by the mean of its in-image 8-neighbours."""
        offsets = np.array([(-1, -1), (-1, 0), (-1, 1), (0, -1),
                            (0, 1), (1, -1), (1, 0), (1, 1)])
        for x, y in dict.fromkeys(points):  # drop duplicates, keep order
            if not (0 <= x < width and 0 <= y < height):
                continue
            nb = offsets + (x, y)
            inside = ((nb[:, 0] >= 0) & (nb[:, 0] < width)
                      & (nb[:, 1] >= 0) & (nb[:, 1] < height))
            nb = nb[inside]
            if len(nb):
                img[y, x] = img[nb[:, 1], nb[:, 0]].mean(axis=0)

    usedSources = []
    seams = []
    it = 0
    while True:
        xs, ys = holeFront()
        if len(xs) == 0:
            break  # hole filled, or no known pixel touches what remains
        i = randint(0, len(xs) - 1)
        zx, zy = zoneAround(xs[i], ys[i])
        seams.extend(zoneSeams(zx, zy))
        fillZone(zx, zy, usedSources)
        it += 1
        print(it)

    print("smoothing edges")
    smoothEdges(seams + outline(x1, y1, x2, y2)
                + outline(x1 - 1, y1 - 1, x2 + 1, y2 + 1))
    return img


def main(path="asset/vache.png"):
    """Show the image at *path* and inpaint each rectangle dragged over it."""
    # Imported here so doPatchMatch stays usable without matplotlib/a display.
    import matplotlib.pyplot as plt
    from matplotlib.widgets import RectangleSelector

    img = plt.imread(path)
    if img.dtype == np.float32:
        img = (img * 255).astype(np.uint8)
    if img.ndim == 2:  # greyscale: replicate into three channels
        img = np.stack((img,) * 3, axis=-1)
    img = img[:, :, 0:3]  # keep only the first three channels (e.g. drop alpha)

    _, ax = plt.subplots()

    def onselect(eclick, erelease):
        # Each selection is applied to a fresh copy of the loaded image.
        print("drawing")
        res = doPatchMatch(np.copy(img), int(eclick.xdata), int(eclick.ydata),
                           int(erelease.xdata), int(erelease.ydata))
        ax.imshow(res)
        plt.draw()
        print("drawed")

    ax.imshow(img)
    # Keep a reference: matplotlib widgets must stay referenced to remain responsive.
    selector = RectangleSelector(ax, onselect, useblit=True, button=[1],
                                 minspanx=5, minspany=5, spancoords='pixels',
                                 interactive=True)
    plt.axis('off')
    plt.show()


if __name__ == "__main__":
    main()