import re

from langchain_openai import ChatOpenAI

from .agent import BaseAgent

# System message sent with every evaluation request: frames the model as a
# judge comparing a Web Agent's generated checklist to a reference checklist.
SYSTEM_PROMPT = "You are an expert evaluator. Your task is to assess how well a Web Agent’s generated checklist aligns with the reference checklist for a given user instruction."

# User-message template. ChecklistEvalAgent.prepare_message fills it with
# str.format using the keys `intent`, `gt_checklist` and `generated_checklist`,
# so any literal braces added to this text must be doubled ("{{" / "}}").
# The requested reply format ("REASON: ..." followed by "SCORE: <1-5>") is the
# contract parsing_score() relies on: it reads the number after "SCORE:".
# NOTE(review): the first paragraph mentions a "task description" and
# "evaluation criteria", but the sections below are titled "User Instruction"
# and "Score Criteria". Left unchanged here so scores stay comparable with
# earlier runs; consider aligning the wording deliberately.
USER_PROMPT = """# Task Description
Use the provided task description, evaluation criteria, and both checklists to assign a score from 1 to 5. Justify your rating with a brief explanation that considers both content overlap and logical structure.

## Score Criteria
- 5: Checklist covers all subgoals, is correct and clearly expressed
- 4: Minor omissions or phrasing issues but mostly accurate and complete
- 3: Partially matches, but with noticeable gaps or errors
- 2: Incomplete or includes incorrect steps
- 1: Mostly irrelevant, incorrect, or missing the task goal

## User Instruction:
{intent}

## Reference Checklist:
{gt_checklist}

## Agent’s Generated Checklist:
{generated_checklist}

# Output Format
Your response should be in the following format:
REASON: [Write 2–4 sentences explaining how well the generated checklist matches the reference. Mention specific matches, omissions, errors, or strengths.]
SCORE: [1–5]
"""


class ChecklistEvalAgent(BaseAgent):
    """LLM judge that rates a generated checklist against a reference one.

    Builds a two-message chat prompt (``SYSTEM_PROMPT`` plus the formatted
    ``USER_PROMPT``) and samples ``self.num_generate`` responses through the
    inherited ``generate_with_retry`` helper. Extracting the numeric score
    from the raw responses is done separately by the module-level
    ``parsing_score`` / ``get_score`` helpers.
    """

    def __init__(self, agent_config: dict):
        super().__init__(agent_config)
        # `_setup` is not defined in this class; presumably provided by
        # BaseAgent (e.g. to build the model client) -- confirm there.
        self._setup()

    def prepare_message(self, model_input: dict, prompt_type=None):
        """Build the chat messages for a single evaluation.

        Args:
            model_input: mapping with the keys ``"intent"``,
                ``"gt_checklist"`` and ``"generated_checklist"``; their values
                are substituted into ``USER_PROMPT`` via ``str.format``.
            prompt_type: not used by this agent. It defaults to ``None`` so
                that ``generate_response`` can call this method with only
                ``model_input`` (previously it was a required positional
                argument, so that call raised ``TypeError``). Presumably kept
                for parity with other agents' signatures -- confirm.

        Returns:
            A list of two ``{"role": ..., "content": ...}`` dicts: the system
            message followed by the user message.

        Raises:
            KeyError: if ``model_input`` lacks one of the required keys.
        """
        user_content = USER_PROMPT.format(
            intent=model_input["intent"],
            gt_checklist=model_input["gt_checklist"],
            generated_checklist=model_input["generated_checklist"],
        )
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    def generate_response(self, model_input: dict):
        """Sample ``self.num_generate`` judge responses for one input.

        Args:
            model_input: same mapping accepted by ``prepare_message``.

        Returns:
            ``(response_list, total_cost)``: the raw responses in sampling
            order and the sum of the per-call costs returned by
            ``generate_with_retry``.
        """
        message = self.prepare_message(model_input)

        response_list = []
        total_cost = 0
        # n independent samples of the same prompt.
        for _ in range(self.num_generate):
            # NOTE(review): ["SCORE"] is passed straight to BaseAgent's
            # generate_with_retry; presumably the marker a response must
            # contain to be accepted -- confirm in BaseAgent.
            response, cost = self.generate_with_retry(message, ["SCORE"])
            response_list.append(response)
            total_cost += cost

        return response_list, total_cost

def parsing_score(response: str, *, min_score: int = 1, max_score: int = 5):
    """Extract the integer score from a judge response.

    The score is the first run of digits on the first non-blank line that
    follows the LAST ``"SCORE:"`` marker, so both ``"SCORE: 4"`` and
    ``"SCORE:\\n4"`` parse. Decimal values are truncated to their integer
    part (``"4.5"`` -> 4).

    Args:
        response: raw model output, expected to contain ``"SCORE: <n>"``.
        min_score: smallest accepted score (inclusive), default 1.
        max_score: largest accepted score (inclusive), default 5.

    Returns:
        The parsed score as an ``int``. Returns ``None`` in these cases: the
        response is not a string, it has no ``"SCORE:"`` marker, no digits
        follow the marker, or the value lies outside
        ``[min_score, max_score]``.
    """
    if not isinstance(response, str) or "SCORE:" not in response:
        # Without the marker, any number found would come from the REASON
        # text (e.g. "covers 3 of 5 subgoals"), which is not a score.
        return None

    tail = response.rsplit("SCORE:", 1)[1]
    # The value may sit on the next line; use the first non-blank line.
    score_line = next((line for line in tail.splitlines() if line.strip()), "")
    match = re.search(r"\d+", score_line)
    if match is None:
        return None

    score = int(match.group())
    # Reject values outside the rubric so they cannot skew averages.
    if min_score <= score <= max_score:
        return score
    return None

def average_score(scores: list[int]):
    """Return the arithmetic mean of *scores*.

    An empty list yields the integer ``0`` rather than raising
    ``ZeroDivisionError``.
    """
    count = len(scores)
    if count == 0:
        return 0
    total = sum(scores)
    return total / count

def get_score(results: list[dict]):
    """Parse and average the judge scores of every result, in place.

    Each result dict must have a ``"response"`` key holding an iterable of
    raw response strings. Every response is run through ``parsing_score``,
    and unparseable ones (``None``) are dropped. The dict then gains two
    keys: ``"score_list"`` holds the surviving integer scores, and
    ``"score"`` holds their mean from ``average_score``, which is ``0`` when
    none parsed.

    Returns:
        ``(results, score_list)``. The second list contains the very same
        (now annotated) dict objects as ``results``, in the same order.
    """
    annotated = []
    for entry in results:
        parsed = (parsing_score(text) for text in entry["response"])
        valid = [value for value in parsed if value is not None]
        entry["score_list"] = valid
        entry["score"] = average_score(valid)
        annotated.append(entry)
    return results, annotated