import copy
from typing import List

from utils.common.log import logger
from ..dataloader import build_dataloader
from data import get_dataset, MergedDataset, Scenario as DAScenario
class _ABDatasetMetaInfo:
    """Metadata describing one target dataset/task: its name, label space,
    task/object type, label aliases, shift type, ignored classes, and the
    mapping from raw label indices into the global label space."""

    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map):
        self.name = name
        self.classes = classes
        self.class_aliases = class_aliases
        self.shift_type = shift_type
        self.task_type = task_type
        self.object_type = object_type
        self.ignore_classes = ignore_classes
        self.idx_map = idx_map

    def __repr__(self) -> str:
        return f'({self.name}, {self.classes}, {self.idx_map})'
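
# Illustrative only: a hypothetical meta-info entry for a class-incremental
# task drawn from CIFAR10 (the name, classes, and idx_map below are made up):
#
#   info = _ABDatasetMetaInfo(
#       name='CIFAR10|0', classes=['airplane', 'automobile'],
#       task_type='Image Classification', object_type='generic',
#       class_aliases=[], shift_type=None,
#       ignore_classes=[], idx_map={0: 10, 1: 11},
#   )
#
# The '|' in the name separates the base dataset from the task index, which
# is why get_cur_task_datasets() below splits on it.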
class Scenario:
    def __init__(self, config, target_datasets_info: List[_ABDatasetMetaInfo], num_classes: int, num_source_classes: int, data_dirs):
        self.config = config  # expected to contain 'num_classes_per_task' and a DAScenario under 'da_scenario'
        self.target_datasets_info = target_datasets_info
        self.num_classes = num_classes
        self.cur_task_index = 0
        self.num_source_classes = num_source_classes
        # target classes are appended after the source classes, so the first
        # task's labels start right after the source label space
        self.cur_class_offset = num_source_classes
        self.data_dirs = data_dirs
        self.target_tasks_order = [i.name for i in self.target_datasets_info]
        # note: despite the name, this is the total number of *classes*
        # across all target tasks, not the number of tasks
        self.num_tasks_to_be_learn = sum(len(i.classes) for i in target_datasets_info)

        logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
                    f'# classes per task: {config["num_classes_per_task"]}')
    def to_json(self):
        # deep-copy so serialization does not mutate the live config
        config = copy.deepcopy(self.config)
        config['da_scenario'] = config['da_scenario'].to_json()
        target_datasets_info = [str(i) for i in self.target_datasets_info]
        return dict(
            config=config, target_datasets_info=target_datasets_info,
            num_classes=self.num_classes
        )

    def __str__(self):
        return f'Scenario({self.to_json()})'
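
    # For illustration, to_json() yields a dict shaped like the following
    # (all values here are hypothetical):
    #
    #   {
    #       'config': {'num_classes_per_task': 2, 'da_scenario': {...}},
    #       'target_datasets_info': ["(CIFAR10|0, ['airplane', 'automobile'], {0: 10, 1: 11})"],
    #       'num_classes': 20,
    #   }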
    def get_cur_class_offset(self):
        # first global label index of the current task
        return self.cur_class_offset

    def get_cur_num_class(self):
        return len(self.target_datasets_info[self.cur_task_index].classes)

    def get_nc_per_task(self):
        # all tasks are assumed to have the same number of classes as task 0
        return len(self.target_datasets_info[0].classes)

    def next_task(self):
        self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
        self.cur_task_index += 1
        logger.info(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')
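
    # Worked example of the offset bookkeeping (hypothetical numbers): with
    # num_source_classes=10 and two target tasks of 5 classes each,
    # cur_class_offset starts at 10; after the first next_task() it becomes
    # 15 (and cur_task_index becomes 1), after the second it becomes 20.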
    def get_cur_task_datasets(self):
        dataset_info = self.target_datasets_info[self.cur_task_index]

        def _build(di, split):
            # each task name encodes '<base dataset>|<task index>'; resolve each
            # task's own base dataset so the merged val/test sets stay correct
            # even when tasks come from different base datasets
            dataset_name = di.name.split('|')[0]
            return get_dataset(dataset_name=dataset_name,
                               root_dir=self.data_dirs[dataset_name],
                               split=split,
                               transform=None,
                               ignore_classes=di.ignore_classes,
                               idx_map=di.idx_map)

        # train: the current task only; val/test: all tasks seen so far,
        # merged, so evaluation also covers previously learnt classes
        seen_tasks_info = self.target_datasets_info[0: self.cur_task_index + 1]
        res = {
            'train': _build(dataset_info, 'train'),
            **{split: MergedDataset([_build(di, split) for di in seen_tasks_info])
               for split in ['val', 'test']}
        }

        # repeat small splits 5x so every split has a workable number of samples
        if len(res['train']) < 1000:
            res['train'] = MergedDataset([res['train']] * 5)
            logger.info('aug train dataset')
        if len(res['val']) < 1000:
            res['val'] = MergedDataset(res['val'].datasets * 5)
            logger.info('aug val dataset')
        if len(res['test']) < 1000:
            res['test'] = MergedDataset(res['test'].datasets * 5)
            logger.info('aug test dataset')

        for k, v in res.items():
            logger.info(f'{k} dataset: {len(v)}')

        return res
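
    # Typical per-task loop (sketch; train_on/evaluate_on are hypothetical
    # placeholders for the caller's training and evaluation code):
    #
    #   datasets = scenario.get_cur_task_datasets()
    #   train_on(datasets['train'])
    #   evaluate_on(datasets['val'], datasets['test'])  # covers all seen tasks
    #   scenario.next_task()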
    def get_cur_task_train_datasets(self):
        # train split of the current task only, without the val/test merging
        # done in get_cur_task_datasets()
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]
        return get_dataset(dataset_name=dataset_name,
                           root_dir=self.data_dirs[dataset_name],
                           split='train',
                           transform=None,
                           ignore_classes=dataset_info.ignore_classes,
                           idx_map=dataset_info.idx_map)
    def get_online_cur_task_samples_for_training(self, num_samples):
        # draw a single batch of `num_samples` training inputs from the
        # current task (the trailing [0] keeps the inputs and drops the labels)
        dataset = self.get_cur_task_datasets()['train']
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]
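

# Minimal driver sketch (illustrative only): `config`, `infos`, and
# `data_dirs` are assumed to be built elsewhere in this codebase, and the
# class counts below are hypothetical.
if __name__ == '__main__':
    scenario = Scenario(config, infos, num_classes=20, num_source_classes=10, data_dirs=data_dirs)
    for _ in range(len(infos)):
        datasets = scenario.get_cur_task_datasets()
        x = scenario.get_online_cur_task_samples_for_training(num_samples=32)
        logger.info(f'offset={scenario.get_cur_class_offset()}, batch shape={getattr(x, "shape", None)}')
        scenario.next_task()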