diff --git a/patchMatch.py b/patchMatch.py new file mode 100644 index 0000000..dc90ae1 --- /dev/null +++ b/patchMatch.py @@ -0,0 +1,210 @@ +from matplotlib.widgets import RectangleSelector +import matplotlib.pyplot as plt +import numpy as np + + + +def doPatchMatch(img,x1,y1,x2,y2,patchSize=17,nbRadomPatch=10): + + def getPatchFromCoord(x,y): + patch = np.array([[i, j] for i in range(patchSize) for j in range(patchSize)]) + patch[:,0] = patch[:,0] + x + patch[:,1] = patch[:,1] + y + return patch + + def distance(patchValue1,patchValue2): + mask = np.all(patchValue1 == [-1, -1, -1, -1], axis=-1) + return np.sum((patchValue1[~mask] - patchValue2[~mask]) ** 2) + + def getBestNeigbourPatch(xy,ogValue,ogDist,step): + x, y = xy + + dist = -1 + + xt, yt = x+step, y + if (0 <= xt <= width - patchSize and 0 <= yt <= height - patchSize): + patch = getPatchFromCoord(xt,yt) + patchValue = patchToValue(patch) + dist = distance(ogValue,patchValue) + + xt, yt = x-step, y + if (0 <= xt <= width - patchSize and 0 <= yt <= height - patchSize): + tpatch = getPatchFromCoord(xt,yt) + tpatchValue = patchToValue(tpatch) + tdist = distance(ogValue,tpatchValue) + if tdist < dist or dist == -1: + dist = tdist + patch = tpatch + patchValue = tpatchValue + + xt, yt = x, y+step + if (0 <= xt <= width - patchSize and 0 <= yt <= height - patchSize): + tpatch = getPatchFromCoord(xt,yt) + tpatchValue = patchToValue(tpatch) + tdist = distance(ogValue,tpatchValue) + if tdist < dist or dist == -1: + dist = tdist + patch = tpatch + patchValue = tpatchValue + + xt, yt = x, y-step + if (0 <= xt <= width - patchSize and 0 <= yt <= height - patchSize): + tpatch = getPatchFromCoord(xt,yt) + tpatchValue = patchToValue(tpatch) + tdist = distance(ogValue,tpatchValue) + if tdist < dist or dist == -1: + dist = tdist + patch = tpatch + patchValue = tpatchValue + if dist == -1: + return False, None, None, None + return dist < ogDist, patch, patchValue, dist + + def getTheBestPatch(addr,ogValue): + patchs = [] + patchsValue = [] + dists = [] + for i in range(nbRadomPatch): + x,y = getRandomPatch() + patch = getPatchFromCoord(x,y) + patchValue = patchToValue(patch) + dist = distance(ogValue,patchValue) + patchs.append(patch) + patchsValue.append(patchValue) + dists.append(dist) + + minIdx = np.argmin(np.array(dist)) + patch = patchs[minIdx] + patchValue = patchsValue[minIdx] + + ogDist = dists[minIdx] + foundNew = True + step = 5 + addr = addr[0] + while foundNew: + foundNew, tpatch, tpatchValue, tdist = getBestNeigbourPatch(addr,ogValue,ogDist,step) + if (foundNew): + addr = tpatch[patchSize//2] + patch = tpatch + patchValue = tpatchValue + ogDist = tdist + step = 5 + else: + step = step*1.25 + foundNew = step < min(width, height)/2 + return patch, patchValue + + + def patchToValue(patch): + return img[patch[0][1]:patch[len(patch)-1][1], patch[0][0]:patch[len(patch)-1][0]] + + def getRandomPatch(): + rx = np.random.randint(0, width - patchSize) + ry = np.random.randint(0, height - patchSize) + return rx, ry + + def initializePermimiter(): + perimeter = [] + for x in range(x1, x2 + 1): + perimeter.append((x, y1)) + perimeter.append((x, y2)) + + for y in range(y1 + 1, y2): + perimeter.append((x1, y)) + perimeter.append((x2, y)) + img[y1:y2+1, x1:x2+1] = -1 + return np.array(perimeter) + + def removeAndAddFromPerimiter(perimiter, addr): + p= [] + npAddr = np.array(addr) + for coord in perimiter: + if not np.any(np.all(npAddr == np.array(coord), axis=1)): + p.append(coord) + + perimiter = p + p1 = patchSize+2 + for dx in range(-1, p1): + for dy in range(-1, p1): + if (dx!=-1 and dx!=p1-1 and dy != -1 and dy != p1-1): + continue + nx, ny = addr[0,0] + dx, addr[0,1] + dy + if 0 <= nx < width and 0 <= ny < height and img[ny, nx][0] == -1: + if len(perimiter) == 0: + perimiter.append([nx, ny]) + continue + if not np.any(np.all(perimiter == np.array([int(nx), int(ny)]), axis=1)): + perimiter.append([nx, ny]) + return perimiter + + def applyPatch(patch,addr): + for i in range(len(addr)): + if img[addr[i, 1], addr[i, 0]][0] == -1: + img[addr[i, 1], addr[i, 0]] = img[patch[i, 1],patch[i, 0]] + + def getRandomFromPerimiter(perimiter): + return perimiter[np.random.randint(len(perimiter))] + + def loop(perimiter): + x,y = getRandomFromPerimiter(perimiter) + addr = getPatchFromCoord(x,y) + ogValue = patchToValue(addr) + patch,patchValue = getTheBestPatch(addr,ogValue) + applyPatch(patch,addr) + perimiter = removeAndAddFromPerimiter(perimiter,addr) + return perimiter + + + + + + semiPatch = int(patchSize/2) + height, width, _ = img.shape + + + + perimiter = initializePermimiter() + it = 0 + + # perimiter = loop(perimiter) + # perimiter = loop(perimiter) + # perimiter = loop(perimiter) + # for coord in perimiter: + # img[coord[1], coord[0]] = [1,1,1,1] + # img[img == -1] = 0 + + while len(perimiter)> 0: + it += 1 + perimiter = loop(perimiter) + if (it == 1000): + it = 0 + print(len(perimiter)) + + return img + + + + +img = plt.imread('asset/boat.png') + +if len(img.shape) == 2: + img = np.stack((img,)*3, axis=-1) + +def onselect(eclick, erelease): + x1, y1 = eclick.xdata, eclick.ydata + x2, y2 = erelease.xdata, erelease.ydata + + print("drawing") + img_copy = np.copy(img) + res = doPatchMatch(img_copy,int(x1),int(y1),int(x2),int(y2),33) + ax.imshow(res) + plt.draw() + print("drawed") + +fig, ax = plt.subplots() +ax.imshow(img) +toggle_selector = RectangleSelector(ax, onselect, useblit=True, + button=[1], minspanx=5, minspany=5, spancoords='pixels', + interactive=True) +plt.axis('off') +plt.show()