import os
import tempfile

from typing import Dict, List, Optional

from bs4 import BeautifulSoup
import yt_dlp
import pandas as pd
import requests
import torch

from langchain_community.document_loaders import YoutubeLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.tools import BearlyInterpreterTool
from langchain_core.documents import Document
from smolagents import (
    DuckDuckGoSearchTool,
    SpeechToTextTool,
    Tool,
    VisitWebpageTool,
    WikipediaSearchTool,
)
from transformers import AutoProcessor, AutoModelForImageTextToText


class RelevantInfoRetrieverTool(Tool):
    name = "relevant_info_retriever"
    description = "Retrieves relevant to the query information."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query for which to retrieve information.",
        },
        "docs": {
            "type": "string",
            "description": "The source documents from which to choose in order to retrieve relevant information",
        },
    }
    output_type = "string"

    def forward(self, query: str, docs: List[Document]):
        # Rank the supplied documents against the query with BM25
        retriever = BM25Retriever.from_documents(docs)
        results = retriever.invoke(query)
        if results:
            return "\n\n".join(doc.page_content for doc in results)
        return "No relevant information found."


class YoutubeTranscriptTool(Tool):
    name = "youtube_transcript"
    description = "Fetches youtube video's transcript."
    inputs = {
        "youtube_url": {
            "type": "string",
            "description": "The youtube video url",
        },
        "source_langs": {
            "type": "array",
            "description": "A list of language codes in a descending priority for the video trascript.",
            "items": {"type": "string"},
            "default": ["en"],
            "required": False,
            "nullable": True,
        },
        "target_lang": {
            "type": "string",
            "description": "The language to which the transcript will be translated.",
            "default": "en",
            "required": False,
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(
        self,
        youtube_url: str,
        source_langs: Optional[List[str]] = None,
        target_lang: Optional[str] = "en",
    ):
        try:
            loader = YoutubeLoader.from_youtube_url(
                youtube_url,
                add_video_info=True,
                language=source_langs or ["en"],
                translation=target_lang,
                # transcript_format=TranscriptFormat.CHUNKS,
                # chunk_size_seconds=30,
            )
            transcript_docs = loader.load()
            # Join the loaded documents into a single string to match output_type
            return "\n\n".join(doc.page_content for doc in transcript_docs)

        except Exception as e:
            return f"Error fetching video's transcript: {e}"


class ReverseStringTool(Tool):
    name = "reverse_string"
    description = "Reverses the input string."
    inputs = {
        "string": {
            "type": "string",
            "description": "The string that needs to be reversed.",
        }
    }
    output_type = "string"

    def forward(self, string: str):
        try:
            return string[::-1]
        except Exception as e:
            return f"Error reversing string: {e}"


class SmolVLM2:
    """The parent class for visual analyzer tools (using SmolVLM2-500M-Video model)"""

    def __init__(self):
        """Initializations for the analyzer tool"""
        model_path = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
        device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            # _attn_implementation="flash_attention_2",
        ).to(device)


class ImagesAnalyzerTool(Tool, SmolVLM2):
    name = "image_analyzer"
    description = "Analyzes each input image according to the query"
    inputs = {
        "query": {
            "type": "string",
            "description": "The query according to which the image will be analyzed.",
        },
        "images_urls": {
            "type": "array",
            "description": "A list of strings containing the images' urls",
            "items": {"type": "string"},
        },
    }
    output_type = "string"

    def __init__(self):
        Tool.__init__(self)
        SmolVLM2.__init__(self)

    def forward(self, query: str, images_urls: List[str]):
        try:
            # One image message entry per image URL
            image_message_ents = [{"type": "image", "url": iu} for iu in images_urls]

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query,
                        },
                    ]
                    + image_message_ents,
                },
            ]

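            # Render the chat template, tokenize, and move tensors to the model's device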
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=torch.bfloat16)

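            # Greedy decoding (do_sample=False) keeps answers deterministic;
            # max_new_tokens caps the response length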
            generated_ids = self.model.generate(
                **inputs, do_sample=False, max_new_tokens=64
            )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            return generated_texts[0]
        except Exception as e:
            return f"Error analyzing image(s): {e}"


class VideoAnalyzerTool(Tool, SmolVLM2):
    name = "video_analyzer"
    description = "Analyzes video at a specified path according to the query"
    inputs = {
        "query": {
            "type": "string",
            "description": "The query according to which the video will be analyzed.",
        },
        "video_path": {
            "type": "string",
            "description": "A string containing the video path",
        },
    }
    output_type = "string"

    def __init__(self):
        Tool.__init__(self)
        SmolVLM2.__init__(self)

    def forward(self, query: str, video_path: str) -> str:
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": video_path},
                        {"type": "text", "text": query},
                    ],
                },
            ]

            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=torch.bfloat16)

            generated_ids = self.model.generate(
                **inputs, do_sample=False, max_new_tokens=64
            )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )

            return generated_texts[0]
        except Exception as e:
            return f"Error analyzing video: {e}"
        finally:
            # Remove the temporary video file once analysis is complete
            if video_path and os.path.exists(video_path):
                os.remove(video_path)
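

# Example (illustrative): outside an agent, the downloader and analyzer can be
# chained by hand, since smolagents Tool instances are callable:
#
#   path = YoutubeVideoDownloaderTool()(video_url="https://youtu.be/...")
#   answer = VideoAnalyzerTool()(query="Describe the scene.", video_path=path)
#
# Note that VideoAnalyzerTool deletes the file at video_path when it finishes.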


class FileDownloaderTool(Tool):
    name = "file_downloader"
    description = "Downloads a file returning the name of the temporarily saved file"
    inputs = {
        "file_url": {
            "type": "string",
            "description": "The url from which the file shall be downloaded.",
        },
    }
    output_type = "string"

    def forward(self, file_url: str) -> str:
        response = requests.get(file_url, stream=True)
        response.raise_for_status()
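        # Best-effort filename extraction from the Content-Disposition header,
        # e.g. 'attachment; filename="report.xlsx"'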
        original_filename = (
            response.headers.get("content-disposition", "")
            .split("=", -1)[-1]
            .strip('"')
        )

        # Even if original_filename is empty or there is no extension, ext will be ""
        ext = os.path.splitext(original_filename)[-1]

        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_file:
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
            return tmp_file.name


class YoutubeVideoDownloaderTool(Tool):
    name = "youtube_video_downloader"
    description = "Downloads the video from the specified url and returns the path where the video was saved"
    inputs = {
        "video_url": {
            "type": "string",
            "description": "A string containing the video url",
        },
    }
    output_type = "string"

    def forward(self, video_url: str) -> str:
        try:
            temp_dir = tempfile.gettempdir()
            ydl_opts = {
                "outtmpl": f"{temp_dir}/%(title)s.%(ext)s",  # Save into the temp directory
                "quiet": True,
            }

            # Download the YouTube video and return the path it was saved to
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(video_url, download=True)
                return ydl.prepare_filename(info)
        except Exception as e:
            return f"Error downloading video: {e}"


class LoadXlsxFileTool(Tool):
    name = "load_xlsx_file"
    description = "This tool loads xlsx file into pandas and returns it"
    inputs = {"file_path": {"type": "string", "description": "File path"}}
    output_type = "object"

    def forward(self, file_path: str) -> object:
        return pd.read_excel(file_path)


class LoadTextFileTool(Tool):
    name = "load_text_file"
    description = "This tool loads any text file"
    inputs = {"file_path": {"type": "string", "description": "File path"}}
    output_type = "string"

    def forward(self, file_path: str) -> str:
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()


class WebpageTablesContextRetrieverTool(Tool):
    name = "webpage_tables_context_retriever"
    description = """Retrieves structural context for all tables on a webpage.
    Returns table indexes with captions, headers, and surrounding text to help identify relevant tables.
    Use this first to determine which table index to extract."""
    inputs = {
        "url": {"type": "string", "description": "The URL of the webpage to analyze"}
    }
    output_type = "object"

    def forward(self, url: str) -> Dict:
        """Retrieve context information for all tables on the page"""
        try:
            response = requests.get(url, timeout=15)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, "html.parser")

            tables = soup.find_all("table")
            if not tables:
                return {
                    "status": "success",
                    "tables": [],
                    "message": "No tables found on page",
                    "url": url,
                }

            results = []
            for i, table in enumerate(tables):
                context = {
                    "index": i,
                    "id": table.get("id", ""),
                    "class": " ".join(table.get("class", [])),
                    "summary": table.get("summary", ""),
                    "caption": self._get_table_caption(table),
                    "preceding_header": self._get_preceding_header(table),
                    "surrounding_text": self._get_surrounding_text(table),
                }
                results.append(context)

            return {
                "status": "success",
                "tables": results,
                "url": url,
                "message": f"Found {len(results)} tables with context information",
                "suggestion": "Use html_table_extractor with the most relevant index",
            }

        except Exception as e:
            return {
                "status": "error",
                "url": url,
                "message": f"Failed to retrieve table contexts: {str(e)}",
            }

    def _get_table_caption(self, table) -> str:
        """Extract table caption text if available"""
        caption = table.find("caption")
        return caption.get_text(strip=True) if caption else ""

    def _get_preceding_header(self, table) -> str:
        """Find the nearest preceding heading"""
        heading = table.find_previous(["h1", "h2", "h3", "h4", "h5", "h6"])
        return heading.get_text(strip=True) if heading else ""

    def _get_surrounding_text(self, table, chars=150) -> str:
        """Get relevant text around the table"""
        prev_text = " ".join(
            t.strip()
            for t in table.find_all_previous(string=True, limit=3)
            if t.strip()
        )
        next_text = " ".join(
            t.strip() for t in table.find_all_next(string=True, limit=3) if t.strip()
        )
        return f"...{prev_text[-chars:]} [TABLE] {next_text[:chars]}..."


class HtmlTableExtractorTool(Tool):
    name = "html_table_extractor"
    description = """Extracts a specific HTML table as structured data.
    Use after webpage_tables_context_retriever to get the correct table index."""
    inputs = {
        "page_url": {
            "type": "string",
            "description": "The webpage URL containing the table",
        },
        "table_index": {
            "type": "integer",
            "description": "0-based index of the table to extract (from webpage_tables_context_retriever)",
        },
    }
    output_type = "object"

    def forward(self, page_url: str, table_index: int) -> Dict:
        """Extract a specific table by index"""
        try:
            # First verify the URL is accessible
            test_request = requests.head(page_url, timeout=5)
            test_request.raise_for_status()

            # Read all tables
            tables = pd.read_html(page_url)

            if not tables:
                return {
                    "status": "error",
                    "message": "No tables found at URL",
                    "url": page_url,
                }

            # Validate index
            if table_index < 0 or table_index >= len(tables):
                return {
                    "status": "error",
                    "message": f"Invalid table index {table_index}. Page has {len(tables)} tables.",
                    "url": page_url,
                    "available_indexes": list(range(len(tables))),
                }

            # Convert the DataFrame to a JSON-serializable list of row records
            df = tables[table_index]
            return {
                "status": "success",
                "table_index": table_index,
                "table_data": df.to_dict(orient="records"),
                "url": page_url,
            }

        except Exception as e:
            return {
                "status": "error",
                "message": f"Table extraction failed: {str(e)}",
                "url": page_url,
                "table_index": table_index,
            }
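

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the tool definitions): one way to
# register these tools with a smolagents agent. The model class below is an
# assumption; substitute whichever smolagents model wrapper your installed
# version provides.
#
#   from smolagents import CodeAgent, InferenceClientModel
#
#   agent = CodeAgent(
#       tools=[
#           ReverseStringTool(),
#           FileDownloaderTool(),
#           YoutubeVideoDownloaderTool(),
#           VideoAnalyzerTool(),
#           WebpageTablesContextRetrieverTool(),
#           HtmlTableExtractorTool(),
#       ],
#       model=InferenceClientModel(),
#   )
#   agent.run("Which row of the table at <url> has the largest value?")
# ---------------------------------------------------------------------------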