|
import logging |
|
from omegaconf import DictConfig |
|
|
|
# Module-level logger. Calling getLogger() without a name returns the root logger.
log = logging.getLogger()
|
|
|
|
|
def get_dataset_cfg(cfg: DictConfig):
    """Return the selected dataset's config, reconciled with top-level overrides.

    The dataset is chosen by ``cfg.dataset`` and looked up in ``cfg.datasets``.
    For each overridable key the two configs are synchronised in place:

    * If the top-level ``cfg[key]`` is not ``None``, it wins and is written
      into the dataset config (an info message is logged).
    * Afterwards, if the dataset config has the key, its value is copied back
      to the top-level ``cfg[key]`` so both views agree.

    Args:
        cfg: Full configuration. Must provide ``dataset``, ``datasets`` and an
            entry (possibly ``None``) for every overridable key.

    Returns:
        The dataset config ``cfg.datasets[cfg.dataset]``. Note that both it
        and ``cfg`` are mutated by this call.
    """
    dataset_name = cfg.dataset
    data_cfg = cfg.datasets[dataset_name]

    potential_overrides = (
        'image_directory',
        'mask_directory',
        'json_directory',
        'size',
        'save_all',
        'use_all_masks',
        'use_long_term',
        'mem_every',
    )

    for override in potential_overrides:
        new_value = cfg[override]
        if new_value is not None:
            # The dataset config may not define this key at all (see the
            # membership check below), so read the previous value with
            # .get() instead of indexing, which would raise before the
            # override could be applied.
            log.info(f'Overriding config {override} from {data_cfg.get(override)} to {new_value}')
            data_cfg[override] = new_value

        # Mirror the effective dataset value back to the top level.
        if override in data_cfg:
            cfg[override] = data_cfg[override]

    return data_cfg
|
|