"""Interpolate intermediate frames between consecutive images of a folder.

For every neighbouring pair of images (in natural filename order) the RIFE
model is applied recursively to synthesise ``2**exp - 1`` in-between frames.
Frames are written to ``CONFIG['out_dir']`` as ``NNNNNNN.png`` (source frame
number N) and ``NNNNNNN_K.png`` (K-th synthesised frame following source
frame N). Optionally the written frames are assembled into an mp4 video.
"""
import os
import re
import cv2
import glob
import numpy as np
import torch
from torch.nn import functional as F

CONFIG = {
    'folder': '/path/to/image/folder',
    'model_dir': 'train_log_wms',
    'exp': 1,  # insert 2^exp-1 frames between each pair
    'out_dir': 'wms_interp_frames',
    'out_video': None,  # optional mp4 path to write a video
    'fps': 30,  # video fps if out_video is set
}

# Try the known model locations in order; the last import is allowed to fail
# loudly so that a missing model is reported instead of silently ignored.
try:
    from model.RIFE_HDv2 import Model
except Exception:
    try:
        from train_log.RIFE_HDv3 import Model
    except Exception:
        try:
            from model.RIFE_HD import Model
        except Exception:
            from model.RIFE import Model

IMG_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')

_DIGIT_RUNS = re.compile(r'(\d+)')


def _natural_key(s: str):
    """Return a sort key that orders embedded digit runs numerically.

    Splitting on a capturing group alternates text and digit tokens, so the
    tokens at odd positions are exactly the digit runs. Using the position
    (rather than ``str.isdigit``) keeps the key type-consistent per position.
    """
    tokens = _DIGIT_RUNS.split(s)
    return [int(t) if i % 2 else t.lower() for i, t in enumerate(tokens)]


def list_images_sorted(folder):
    """Return the paths of all images in *folder*, in natural filename order.

    A file counts as an image when its lower-cased name ends with one of
    ``IMG_EXTS``.
    """
    files = [
        os.path.join(folder, f)
        for f in os.listdir(folder)
        if f.lower().endswith(IMG_EXTS)
    ]
    files.sort(key=_natural_key)
    return files


def pad32(t):
    """Pad a 4-D tensor on the right/bottom so H and W are multiples of 32.

    Returns:
        A tuple ``(padded, pad, (h, w))`` where *pad* is the argument given
        to ``F.pad`` and ``(h, w)`` is the size before padding.
    """
    n, c, h, w = t.shape
    ph = ((h - 1) // 32 + 1) * 32
    pw = ((w - 1) // 32 + 1) * 32
    pad = (0, pw - w, 0, ph - h)
    return F.pad(t, pad), pad, (h, w)


def _read_bgr(path):
    """Read *path* as an 8-bit, 3-channel BGR image or raise RuntimeError.

    IMREAD_COLOR normalises grayscale, alpha and 16-bit inputs to the layout
    the rest of the pipeline assumes; IMREAD_IGNORE_ORIENTATION keeps the
    pixels un-rotated, as the previous IMREAD_UNCHANGED read did.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    if img is None:
        raise RuntimeError(f'Could not read image: {path}')
    return img


def _to_tensor(img_bgr):
    """Convert an HxWx3 BGR uint8 image to a 1x3xHxW RGB float tensor in [0, 1]."""
    # torch.from_numpy rejects negative strides, so the channel-reversed view
    # must be copied into a contiguous array first.
    rgb_chw = np.ascontiguousarray(img_bgr[:, :, ::-1].transpose(2, 0, 1))
    return torch.from_numpy(rgb_chw).float().unsqueeze(0) / 255.0


def _to_bgr_image(t, h, w):
    """Convert the first item of a batched RGB tensor to an h x w BGR uint8 image."""
    # Clamp before the uint8 cast: out-of-range floats would otherwise wrap.
    arr = (t[0].detach().clamp(0, 1).cpu() * 255).byte().numpy()
    arr = arr.transpose(1, 2, 0)
    # Crop away the pad32 padding and flip RGB -> BGR for OpenCV.
    return np.ascontiguousarray(arr[:h, :w, ::-1])


def _interpolate(model, a, b, depth):
    """Return the ``2**depth - 1`` frames between *a* and *b*, in time order.

    Each level asks the model for the midpoint and recurses into both halves.
    """
    if depth <= 0:
        return []
    mid = model.inference(a, b)
    left = _interpolate(model, a, mid, depth - 1)
    right = _interpolate(model, mid, b, depth - 1)
    return left + [mid] + right


def _write_image(path, img):
    """Write *img* to *path*, raising RuntimeError if OpenCV reports failure."""
    if not cv2.imwrite(path, img):
        raise RuntimeError(f'Could not write image: {path}')


def _write_video(frame_paths, out_path, fps):
    """Encode the images in *frame_paths*, in the given order, as an mp4v video."""
    if not frame_paths:
        raise RuntimeError('No frames written; cannot make video.')
    sample = cv2.imread(frame_paths[0])
    if sample is None:
        raise RuntimeError(f'Could not read image: {frame_paths[0]}')
    h, w = sample.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    vw = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
    if not vw.isOpened():
        raise RuntimeError(f'Could not open video for writing: {out_path}')
    try:
        for path in frame_paths:
            frame = cv2.imread(path)
            if frame is None:
                raise RuntimeError(f'Could not read image: {path}')
            vw.write(frame)
    finally:
        vw.release()


def main():
    """Run the interpolation described by ``CONFIG``.

    Raises:
        RuntimeError: if fewer than two images are found, an image cannot be
            read or written, consecutive images differ in size, or the
            output video cannot be opened.
    """
    cfg = CONFIG
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model()
    model.load_model(cfg['model_dir'], -1)
    model.eval()
    # NOTE(review): presumably moves the network to the active device --
    # confirm against the Model class.
    model.device()

    imgs = list_images_sorted(cfg['folder'])
    if len(imgs) < 2:
        raise RuntimeError('Need at least two images in the folder.')
    os.makedirs(cfg['out_dir'], exist_ok=True)

    depth = max(0, int(cfg['exp']))
    # Frames in playback order; used for the video so that neither filename
    # sorting nor stale files in out_dir can affect it.
    written = []

    def save(name, img):
        path = os.path.join(cfg['out_dir'], name)
        _write_image(path, img)
        written.append(path)

    prev_path = imgs[0]
    prev = _read_bgr(prev_path)
    with torch.no_grad():
        for idx, next_path in enumerate(imgs[1:]):
            nxt = _read_bgr(next_path)
            if nxt.shape != prev.shape:
                raise RuntimeError(
                    f'Image size mismatch: {prev_path} is {prev.shape[:2]}, '
                    f'{next_path} is {nxt.shape[:2]}'
                )
            # Every source frame opens its own interval.
            save(f'{idx:07d}.png', prev)
            if depth:
                t0, _, (h, w) = pad32(_to_tensor(prev))
                t1, _, _ = pad32(_to_tensor(nxt))
                mids = _interpolate(model, t0.to(device), t1.to(device), depth)
                for k, m in enumerate(mids, start=1):
                    save(f'{idx:07d}_{k}.png', _to_bgr_image(m, h, w))
            prev_path, prev = next_path, nxt
    # The last source frame closes the final interval.
    save(f'{len(imgs) - 1:07d}.png', prev)

    if cfg['out_video']:
        _write_video(written, cfg['out_video'], cfg['fps'])
        print(f"Wrote video: {cfg['out_video']}")
    print(f"Done. Frames in: {cfg['out_dir']}")


if __name__ == '__main__':
    main()