You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

156 lines
5.2 KiB

from matplotlib.widgets import RectangleSelector
import matplotlib.pyplot as plt
from random import randint
import numpy as np
import cv2
import time
from function import *
def reScale(img, scale):
    """Resize ``img`` so that each side is divided by ``scale``.

    The target size is truncated to an integer. ``scale`` > 1 shrinks and
    ``scale`` < 1 enlarges. The resize uses ``cv2.INTER_AREA``.

    Returns:
        (resized_image, new_height, new_width)
    """
    rows, cols = img.shape[:2]
    targetH, targetW = (int(dim / scale) for dim in (rows, cols))
    resized = cv2.resize(img, (targetW, targetH), interpolation=cv2.INTER_AREA)
    return resized, targetH, targetW
def reScaleCoord(oWidth, oHeight, nWidth, nHeight, x1, y1, x2, y2):
    """Map rectangle corners from an oWidth x oHeight image onto nWidth x nHeight.

    Each coordinate is scaled by the size ratio along its axis and truncated
    toward zero with ``int``.

    Returns:
        (x1, y1, x2, y2) in the new image's pixel coordinates.
    """
    def toNewX(value):
        return int(value * nWidth / oWidth)

    def toNewY(value):
        return int(value * nHeight / oHeight)

    return toNewX(x1), toNewY(y1), toNewX(x2), toNewY(y2)
def getDist(pValue1, pValue2):
    """Return the sum of squared differences between two pixel-value arrays.

    Both inputs are promoted to float64 before subtracting. With uint8 pixel
    data (what the script feeds in), ``a - b`` would otherwise wrap around
    modulo 256 and the square would overflow, giving meaningless distances.

    Args:
        pValue1, pValue2: array-likes of equal (or broadcastable) shape.

    Returns:
        The squared-difference sum as a numpy float64.
    """
    diff = np.asarray(pValue1, dtype=np.float64) - np.asarray(pValue2, dtype=np.float64)
    return np.sum(diff * diff)
def _randomStartOutside(before, after, last):
    """Return a random start in [0, min(before, last)] or in [after, last].

    One of the two ranges is chosen with equal probability among those that
    are non-empty, so a hole touching either image border no longer crashes.

    Raises:
        ValueError: if neither range is non-empty.
    """
    ranges = ((0, min(before, last)), (after, last))
    candidates = [(lo, hi) for lo, hi in ranges if lo <= hi]
    if not candidates:
        raise ValueError("no room for a patch on either side of the hole")
    lo, hi = candidates[randint(0, len(candidates) - 1)]
    return randint(lo, hi)


def getRandomPatch(img2, pSize, x1, y1, x2, y2):
    """Pick a random pSize x pSize zone of ``img2`` beside the rectangle (x1, y1)-(x2, y2).

    The zone's x origin is drawn either left of the rectangle
    ([0, x1], clamped so the zone fits) or right of it ([x2, width - pSize]).
    The y origin is drawn likewise, above or below. Each side is chosen with
    equal probability among the sides that have room.

    Returns:
        The zone's coordinate array, as built by getZoneFromCoord.

    Raises:
        ValueError: if the zone cannot fit on either side along some axis.
    """
    height, width = img2.shape[:2]
    x = _randomStartOutside(x1, x2, width - pSize)
    y = _randomStartOutside(y1, y2, height - pSize)
    return getZoneFromCoord(x, y, pSize)
def getValueFromPatch(img, patch, pSize):
    """Return the pixels of the pSize x pSize window of ``img`` at ``patch[0]``.

    ``patch[0]`` is read as (x, y) = (column, row) of the window's top-left
    corner. The window is flattened x-outer / y-inner, which is the same order
    getZoneFromCoord lists its coordinates in. Row i of the result therefore
    belongs to coordinate i of such a zone.

    The result has shape (k, channels), using the image's real channel
    count. Near the border numpy slicing truncates the window, so k can be
    smaller than pSize * pSize.
    """
    x0, y0 = patch[0][0], patch[0][1]
    window = img[y0:y0 + pSize, x0:x0 + pSize]
    return window.transpose(1, 0, 2).reshape(-1, window.shape[2])
def applyPatch(img, zone, patchValue):
    """Write ``patchValue[i]`` into ``img`` at pixel ``zone[i] = (x, y)``, in place.

    Entries of ``patchValue`` beyond ``len(zone)`` are ignored, and an empty
    zone is a no-op.

    Returns:
        ``img`` itself (the same object), for chaining.

    Raises:
        IndexError: if ``patchValue`` is shorter than ``zone`` or a coordinate
            is out of bounds.
    """
    coords = np.asarray(zone)
    if len(coords) == 0:
        return img
    values = np.asarray(patchValue)[:len(coords)]
    if len(values) < len(coords):
        raise IndexError("patchValue has fewer entries than zone")
    # One vectorized write instead of a Python loop per pixel. This assumes
    # the zone repeats no coordinate (true for getZoneFromCoord zones), since
    # numpy leaves the winner of duplicate fancy-index writes unspecified.
    img[coords[:, 1], coords[:, 0]] = values
    return img
def findBestPatchFromNeigbour(zoneValue, oDist, patch, offset, height, width, img, pSize):
    """Probe the 8 neighbours of ``patch`` at distance ``offset`` for a closer match.

    A neighbour is ``patch`` with every coordinate shifted by
    (dx * offset, dy * offset), where dx and dy are in {-1, 0, 1} and not both
    0. It is considered only when two conditions hold:

    * all of its coordinates lie in [0, width) x [0, height);
    * the pSize x pSize window getValueFromPatch reads from its first row fits
      inside ``img``.

    Among the considered neighbours, the one strictly closer (getDist) than
    the running best wins. Ties keep the earlier direction.

    Args:
        zoneValue: pixel values of the zone being matched.
        oDist: distance of the current match; a neighbour must beat it strictly.
        patch: (N, 2) array of (x, y) coordinates; row 0 is the window origin.
        offset: shift in pixels.
        height, width: bounds the shifted coordinates must respect.
        img: image the candidate windows are read from.
        pSize: side length of the window compared with ``zoneValue``.

    Returns:
        (found, bestPatch, bestDist). When nothing beat ``oDist``, ``found``
        is False and ``patch`` and ``oDist`` are returned unchanged.
    """
    directions = ((-1, -1), (-1, 0), (0, -1), (-1, 1), (1, -1), (0, 1), (1, 0), (1, 1))
    imgHeight, imgWidth = img.shape[:2]
    found = False
    best = patch
    for dx, dy in directions:
        candidate = patch.copy()
        candidate[:, 0] += dx * offset
        candidate[:, 1] += dy * offset
        if (np.any(candidate < 0) or np.any(candidate[:, 0] >= width)
                or np.any(candidate[:, 1] >= height)):
            continue
        # The window read below starts at candidate[0] and spans pSize pixels.
        # Past the image edge numpy would silently truncate it, and the
        # distance would then compare arrays of different sizes.
        ox, oy = candidate[0][0], candidate[0][1]
        if ox + pSize > imgWidth or oy + pSize > imgHeight:
            continue
        dist = getDist(zoneValue, getValueFromPatch(img, candidate, pSize))
        if dist < oDist:
            found = True
            oDist = dist
            best = candidate
    return found, best, oDist
def findBestPatch(img2, zone, zoneValue, pSize, pixSize, height, width, x1, y1, x2, y2, nSamples=500):
    """Find the window of ``img2`` whose pixels best match ``zoneValue``.

    If the origin ``zone[0]`` does not lie inside the rectangle
    [x1, x2] x [y1, y2], a copy of ``zone`` is returned unchanged. Otherwise
    the search runs in two stages:

    1. Keep the closest of 1 + ``nSamples`` random candidates from
       getRandomPatch.
    2. Walk greedily with findBestPatchFromNeigbour at offsets 1, 2, 4, ...
       while below min(height, width) / 3, restarting at offset 1 after
       every improvement.

    Args:
        img2: image the candidate windows are read from.
        zone: (N, 2) array of (x, y) coordinates; row 0 is the origin.
        zoneValue: pixel values to match.
        pSize: side length of the compared window.
        pixSize: scale step; the random zone side is
            max(pSize, int(pSize / pixSize) * 2).
        height, width: coordinate bounds used by the neighbour walk.
        x1, y1, x2, y2: the rectangle being filled.
        nSamples: number of extra random candidates (default 500).

    Returns:
        A coordinate array whose first row is the chosen window origin in img2.

    NOTE(review): the neighbour walk is not given the rectangle, so it can
    drift onto img2 pixels inside it. Confirm that this is acceptable.
    """
    ox, oy = zone[0][0], zone[0][1]
    if not (x1 <= ox <= x2 and y1 <= oy <= y2):
        return zone.copy()
    # getRandomPatch bounds origins by the zone size it is given, while
    # getValueFromPatch reads a pSize window. Never let the zone be smaller
    # than that window, or windows near the border come back truncated.
    sampleSize = max(pSize, int(pSize / pixSize) * 2)
    patch = getRandomPatch(img2, sampleSize, x1, y1, x2, y2)
    pdist = getDist(zoneValue, getValueFromPatch(img2, patch, pSize))
    for _ in range(nSamples):
        candidate = getRandomPatch(img2, sampleSize, x1, y1, x2, y2)
        dist = getDist(zoneValue, getValueFromPatch(img2, candidate, pSize))
        if dist < pdist:
            pdist = dist
            patch = candidate
    offset = 1
    while offset < min(height, width) / 3:
        found, nPatch, nDist = findBestPatchFromNeigbour(
            zoneValue, pdist, patch, int(offset), height, width, img2, pSize)
        if found:
            patch, pdist = nPatch, nDist
            offset = 1
        else:
            offset *= 2
    return patch
def getZoneFromCoord(x, y, patchSize):
    """List the (x, y) pixel coordinates of the patchSize x patchSize square at (x, y).

    Coordinates are ordered x-outer / y-inner:
    [[x, y], [x, y+1], ..., [x+1, y], ...].

    Returns:
        Integer array of shape (patchSize**2, 2). A non-positive patchSize
        gives shape (0, 2).
    """
    cols, rows = np.meshgrid(np.arange(x, x + patchSize),
                             np.arange(y, y + patchSize),
                             indexing='ij')
    return np.stack((cols.ravel(), rows.ravel()), axis=1)
def rebuildImg(img1, img2, pixSize, x1, y1, x2, y2):
    """Refine ``img1`` in place, one square block of side 2 * pixSize at a time.

    Blocks are visited column by column. How a block is filled depends on its
    origin:

    * outside [x1, x2] x [y1, y2]: the block takes img2's pixels at the same
      place;
    * inside: it takes the img2 window chosen by findBestPatch for img1's
      current pixels there.

    The right and bottom strips left over by the integer division of the
    image size by the block side are not visited.

    Presumably img2 is at least as large as img1 (TODO confirm with caller).
    NOTE(review): hole membership is judged by the block origin only, so a
    block starting just left of x1 or above y1 is copied from img2 even
    though it overlaps the rectangle.

    Returns:
        img1 (modified in place).
    """
    height, width = img1.shape[:2]
    step = pixSize * 2
    nCols, nRows = int(width / step), int(height / step)
    for col in range(nCols):
        ox = col * step
        for row in range(nRows):
            oy = row * step
            zone = getZoneFromCoord(ox, oy, step)
            if x1 <= ox <= x2 and y1 <= oy <= y2:
                current = getValueFromPatch(img1, zone, step)
                source = findBestPatch(img2, zone, current, step, pixSize,
                                       height, width, x1, y1, x2, y2)
            else:
                source = zone
            img1 = applyPatch(img1, zone, getValueFromPatch(img2, source, step))
    return img1
def doPatchMatch(image, x1, y1, x2, y2, scaleFactor=20, patchSize=129):
    """Erase the rectangle (x1, y1)-(x2, y2) from ``image``, coarse to fine.

    First the image is shrunk by ``scaleFactor``, the rectangle is blacked
    out, and ``initialPatchMatch`` fills it. That function is not defined in
    this section; presumably it comes from ``from function import *``.

    Then, one integer scale step at a time down to scale 2, each step does
    three things:

    1. enlarges the current result;
    2. rebuilds it against ``image`` shrunk to that step, using rebuildImg;
    3. draws a preview on the module-level axes ``ax`` and prints the step.

    Args:
        image: source image. It is only passed to reScale, so the work
            happens on resized copies.
        x1, y1, x2, y2: rectangle corners in ``image`` pixels, in any order
            (they are sorted here).
        scaleFactor: initial downscale divisor; must be positive.
        patchSize: unused; kept so existing callers keep working.

    Returns:
        The final reconstruction, enlarged by the last scale factor reached.

    Raises:
        ValueError: if ``scaleFactor`` is not positive.
    """
    if scaleFactor <= 0:
        raise ValueError("scaleFactor must be positive")
    x1, x2 = min(x1, x2), max(x1, x2)
    y1, y2 = min(y1, y2), max(y1, y2)
    oHeight, oWidth = image.shape[:2]
    rImage, nHeight, nWidth = reScale(image, scaleFactor)
    nx1, ny1, nx2, ny2 = reScaleCoord(oWidth, oHeight, nWidth, nHeight, x1, y1, x2, y2)
    rImage[ny1:ny2 + 1, nx1:nx2 + 1] = 0
    rImage = initialPatchMatch(rImage, nx1, ny1, nx2, ny2, 5)
    tempRes = None
    # '> 2' rather than '!= 2': a non-integer or too-small factor would
    # otherwise never hit 2 and loop until a division by zero.
    while scaleFactor > 2:
        scaleFactor -= 1
        rImage, nHeight, nWidth = reScale(rImage, scaleFactor / (scaleFactor + 1))
        timg, h, w = reScale(image, scaleFactor)
        nx1, ny1, nx2, ny2 = reScaleCoord(oWidth, oHeight, w, h, x1, y1, x2, y2)
        # rebuildImg divides by 2 * pixSize, so never hand it 0.
        pixSize = max(1, int(h / nHeight))
        rImage = rebuildImg(rImage, timg, pixSize, nx1, ny1, nx2, ny2)
        tempRes, _, _ = reScale(rImage, 1 / scaleFactor)
        ax.imshow(tempRes)
        plt.draw()
        plt.pause(0.1)
        print(scaleFactor)
    if tempRes is None:
        # No refinement step ran (scaleFactor <= 2): enlarge the initial fill.
        tempRes, _, _ = reScale(rImage, 1 / scaleFactor)
    return tempRes
img = plt.imread('asset/vache.png')
# matplotlib returns PNG pixels as floats in [0, 1]. Convert any float image
# to 0-255 uint8, rounding rather than truncating: a value that lands a hair
# below an integer after scaling must not drop by one level.
if np.issubdtype(img.dtype, np.floating):
    img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
# Grayscale (H, W) or grayscale+alpha (H, W, 2) input: replicate the first
# channel so the rest of the script always sees three colour channels.
if img.ndim == 2:
    img = img[:, :, np.newaxis]
if img.shape[2] < 3:
    img = np.repeat(img[:, :, :1], 3, axis=2)
# Drop a possible alpha channel.
img = img[:, :, 0:3]
def onselect(eclick, erelease):
    """RectangleSelector callback: inpaint the dragged rectangle and redraw.

    The press and release events' data coordinates are truncated to int and
    passed, with a copy of the module-level ``img``, to doPatchMatch. The
    result is then shown on ``ax``.
    """
    pressX, pressY = eclick.xdata, eclick.ydata
    releaseX, releaseY = erelease.xdata, erelease.ydata
    print("drawing")
    corners = (int(pressX), int(pressY), int(releaseX), int(releaseY))
    result = doPatchMatch(np.copy(img), *corners)
    ax.imshow(result)
    plt.draw()
    print("drawed")
# Show the image and let the user drag a rectangle with the left mouse button;
# onselect is invoked with the press/release events once the drag ends.
fig, ax = plt.subplots()
ax.imshow(img)
# Keep the selector bound to a module-level name: matplotlib widgets stop
# responding once they are garbage-collected.
toggle_selector = RectangleSelector(
    ax,
    onselect,
    useblit=True,
    button=[1],
    minspanx=5,
    minspany=5,
    spancoords='pixels',
    interactive=True,
)
ax.axis('off')
plt.show()