|
import os
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
from .ssim import ssim_matlab
|
|
|
|
from .RIFE_HDv3 import Model
|
|
|
|
def get_frame(frames, frame_no):
    """Fetch one frame from ``frames`` and map it from [-1, 1] to [0, 1].

    Args:
        frames: tensor indexed as ``frames[:, frame_no]``; dim 1 is the
            frame axis (presumably (C, T, H, W) — TODO confirm with caller).
        frame_no: index along dim 1.

    Returns:
        The selected slice, rescaled with ``(x + 1) / 2`` and clipped to
        [0, 1], or ``None`` when ``frame_no`` is at or beyond the frame count.
    """
    frame_count = frames.shape[1]
    if frame_no < frame_count:
        return ((frames[:, frame_no] + 1) / 2).clip(0., 1.)
    return None
|
|
|
|
def add_frame(frames, frame, h, w):
    """Convert ``frame`` back to [-1, 1] and append it to the list ``frames``.

    Args:
        frames: Python list collecting output slices (mutated in place).
        frame: tensor in [0, 1]; a leading dim of size 1 (batch) is dropped
            if present. Its last two dims are cropped to ``h`` x ``w``.
        h: target height after cropping.
        w: target width after cropping.

    The appended tensor lives on the CPU and has a new size-1 axis inserted
    at position 1, so appended slices can later be concatenated on dim 1.
    """
    rescaled = (frame * 2 - 1).clip(-1., 1.)
    cropped = rescaled.squeeze(0)[:, :h, :w]
    frames.append(cropped.unsqueeze(1).cpu())
|
|
|
|
def process_frames(model, device, frames, exp):
    """Insert interpolated frames between consecutive frames of ``frames``.

    This is an adaptation of the RIFE video loop. Frames are read one at a
    time with ``get_frame`` and padded to a multiple of 32. Each consecutive
    pair ``(I0, I1)`` then contributes ``I0`` followed by ``2**exp - 1``
    intermediate frames produced by recursive midpoint bisection through
    ``model.inference``. The final frame is appended once at the end.

    Two special cases are based on an SSIM computed on 32x32 thumbnails:
      * ``ssim > 0.996`` (near-duplicate), or ``pos > 100``: the next input
        frame is read ahead. ``I1`` is replaced by the midpoint of ``I0``
        and that frame, and the read-ahead frame is reused in the next
        iteration.
      * ``ssim < 0.2`` (likely scene cut): ``I0`` is repeated instead of
        being interpolated.

    Args:
        model: object exposing ``inference(I0, I1, scale)`` that returns the
            midpoint frame (a RIFE ``Model`` in this module).
        device: device the frames are moved to before inference.
        frames: tensor indexed as ``frames[:, t]`` with values in [-1, 1]
            (assumed (C, T, H, W) — TODO confirm with caller).
        exp: recursion depth; ``2**exp - 1`` frames are inserted per gap
            (``0`` inserts none).

    Returns:
        CPU tensor built by concatenating the ``add_frame`` slices along
        dim 1.

    Raises:
        ValueError: if ``frames`` contains no frames.
    """
    pos = 0
    output_frames = []

    lastframe = get_frame(frames, 0)
    if lastframe is None:
        # Without this guard the next line fails with an opaque
        # AttributeError on ``None.shape``.
        raise ValueError("frames must contain at least one frame")
    _, h, w = lastframe.shape
    scale = 1
    fp16 = False

    def make_inference(I0, I1, n):
        """Return ``n`` intermediate frames between I0 and I1 by recursive bisection.

        ``n`` must be >= 1. With ``n == 0``, ``n // 2`` stays 0 and the
        recursion never terminates, which is why callers guard with ``if exp``.
        """
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n // 2)
        second_half = make_inference(middle, I1, n=n // 2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    # Pad H and W up to a multiple of ``tmp`` (32 at scale 1). This is
    # presumably required by the flow network's downsampling. Outputs are
    # cropped back to (h, w) in add_frame.
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)  # (left, right, top, bottom)

    def pad_image(img):
        """Pad ``img`` on the right/bottom to (ph, pw); cast to half if fp16."""
        padded = F.pad(img, padding)
        return padded.half() if fp16 else padded

    I1 = pad_image(lastframe.to(device, non_blocking=True).unsqueeze(0))
    # Frame already read ahead by the duplicate branch, consumed next iteration.
    temp = None

    while True:
        if temp is not None:
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = pad_image(frame.to(device, non_blocking=True).unsqueeze(0))
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        # NOTE(review): once pos > 100 this branch fires on every iteration,
        # replacing each later frame with a midpoint; the intent is not
        # visible here — confirm.
        if ssim > 0.996 or pos > 100:
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                # No frame left to read ahead: interpolate toward the last
                # emitted frame and stop after this iteration.
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = pad_image(frame.to(device, non_blocking=True).unsqueeze(0))
            I1 = model.inference(I0, I1, scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1[0][:, :h, :w]

        n_mid = 2 ** exp - 1
        if ssim < 0.2:
            # Dissimilar pair: repeat I0 rather than interpolate across it.
            output = [I0] * n_mid
        else:
            output = make_inference(I0, I1, n_mid) if exp else []

        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break

    add_frame(output_frames, lastframe, h, w)
    return torch.cat(output_frames, dim=1)
|
|
|
|
def temporal_interpolation(model_path, frames, exp, device ="cuda"):
    """Load a RIFE model from ``model_path`` and interpolate ``frames``.

    Args:
        model_path: path handed to ``Model.load_model``.
        frames: frame tensor in [-1, 1], converted to float before
            processing (layout as expected by ``process_frames``).
        exp: ``2**exp - 1`` frames are inserted between neighbours.
        device: device used for the model and the inference inputs.

    Returns:
        The tensor produced by ``process_frames``, computed without
        gradient tracking.
    """
    rife = Model()
    # NOTE(review): the meaning of the second argument (-1) is defined by
    # Model.load_model, which is not visible here — confirm.
    rife.load_model(model_path, -1, device=device)
    rife.eval()
    rife.to(device=device)

    with torch.no_grad():
        interpolated = process_frames(rife, device, frames.float(), exp)
    return interpolated
|
|
|