|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import os |
|
|
|
|
|
def get_clothes_mask(old_label):
    """Return a binary mask of the pixels whose label equals 3.

    Label 3 is the clothes class, judging by this function's name.

    Args:
        old_label: Label tensor of any shape, on any device.

    Returns:
        A CPU ``float32`` tensor with the same shape as ``old_label``.
        It holds 1.0 where ``old_label == 3`` and 0.0 elsewhere.
    """
    # The previous NumPy round-trip used ``np.int``, an alias removed in
    # NumPy 1.24. Comparing directly in torch yields the same float mask.
    return (old_label.cpu() == 3).float()
|
|
|
|
|
def changearm(old_label):
    """Relabel every pixel with label 5 or 6 as label 3.

    Labels 5 and 6 are the arms, judging by the function name.

    Args:
        old_label: Label tensor of any shape. It is not modified.

    Returns:
        A new tensor on ``old_label``'s device in which each 5 or 6 has been
        replaced by 3. The masks are ``float32``, so integer inputs are
        promoted to ``float32``.
    """
    # Build the masks on the input's own device. The old NumPy round-trip
    # always produced CPU masks, which broke CUDA inputs, and it relied on the
    # removed ``np.int`` alias.
    arm1 = (old_label == 5).float()
    arm2 = (old_label == 6).float()

    label = old_label * (1 - arm1) + arm1 * 3
    label = label * (1 - arm2) + arm2 * 3
    return label
|
|
|
|
|
def gen_noise(shape):
    """Sample a 0/1 noise tensor with the given shape.

    ``cv2.randn`` fills a ``uint8`` buffer with Gaussian samples (mean 0,
    standard deviation 255), which OpenCV clips to the ``uint8`` range.
    Dividing by 255 and truncating back to ``uint8`` leaves 1 where the stored
    value was 255 and 0 everywhere else.

    Returns:
        A ``float32`` tensor with the same shape as the sampled buffer.
    """
    gaussian = cv2.randn(np.zeros(shape, dtype=np.uint8), 0, 255)
    binary = (gaussian / 255).astype(np.uint8)
    return torch.tensor(binary, dtype=torch.float32)
|
|
|
|
|
def cross_entropy2d(input, target, weight=None, size_average=True):
    """Compute pixel-wise cross-entropy between class scores and a label map.

    If the spatial size of ``input`` differs from ``target``, ``input`` is
    first resized to the target size with bilinear interpolation.

    Args:
        input: ``(N, C, H, W)`` unnormalized class scores.
        target: ``(N, Ht, Wt)`` class-index map. Pixels equal to 250 are
            ignored.
        weight: Optional per-class weights, forwarded to ``F.cross_entropy``.
        size_average: Legacy flag. A truthy value or ``None`` averages over the
            non-ignored pixels. A falsy value sums them.

    Returns:
        Scalar loss tensor.
    """
    _, c, h, w = input.size()
    _, ht, wt = target.size()

    if (h, w) != (ht, wt):
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # Reshape (N, C, H, W) to (N*H*W, C) so each pixel is one sample.
    logits = input.permute(0, 2, 3, 1).reshape(-1, c)
    labels = target.reshape(-1)

    # ``size_average`` is deprecated in F.cross_entropy and warns on every
    # call. Translate it to the ``reduction`` string it used to imply.
    reduction = "mean" if size_average is None or size_average else "sum"
    return F.cross_entropy(
        logits, labels, weight=weight, reduction=reduction, ignore_index=250
    )
|
|
|
|
|
def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0):
    """Collapse one sample's channel scores into a per-pixel class-index map.

    Args:
        image_tensor: Batched tensor indexed as ``image_tensor[batch]``. The
            arg-max is taken over the first axis of that sample.
        imtype: NumPy dtype of the returned array.
        batch: Index of the sample to convert.

    Returns:
        NumPy array of arg-max channel indices, cast to ``imtype``.
    """
    sample = image_tensor[batch]
    scores = sample.cpu().float().numpy()
    class_map = scores.argmax(axis=0)
    return class_map.astype(imtype)
|
|
|
|
|
def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0) :
    """Render one sample of a segmentation map as a color-paletted image.

    Args:
        input: Batched segmentation tensor. When ``multi_channel`` is true it
            holds per-class scores, reduced with ``ndim_tensor2im``. Otherwise
            ``input[batch][0]`` is used directly as the label map.
        multi_channel: Whether ``input`` carries one channel per class.
        tensor_out: If true, return an RGB tensor via ``ToTensor``. Otherwise
            return the ``'P'``-mode PIL image.
        batch: Index of the sample to render.
    """
    # One RGB triple per label index, flattened below into PIL's palette format.
    colors = [
        (0, 0, 0), (128, 0, 0), (254, 0, 0), (0, 85, 0), (169, 0, 51),
        (254, 85, 0), (0, 0, 85), (0, 119, 220), (85, 85, 0), (0, 85, 85),
        (85, 51, 0), (52, 86, 128), (0, 128, 0), (0, 0, 254), (51, 169, 220),
        (0, 254, 254), (85, 254, 169), (169, 254, 85), (254, 254, 0), (254, 169, 0),
    ]
    flat_palette = [channel for rgb in colors for channel in rgb]

    segmap = input.detach()
    if not multi_channel:
        labels = np.asarray(segmap[batch][0].cpu())
    else:
        labels = ndim_tensor2im(segmap, batch=batch)

    picture = Image.fromarray(labels.astype(np.uint8), 'P')
    picture.putpalette(flat_palette)

    if not tensor_out:
        return picture
    to_tensor = transforms.ToTensor()
    return to_tensor(picture.convert('RGB'))
|
|
|
|
|
def pred_to_onehot(prediction, num_classes=13):
    """Convert per-class scores into a one-hot map of the arg-max class.

    Args:
        prediction: ``(N, C, H, W)`` score tensor.
        num_classes: Channel count of the one-hot output. Defaults to 13, the
            previously hard-coded value. It should be at least ``C``;
            out-of-range arg-max indices make ``scatter_`` fail.

    Returns:
        ``float32`` tensor of shape ``(N, num_classes, H, W)`` on the same
        device as ``prediction``. Each pixel has 1.0 in its arg-max channel
        and 0.0 elsewhere.
    """
    n, _, h, w = prediction.shape
    winners = torch.argmax(prediction, dim=1, keepdim=True)
    # Allocate on the prediction's device so scatter_ also works for CUDA
    # inputs. A bare FloatTensor always lived on the CPU.
    onehot = torch.zeros(
        (n, num_classes, h, w), dtype=torch.float32, device=prediction.device
    )
    return onehot.scatter_(1, winners, 1.0)
|
|
|
|
|
def cal_miou(prediction, target, class_ids=tuple(range(1, 9))):
    """Compute a pooled IoU between arg-max predictions and a one-hot target.

    Intersections and unions are summed over the whole batch and over all of
    ``class_ids`` before dividing. The result is therefore a pooled IoU, not
    the mean of per-class IoUs.

    Args:
        prediction: ``(N, C, H, W)`` class scores. They are converted to a
            one-hot map with ``pred_to_onehot``.
        target: Tensor indexed as ``target[:, class_id]``, with nonzero
            meaning "present". Presumably a one-hot ``(N, 13, H, W)`` map with
            the same batch size as ``prediction``; TODO confirm with callers.
        class_ids: Channels to score. Defaults to 1..8, the previously
            hard-coded list.

    Returns:
        Python float in [0, 1]. When the union is empty, meaning neither
        tensor marks any scored class (including an empty batch), 1.0 is
        returned instead of raising.
    """
    target = target.cpu()
    prediction = pred_to_onehot(prediction.detach().cpu())
    channels = list(class_ids)

    # Vectorized equivalent of looping over every (sample, class) pair.
    tgt = target[:, channels]
    pred = prediction[:, channels]
    intersection = torch.logical_and(tgt, pred).sum().item()
    union = torch.logical_or(tgt, pred).sum().item()

    if union == 0:
        return 1.0
    return intersection / union
|
|
|
|
|
def save_images(img_tensors, img_names, save_dir):
    """Write each image tensor to ``save_dir`` as a PNG file.

    Values are mapped with ``(x + 1) * 0.5 * 255``, clamped to [0, 255] and
    truncated to ``uint8``. This presumes inputs in [-1, 1].

    Args:
        img_tensors: Iterable of channel-first tensors. ``(1, H, W)`` is saved
            as grayscale and ``(3, H, W)`` as RGB. Other leading sizes are
            handed to PIL unchanged.
        img_names: File names, paired with ``img_tensors`` via ``zip``. Extra
            items in the longer iterable are ignored. Files are written in PNG
            format regardless of extension.
        save_dir: Directory to write into. It is not created here.
    """
    for img_tensor, img_name in zip(img_tensors, img_names):
        # Detach up front because .numpy() refuses tensors that require grad.
        # The old code handled that with a bare ``except:``, which also hid
        # unrelated errors.
        tensor = ((img_tensor.detach() + 1) * 0.5 * 255).cpu().clamp(0, 255)
        array = tensor.numpy().astype('uint8')

        if array.shape[0] == 1:
            array = array.squeeze(0)  # (1, H, W) -> (H, W) grayscale
        elif array.shape[0] == 3:
            array = array.transpose(1, 2, 0)  # CHW -> HWC, as PIL expects

        im = Image.fromarray(array)
        im.save(os.path.join(save_dir, img_name), format='PNG')
|
|
|
|
|
|
|
|
def create_network(cls, opt):
    """Build a network from ``opt``, optionally move it to CUDA, and init weights.

    Args:
        cls: Network class. It is called as ``cls(opt)`` and must provide
            ``print_network``, ``cuda`` and ``init_weights``.
        opt: Options object. It is read for ``gpu_ids``, ``init_type`` and
            ``init_variance``.

    Returns:
        The constructed network instance.
    """
    network = cls(opt)
    network.print_network()

    use_gpu = len(opt.gpu_ids) > 0
    if use_gpu:
        # GPU ids were requested, so CUDA must actually be present.
        assert torch.cuda.is_available()
        network.cuda()

    network.init_weights(opt.init_type, opt.init_variance)
    return network
|
|
|