import argparse
import torch
import os
import json
from tqdm import tqdm

import project_subpath
from backend.dataloader import create_dataloader_frames_only
from backend.inference import setup_model, do_detection


def main(args, verbose=False):
    """
    Run yolov5 detection on every sequence of one location and save the raw results.

    Args:
        args: parsed command-line namespace; the following attributes are read
            frames (str): root directory holding one sub-directory per location
            location (str): name of the location sub-directory to process
            output (str): root directory under which detections are stored
            weights (str): path to model weights
        verbose (bool): forwarded unchanged to the detection routine
    """
    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())

    model, device = setup_model(args.weights)

    # Input and output trees share the same <root>/<location> layout.
    source_dir = os.path.join(args.frames, args.location)
    target_dir = os.path.join(args.output, args.location)
    for path in (source_dir, target_dir):
        print(path)

    detect_location(source_dir, target_dir, model, device, verbose)


                
def detect_location(in_loc_dir, out_loc_dir, model, device, verbose):
    """
    Run detection on every sequence found directly under a location directory.

    Entries whose name starts with "." are ignored. For each remaining entry an
    output directory of the same name is created under out_loc_dir (if it does
    not already exist) and detect() is called on the pair of directories.

    Args:
        in_loc_dir (str): location directory containing one entry per sequence
        out_loc_dir (str): directory under which per-sequence results are written
        model: model object passed through to detect()
        device: device object passed through to detect()
        verbose (bool): passed through to detect()
    """
    # os.listdir yields entries in arbitrary order; sort so runs are reproducible.
    # Hidden entries are dropped up front so the progress total matches the
    # number of sequences that are actually processed.
    seq_list = sorted(
        name for name in os.listdir(in_loc_dir) if not name.startswith(".")
    )

    with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
        for seq in seq_list:
            pbar.set_description("Processing " + seq)

            in_seq_dir = os.path.join(in_loc_dir, seq)
            out_seq_dir = os.path.join(out_loc_dir, seq)
            os.makedirs(out_seq_dir, exist_ok=True)

            detect(in_seq_dir, out_seq_dir, model, device, verbose)

            # Advance only once the sequence is done so the bar reflects
            # completed work rather than started work.
            pbar.update(1)

def detect(in_seq_dir, out_seq_dir, model, device, verbose):
    """
    Run detection on one sequence directory and persist the results.

    Two files are written into out_seq_dir:
        pred.json: the image shapes, width and height returned by do_detection
        inference.pt: the inference result returned by do_detection, saved with torch.save

    Args:
        in_seq_dir (str): directory handed to the frame dataloader
        out_seq_dir (str): existing directory that receives the output files
        model: model object passed to do_detection
        device: device object passed to do_detection
        verbose (bool): passed to do_detection
    """
    frame_loader = create_dataloader_frames_only(in_seq_dir)
    inference, image_shapes, width, height = do_detection(
        frame_loader, model, device, verbose=verbose
    )

    # Metadata goes to JSON first, then the inference result to a torch file.
    meta_path = os.path.join(out_seq_dir, 'pred.json')
    with open(meta_path, 'w') as meta_file:
        json.dump(
            {'image_shapes': image_shapes, 'width': width, 'height': height},
            meta_file,
        )

    inference_path = os.path.join(out_seq_dir, 'inference.pt')
    torch.save(inference, inference_path)

def argument_parser():
    """
    Build the command-line parser for this script.

    Every option has a default value, so none of them is mandatory on the
    command line. The help text of each option reports its actual default via
    argparse's %(default)s substitution, so the two cannot drift apart.

    Returns:
        argparse.ArgumentParser: parser exposing --frames, --location, --output and --weights
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", default="../frames/images", help="Path to frame directory. Default: %(default)s")
    parser.add_argument("--location", default="kenai-val", help="Name of location dir. Default: %(default)s")
    parser.add_argument("--output", default="../frames/detections/detection_storage/", help="Path to output directory. Default: %(default)s")
    parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: %(default)s")
    return parser

if __name__ == "__main__":
    # Script entry point: parse the command line and run detection with it.
    cli_parser = argument_parser()
    main(cli_parser.parse_args())