# finetune_custom_loss_train.py
# Fine-tunes a RIFE HDv3 frame-interpolation model on WMS data using a custom
# loss: L1 + motion-weighted Charbonnier + (1 - SSIM) + flow smoothness.
import os
import math
import csv

from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from finetune_wms_dataset import WMSDataset  # your dataset
from train_log.RIFE_HDv3 import Model        # your OG model wrapper

torch.backends.cudnn.benchmark = True

CONFIG = {
    # data
    'data_root': r'D:\main-projects\WMS_Cleaned',
    'crop': 512,         # 256 is fast & safe; try 384/512 if VRAM allows
    'stride': 1,
    'augment': True,
    # training
    'epochs': 10,
    'batch_size': 12,
    'num_workers': 4,
    # learning rate schedule
    'base_lr': 2e-4,
    'min_lr': 1e-6,
    'warmup': 1000,      # steps
    # checkpoints
    'resume': 'train_log',                    # baseline v3 weights dir
    'save_dir': 'train_log_wms_custom_loss',  # fine-tuned output dir with custom loss
    # device
    'device': 'cuda',
    # custom loss knobs
    'lambda_mw': 0.5,    # motion-weighted Charbonnier
    'lambda_ssim': 0.1,  # (1 - SSIM) weight; set 0.0 to disable
    'mw_tau': 0.05,      # motion mask scale
    'mw_gamma': 2.0,     # motion mask focusing exponent
    'smooth_w': 0.1,     # flow smoothness weight
}


# -------- loss helpers --------

def charbonnier(x, eps=1e-3):
    """Charbonnier penalty sqrt(x^2 + eps^2): a smooth L1 surrogate,
    differentiable at zero."""
    return torch.sqrt(x * x + eps * eps)


def motion_weight_from_imgs(I0, I1, tau=0.05, gamma=2.0):
    """
    Cheap motion proxy from photometric change.
    I0, I1: [B,3,H,W] in [0,1] -> returns W in [0,1]
    """
    w = (I1 - I0).abs().mean(1, keepdim=True)          # [B,1,H,W]
    w = w / (w.mean(dim=[2, 3], keepdim=True) + 1e-6)  # normalize per sample
    w = (w / (tau + 1e-6)).clamp(0, 4.0).pow(gamma)    # scale + focus strong motion
    w = w / (w.amax(dim=[2, 3], keepdim=True) + 1e-6)  # rescale so peak ~ 1
    return w.clamp(0, 1)


def motion_weighted_charbonnier(pred, gt, weight, eps=1e-3):
    """Charbonnier loss re-weighted per pixel by a [B,1,H,W] motion mask."""
    diff = pred - gt
    loss_map = charbonnier(diff, eps=eps).mean(1, keepdim=True)  # per-pixel scalar
    return (weight * loss_map).mean()


def ssim_fast(pred, gt, C1=0.01**2, C2=0.03**2):
    """
    Simple SSIM (avg-pool kernel ~11x11 via box-filter approximation).
    Inputs in [0,1], BCHW. Returns global mean.
    """
    k = 11
    mu1 = F.avg_pool2d(pred, k, 1, 0)
    mu2 = F.avg_pool2d(gt, k, 1, 0)
    mu1_sq, mu2_sq, mu12 = mu1 * mu1, mu2 * mu2, mu1 * mu2
    sigma1_sq = F.avg_pool2d(pred * pred, k, 1, 0) - mu1_sq
    sigma2_sq = F.avg_pool2d(gt * gt, k, 1, 0) - mu2_sq
    sigma12 = F.avg_pool2d(pred * gt, k, 1, 0) - mu12
    ssim_map = ((2 * mu12 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + 1e-12)
    return ssim_map.clamp(0, 1).mean()


def psnr_batch(pred, gt, max_val=1.0):
    """Per-sample PSNR in dB over BCHW inputs; returns a [B] tensor."""
    mse = torch.mean((pred - gt) ** 2, dim=[1, 2, 3])
    return 10.0 * torch.log10((max_val ** 2) / (mse + 1e-12))


# -------- cosine lr --------

def get_cosine_lr(step, total_steps, base_lr=2e-4, min_lr=1e-6, warmup=1000):
    """Linear warmup for `warmup` steps, then cosine decay to `min_lr`."""
    if step < warmup:
        return base_lr * (step / max(1, warmup))
    # FIX: clamp progress to 1.0 — without it, a step past total_steps pushes
    # the cosine beyond pi and the LR starts climbing back up.
    prog = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * prog))


def ensure_csv(path, header):
    """Open `path` for appending and write `header` only if the file is new.

    Returns (file_handle, csv_writer). The caller owns closing the handle.
    """
    exists = os.path.isfile(path)
    f = open(path, 'a', newline='', encoding='utf-8')
    w = csv.writer(f)
    if not exists:
        w.writerow(header)
        f.flush()
    return f, w


def main():
    """Entry point: fine-tune the model, validate each epoch, log steps to CSV."""
    cfg = CONFIG
    device = torch.device(cfg['device'] if torch.cuda.is_available() else 'cpu')

    # data
    train_set = WMSDataset(cfg['data_root'], 'train', cfg['crop'], cfg['stride'], 'rgb', cfg['augment'])
    val_set = WMSDataset(cfg['data_root'], 'val', cfg['crop'], cfg['stride'], 'rgb', False)
    pw = cfg['num_workers'] > 0
    train_loader = DataLoader(
        train_set, batch_size=cfg['batch_size'], shuffle=True,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=True,
        prefetch_factor=(2 if pw else None), persistent_workers=pw
    )
    val_loader = DataLoader(
        val_set, batch_size=cfg['batch_size'] * 2, shuffle=False,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=False,
        prefetch_factor=(2 if pw else None), persistent_workers=pw
    )

    # model
    model = Model()
    model.load_model(cfg['resume'], -1)  # loads baseline v3 checkpoint
    model.device()
    model.train()

    os.makedirs(cfg['save_dir'], exist_ok=True)

    # CSV logger
    csv_path = os.path.join(cfg['save_dir'], 'train_steps.csv')
    csv_file, csv_writer = ensure_csv(
        csv_path,
        header=["epoch", "global_step", "lr", "loss_total", "loss_l1",
                "loss_mw", "loss_ssim", "loss_smooth"]
    )

    global_step = 0
    total_steps = cfg['epochs'] * max(1, len(train_loader))
    best_psnr = -1e9

    # persistent metrics line under the main bar
    status = tqdm(total=0, position=1, bar_format='{desc}', leave=True)
    status.set_description_str("loading…")

    # FIX: guarantee the CSV handle and the status bar are released even when
    # training raises (the original leaked both on any exception).
    try:
        for epoch in range(cfg['epochs']):
            # main progress bar (auto width)
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg['epochs']}")
            for batch in pbar:
                data_gpu, _ = batch
                data_gpu = data_gpu.to(device, non_blocking=True)
                imgs = data_gpu[:, :6]  # I0, I1
                gt = data_gpu[:, 6:9]   # It
                I0 = imgs[:, :3]
                I1 = imgs[:, 3:6]

                # LR schedule
                lr = get_cosine_lr(
                    global_step, total_steps,
                    base_lr=cfg['base_lr'], min_lr=cfg['min_lr'], warmup=cfg['warmup']
                )
                for pg in model.optimG.param_groups:
                    pg['lr'] = lr

                # ===== TRAIN FORWARD WITH GRAD =====
                # Call flownet directly so the graph is built (no no_grad).
                model.flownet.train()
                x = torch.cat((imgs, gt), 1)  # [B, 9, H, W]
                scale_list = [4, 2, 1]
                flow_list, mask, merged = model.flownet(x, scale_list=scale_list, training=True)
                pred = merged[2].clamp(0, 1)

                # Losses
                W = motion_weight_from_imgs(I0, I1, tau=cfg['mw_tau'], gamma=cfg['mw_gamma'])
                loss_l1 = (pred - gt).abs().mean()
                loss_mw = motion_weighted_charbonnier(pred, gt, W)
                loss_ssim = (1.0 - ssim_fast(pred, gt)) if cfg['lambda_ssim'] > 0.0 else pred.new_tensor(0.0)
                loss_sm = model.sobel(flow_list[2], flow_list[2] * 0).mean()  # smoothness on finest flow
                loss_total = (loss_l1 + cfg['lambda_mw'] * loss_mw
                              + cfg['lambda_ssim'] * loss_ssim + cfg['smooth_w'] * loss_sm)

                # Step
                model.optimG.zero_grad(set_to_none=True)
                loss_total.backward()
                model.optimG.step()

                # update the second line only (no postfix on main bar)
                status.set_description_str(
                    f"L1={loss_l1.item():.4f} MW={loss_mw.item():.4f} "
                    f"SSIM={loss_ssim.item():.4f} SM={loss_sm.item():.4f} LR={lr:.2e}"
                )

                # CSV log
                csv_writer.writerow([
                    epoch + 1,
                    global_step,
                    f"{lr:.8f}",
                    f"{loss_total.item():.6f}",
                    f"{loss_l1.item():.6f}",
                    f"{loss_mw.item():.6f}",
                    f"{loss_ssim.item():.6f}",
                    f"{loss_sm.item():.6f}",
                ])
                if (global_step % 50) == 0:
                    csv_file.flush()
                global_step += 1

            # ===== VALIDATION (no grad) =====
            model.eval()
            # clear status line for clean print
            status.set_description_str("")
            psnrs = []
            with torch.no_grad():
                for batch in val_loader:
                    data_gpu, _ = batch
                    data_gpu = data_gpu.to(device, non_blocking=True)
                    imgs = data_gpu[:, :6]
                    gt = data_gpu[:, 6:9]
                    # we can reuse the OG update for inference
                    pred, _ = model.update(imgs, gt, training=False)
                    pred = pred.clamp(0, 1)
                    psnrs.append(psnr_batch(pred, gt).cpu())
            epoch_psnr = torch.cat(psnrs).mean().item() if psnrs else float('nan')
            model.train()

            model.save_model(cfg['save_dir'], 0)
            if epoch_psnr > best_psnr:
                best_psnr = epoch_psnr
            print(f"[Val] PSNR: {epoch_psnr:.3f} dB | best: {best_psnr:.3f} dB")
    finally:
        # final flush & close
        csv_file.flush()
        csv_file.close()
        status.close()

    print(f"Training complete. Best PSNR: {best_psnr:.3f} dB.")
    print(f"Step logs saved to: {csv_path}")


if __name__ == '__main__':
    main()