import os
import json
import shutil

from optimum.exporters.onnx import main_export
import onnx
from onnxconverter_common import float16
import onnxruntime as rt
from onnxruntime.tools.onnx_model_utils import *
from onnxruntime.quantization import quantize_dynamic, QuantType
from huggingface_hub import hf_hub_download


# Load the conversion settings. JSON is UTF-8 by spec; without an explicit
# encoding, open() falls back to the platform locale encoding and can
# mis-decode non-ASCII content.
with open('conversion_config.json', encoding='utf-8') as json_file:
    conversion_config = json.load(json_file)

    # Settings read from conversion_config.json.
    model_id = conversion_config["model_id"]  # Hub repo id passed to hf_hub_download
    number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
    precision_to_filename_map = conversion_config["precision_to_filename_map"]  # precision name -> filename in the repo
    opset = conversion_config["opset"]  # opset version stamped on re-saved models
    IR = conversion_config["IR"]  # IR version stamped on re-saved models

    # Opset import for the default ONNX domain (domain field left empty).
    op = onnx.OperatorSetIdProto()
    op.version = opset

    # Output directory; exist_ok avoids the check-then-create race of
    # testing os.path.exists() first.
    os.makedirs("onnx", exist_ok=True)
    

    # Download the pre-quantized int8/uint8 model files named in the config
    # from the Hub repo and re-save them in place, stamped with the configured
    # IR version and default-domain opset version.
    for precision in ("int8", "uint8"):
        if precision not in precision_to_filename_map:
            continue
        print(f"Exporting the {precision} onnx file...")

        filename = precision_to_filename_map[precision]
        # hf_hub_download returns the local path of the downloaded file; use it
        # rather than assuming where the file landed relative to the cwd.
        local_path = hf_hub_download(repo_id=model_id, filename=filename, local_dir="./")
        model = onnx.load(local_path)

        # Keep opset imports for non-default domains (e.g. "com.microsoft"
        # contrib ops, which quantized models may use). Only the default ONNX
        # domain is replaced by the configured version. Dropping the others
        # would leave nodes in those domains without an opset import, which
        # makes the model invalid.
        extra_opsets = [
            imp for imp in model.opset_import if imp.domain not in ("", "ai.onnx")
        ]
        # Rebuild the model around the original graph so it has compatible
        # opset and IR versions. NOTE: this only rewrites the version stamps; it
        # does not convert operators between opsets. Model-level fields outside
        # the graph (producer info, metadata_props) are not carried over.
        model_fixed = onnx.helper.make_model(
            model.graph, ir_version=IR, opset_imports=[op, *extra_opsets]
        )
        onnx.save(model_fixed, local_path)

        print("Done\n\n")