You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
4.9 KiB
135 lines
4.9 KiB
from maskedImage import MaskedImage
|
|
import matplotlib.pyplot as plt
|
|
from nnf import Nnf
|
|
import numpy as np
|
|
import threading
|
|
import cv2
|
|
|
|
import concurrent.futures
|
|
|
|
def read(file):
    """Load an image with ``plt.imread`` and return its first three channels.

    Floating-point images (``plt.imread`` yields floats in [0, 1] for PNG
    files) are scaled to 0..255 and converted to ``uint8``; integer images
    keep their dtype. A 2-D (grayscale) array is replicated into three
    identical channels; channels beyond the third (e.g. alpha) are dropped.

    Args:
        file: anything ``plt.imread`` accepts (a path or file-like object).

    Returns:
        A numpy array of shape (height, width, 3).
    """
    img = plt.imread(file)
    if np.issubdtype(img.dtype, np.floating):
        # Round instead of truncating: in float32, k/255 * 255 can come out a
        # hair below k, which astype() alone would turn into k - 1.
        img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    if img.ndim == 2:
        # Grayscale: the [:, :, 0:3] slice below needs a channel axis.
        img = np.stack((img, img, img), axis=-1)
    return img[:, :, 0:3]
|
|
|
|
def doTheInpainting(img,mask,radius):
    """Inpaint the masked part of ``img`` with multi-scale PatchMatch and EM.

    ``MaskedImage(img, mask)`` is repeatedly downsampled into a pyramid until
    its width or height is no longer greater than ``radius``. Working from the
    coarsest level down to level 1, nearest-neighbour fields (``Nnf``) between
    the level's source and the current target estimate are refined, and the
    target is rebuilt by voting (``ExpectationStep``) and averaging
    (``MaximizationStep``). The final EM iteration of every level votes at the
    next finer level's size, so the estimate leaving level 1 has the size of
    ``pyramid[0]``. Intermediate results are drawn with ``plt.imshow``.

    Args:
        img: image handed to ``MaskedImage`` (for example the output of
            ``read``).
        mask: mask handed to ``MaskedImage``; presumably truthy where pixels
            must be filled -- TODO confirm against ``MaskedImage``.
        radius: passed to ``Nnf`` and ``containsMask`` (presumably the patch
            radius) and used as the pyramid's stopping size.

    Returns:
        The ``image`` attribute of the final target estimate.

    Raises:
        ValueError: if the input's width or height is not greater than
            ``radius``. No coarser pyramid level can then be built, and the
            coarse-to-fine loop would never run.
    """

    def maximizeForTheScale(scale):
        """Run the EM iterations of pyramid level ``scale``.

        Reads ``sourceToTarget``, ``targetToSource``, ``pyramid``,
        ``initial`` and ``radius`` from the enclosing function. Returns the
        new target estimate together with both NNFs. When ``scale >= 1``, the
        last iteration has upscaled the estimate to level ``scale - 1``.
        """
        # Coarser levels (larger ``scale``) get more EM and NNF iterations.
        iterEM = 1+2*scale
        iterNnf = min(7,1+scale)
        source = sourceToTarget.input
        # NOTE(review): targetToSource is built as Nnf(target, source), so its
        # ``output`` looks like the source image rather than the current
        # target estimate -- confirm whether ``.input`` was intended.
        target = targetToSource.output
        newTarget = None
        for emloop in range(1,iterEM+1):
            if newTarget is not None:
                # Feed the previous iteration's estimate back in as the target.
                targetToSource.input = newTarget
                target = newTarget
                newTarget = None
            # Where containsMask reports the source patch around (x, y) as
            # mask-free, pin the match to (x, y) itself with distance 0.
            for y in range(source.height):
                for x in range(source.width):
                    if not source.containsMask(x,y,radius):
                        targetToSource.field[y,x] = (x,y,0)
            targetToSource.minimize(iterNnf)

            if scale>=1 and emloop == iterEM:
                # Last iteration: vote directly at the next finer level.
                newSource = pyramid[scale-1]
                newTarget = target.upscale(newSource.height,newSource.width)
                upscaled = True
            else:
                newSource = pyramid[scale]
                newTarget = target.copy()
                upscaled = False

            # ``vote`` is indexed [x, y, channel]: 3 colour sums + 1 weight.
            vote = np.zeros((newTarget.width, newTarget.height, 4))
            ExpectationStep(targetToSource,vote,newSource,upscaled)
            MaximizationStep(newTarget, vote)
            # Preview the current estimate at the input's size.
            result = cv2.resize(newTarget.image, (initial.width, initial.height), interpolation=cv2.INTER_AREA)
            plt.imshow(result)
            plt.pause(0.01)
        return newTarget, sourceToTarget, targetToSource

    initial = MaskedImage(img,mask)
    # pyramid[0] is the input; each further level downsamples the previous one.
    pyramid = [initial]
    source = initial
    while source.width>radius and source.height>radius:
        source = source.downsample()
        pyramid.append(source)
    maxLevel = len(pyramid)
    if maxLevel < 2:
        # Without this check the loop below never runs and ``target`` would
        # be unbound at the end.
        raise ValueError(
            "image of size {}x{} is too small for radius {}: width and "
            "height must both exceed the radius".format(
                initial.width, initial.height, radius))

    # Coarse to fine. Level 0 is reached through the upscale performed by the
    # last EM iteration of level 1.
    for level in range(maxLevel-1,0,-1):
        source = pyramid[level]
        if level == maxLevel-1:
            # Coarsest level: start from a fully unmasked copy of the source
            # and call randomize() on both NNFs.
            target = source.copy()
            target.mask[0:target.height,0:target.width] = False
            sourceToTarget = Nnf(source,target,radius)
            sourceToTarget.randomize()
            targetToSource = Nnf(target,source,radius)
            targetToSource.randomize()
        else:
            # Finer levels: initialise both NNFs from the previous level's.
            newNnf = Nnf(source,target,radius)
            newNnf.initializeFromNnf(sourceToTarget)
            sourceToTarget = newNnf
            newNnfRev = Nnf(target,source,radius)
            newNnfRev.initializeFromNnf(targetToSource)
            targetToSource = newNnfRev
        target, sourceToTarget, targetToSource = maximizeForTheScale(level)
    plt.imshow(target.image)
    plt.pause(0.01)
    return target.image
|
|
|
|
def ExpectationStep(nnf, vote, source, upscale):
    """Accumulate patch votes for every pixel of ``nnf.input`` into ``vote``.

    For each pixel (x, y) of ``nnf.input``, ``nnf.field[y, x]`` gives its
    match (xp, yp) in ``nnf.output`` and a distance dp. Every pixel of the
    (2R+1) x (2R+1) neighbourhood of (xp, yp), with R = ``nnf.patchSize``,
    votes via ``weightedCopy`` for the pixel at the same offset around (x, y).
    Each vote carries weight ``MaskedImage.similarity[dp]``. Offsets that fall
    outside either image are skipped. With ``upscale``, each vote is replayed
    on the 2x2 block at doubled coordinates of ``source`` and ``vote``.

    Args:
        nnf: nearest-neighbour field; ``input``, ``output``, ``field`` and
            ``patchSize`` are read.
        vote: array indexed ``vote[x, y, c]``, updated in place. Channels
            c = 0..2 receive the weighted colour sums and c = 3 the summed
            weights.
        source: image the colours are read from (at doubled resolution when
            ``upscale`` is true).
        upscale: whether to vote at twice the NNF's resolution.

    Raises:
        Any exception raised while voting in a worker (for example IndexError
        if doubled coordinates fall outside ``source`` or ``vote``) is
        re-raised here rather than being silently dropped.

    The rows of ``nnf.input`` are split into bands processed on a thread pool.
    Patches reach ``patchSize`` rows beyond their band, so neighbouring bands
    touch the same rows of ``vote``. A numpy element ``+=`` is a
    read-modify-write that the GIL does not make atomic. Each band therefore
    accumulates into a private buffer covering only the rows it can reach,
    and the buffers are added into ``vote`` afterwards in band order.
    """
    inWidth = nnf.input.width
    inHeight = nnf.input.height
    outWidth = nnf.output.width
    outHeight = nnf.output.height
    R = nnf.patchSize
    similarity = MaskedImage.similarity
    factor = 2 if upscale else 1
    nBands = 8

    def voteForBand(nb):
        """Vote for rows [y0, y1) of ``nnf.input`` into a private buffer.

        Returns ``(lo, acc)``, where ``acc[:, k]`` corresponds to
        ``vote[:, lo + k]``, or None when the band has no rows.
        """
        y0 = nb * inHeight // nBands
        y1 = (nb + 1) * inHeight // nBands
        if y0 >= y1:
            return None
        # Rows of ``vote`` this band can reach: R input rows on either side
        # of the band, scaled by ``factor`` when upscaling.
        lo = factor * max(0, y0 - R)
        hi = min(vote.shape[1], factor * min(inHeight, y1 + R))
        acc = np.zeros((vote.shape[0], max(0, hi - lo), vote.shape[2]))
        for y in range(y0, y1):
            for x in range(inWidth):
                xp, yp, dp = nnf.field[y, x]
                w = similarity[dp]
                for dy in range(-R, R + 1):
                    ys = yp + dy
                    yt = y + dy
                    if not (0 <= ys < outHeight and 0 <= yt < inHeight):
                        continue
                    for dx in range(-R, R + 1):
                        xs = xp + dx
                        xt = x + dx
                        if not (0 <= xs < outWidth and 0 <= xt < inWidth):
                            continue
                        if upscale:
                            yb = 2*yt - lo
                            weightedCopy(source, 2*xs, 2*ys, acc, 2*xt, yb, w)
                            weightedCopy(source, 2*xs+1, 2*ys, acc, 2*xt+1, yb, w)
                            weightedCopy(source, 2*xs, 2*ys+1, acc, 2*xt, yb+1, w)
                            weightedCopy(source, 2*xs+1, 2*ys+1, acc, 2*xt+1, yb+1, w)
                        else:
                            weightedCopy(source, xs, ys, acc, xt, yt - lo, w)
        return lo, acc

    with concurrent.futures.ThreadPoolExecutor(max_workers=nBands) as pool:
        futures = [pool.submit(voteForBand, nb) for nb in range(nBands)]
    # All workers have finished once the ``with`` block exits. Merging in band
    # order keeps the result deterministic, and ``result()`` re-raises any
    # exception raised by a worker.
    for future in futures:
        band = future.result()
        if band is not None:
            lo, acc = band
            vote[:, lo:lo + acc.shape[1]] += acc
|
|
|
|
def weightedCopy(src,xs,ys,vote,xd,yd,w):
    """Add one weighted colour vote from pixel (xs, ys) of ``src`` to ``vote``.

    Nothing happens when ``src.mask[ys, xs]`` is truthy. Otherwise ``w``
    times each of the first three channels of ``src.image[ys, xs]`` is added
    to ``vote[xd, yd, 0..2]``, and ``w`` itself to ``vote[xd, yd, 3]``.

    Note the differing index order: ``src.image`` and ``src.mask`` are
    indexed [y, x], while ``vote`` is indexed [x, y].
    """
    if not src.mask[ys, xs]:
        pixel = src.image[ys, xs]
        for channel in range(3):
            vote[xd, yd, channel] += w * pixel[channel]
        vote[xd, yd, 3] += w
|
|
|
|
def MaximizationStep(target,vote):
    """Write the vote-weighted mean colour into every voted pixel of ``target``.

    Args:
        target: image estimate. Its ``image`` (indexed [y, x, channel]) and
            ``mask`` (indexed [y, x]) are updated in place over the first
            ``target.height`` rows and ``target.width`` columns.
        vote: accumulator indexed ``vote[x, y, c]``, with the weighted colour
            sums in channels 0-2 and the total weight in channel 3.

    Pixels with a positive total weight receive the weighted mean, rounded to
    the nearest integer (``np.rint``: halves go to even), and are unmasked.
    All other pixels are left unchanged.
    """
    # ``vote`` is indexed [x, y, c]; transpose to the [y, x, c] order used by
    # ``target.image`` and ``target.mask``.
    votes = vote[:target.width, :target.height].transpose(1, 0, 2)
    weight = votes[:, :, 3]
    voted = weight > 0
    mean = votes[:, :, :3][voted] / weight[voted][:, None]
    # Round to nearest: truncating with int() biased every filled pixel
    # downward by up to one intensity level.
    target.image[:target.height, :target.width][voted] = np.rint(mean).astype(np.int64)
    target.mask[:target.height, :target.width][voted] = False
|