# =============================
# 📄 codet5_summarizer.py
# =============================
import os
import re

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

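# Model registry. The CodeT5 variants and describeai/gemini load with the
# seq2seq head below; Mistral-7B-Instruct is decoder-only and may require
# accepting its terms on the Hugging Face Hub before download.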
MODEL_OPTIONS = {
    "CodeT5 Base (multi-sum)": "Salesforce/codet5-base-multi-sum",
    "CodeT5 Base": "Salesforce/codet5-base",
    "CodeT5 Small (Python-specific)": "stmnk/codet5-small-code-summarization-python",
    "Gemini (describeai)": "describeai/gemini",
    "Mistral 7B Instruct (v0.2)": "mistralai/Mistral-7B-Instruct-v0.2",
}

class CodeT5Summarizer:
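    """Summarize Python source with a pretrained seq2seq or causal LM.

    Function and class extraction below is regex- and indentation-based,
    a heuristic rather than a full parser.
    """
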
    def __init__(self, model_name=None):
        model_name = model_name or MODEL_OPTIONS["CodeT5 Base (multi-sum)"]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        hf_token = os.getenv('HF_TOKEN')
        if hf_token is None:
            raise ValueError("Hugging Face token must be set in the environment variable 'HF_TOKEN'.")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

        # Try the seq2seq head first (CodeT5 and other T5-style models); fall
        # back to a causal LM head for decoder-only models such as Mistral.
        try:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token).to(self.device)
        except (ValueError, OSError):
            self.model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(self.device)

        self.is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)

    def preprocess_code(self, code):
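        """Strip blank lines and # comments while keeping docstrings.

        Purely heuristic: docstring tracking toggles on triple quotes and
        can be fooled by triple quotes inside ordinary string literals.
        """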
        code = re.sub(r'\n\s*\n', '\n', code)
        lines = code.split('\n')
        clean = []
        docstring = False
        for line in lines:
            # Toggle only on an odd number of triple quotes so a one-line
            # docstring ("""...""") does not leave the flag stuck on.
            if line.count('"""') % 2 or line.count("'''") % 2:
                docstring = not docstring
            if docstring or not line.strip().startswith('#'):
                clean.append(line)
        # Collapse runs of spaces only after the first non-space character,
        # preserving the indentation that the extractors below rely on.
        return re.sub(r'(?<=\S) +', ' ', '\n'.join(clean))

    def extract_functions(self, code):
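        """Return (name, source) pairs for top-level functions and class methods.

        Boundaries are found by indentation, not by parsing, so decorators
        and unusual layouts may be missed.
        """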
        function_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\).*?:'
        function_matches = re.finditer(function_pattern, code, re.DOTALL)
        functions = []
        for match in function_matches:
            start_pos = match.start()
            function_name = match.group(1)
            lines = code[start_pos:].split('\n')
            body_start = 1
            while body_start < len(lines) and not lines[body_start].strip():
                body_start += 1
            if body_start < len(lines):
                body_indent = len(lines[body_start]) - len(lines[body_start].lstrip())
                function_body = [lines[0]]
                i = 1
                while i < len(lines):
                    line = lines[i]
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                        break
                    function_body.append(line)
                    i += 1
                function_code = '\n'.join(function_body)
                functions.append((function_name, function_code))

        # Class method detection
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_code = code[start_pos:]
            method_matches = re.finditer(function_pattern, class_code, re.DOTALL)
            for method_match in method_matches:
                # class_code runs from the class header to the end of the file,
                # so this crude 200-character cutoff avoids attributing later
                # top-level functions or other classes' methods to this class.
                if method_match.start() > 200:
                    break
                method_name = method_match.group(1)
                method_start = method_match.start()
                method_lines = class_code[method_start:].split('\n')
                body_start = 1
                while body_start < len(method_lines) and not method_lines[body_start].strip():
                    body_start += 1
                if body_start < len(method_lines):
                    body_indent = len(method_lines[body_start]) - len(method_lines[body_start].lstrip())
                    method_body = [method_lines[0]]
                    i = 1
                    while i < len(method_lines):
                        line = method_lines[i]
                        if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                            break
                        method_body.append(line)
                        i += 1
                    method_code = '\n'.join(method_body)
                    functions.append((f"{class_name}.{method_name}", method_code))
        return functions

    def extract_classes(self, code):
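        """Return (name, source) pairs for classes, delimited by indentation."""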
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        classes = []
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_lines = code[start_pos:].split('\n')
            body_start = 1
            while body_start < len(class_lines) and not class_lines[body_start].strip():
                body_start += 1
            if body_start < len(class_lines):
                body_indent = len(class_lines[body_start]) - len(class_lines[body_start].lstrip())
                class_body = [class_lines[0]]
                i = 1
                while i < len(class_lines):
                    line = class_lines[i]
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent:
                        break
                    class_body.append(line)
                    i += 1
                class_code = '\n'.join(class_body)
                classes.append((class_name, class_code))
        return classes

    def summarize(self, code, max_length=512):
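        """Generate a summary of code; max_length caps the new tokens generated."""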
        # Truncate the input to 512 tokens, CodeT5's context window.
        inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            if self.is_encoder_decoder:
                output = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_length,
                    num_beams=4,
                    early_stopping=True
                )
                return self.tokenizer.decode(output[0], skip_special_tokens=True)
            else:
                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]

                output = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_length,
                    do_sample=False,
                    num_beams=4,
                    early_stopping=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
                # Decoder-only models echo the prompt in their output, so
                # decode only the newly generated tokens.
                return self.tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)

    def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
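        """Summarize the whole file and, optionally, each function and class."""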
        preprocessed_code = self.preprocess_code(code)
        results = {
            "file_summary": None,
            "function_summaries": {},
            "class_summaries": {}
        }
        try:
            results["file_summary"] = self.summarize(preprocessed_code)
        except Exception as e:
            results["file_summary"] = f"Error generating file summary: {str(e)}"

        if summarize_functions:
            for function_name, function_code in self.extract_functions(preprocessed_code):
                try:
                    summary = self.summarize(function_code)
                    results["function_summaries"][function_name] = summary
                except Exception as e:
                    results["function_summaries"][function_name] = f"Error: {str(e)}"

        if summarize_classes:
            for class_name, class_code in self.extract_classes(preprocessed_code):
                try:
                    summary = self.summarize(class_code)
                    results["class_summaries"][class_name] = summary
                except Exception as e:
                    results["class_summaries"][class_name] = f"Error: {str(e)}"

        return results
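

# Minimal usage sketch: assumes HF_TOKEN is set in the environment and the
# default CodeT5 checkpoint can be downloaded from the Hugging Face Hub.
if __name__ == "__main__":
    sample = '''
def add(a, b):
    """Return the sum of a and b."""
    return a + b
'''
    summarizer = CodeT5Summarizer()
    results = summarizer.summarize_code(sample)
    print("File summary:", results["file_summary"])
    for name, summary in results["function_summaries"].items():
        print(f"{name}: {summary}")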