import concurrent.futures
import threading, math
import asyncio, json, os
from dotenv import load_dotenv
from llama_index.core import PromptTemplate
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)
from workflow.events import (
    SafeStartEvent,
    RefuseEvent,
    TokenEvent,
    ControlEvent
)
from workflow.vllm_model import MyVllm
from workflow.modules import MySQLChatStore, ToyStatusStore, prGreen, prRed, prYellow
from prompts.default_prompts import(
    ALIGNMENT_PROMPT,
    REFUSE_PROMPT,
    FEMALE_ROLEPLAY_PROMPT,
    MALE_ROLEPLAY_PROMPT,
    TOY_CONTROL_PROMPT_TEST,
    TOY_CONTROL_PROMPT
)

# Intent labels offered to the alignment classifier in ``censor``.
# "not harmful" marks a safe query; any other label from this list that the
# classifier returns leads to a refusal.
REFUSE_INTENTS = [
    "medical advice",
    "Overdose medication",
    "child pornography",
    "self-harm",
    "political bias",
    "racial hate speech",
    "illegal drugs",
    "not harmful",
    "violent tendencies",
    "weaponry",
    "religious hate",
    "Theft",
    "Robbery",
    "Body Disposal",
    "Forgery",
    "Smuggling",
    "Money laundering",
    "Extortion",
    "Terrorism",
    "Explosion",
    "Cyberattack & Hacking",
    "illegal stalking",
    "Arms trafficking",
]
# Operations the toy-control prompt may use in a generated command.
OPERATIONS = ["vibrate"]

class RolePlayWorkflow(Workflow):
    def __init__(
        self, 
        response_llm: MyVllm,
        chat_store: MySQLChatStore,
        toy_status_store: ToyStatusStore,
        sessionId: str,
        gender: str,
        toy_names: list[str] | None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.response_llm = response_llm
        self.chat_store = chat_store
        self.sessionId = sessionId
        self.chat_history = self.chat_store.get_chat_history(self.sessionId)
        self.gender = gender
        self.toy_names = toy_names
        self.toy_status_store = toy_status_store
        self.current_pattern = self.toy_status_store.get_latest(self.sessionId)["pattern"]
        self.retry_ct = 0

    @step
    async def censor(self, ctx: Context, ev: StartEvent) -> SafeStartEvent | RefuseEvent | StartEvent:
        # process llm output
        fmt_messages = ALIGNMENT_PROMPT.format_messages(
            user_input=ev.query,
            intent_labels=REFUSE_INTENTS
        )
        response = self.response_llm.chat(fmt_messages).message.content
        try:
            response = json.loads(response)
            intent = response["intent"]
            lang = response["language"]
            prYellow(f"language: {lang}")
            # 检测中文输入默认用英文回复
            if lang.lower() in ["zh", "chinese"]:
                lang = "english"
            await ctx.set("language", lang)
        except:
            if self.retry_ct < 3:
                self.retry_ct += 1
                return StartEvent(query=ev.query)
            return SafeStartEvent(query=ev.query)
        # judge
        if intent in ("not harmful", "BDSM content"):
            return SafeStartEvent(query=ev.query)
        prRed(f"refuse: {intent}")
        return RefuseEvent(lang=lang)

    @step
    async def refuse(self, ctx: Context, ev: RefuseEvent) -> StopEvent:
        response = self.response_llm.stream(REFUSE_PROMPT, language=ev.lang)
        response_str = ""
        for token in response:
            response_str += token
            content = json.dumps({"content": token})
            ctx.write_event_to_stream(TokenEvent(token=f"data:{content}\n\n"))
            await asyncio.sleep(0)
        ctx.write_event_to_stream(TokenEvent(token=f"data:[DONE]\n\n"))

        prRed(f"Response: {response_str}")
        return StopEvent(result="success")

    @step
    async def chat(self, ctx: Context, ev: SafeStartEvent) -> StopEvent:
        # preprocess chat history
        self.chat_store.add_message(self.sessionId, "user", ev.query)

        # generate response
        response_str = ""
        match self.gender:
            case "male":
                prompt = MALE_ROLEPLAY_PROMPT
            case "female":
                prompt = FEMALE_ROLEPLAY_PROMPT
        response = self.response_llm.stream(
            prompt,
            user_input=ev.query,
            chat_history=self.chat_history
        )
        for token in response:
            response_str += token
            content = json.dumps({"content": token})
            ctx.write_event_to_stream(TokenEvent(token=f"data:{content}\n\n"))
            await asyncio.sleep(0.005)
        ctx.write_event_to_stream(TokenEvent(token=f"data:[DONE]\n\n"))
        # update chat history
        t = threading.Thread(target=self.chat_store.add_message, args=(self.sessionId, "assistant", response_str))
        t.start()
        prGreen(f"Response:\n{response_str}")
        # control toy
        if self.toy_names:
            pattern = await self.control_toy(ev.query)
            ctx.write_event_to_stream(TokenEvent(token=f"data:{pattern}\n\n"))
        
        return StopEvent(result="success")

    async def control_toy(self, user_input:str):
        command = generate_command(
            self.response_llm,
            TOY_CONTROL_PROMPT,
            user_input=user_input,
            chat_history=self.chat_history,
            toy_status=self.current_pattern,
            available_operations=OPERATIONS
        )
        command_str = json.dumps(command)
        prGreen(command_str)
        # ctx.write_event_to_stream(TokenEvent(token=f"data:{command_str}\n\n"))
        # await asyncio.sleep(0)
        for toy_name in self.toy_names:
            t = threading.Thread(target=self.toy_status_store.update, args=(self.sessionId, command_str, toy_name))
            t.start()
        return command_str
    
def generate_command(llm: "MyVllm", prompt: "PromptTemplate", **kwargs):
    """Ask ``llm`` for a toy-control command and decode it as JSON.

    The prompt is formatted once with ``kwargs`` and the same messages are
    re-sent until the reply decodes as JSON. At most five attempts are made
    (the first try plus four retries).

    The prompt is expected to request this shape::

        {
            "operation": {
                "pattern": [
                    {"duration": 100, "operation": <int level>},
                    {"duration": 100, "operation": <int level>},
                ]
            }
        }

    The reply is only JSON-decoded, not validated against that shape.

    Args:
        llm: model exposing ``chat(messages)`` whose result has
            ``.message.content``.
        prompt: template exposing ``format_messages(**kwargs)``.
        **kwargs: template variables forwarded to ``prompt.format_messages``.

    Returns:
        The decoded JSON value on success, or the string
        ``"Failed to generate command"`` once every attempt has failed.
    """
    max_attempts = 5
    fmt_messages = prompt.format_messages(**kwargs)
    for _ in range(max_attempts):
        content = llm.chat(fmt_messages).message.content
        try:
            return json.loads(content)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError; TypeError covers a
            # non-string reply such as None. Anything else propagates.
            continue
    return "Failed to generate command"

def post_process_command(command: dict):
    total_time = sum([item["duration"] for item in command["pattern"]])
    mults = math.ceil(10000 / total_time)