You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
5.2 KiB
156 lines
5.2 KiB
from matplotlib.widgets import RectangleSelector
|
|
import matplotlib.pyplot as plt
|
|
from random import randint
|
|
import numpy as np
|
|
import cv2
|
|
import time
|
|
from function import *
|
|
|
|
def reScale(img, scale):
    """Resize ``img`` by dividing both dimensions by ``scale``.

    ``scale > 1`` shrinks the image and ``scale < 1`` enlarges it. Each new
    dimension is truncated to an int but never drops below one pixel:
    ``cv2.resize`` cannot produce a zero-sized image, so small inputs used to
    crash at large scale factors.

    Args:
        img: image array; only its first two dimensions are resized.
        scale: divisor applied to height and width.

    Returns:
        (scaled_img, new_height, new_width)
    """
    height, width = img.shape[:2]
    new_height = max(1, int(height / scale))
    new_width = max(1, int(width / scale))
    # cv2.resize expects the target size as (width, height).
    scaled_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return scaled_img, new_height, new_width
|
|
|
|
def reScaleCoord(oWidth, oHeight, nWidth, nHeight, x1, y1, x2, y2):
    """Map rectangle corners from an oWidth x oHeight image to nWidth x nHeight.

    Each x coordinate is multiplied by nWidth / oWidth and each y coordinate
    by nHeight / oHeight, then truncated to an int.

    Returns:
        (x1, y1, x2, y2) expressed in the new image's pixel grid.
    """
    def scale_x(value):
        return int(value * nWidth / oWidth)

    def scale_y(value):
        return int(value * nHeight / oHeight)

    return scale_x(x1), scale_y(y1), scale_x(x2), scale_y(y2)
|
|
|
|
def getDist(pValue1, pValue2):
    """Return the sum of squared differences (SSD) between two pixel arrays.

    Both inputs are promoted to float64 before subtracting. The image data in
    this module is usually uint8, where ``a - b`` wraps modulo 256 and
    ``** 2`` overflows, which silently corrupted the distance.

    Args:
        pValue1, pValue2: arrays of the same (or broadcast-compatible) shape,
            e.g. the (N, 3) output of getValueFromPatch.

    Returns:
        The SSD as a Python float.
    """
    diff = np.asarray(pValue1, dtype=np.float64) - np.asarray(pValue2, dtype=np.float64)
    return float(np.sum(diff * diff))
|
|
|
|
def getRandomPatch(img2, pSize, x1, y1, x2, y2):
    """Draw a random pSize x pSize patch of img2 outside the rectangle (x1, y1)-(x2, y2).

    On each axis, one of two ranges is picked at random: up to the
    rectangle's start ([0, x1]) or from its end ([x2, width - pSize]). The
    origin is then drawn uniformly from that range. Both axes are drawn
    outside the rectangle's span, so candidates only come from the four
    regions diagonal to it.

    Args:
        img2: image the patch is taken from (only its shape is read here).
        pSize: side length of the square patch.
        x1, y1, x2, y2: rectangle to stay away from.

    Returns:
        getZoneFromCoord(x, y, pSize) for the chosen origin (x, y).

    Raises:
        ValueError: if, along some axis, neither range can hold the patch.
    """
    height, width = img2.shape[:2]

    def _pick_origin(hole_start, hole_end, size):
        # Clip both ranges so the whole patch stays inside the image, and skip
        # any range that is empty. Previously randint(hole_end, size - pSize)
        # raised whenever the selection reached the right/bottom edge.
        last = size - pSize
        ranges = [
            (lo, hi)
            for lo, hi in ((0, min(hole_start, last)), (hole_end, last))
            if lo <= hi
        ]
        if not ranges:
            raise ValueError(
                f"no room for a {pSize}-pixel patch outside the selected region"
            )
        lo, hi = ranges[randint(0, len(ranges) - 1)]
        return randint(lo, hi)

    x = _pick_origin(x1, x2, width)
    y = _pick_origin(y1, y2, height)
    return getZoneFromCoord(x, y, pSize)
|
|
|
|
def getValueFromPatch(img, patch, pSize):
    """Read the pSize x pSize window of ``img`` whose top-left corner is patch[0].

    ``patch[0]`` is an (x, y) pair, so x selects columns and y selects rows.
    The window is transposed before flattening. Pixels therefore come out
    column by column (x outer, y inner), matching the coordinate order
    produced by getZoneFromCoord. Each output row is one pixel; the final
    reshape assumes a 3-channel image.

    Returns:
        An (N, 3) array of pixel values.
    """
    left = patch[0][0]
    top = patch[0][1]
    window = img[top:top + pSize, left:left + pSize]
    return window.transpose(1, 0, 2).reshape(-1, 3)
|
|
|
|
def applyPatch(img, zone, patchValue):
    """Write ``patchValue`` into ``img`` at the coordinates listed in ``zone``.

    ``zone`` is a sequence of (x, y) pixel coordinates, expected to be
    distinct. ``patchValue`` holds one value per coordinate in the same
    order; extra trailing values are ignored. ``img`` is modified in place
    and is also returned.

    Raises:
        IndexError: if patchValue has fewer entries than zone.
    """
    coords = np.asarray(zone)
    count = len(coords)
    if count == 0:
        return img
    values = np.asarray(patchValue)
    if len(values) < count:
        raise IndexError(f"applyPatch: {len(values)} values for {count} coordinates")
    # A single fancy-index assignment replaces the per-pixel Python loop.
    # Coordinates are (x, y), so the row index is column 1.
    img[coords[:, 1], coords[:, 0]] = values[:count]
    return img
|
|
|
|
def findBestPatchFromNeigbour(zoneValue, oDist, patch, offset, height, width, img, pSize):
    """Try shifting ``patch`` by ``offset`` in the 8 neighbouring directions.

    Each shifted candidate is scored with getDist against ``zoneValue``. The
    best candidate that strictly beats ``oDist`` is kept. Directions are
    tried in a fixed order, so ties keep the earlier direction.

    This function has no knowledge of the region being filled, so a shifted
    candidate may land inside it.

    Args:
        zoneValue: target pixel values (getValueFromPatch layout).
        oDist: distance of the current patch; a neighbour must be lower.
        patch: integer array of (x, y) coordinates; patch[0] is the top-left
            corner of the window that is read.
        offset: shift in pixels applied along each axis.
        height, width: every shifted coordinate must stay below these.
        img: image the candidate pixels are read from.
        pSize: side of the square window read at patch[0].

    Returns:
        (found, best_patch, best_dist), where found is True if some neighbour
        improved on oDist. Otherwise best_patch is ``patch`` and best_dist is
        ``oDist``.
    """
    directions = ((-1, -1), (-1, 0), (0, -1), (-1, 1), (1, -1), (0, 1), (1, 0), (1, 1))
    img_height, img_width = img.shape[:2]
    found = False
    best = patch
    for dx, dy in directions:
        candidate = patch.copy()
        candidate[:, 0] += dx * offset
        candidate[:, 1] += dy * offset
        if np.any(candidate < 0) or np.any(candidate[:, 0] >= width) or np.any(candidate[:, 1] >= height):
            continue
        # The window read below spans pSize pixels from candidate[0]. If it ran
        # past img's border, the slice would be clipped and getDist would
        # compare arrays of different sizes.
        left, top = candidate[0]
        if left + pSize > img_width or top + pSize > img_height:
            continue
        value = getValueFromPatch(img, candidate, pSize)
        dist = getDist(zoneValue, value)
        if dist < oDist:
            found = True
            oDist = dist
            best = candidate
    return found, best, oDist
|
|
|
|
def findBestPatch(img2, zone, zoneValue, pSize, pixSize, height, width, x1, y1, x2, y2):
    """Find the patch of img2 whose pixels best match ``zoneValue``.

    If the tile's top-left corner lies outside the rectangle
    (x1, y1)-(x2, y2), a copy of ``zone`` is returned unchanged. Otherwise
    the search runs in two stages:

    1. Random search: 501 candidates from getRandomPatch; keep the lowest
       getDist.
    2. Local refinement with findBestPatchFromNeigbour, using offsets
       1, 2, 4, ... The offset resets to 1 after every improvement, and the
       search stops once the offset reaches min(height, width) / 3.

    Args:
        img2: image candidate patches are drawn from.
        zone: (x, y) coordinate array of the tile being filled; zone[0] is its
            top-left corner.
        zoneValue: current pixel values of that tile, used as the guide.
        pSize: side of the square window that is compared and copied.
        pixSize: no longer used; kept so existing calls keep working.
        height, width: bounds for the neighbour search and its stop offset.
        x1, y1, x2, y2: rectangle being filled.

    Returns:
        A coordinate array whose [0] is the chosen top-left corner in img2.
    """
    left, top = zone[0][0], zone[0][1]
    # Check both y bounds, consistent with the tile test in rebuildImg.
    if not (x1 <= left <= x2 and y1 <= top <= y2):
        return zone.copy()

    # Random candidates must be sized like the pSize window actually read
    # from them. Otherwise getRandomPatch can return origins whose window
    # overruns img2.
    patch = getRandomPatch(img2, pSize, x1, y1, x2, y2)
    pdist = getDist(zoneValue, getValueFromPatch(img2, patch, pSize))
    for _ in range(500):
        candidate = getRandomPatch(img2, pSize, x1, y1, x2, y2)
        cdist = getDist(zoneValue, getValueFromPatch(img2, candidate, pSize))
        if cdist < pdist:
            patch, pdist = candidate, cdist

    offset = 1
    limit = min(height, width) / 3
    while offset < limit:
        found, nPatch, nDist = findBestPatchFromNeigbour(
            zoneValue, pdist, patch, int(offset), height, width, img2, pSize
        )
        if found:
            patch, pdist = nPatch, nDist
            offset = 1
        else:
            offset *= 2
    return patch
|
|
|
|
def getZoneFromCoord(x, y, patchSize):
    """Return the (x, y) coordinates of the patchSize x patchSize square at (x, y).

    (x, y) is the square's top-left corner. Coordinates are listed with x as
    the outer index and y as the inner one: [x, y], [x, y+1], ...,
    [x+1, y], ... Each row is one pixel, so the result is a
    (patchSize**2, 2) integer array. A non-positive patchSize gives an empty
    (0, 2) array.

    This is built with NumPy rather than a Python list comprehension because
    it runs for every random candidate in the patch search.
    """
    xs, ys = np.meshgrid(
        np.arange(x, x + patchSize), np.arange(y, y + patchSize), indexing="ij"
    )
    return np.stack((xs.ravel(), ys.ravel()), axis=1)
|
|
|
|
def rebuildImg(img1, img2, pixSize, x1, y1, x2, y2):
    """Rebuild ``img1`` tile by tile from pixels of ``img2``.

    img1 is split into square tiles of side 2 * pixSize. Tiles are classified
    by their top-left corner:

    * Outside the rectangle (x1, y1)-(x2, y2): copied straight from the same
      position in img2.
    * Inside it: img1's current content is used as the guide. findBestPatch
      picks an img2 patch, and that patch's pixels are written into the tile.

    Pixels past the last whole tile along the right and bottom edges are not
    touched. img1 is modified in place and returned.
    """
    height, width = img1.shape[:2]
    pSize = pixSize * 2
    col_count = int(width / pSize)
    row_count = int(height / pSize)
    for col in range(col_count):
        left = col * pSize
        for row in range(row_count):
            top = row * pSize
            zone = getZoneFromCoord(left, top, pSize)
            in_hole = x1 <= left <= x2 and y1 <= top <= y2
            if in_hole:
                guide = getValueFromPatch(img1, zone, pSize)
                source = findBestPatch(img2, zone, guide, pSize, pixSize, height, width, x1, y1, x2, y2)
                img1 = applyPatch(img1, zone, getValueFromPatch(img2, source, pSize))
            else:
                applyPatch(img1, zone, getValueFromPatch(img2, zone, pSize))
    return img1
|
|
|
|
def doPatchMatch(image, x1, y1, x2, y2, scaleFactor=20, patchSize=129):
    """Fill the rectangle (x1, y1)-(x2, y2) of ``image`` coarse-to-fine.

    1. Shrink the image by scaleFactor, zero the rectangle, and fill it with
       initialPatchMatch (imported from the ``function`` module).
    2. For each scale from scaleFactor - 1 down to 2:
       * upsample the current result one level;
       * rebuild it against ``image`` shrunk to that scale, using rebuildImg;
       * draw an enlarged preview on the module-level ``ax``.

    ``image`` itself is not modified; the zeroing happens on the downscaled
    copy.

    Args:
        image: source image.
        x1, y1, x2, y2: rectangle to fill, in ``image`` pixel coordinates.
        scaleFactor: starting downscale factor. Values <= 2 skip the
            refinement loop.
        patchSize: unused; kept for backward compatibility.

    Returns:
        The last level's result enlarged by the final scale factor. Because
        each resize truncates, this is approximately (not exactly) the size
        of ``image``.
    """
    oHeight, oWidth = image.shape[:2]
    rImage, nHeight, nWidth = reScale(image, scaleFactor)
    nx1, ny1, nx2, ny2 = reScaleCoord(oWidth, oHeight, nWidth, nHeight, x1, y1, x2, y2)
    # Zero the selection before the initial fill (presumably how
    # initialPatchMatch recognises the hole; confirm in `function`).
    rImage[ny1:ny2 + 1, nx1:nx2 + 1] = 0
    rImage = initialPatchMatch(rImage, nx1, ny1, nx2, ny2, 5)

    if scaleFactor <= 2:
        # No refinement level will run. Previously tempRes was left unbound
        # (scaleFactor == 2), or the `!= 2` loop overshot into a division by
        # zero (scaleFactor < 2 or non-integer).
        tempRes, _, _ = reScale(rImage, 1 / scaleFactor)

    while scaleFactor > 2:
        scaleFactor -= 1
        # Upsample the previous level from 1/(s+1) to 1/s of the original.
        rImage, nHeight, nWidth = reScale(rImage, scaleFactor / (scaleFactor + 1))
        timg, h, w = reScale(image, scaleFactor)
        nx1, ny1, nx2, ny2 = reScaleCoord(oWidth, oHeight, w, h, x1, y1, x2, y2)
        rImage = rebuildImg(rImage, timg, int(h / nHeight), nx1, ny1, nx2, ny2)
        tempRes, _, _ = reScale(rImage, 1 / scaleFactor)
        ax.imshow(tempRes)
        plt.draw()
        plt.pause(0.1)
        print(scaleFactor)

    return tempRes
|
|
|
|
|
|
|
|
|
|
|
|
# Load the demo image and normalise it to an 8-bit, 3-channel array, which is
# what getValueFromPatch's reshape(-1, 3) expects.
img = plt.imread('asset/vache.png')
if np.issubdtype(img.dtype, np.floating):
    # plt.imread returns PNGs as floats in [0, 1]. Round instead of
    # truncating: an 8-bit level stored as a float can come back a hair below
    # the integer, and astype() alone would then drop it by one.
    img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
if img.ndim == 2:
    # Greyscale image: replicate the single channel so the rest of the
    # pipeline can treat it as RGB instead of crashing on the slice below.
    img = np.stack((img,) * 3, axis=-1)
img = img[:, :, 0:3]
|
|
|
|
def onselect(eclick, erelease):
    """RectangleSelector callback: fill the dragged rectangle and display it.

    eclick and erelease are the press and release events. Their data
    coordinates give two opposite corners of the selection. The fill runs on
    a copy of the module-level ``img``, and the result is drawn on ``ax``.
    """
    # Order the corners so x1 <= x2 and y1 <= y2, whichever way the user
    # dragged. doPatchMatch's slicing and rebuildImg's range checks rely on
    # that ordering.
    x1, x2 = sorted((eclick.xdata, erelease.xdata))
    y1, y2 = sorted((eclick.ydata, erelease.ydata))

    print("drawing")
    img_copy = np.copy(img)
    res = doPatchMatch(img_copy, int(x1), int(y1), int(x2), int(y2))

    ax.imshow(res)
    plt.draw()
    print("drawed")
|
|
|
|
# Show the image and let the user drag a rectangle (left mouse button,
# at least 5x5 screen pixels) marking the region to fill. The selector is
# bound to a module-level name so a live reference to the widget is kept.
fig, ax = plt.subplots()
ax.imshow(img)
toggle_selector = RectangleSelector(
    ax,
    onselect,
    useblit=True,
    button=[1],
    minspanx=5,
    minspany=5,
    spancoords='pixels',
    interactive=True,
)
ax.axis('off')
plt.show()