# Wan2GP / rife/inference.py
# (hosting-page residue) zxymimi23451's picture
# (hosting-page residue) Upload 258 files
# (hosting-page residue) 78360e7 verified
import os
import torch
from torch.nn import functional as F
# from .model.pytorch_msssim import ssim_matlab
from .ssim import ssim_matlab
from .RIFE_HDv3 import Model
def get_frame(frames, frame_no):
    """Return frame ``frame_no`` of a clip, rescaled and clipped to [0, 1].

    Frames are indexed along dimension 1 of ``frames``. The selected slice is
    mapped through ``(x + 1) / 2`` (so values in [-1, 1] land in [0, 1]) and
    anything outside [0, 1] afterwards is clipped.

    Returns ``None`` when ``frame_no`` is at or past ``frames.shape[1]``.
    """
    frame_count = frames.shape[1]
    if frame_no < frame_count:
        rescaled = (frames[:, frame_no] + 1) / 2
        return rescaled.clip(0., 1.)
    return None
def add_frame(frames, frame, h, w):
    """Rescale a frame, crop it to ``h`` x ``w`` and append it to ``frames``.

    The frame is mapped through ``x * 2 - 1`` (so [0, 1] becomes [-1, 1]) and
    clipped to [-1, 1]. A leading dimension of size 1 is squeezed away, the
    last two dimensions are cropped to ``[:h, :w]``, and a new dimension is
    inserted at index 1 before the result is moved to the CPU and appended.

    ``frames`` is mutated in place; nothing is returned.
    """
    restored = ((frame * 2) - 1).clip(-1., 1.)
    cropped = restored.squeeze(0)[:, :h, :w]
    frames.append(cropped.unsqueeze(1).cpu())
def process_frames(model, device, frames, exp):
    """Insert interpolated frames between consecutive frames of a clip.

    Frames are read along dimension 1 of ``frames`` with ``get_frame`` and
    collected with ``add_frame``. For every processed pair of frames,
    ``2**exp - 1`` in-between frames are emitted after the first frame of the
    pair; the very last frame is appended once the loop ends.

    Args:
        model: Object exposing ``inference(img0, img1, scale)``; its result is
            used here as the frame halfway between ``img0`` and ``img1``.
        device: Device every frame is moved to before inference.
        frames: Clip tensor whose slices ``frames[:, i]`` are 3-D with height
            and width as the last two dimensions (the slice shape is unpacked
            as ``_, h, w``).
        exp: Interpolation exponent; ``0`` emits no in-between frames.

    Returns:
        All collected frames concatenated along dimension 1.

    Raises:
        ValueError: If ``frames`` contains no frames along dimension 1.
    """
    if frames.shape[1] == 0:
        # Fail with a clear message instead of an AttributeError on
        # ``None.shape`` a few lines below.
        raise ValueError("process_frames() needs at least one frame, got 0")

    pos = 0
    output_frames = []

    lastframe = get_frame(frames, 0)
    _, h, w = lastframe.shape
    scale = 1
    fp16 = False  # set to True to feed half-precision tensors to the model
    n_mid = (2 ** exp) - 1  # in-between frames emitted per frame pair

    def make_inference(I0, I1, n):
        """Recursively bisect the I0..I1 interval into ``n`` in-between frames."""
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n // 2)
        second_half = make_inference(middle, I1, n=n // 2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    # Round height and width up to the next multiple of ``tmp`` (32 while
    # ``scale`` is 1); add_frame later crops the padding away again.
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)  # F.pad order: left, right, top, bottom

    def pad_image(img):
        """Pad ``img`` on the right/bottom; cast to half when ``fp16`` is set."""
        padded = F.pad(img, padding)
        return padded.half() if fp16 else padded

    I1 = lastframe.to(device, non_blocking=True).unsqueeze(0)
    I1 = pad_image(I1)
    temp = None  # look-ahead frame fetched while handling a near-static pair

    while True:
        if temp is not None:
            # Consume the frame that the previous iteration already fetched.
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = frame.to(device, non_blocking=True).unsqueeze(0)
        I1 = pad_image(I1)
        # Similarity is measured on 32x32 thumbnails of the first 3 channels.
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        # NOTE(review): the ``pos > 100`` clause forces this branch for every
        # frame past index 100 regardless of similarity -- confirm it is
        # intentional.
        if ssim > 0.996 or pos > 100:
            # Near-identical pair: replace the current frame with the model's
            # blend of I0 and the *next* frame, keeping that next frame in
            # ``temp`` so it is still processed on the following iteration.
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                # No next frame: blend with the last known frame and stop
                # after this iteration.
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = frame.to(device, non_blocking=True).unsqueeze(0)
            I1 = pad_image(I1)
            I1 = model.inference(I0, I1, scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1[0][:, :h, :w]

        if ssim < 0.2:
            # Very dissimilar pair: repeat I0 instead of interpolating.
            output = [I0] * n_mid
        else:
            output = make_inference(I0, I1, n_mid) if exp else []

        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break

    add_frame(output_frames, lastframe, h, w)
    return torch.cat(output_frames, dim=1)
def temporal_interpolation(model_path, frames, exp, device="cuda"):
    """Load the interpolation model and run ``process_frames`` on a clip.

    A fresh ``Model`` is created, its weights are loaded from ``model_path``
    (``load_model`` is called with ``-1`` as its second argument), it is put
    in eval mode and moved to ``device``. The clip is cast to float and
    processed with gradient tracking disabled.

    Returns whatever ``process_frames`` returns for the given ``exp``.
    """
    net = Model()
    net.load_model(model_path, -1, device=device)
    net.eval()
    net.to(device=device)

    with torch.no_grad():
        return process_frames(net, device, frames.float(), exp)