import os
import math
import time
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader

from finetune_wms_dataset import WMSDataset

# Enable cuDNN autotuning before the model module is imported/constructed so
# the first convolutions already benefit from kernel benchmarking.
torch.backends.cudnn.benchmark = True

from train_log.RIFE_HDv3 import Model

# All hyperparameters in one place; edit here rather than via CLI flags.
CONFIG = {
    'data_root': r'D:\main-projects\WMS_Cleaned',
    'save_dir': 'train_log_wms',   # where fine-tuned checkpoints are written
    'resume': 'train_log',         # directory holding the pretrained baseline
    'epochs': 5,
    'batch_size': 16,
    'num_workers': 8,
    'crop': 512,
    'stride': 1,
    'lr': 2e-4,
    'device': 'cuda',
    'use_motion_weight': True,     # enable extra motion-weighted L1 term
    'motion_lambda': 0.5,          # scale for the motion-weighted L1 term
}


def main():
    """Fine-tune a pretrained RIFE (HDv3) model on the WMS interpolation dataset.

    Loads baseline weights from ``CONFIG['resume']``, trains for
    ``CONFIG['epochs']`` epochs with a warmup + cosine learning-rate schedule,
    evaluates mean PSNR on the validation split after every epoch, and saves
    the latest checkpoint to ``CONFIG['save_dir']`` each epoch.  The best
    checkpoint (by validation PSNR) is additionally preserved under
    ``CONFIG['save_dir']/best`` so a later, worse epoch cannot overwrite it.
    """
    cfg = CONFIG
    device = torch.device(cfg['device'] if torch.cuda.is_available() else 'cpu')

    # Datasets & loaders.  Validation uses a doubled batch (no gradients) and
    # no augmentation.
    train_set = WMSDataset(cfg['data_root'], split='train', crop_size=cfg['crop'],
                           stride=cfg['stride'], color_mode='rgb')
    val_set = WMSDataset(cfg['data_root'], split='val', crop_size=cfg['crop'],
                         stride=cfg['stride'], color_mode='rgb', augment=False)
    train_loader = DataLoader(train_set, batch_size=cfg['batch_size'], shuffle=True,
                              num_workers=cfg['num_workers'], pin_memory=True,
                              drop_last=True, prefetch_factor=2,
                              persistent_workers=True)
    val_loader = DataLoader(val_set, batch_size=cfg['batch_size'] * 2, shuffle=False,
                            num_workers=cfg['num_workers'], pin_memory=True,
                            drop_last=False, prefetch_factor=2,
                            persistent_workers=True)

    # Model: load pretrained baseline weights, then move to the device
    # (RIFE's Model wraps its own optimizer and update step).
    model = Model()
    model.load_model(cfg['resume'], -1)
    model.device()
    print("Model class:", type(model).__module__)
    try:
        print("Flow net module:", type(model.flownet).__module__)
    except Exception:
        pass  # diagnostics only; flownet attribute may not exist

    os.makedirs(cfg['save_dir'], exist_ok=True)
    best_dir = os.path.join(cfg['save_dir'], 'best')
    os.makedirs(best_dir, exist_ok=True)

    global_step = 0
    total_steps = cfg['epochs'] * max(1, len(train_loader))

    def get_cosine_lr(step, total_steps, base_lr=2e-4, min_lr=1e-6, warmup=1000):
        """Linear warmup for `warmup` steps, then cosine decay base_lr -> min_lr."""
        if step < warmup:
            return base_lr * (step / max(1, warmup))
        prog = (step - warmup) / max(1, total_steps - warmup)
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * prog))

    def eval_once():
        """Run one full validation pass and return mean PSNR in dB (nan if empty)."""
        model.eval()
        psnrs = []
        with torch.no_grad():
            for batch in val_loader:
                data_gpu, timestep = batch
                data_gpu = data_gpu.to(device, non_blocking=True)
                imgs = data_gpu[:, :6]   # channels 0-5: I0, I1
                gt = data_gpu[:, 6:9]    # channels 6-8: ground-truth It
                # Forward only; training=False must not step the optimizer.
                pred, info = model.update(imgs, gt, training=False)
                # Per-sample PSNR; epsilon guards log10(0) on perfect predictions.
                mse = torch.mean((gt - pred) ** 2, dim=[1, 2, 3])
                psnr = -10.0 * torch.log10(mse + 1e-8)
                psnrs.append(psnr.cpu())
        psnr_mean = torch.cat(psnrs).mean().item() if psnrs else float('nan')
        model.train()
        return psnr_mean

    best_psnr = -1e9
    # Best-effort extra loss: permanently disabled on first runtime failure.
    motion_loss_ok = cfg['use_motion_weight']
    model.train()
    for epoch in range(cfg['epochs']):
        pbar = tqdm(train_loader, ncols=100,
                    desc=f"Epoch {epoch+1}/{cfg['epochs']}")
        for batch in pbar:
            data_gpu, timestep = batch
            data_gpu = data_gpu.to(device, non_blocking=True)
            imgs = data_gpu[:, :6]   # I0, I1
            gt = data_gpu[:, 6:9]    # It
            lr = get_cosine_lr(global_step, total_steps, base_lr=cfg['lr'],
                               min_lr=1e-6, warmup=1000)
            # Primary RIFE update (internally handles backward + optimizer step).
            pred, info = model.update(imgs, gt, learning_rate=lr, training=True)

            # Optional extra motion-aware L1 (best-effort), scaled by
            # motion_lambda.  NOTE(review): this takes a SECOND optimizer step
            # per iteration and requires `pred` to still carry a live autograd
            # graph after model.update(); if that graph was already freed by
            # the internal backward, the extra backward raises — we disable
            # the term instead of crashing the run.
            if motion_loss_ok and getattr(model, 'optimizer', None) is not None:
                try:
                    I0 = imgs[:, :3]
                    I1 = imgs[:, 3:6]
                    # Per-pixel weight = |I1 - I0| magnitude, normalized so the
                    # mean weight per sample is ~1.
                    w = torch.mean(torch.abs(I1 - I0), dim=1, keepdim=True)  # Bx1xHxW
                    w = w / (w.mean(dim=[1, 2, 3], keepdim=True) + 1e-6)
                    extra = cfg['motion_lambda'] * (w * torch.abs(pred - gt)).mean()
                    model.optimizer.zero_grad(set_to_none=True)
                    extra.backward()
                    model.optimizer.step()
                except RuntimeError as e:
                    motion_loss_ok = False
                    print(f"Motion-weighted loss disabled: {e}")

            pbar.set_postfix({'l1': float(info.get('loss_l1', 0.0)), 'lr': lr})
            global_step += 1

        # Eval + save: latest checkpoint every epoch; best weights kept in a
        # separate directory so later epochs cannot clobber them.
        val_psnr = eval_once()
        model.save_model(cfg['save_dir'], 0)
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            model.save_model(best_dir, 0)

    print(f"Training complete. Best PSNR: {best_psnr:.3f} dB. "
          f"Models saved to {cfg['save_dir']}")


if __name__ == '__main__':
    main()