PnLCalib / model /dataloader_l.py
2nzi's picture
Upload 63 files
3d1f2c9
import os
import sys
import glob
import json
import torch
import numpy as np
from torchvision.transforms import v2
from torch.utils.data import Dataset
from PIL import Image
from utils.utils_lines import LineKeypointsDB
from utils.utils_linesWC import LineKeypointsWCDB
class SoccerNetCalibrationDataset(Dataset):
def __init__(self, root_dir, split, transform, main_cam_only=True):
self.root_dir = root_dir
self.split = split
self.transform = transform
#self.match_info = json.load(open(root_dir + split + '/match_info.json'))
self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
if main_cam_only:
self.get_main_camera()
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = self.files[idx]
image = Image.open(img_name)
data = json.load(open(img_name.split('.')[0] + ".json"))
sample = self.transform({'image': image, 'data': data})
img_db = LineKeypointsDB(sample['data'], sample['image'])
target = img_db.get_tensor()
return sample['image'], torch.from_numpy(target).float()
def get_main_camera(self):
self.files = [file for file in self.files if int(self.match_info[file.split('/')[-1]]['ms_time']) == \
int(self.match_info[file.split('/')[-1]]['replay_time'])]
class WorldCup2014Dataset(Dataset):
def __init__(self, root_dir, split, transform):
self.root_dir = root_dir
self.split = split
self.transform = transform
assert self.split in ['train_val', 'test'], f'unknown dataset type {self.split}'
self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
self.homographies = glob.glob(os.path.join(self.root_dir + self.split, "*.homographyMatrix"))
self.num_samples = len(self.files)
self.files.sort()
self.homographies.sort()
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
image = self.get_image_by_index(idx)
homography = self.get_homography_by_index(idx)
img_db = LineKeypointsWCDB(image, homography, (960,540))
target, mask = img_db.get_tensor_w_mask()
sample = self.transform({'image': image, 'target': target, 'mask': mask})
return sample['image'], sample['target'], sample['mask']
def convert_homography_WC14GT_to_SN(self, H):
T = np.eye(3)
#T[0, -1] = -115 / 2
#T[1, -1] = -74 / 2
yard2meter = 0.9144
S = np.eye(3)
S[0, 0] = yard2meter
S[1, 1] = yard2meter
H_SN = S @ (T @ H)
return H_SN
def get_image_by_index(self, index):
img_file = self.files[index]
image = Image.open(img_file)
return image
def get_homography_by_index(self, index):
homography_file = self.homographies[index]
with open(homography_file, 'r') as file:
lines = file.readlines()
matrix_elements = []
for line in lines:
matrix_elements.extend([float(element) for element in line.split()])
homography = np.array(matrix_elements).reshape((3, 3))
homography = self.convert_homography_WC14GT_to_SN(homography)
homography = torch.from_numpy(homography)
homography = homography / homography[2:3, 2:3]
return homography
class TSWorldCupDataset(Dataset):
def __init__(self, root_dir, split, transform):
self.root_dir = root_dir
self.split = split
self.transform = transform
assert self.split in ['train', 'test'], f'unknown dataset type {self.split}'
self.files_txt = self.get_txt()
self.files = self.get_jpg_files()
self.homographies = self.get_homographies()
self.num_samples = len(self.files)
self.files.sort()
self.homographies.sort()
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
image = self.get_image_by_index(idx)
homography = self.get_homography_by_index(idx)
img_db = LineKeypointsWCDB(image, homography, (960,540))
target, mask = img_db.get_tensor_w_mask()
sample = self.transform({'image': image, 'target': target, 'mask': mask})
return sample['image'], sample['target'], sample['mask']
def get_txt(self):
with open(self.root_dir + self.split + '.txt', 'r') as file:
lines = file.readlines()
lines = [line.strip() for line in lines]
return lines
def get_jpg_files(self):
all_jpg_files = []
for dir in self.files_txt:
full_dir = self.root_dir + "Dataset/80_95/" + dir
jpg_files = []
for file in os.listdir(full_dir):
if file.lower().endswith('.jpg') or file.lower().endswith('.jpeg'):
jpg_files.append(os.path.join(full_dir, file))
all_jpg_files.extend(jpg_files)
return all_jpg_files
def get_homographies(self):
all_homographies = []
for dir in self.files_txt:
full_dir = self.root_dir + "Annotations/80_95/" + dir
homographies = []
for file in os.listdir(full_dir):
if file.lower().endswith('.npy'):
homographies.append(os.path.join(full_dir, file))
all_homographies.extend(homographies)
return all_homographies
def convert_homography_WC14GT_to_SN(self, H):
T = np.eye(3)
#T[0, -1] = -115 / 2
#T[1, -1] = -74 / 2
yard2meter = 0.9144
S = np.eye(3)
S[0, 0] = yard2meter
S[1, 1] = yard2meter
H_SN = S @ (T @ H)
return H_SN
def get_image_by_index(self, index):
img_file = self.files[index]
image = Image.open(img_file)
return image
def get_homography_by_index(self, index):
homography_file = self.homographies[index]
homography = np.load(homography_file)
homography = self.convert_homography_WC14GT_to_SN(homography)
homography = torch.from_numpy(homography)
homography = homography / homography[2:3, 2:3]
return homography