from random import randint

import numpy as np

# Offsets (dx, dy) of the 8-connected neighbourhood.
_NEIGHBOURS = ((-1, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-1, 1), (0, 1), (1, 1))


def doPatchMatch(img, x1, y1, x2, y2, patchSize=65):
    """Inpaint the rectangle ``[x1, x2] x [y1, y2]`` of ``img`` in place.

    The rectangle (inclusive pixel coordinates, x = column, y = row) is
    blanked and then filled from its border inwards.  Each iteration:

    1. picks a random pixel of the fill front (a hole pixel with at least one
       known 8-neighbour);
    2. takes the square "zone" of side ``2 * (patchSize // 2) + 1`` around it
       (shifted inwards if it would leave the image);
    3. searches for a hole-free source patch of the same size whose pixels
       best match the zone's known pixels (sum of squared differences).  The
       search starts from a previously chosen source patch (or a random
       hole-free one on the first iteration) and hill-climbs over the 8
       neighbouring positions at offsets 1, 2, 4, ... -- reset to 1 after every
       improvement -- while the offset stays below a third of the smaller
       image side;
    4. copies the source pixels into the zone's hole pixels.

    Each iteration fills at least one hole pixel, so the loop terminates.
    The iteration count is printed after every filled zone.

    Args:
        img: ``(height, width)`` or ``(height, width, channels)`` array;
            modified in place.
        x1, y1, x2, y2: Opposite corners of the rectangle, in any order.
            The rectangle is clipped to the image.
        patchSize: Approximate side length of the compared square patches.

    Returns:
        ``img`` itself, with the rectangle filled.

    Raises:
        ValueError: If ``patchSize`` < 1, the patch does not fit in the
            image, or no patch of that size lies entirely outside the hole.
    """
    if patchSize < 1:
        raise ValueError(f"patchSize must be >= 1, got {patchSize}")
    height, width = img.shape[:2]
    semiPatch = int(patchSize) // 2
    side = 2 * semiPatch + 1
    if side > width or side > height:
        raise ValueError(
            f"a {side}x{side} patch does not fit in a {width}x{height} image")

    x1, x2 = sorted((int(x1), int(x2)))
    y1, y2 = sorted((int(y1), int(y2)))
    x1, x2 = max(x1, 0), min(x2, width - 1)
    y1, y2 = max(y1, 0), min(y2, height - 1)
    if x1 > x2 or y1 > y2:  # selection entirely outside the image
        return img

    # Track the hole explicitly: using "pixel == 0" would treat genuinely
    # black pixels as missing data and overwrite them.
    hole = np.zeros((height, width), dtype=bool)
    hole[y1:y2 + 1, x1:x2 + 1] = True
    img[y1:y2 + 1, x1:x2 + 1] = 0

    # Top-left corners of every side x side window that contains no hole
    # pixel, via a summed-area table.  The hole only shrinks, so these stay
    # valid for the whole run.
    sat = np.zeros((height + 1, width + 1), dtype=np.int64)
    sat[1:, 1:] = hole.cumsum(axis=0).cumsum(axis=1)
    holeCount = (sat[side:, side:] - sat[:-side, side:]
                 - sat[side:, :-side] + sat[:-side, :-side])
    freeY, freeX = np.nonzero(holeCount == 0)
    if len(freeX) == 0:
        raise ValueError(
            f"no {side}x{side} patch lies outside the selection; "
            "use a smaller patchSize or selection")

    def fillFront():
        """Return (xs, ys) of hole pixels that touch at least one known pixel."""
        # The rectangle grown by one pixel contains every neighbour of the hole.
        top, left = max(y1 - 1, 0), max(x1 - 1, 0)
        sub = hole[top:min(y2 + 2, height), left:min(x2 + 2, width)]
        rows, cols = sub.shape
        known = np.pad(~sub, 1, constant_values=False)  # outside image = unknown
        touches = np.zeros_like(sub)
        for dx, dy in _NEIGHBOURS:
            touches |= known[1 + dy:1 + dy + rows, 1 + dx:1 + dx + cols]
        ys, xs = np.nonzero(sub & touches)
        return xs + left, ys + top

    def isValid(cx, cy):
        """True if the patch with top-left corner (cx, cy) is in-bounds and hole-free."""
        return (0 <= cx <= width - side and 0 <= cy <= height - side
                and not hole[cy:cy + side, cx:cx + side].any())

    def distance(zoneValues, known, cx, cy):
        """Sum of squared differences over the zone's known pixels."""
        # float64 avoids the wrap-around of uint8 subtraction/squaring.
        candidate = img[cy:cy + side, cx:cx + side].astype(np.float64)
        return float(np.sum((zoneValues - candidate)[known] ** 2))

    def findBestPatch(zoneValues, known, patchCoordFound):
        """Hill-climb to a hole-free patch matching the zone; return its corner."""
        if patchCoordFound:
            cx, cy = patchCoordFound[randint(0, len(patchCoordFound) - 1)]
        else:
            k = randint(0, len(freeX) - 1)
            cx, cy = int(freeX[k]), int(freeY[k])
        best = distance(zoneValues, known, cx, cy)
        offset = 1
        while offset < min(width, height) / 3:
            move = None
            for dx, dy in _NEIGHBOURS:
                nx, ny = cx + dx * offset, cy + dy * offset
                # Rejecting hole-overlapping candidates also rules out the zone
                # itself, which would otherwise be a "perfect" match of zeros.
                if not isValid(nx, ny):
                    continue
                d = distance(zoneValues, known, nx, ny)
                if d < best:
                    best, move = d, (nx, ny)
            if move is None:
                offset *= 2
            else:
                cx, cy = move
                offset = 1
        return cx, cy

    patchCoordFound = []  # corners of source patches already used
    it = 0
    while hole[y1:y2 + 1, x1:x2 + 1].any():
        frontX, frontY = fillFront()
        k = randint(0, len(frontX) - 1)
        # Clamp the zone inside the image; it still contains the front pixel
        # because side <= width and side <= height.
        zx = min(max(int(frontX[k]) - semiPatch, 0), width - side)
        zy = min(max(int(frontY[k]) - semiPatch, 0), height - side)
        zone = (slice(zy, zy + side), slice(zx, zx + side))

        zoneHole = hole[zone].copy()
        cx, cy = findBestPatch(img[zone].astype(np.float64), ~zoneHole,
                               patchCoordFound)
        patchCoordFound.append((cx, cy))

        # Copy the best patch's values (not the starting guess's) into the hole.
        source = img[cy:cy + side, cx:cx + side]
        img[zone][zoneHole] = source[zoneHole]
        hole[zone] = False
        it += 1
        print(it)
    return img


def main():
    """Show ``asset/mur.jpg`` and inpaint every rectangle the user drags over it."""
    # Imported here so doPatchMatch can be imported without matplotlib or a
    # GUI backend.
    import matplotlib.pyplot as plt
    from matplotlib.widgets import RectangleSelector

    img = plt.imread('asset/mur.jpg')
    if np.issubdtype(img.dtype, np.floating):
        # Float images (e.g. PNG) are in [0, 1]; work on 8-bit data.
        img = (img * 255).astype(np.uint8)
    if img.ndim == 2:  # grayscale -> RGB
        img = np.stack([img] * 3, axis=-1)
    img = img[:, :, 0:3]  # drop any alpha channel

    fig, ax = plt.subplots()
    shown = ax.imshow(img)

    def onselect(eclick, erelease):
        # imshow puts pixel centres on integer coordinates, so round.
        x1, y1 = int(round(eclick.xdata)), int(round(eclick.ydata))
        x2, y2 = int(round(erelease.xdata)), int(round(erelease.ydata))
        print("drawing")
        try:
            res = doPatchMatch(np.copy(img), x1, y1, x2, y2)
        except ValueError as err:
            print(err)
            return
        # Update the existing image instead of stacking a new artist per call.
        shown.set_data(res)
        fig.canvas.draw_idle()
        print("drawed")

    # Keep a reference until plt.show() returns: an unreferenced widget is
    # garbage-collected and stops responding.
    selector = RectangleSelector(ax, onselect, useblit=True, button=[1],
                                 minspanx=5, minspany=5, spancoords='pixels',
                                 interactive=True)
    plt.axis('off')
    plt.show()
    return selector


if __name__ == "__main__":
    main()