import torch
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP

from .IFNet_HDv3 import *
# SOBEL comes from the shared loss module and is needed by update();
# adjust the relative path if your package layout differs.
from ..model.loss import *


class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        # The optimizer and Sobel smoothness loss are only exercised by update();
        # inference-only use never touches them.
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.sobel = SOBEL()
        # self.vgg = VGGPerceptualLoss().to(device)
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def to(self, device):
        self.flownet.to(device)

    def load_model(self, path, rank=0, device="cuda"):
        self.device = device

        def convert(param):
            # Checkpoints saved from a DDP-wrapped model prefix every key with
            # "module."; strip the prefix so the keys match a bare IFNet.
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param

        self.flownet.load_state_dict(convert(torch.load(path, map_location=device)))

    def save_model(self, path, rank=0):
        # Note the asymmetry: save_model takes a directory, load_model a file path.
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, scale=1.0):
        imgs = torch.cat((img0, img1), 1)
        # Coarse-to-fine pyramid; passing scale < 1 (e.g. 0.5 for 4K input)
        # estimates flow at a coarser resolution.
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        flow, mask, merged = self.flownet(imgs, scale_list)
        return merged[2]  # finest-level blend is the interpolated frame
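
    # Hedged usage sketch (hypothetical checkpoint path and shapes, not from the
    # original file): img0/img1 are 1x3xHxW tensors in [0, 1], spatially padded
    # as the network requires.
    #   model = Model()
    #   model.load_model('train_log/flownet.pkl', rank=0, device='cuda')
    #   model.eval()
    #   model.to('cuda')
    #   mid = model.inference(img0, img1)  # interpolated frame at t = 0.5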

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        scale = [4, 2, 1]
        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
        loss_l1 = (merged[2] - gt).abs().mean()
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # The flow-distillation term is not computed in this trimmed file; keep a
        # zero placeholder so the total loss and the logged dict stay well defined.
        loss_cons = torch.zeros((), device=gt.device)
        # loss_vgg = self.vgg(merged[2], gt)
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        return merged[2], {
            'mask': mask,
            'flow': flow[2][:, :2],
            'loss_l1': loss_l1,
            'loss_cons': loss_cons,
            'loss_smooth': loss_smooth,
        }
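

# Minimal smoke-test sketch (an assumption, not part of the original file):
# builds the model and interpolates the midpoint between two random frames.
# Real use would first load trained weights via model.load_model(...).
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model()
    model.eval()
    model.to(device)
    img0 = torch.rand(1, 3, 256, 256, device=device)  # dummy frame 0, RGB in [0, 1]
    img1 = torch.rand(1, 3, 256, 256, device=device)  # dummy frame 1
    with torch.no_grad():
        mid = model.inference(img0, img1, scale=1.0)
    print(mid.shape)  # expected: torch.Size([1, 3, 256, 256])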