"""
Dataset used to train the segmentation network that predicts the extremities of the soccer pitch lines.
"""
import json
import os
import time
from argparse import ArgumentParser

import cv2 as cv
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

from src.soccerpitch import SoccerPitch


class SoccerNetDataset(Dataset):
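    """Dataset of SoccerNet frames paired with rasterized pitch-line segmentation masks."""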
    def __init__(self,
                 datasetpath,
                 split="test",
                 width=640,
                 height=360,
                 mean="../resources/mean.npy",
                 std="../resources/std.npy"):
        self.mean = np.load(mean)
        self.std = np.load(std)
        self.width = width
        self.height = height

        dataset_dir = os.path.join(datasetpath, split)
        if not os.path.exists(dataset_dir):
            raise FileNotFoundError(f"Invalid dataset path: {dataset_dir}")

        frames = [f for f in os.listdir(dataset_dir) if f.endswith(".jpg")]

        self.data = []
        self.n_samples = 0
        for frame in frames:
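            # Each frame "<index>.jpg" is expected to have a matching "<index>.json"
            # annotation file in the same directory; frames without one are skipped.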
            frame_index = frame.split(".")[0]
            annotation_file = os.path.join(dataset_dir, f"{frame_index}.json")
            if not os.path.exists(annotation_file):
                continue

            with open(annotation_file, "r") as f:
                groundtruth_lines = json.load(f)
            img_path = os.path.join(dataset_dir, frame)
            if groundtruth_lines:
                self.data.append({
                    "image_path": img_path,
                    "annotations": groundtruth_lines,
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]

        img = cv.imread(item["image_path"])
        img = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_LINEAR)
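        # cv.imread loads the image as BGR uint8; the empty (H, W) mask below holds
        # one integer class label per pixel (0 = background).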
        mask = np.zeros(img.shape[:-1], dtype=np.uint8)
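        # Scale to [0, 1], normalize with the precomputed statistics, and move
        # channels first (C, H, W) as expected by PyTorch.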
        img = np.asarray(img, np.float32) / 255.
        img -= self.mean
        img /= self.std
        img = img.transpose((2, 0, 1))
        for class_number, class_ in enumerate(SoccerPitch.lines_classes):
            if class_ in item["annotations"]:
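                # Annotation points are normalized to [0, 1]; rasterize the polyline
                # into the mask, drawing value class_number + 1 (0 is kept for background).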
                line = item["annotations"][class_]
                prev_point = line[0]
                for i in range(1, len(line)):
                    next_point = line[i]
                    cv.line(mask,
                            (int(prev_point["x"] * mask.shape[1]), int(prev_point["y"] * mask.shape[0])),
                            (int(next_point["x"] * mask.shape[1]), int(next_point["y"] * mask.shape[0])),
                            class_number + 1,
                            2)
                    prev_point = next_point
        return img, mask


if __name__ == "__main__":
    parser = ArgumentParser(description='dataloader')
    parser.add_argument('--SoccerNet_path', default="./annotations/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('--tiny', required=False, type=int, default=None, help='Select a subset of x games')
    parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
    parser.add_argument('--num_workers', required=False, type=int, default=4,
                        help='Number of workers for the dataloader')
    parser.add_argument('--resolution_width', required=False, type=int, default=1920,
                        help='Width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=1080,
                        help='Height resolution of the images')
    parser.add_argument('--preload_images', action='store_true',
                        help="Preload the images when constructing the dataset")
    parser.add_argument('--zipped_images', action='store_true', help="Read images from zipped folder")
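    # Note: only --SoccerNet_path and --split are used by this demo; the remaining
    # flags are currently unused here.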
    args = parser.parse_args()

    start_time = time.time()
    soccernet = SoccerNetDataset(args.SoccerNet_path, split=args.split)
    with tqdm(enumerate(soccernet), total=len(soccernet), ncols=160) as t:
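        # Walk once through the dataset, displaying each sample; press any key to
        # advance. The normalized float image is cast to uint8 purely for display,
        # so it will look dark/noisy, and the mask holds small class indices.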
        for i, data in t:
            img, mask = data
            img = img.astype(np.uint8).transpose((1, 2, 0))
            print(img.shape)
            print(img.dtype)
            cv.imshow("Normalized image", img)
            cv.waitKey(0)
            cv.destroyAllWindows()
            print(mask.shape)
            cv.imshow("Mask", mask)
            cv.waitKey(0)
            cv.destroyAllWindows()
    end_time = time.time()
    print(f"Processed {len(soccernet)} samples in {end_time - start_time:.2f} seconds.")