You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

167 lines
5.8 KiB

from matplotlib.widgets import RectangleSelector
import matplotlib.pyplot as plt
from random import randint
import numpy as np
# The 8 neighbouring directions (dx, dy) explored by the local source-patch search.
_NEIGHBOUR_STEPS = ((-1, -1), (-1, 0), (-1, 1), (0, -1),
                    (0, 1), (1, -1), (1, 0), (1, 1))


def _clipped_offsets(center, half, size):
    """Return inclusive offsets (lo, hi) of a radius-*half* window around *center*.

    The offsets are clipped so that ``center + lo >= 0`` and
    ``center + hi <= size - 1``.
    """
    return max(-half, -center), min(half, size - 1 - center)


def _window(arr, cx, cy, box):
    """Return the view of *arr* covering offsets *box* = (dy0, dy1, dx0, dx1) around (cx, cy)."""
    dy0, dy1, dx0, dx1 = box
    return arr[cy + dy0:cy + dy1 + 1, cx + dx0:cx + dx1 + 1]


def _ssd(img, cx, cy, box, known, target):
    """Return the sum of squared differences between *target* and the patch at (cx, cy).

    Only the pixels selected by the boolean mask *known* are compared. The
    arithmetic is done in float64 so that uint8 images cannot wrap around.
    """
    source = _window(img, cx, cy, box)[known].astype(np.float64)
    return float(np.sum((source - target) ** 2))


def _grow8(mask):
    """Return the 2-D boolean *mask* dilated by one pixel in all 8 directions."""
    h, w = mask.shape
    padded = np.pad(mask, 1)
    grown = np.zeros_like(mask)
    for dy in range(3):
        for dx in range(3):
            grown |= padded[dy:dy + h, dx:dx + w]
    return grown


def doPatchMatch(img, x1, y1, x2, y2, patchSize=65, *, verbose=True):
    """Fill the rectangle (x1, y1)-(x2, y2) of *img* in place by greedy patch copying.

    The rectangle is cleared to 0 and then filled from its border inwards.
    Corners are inclusive, given as (column, row), and may be in any order.
    Each step does the following:

    - Pick a random pixel of the fill front.
    - Take the square zone of radius ``patchSize // 2`` around it, clipped to
      the image.
    - Search for the source patch whose pixels best match the zone's
      already-known pixels (sum of squared differences).
    - Copy that source into the zone's still-unknown pixels.

    The search starts from a previously used source patch, or from a random
    one. It moves to the best improving neighbour at distance ``step``,
    doubles ``step`` when no neighbour improves, and resets it to 1 after a
    move.

    Source patches always lie entirely inside the image and never overlap the
    rectangle, so only original image content is ever copied. Pixels outside
    the rectangle are never modified, whatever their colour.

    Args:
        img: (H, W) or (H, W, C) image array; modified in place.
        x1, y1, x2, y2: corners of the rectangle to fill; clipped to the image.
        patchSize: patch side length; patches are ``2 * (patchSize // 2) + 1``
            pixels wide.
        verbose: print the iteration counter after every filled zone.

    Returns:
        *img* itself (the same array object), with the rectangle filled.

    Raises:
        ValueError: if ``patchSize < 1``, or if no full source patch fits in
            the image without overlapping the rectangle. In both cases *img*
            is left untouched.
    """
    if patchSize < 1:
        raise ValueError("patchSize must be at least 1")
    height, width = img.shape[:2]
    half = patchSize // 2

    # Accept the corners in any order and clip the rectangle to the image.
    xa, xb = sorted((int(x1), int(x2)))
    ya, yb = sorted((int(y1), int(y2)))
    xa, xb = max(xa, 0), min(xb, width - 1)
    ya, yb = max(ya, 0), min(yb, height - 1)
    if xa > xb or ya > yb:
        return img

    # Admissible source-patch centres: the whole patch is inside the image and
    # does not intersect the rectangle (whose pixels are unknown).
    valid = np.zeros((height, width), dtype=bool)
    valid[half:max(height - half, 0), half:max(width - half, 0)] = True
    valid[max(ya - half, 0):yb + half + 1, max(xa - half, 0):xb + half + 1] = False
    centres = np.argwhere(valid)  # rows of (cy, cx)
    if len(centres) == 0:
        raise ValueError(
            f"no {2 * half + 1}x{2 * half + 1} source patch fits in the image "
            "outside the selected rectangle; use a smaller patchSize")

    # Pixels still to be filled. The hole is tracked by this mask, not by
    # colour, so genuinely black pixels are never mistaken for the hole.
    hole = np.zeros((height, width), dtype=bool)
    hole[ya:yb + 1, xa:xb + 1] = True
    img[hole] = 0

    # Fill front (always a subset of the hole), seeded with the rectangle border.
    front = {(x, y) for x in range(xa, xb + 1) for y in (ya, yb)}
    front.update((x, y) for y in range(ya, yb + 1) for x in (xa, xb))

    used = []  # centres (cx, cy) of source patches already copied
    max_step = min(width, height) / 3
    iteration = 0
    while front:
        # Zone: the window around a random front pixel, clipped to the image.
        pending = tuple(front)
        px, py = pending[np.random.randint(len(pending))]
        box = _clipped_offsets(py, half, height) + _clipped_offsets(px, half, width)
        dy0, dy1, dx0, dx1 = box
        zone = _window(img, px, py, box)
        unknown = _window(hole, px, py, box).copy()
        known = ~unknown
        target = zone[known].astype(np.float64)

        # Initial guess: a previously used source, else a random admissible one.
        if used:
            cx, cy = used[np.random.randint(len(used))]
        else:
            cy, cx = map(int, centres[np.random.randint(len(centres))])
        best = _ssd(img, cx, cy, box, known, target)

        # Local search with a step that doubles until it reaches max_step.
        step = 1
        while step < max_step:
            move = None
            for ox, oy in _NEIGHBOUR_STEPS:
                nx, ny = cx + ox * step, cy + oy * step
                if 0 <= nx < width and 0 <= ny < height and valid[ny, nx]:
                    d = _ssd(img, nx, ny, box, known, target)
                    if d < best:
                        best, move = d, (nx, ny)
            if move is None:
                step *= 2
            else:
                cx, cy = move
                step = 1
        used.append((cx, cy))

        # Copy the best source into the zone's unknown pixels only. The source
        # does not overlap the rectangle, so it reads original pixels.
        zone[unknown] = _window(img, cx, cy, box)[unknown]
        _window(hole, px, py, box)[unknown] = False

        # Filled pixels leave the front (the picked pixel is always among them,
        # so the hole shrinks every iteration and the loop terminates).
        rows, cols = np.nonzero(unknown)
        front.difference_update(
            zip((cols + px + dx0).tolist(), (rows + py + dy0).tolist()))

        # Still-unfilled 8-neighbours of the freshly filled pixels join the front.
        gy0, gy1 = max(py + dy0 - 1, 0), min(py + dy1 + 1, height - 1)
        gx0, gx1 = max(px + dx0 - 1, 0), min(px + dx1 + 1, width - 1)
        filled = np.zeros((gy1 - gy0 + 1, gx1 - gx0 + 1), dtype=bool)
        filled[py + dy0 - gy0:py + dy1 + 1 - gy0,
               px + dx0 - gx0:px + dx1 + 1 - gx0] = unknown
        rows, cols = np.nonzero(_grow8(filled) & hole[gy0:gy1 + 1, gx0:gx1 + 1])
        front.update(zip((cols + gx0).tolist(), (rows + gy0).tolist()))

        iteration += 1
        if verbose:
            print(iteration)
    return img
# Demo image, normalised to 8-bit, 3-channel for doPatchMatch. pyplot.imread
# returns floats in [0, 1] for PNG files, integers for other formats, and a
# 2-D array for greyscale images.
img = plt.imread('asset/mur.jpg')
if np.issubdtype(img.dtype, np.floating):
    # Round rather than truncate when scaling to 0..255.
    img = np.round(np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
if img.ndim == 2:
    img = np.stack([img] * 3, axis=-1)
img = img[:, :, 0:3]  # keep the first three channels (drops alpha, if present)
def onselect(eclick, erelease):
    """Inpaint the rectangle just selected with the mouse and redraw the axes.

    Callback for ``RectangleSelector``. *eclick* and *erelease* are the press
    and release mouse events; their ``xdata``/``ydata`` give the corners in
    data coordinates. The module-level ``img`` is not modified: inpainting
    runs on a copy, which then replaces the displayed image.
    """
    # imshow centres pixel i on coordinate i by default, so round (not
    # truncate) to get the pixel under each corner.
    x1, y1 = int(round(eclick.xdata)), int(round(eclick.ydata))
    x2, y2 = int(round(erelease.xdata)), int(round(erelease.ydata))
    print("drawing")
    res = doPatchMatch(np.copy(img), x1, y1, x2, y2)
    # Update the displayed image instead of stacking a new artist per selection.
    if ax.images:
        ax.images[-1].set_data(res)
    else:
        ax.imshow(res)
    ax.figure.canvas.draw_idle()
    print("drawed")
# Show the image and let the user drag a rectangle with the left mouse
# button; onselect inpaints the selected area.
fig, ax = plt.subplots()
ax.imshow(img)
# Keep a module-level reference so the selector widget stays responsive.
toggle_selector = RectangleSelector(
    ax,
    onselect,
    button=[1],
    minspanx=5,
    minspany=5,
    spancoords='pixels',
    useblit=True,
    interactive=True,
)
ax.set_axis_off()
plt.show()