# Based on code from https://huggingface.co/spaces/gokaygokay/Gemma-2-llamacpp — thank you!
import spaces
import os
import json
import subprocess
from llama_cpp import Llama
# from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
# from llama_cpp_agent.providers import LlamaCppPythonProvider
# from llama_cpp_agent.chat_history import BasicChatHistory
# from llama_cpp_agent.chat_history.messages import Roles
import gradio as gr
from huggingface_hub import hf_hub_download


# huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Fetch both Q2_K-quantized KhanomTanLLM GGUF checkpoints into ./models;
# respond() later loads them from "models/<filename>".
for _repo_id, _filename in (
    ("wannaphong/KhanomTanLLM-1B-Instruct-Q2_K-GGUF", "khanomtanllm-1b-instruct-q2_k.gguf"),
    ("wannaphong/KhanomTanLLM-3B-Instruct-Q2_K-GGUF", "khanomtanllm-3b-instruct-q2_k.gguf"),
):
    hf_hub_download(
        repo_id=_repo_id,
        filename=_filename,
        local_dir="./models",
    )

# hf_hub_download(
#     repo_id="google/gemma-2-2b-it-GGUF",
#     filename="2b_it_v2.gguf",
#     local_dir="./models",
#     token=huggingface_token
# )



# Lazily-populated model cache reused across calls to respond(): `llm` holds the
# loaded llama_cpp.Llama instance and `llm_model` the filename it was loaded from.
llm = llm_model = None
def _build_messages(system_message, history, message):
    """Convert Gradio tuple-style chat history into a chat-completion message list.

    Args:
        system_message: Text for the leading ``system`` turn.
        history: Prior ``(user, assistant)`` pairs; ``None`` is treated as empty.
        message: The new user message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system turn, then
        each history pair as a user turn followed by an assistant turn, then the
        new user turn.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history or ():
        messages.append({"role": "user", "content": turn[0]})
        messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU  # a time budget can be requested with @spaces.GPU(duration=120)
def respond(
    message,
    history: list[tuple[str, str]],
    model,
    system_message,
    max_tokens,
    temperature,
    min_p,
    top_p,
    top_k,
    repeat_penalty,
):
    """Stream a chat reply from the selected GGUF model.

    The loaded model is cached in the module globals ``llm`` / ``llm_model`` and
    is only (re)loaded when nothing is loaded yet or ``model`` changes.

    Args:
        message: The new user message.
        history: Earlier ``(user, assistant)`` turns of the conversation.
        model: GGUF filename, loaded from ``models/<model>``.
        system_message: Content of the system turn.
        max_tokens: Maximum number of tokens to generate (coerced to ``int``).
        temperature, min_p, top_p, repeat_penalty: Sampling settings forwarded
            to ``Llama.create_chat_completion``.
        top_k: Top-k sampling cutoff (coerced to ``int``).

    Yields:
        The accumulated reply text, growing by one streamed chunk per yield.
    """
    global llm
    global llm_model

    if llm is None or llm_model != model:
        # Drop the reference to the previous model before constructing the next
        # one, so the old instance can be released instead of both being held.
        llm = None
        # NOTE(review): n_gpu_layers is not passed, so llama-cpp-python's default
        # applies (presumably 0 = CPU only) despite @spaces.GPU — confirm intent.
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,
            n_batch=1024,
            n_ctx=2048,
        )
        llm_model = model

    messages = _build_messages(system_message, history, message)
    print(messages)
    stream = llm.create_chat_completion(
        messages=messages,
        temperature=temperature,
        # Slider values may arrive as floats; these two are integer counts.
        top_k=int(top_k),
        top_p=top_p,
        min_p=min_p,
        max_tokens=int(max_tokens),
        repeat_penalty=repeat_penalty,
        stream=True,
    )

    outputs = ""
    for chunk in stream:
        # Streaming deltas may carry only a role, or no content at all; skip
        # those instead of iterating a missing or None value.
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            outputs += content
            # One UI update per streamed chunk, not one per character.
            yield outputs

# Text passed as the ChatInterface description; currently a single newline,
# i.e. effectively empty.
description = "\n"

# Chat UI. gr.ChatInterface calls respond(message, history, *additional_inputs),
# so the order of additional_inputs below must match respond()'s parameters after
# `history`: model, system_message, max_tokens, temperature, min_p, top_p,
# top_k, repeat_penalty.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Dropdown(
            [
                'khanomtanllm-1b-instruct-q2_k.gguf',
                'khanomtanllm-3b-instruct-q2_k.gguf',
            ],
            value="khanomtanllm-1b-instruct-q2_k.gguf",
            label="Model",
        ),
        gr.Textbox(value="You are a helpful assistant.", label="System message", lines=6),
        gr.Slider(minimum=1, maximum=2048, value=2048, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=2.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.05,
            label="min-p",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            value=40,
            step=1,
            label="Top-k",
        ),
        gr.Slider(
            # Lower bound is 0.1, not 0.0: llama.cpp's repetition penalty divides
            # positive logits of repeated tokens by this value, so 0 is degenerate.
            minimum=0.1,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition penalty",
        ),
    ],
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    submit_btn="Send",
    title="Chat with KhanomTanLLM using llama.cpp",
    description=description,
    chatbot=gr.Chatbot(
        scale=1,
        likeable=False,
        show_copy_button=True,
    ),
)

if __name__ == "__main__":
    # Start the Gradio app only when this file is executed directly.
    demo.launch()