#
# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
# SPDX-License-Identifier: Apache-2.0
#

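# Gradio chat front-end that streams completions from remote model servers and
# extracts text (including OCR) from uploaded documents. Endpoints, keys,
# prompts, and UI strings are all injected via environment variables.
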
import asyncio
import docx
import gradio as gr
import httpx
import json
import os
import pandas as pd
import pdfplumber
import pytesseract
import random
import requests
import threading
import uuid
import zipfile
import io

from PIL import Image
from pathlib import Path
from pptx import Presentation
from openpyxl import load_workbook

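# Install the Tesseract OCR engine with English and Indonesian language packs
# at startup; this assumes a Debian-based environment with root access, such
# as a Hugging Face Space container.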
os.system("apt-get update -q -y && apt-get install -q -y tesseract-ocr tesseract-ocr-eng tesseract-ocr-ind libleptonica-dev libtesseract-dev")

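# All runtime configuration comes from environment variables; JSON-valued
# settings fall back to empty structures so a missing variable cannot break
# the import.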
JARVIS_INIT = json.loads(os.getenv("HELLO", "[]"))

INTERNAL_AI_GET_SERVER = os.getenv("INTERNAL_AI_GET_SERVER")
INTERNAL_TRAINING_DATA = os.getenv("INTERNAL_TRAINING_DATA")

SYSTEM_PROMPT_MAPPING = json.loads(os.getenv("SYSTEM_PROMPT_MAPPING", "{}"))
SYSTEM_PROMPT_DEFAULT = os.getenv("DEFAULT_SYSTEM")

LINUX_SERVER_HOSTS = [h for h in json.loads(os.getenv("LINUX_SERVER_HOST", "[]")) if h]

LINUX_SERVER_PROVIDER_KEYS = [k for k in json.loads(os.getenv("LINUX_SERVER_PROVIDER_KEY", "[]")) if k]
LINUX_SERVER_PROVIDER_KEYS_MARKED = set()
LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS = {}

# Status codes that mark a provider key as failing; skip empty entries so an
# unset variable does not raise ValueError at import time.
LINUX_SERVER_ERRORS = {int(x) for x in os.getenv("LINUX_SERVER_ERROR", "").split(",") if x.strip()}

AI_TYPES = {f"AI_TYPE_{i}": os.getenv(f"AI_TYPE_{i}") for i in range(1, 8)}

RESPONSES = {f"RESPONSE_{i}": os.getenv(f"RESPONSE_{i}") for i in range(1, 11)}

MODEL_MAPPING = json.loads(os.getenv("MODEL_MAPPING", "{}"))
MODEL_CONFIG = json.loads(os.getenv("MODEL_CONFIG", "{}"))
MODEL_CHOICES = list(MODEL_MAPPING.values())

DEFAULT_CONFIG = json.loads(os.getenv("DEFAULT_CONFIG", "{}"))
DEFAULT_MODEL_KEY = list(MODEL_MAPPING.keys())[0] if MODEL_MAPPING else None

META_TAGS = os.getenv("META_TAGS")

ALLOWED_EXTENSIONS = json.loads(os.getenv("ALLOWED_EXTENSIONS", "[]"))

class SessionWithID(requests.Session):
    # A requests.Session tagged with a unique ID plus per-session cancellation
    # state shared between the UI handlers and the streaming task.
    def __init__(self):
        super().__init__()
        self.session_id = str(uuid.uuid4())
        self.stop_event = asyncio.Event()
        self.cancel_token = {"cancelled": False}

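# Factory for gr.State so every browser session gets its own ID and
# cancellation primitives.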
def create_session():
    return SessionWithID()

def ensure_stop_event(sess):
    # Recreate the cancellation primitives if a session object arrives without
    # them (e.g. after being copied or restored by Gradio state handling).
    if not hasattr(sess, "stop_event"):
        sess.stop_event = asyncio.Event()
    if not hasattr(sess, "cancel_token"):
        sess.cancel_token = {"cancelled": False}

def marked_item(item, marked, attempts):
    # Mark a failing provider key/host; once it has accumulated three failures
    # it is scheduled to be unmarked (and its counter reset) five minutes later.
    marked.add(item)
    attempts[item] = attempts.get(item, 0) + 1
    if attempts[item] >= 3:
        def remove():
            marked.discard(item)
            attempts.pop(item, None)
        threading.Timer(300, remove).start()

def get_model_key(display):
    return next((k for k, v in MODEL_MAPPING.items() if v == display), DEFAULT_MODEL_KEY)

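# PDF extraction: native text layer first, then OCR over embedded images,
# then tables flattened to tab-separated rows.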
def extract_pdf_content(fp):
    content = ""
    try:
        with pdfplumber.open(fp) as pdf:
            for page in pdf.pages:
                text = page.extract_text() or ""
                content += text + "\n"
                if page.images:
                    # The page is rasterized at 300 dpi while pdfplumber bboxes
                    # are in PDF points (72 dpi), so scale crop coordinates to
                    # the bitmap's resolution before OCR.
                    img_obj = page.to_image(resolution=300)
                    scale = 300 / 72
                    for img in page.images:
                        bbox = tuple(int(v * scale) for v in (img["x0"], img["top"], img["x1"], img["bottom"]))
                        cropped = img_obj.original.crop(bbox)
                        ocr_text = pytesseract.image_to_string(cropped)
                        if ocr_text.strip():
                            content += ocr_text + "\n"
                tables = page.extract_tables()
                for table in tables:
                    for row in table:
                        cells = [str(cell) for cell in row if cell is not None]
                        if cells:
                            content += "\t".join(cells) + "\n"
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()

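# DOCX extraction: paragraphs and tables via python-docx, plus OCR on any
# pictures stored inside the document archive.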
def extract_docx_content(fp):
    content = ""
    try:
        doc = docx.Document(fp)
        for para in doc.paragraphs:
            content += para.text + "\n"
        for table in doc.tables:
            for row in table.rows:
                cells = [cell.text for cell in row.cells]
                content += "\t".join(cells) + "\n"
        # Word stores embedded pictures inside the .docx ZIP under word/media/;
        # OCR each one and keep any recognized text.
        with zipfile.ZipFile(fp) as z:
            for file in z.namelist():
                if file.startswith("word/media/"):
                    data = z.read(file)
                    try:
                        img = Image.open(io.BytesIO(data))
                        ocr_text = pytesseract.image_to_string(img)
                        if ocr_text.strip():
                            content += ocr_text + "\n"
                    except Exception:
                        pass
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()

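# Spreadsheet extraction: every sheet dumped as CSV, plus OCR on embedded
# images whose raw bytes are available.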
def extract_excel_content(fp):
    content = ""
    try:
        sheets = pd.read_excel(fp, sheet_name=None)
        for name, df in sheets.items():
            content += f"Sheet: {name}\n"
            content += df.to_csv(index=False) + "\n"
        # openpyxl keeps embedded pictures on the private `_images` attribute;
        # only images whose `ref` is raw bytes are OCR'd here.
        wb = load_workbook(fp, data_only=True)
        for image in wb._images:
            img = image.ref
            if isinstance(img, bytes):
                try:
                    pil_img = Image.open(io.BytesIO(img))
                    ocr_text = pytesseract.image_to_string(pil_img)
                    if ocr_text.strip():
                        content += ocr_text + "\n"
                except Exception:
                    pass
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()

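# PPTX extraction: shape text and OCR'd pictures first, then table contents,
# slide by slide.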
def extract_pptx_content(fp):
    content = ""
    try:
        prs = Presentation(fp)
        for slide in prs.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text") and shape.text:
                    content += shape.text + "\n"
                # shape_type 13 is MSO_SHAPE_TYPE.PICTURE, i.e. an embedded image.
                if shape.shape_type == 13 and hasattr(shape, "image") and shape.image:
                    try:
                        img = Image.open(io.BytesIO(shape.image.blob))
                        ocr_text = pytesseract.image_to_string(img)
                        if ocr_text.strip():
                            content += ocr_text + "\n"
                    except Exception:
                        pass
            for shape in slide.shapes:
                if shape.has_table:
                    table = shape.table
                    for row in table.rows:
                        cells = [cell.text for cell in row.cells]
                        content += "\t".join(cells) + "\n"
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()

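# Dispatch an uploaded file to the matching extractor by extension; anything
# unrecognized is read as plain UTF-8 text.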
def extract_file_content(fp):
    ext = Path(fp).suffix.lower()
    if ext == ".pdf":
        return extract_pdf_content(fp)
    elif ext in [".doc", ".docx"]:
        return extract_docx_content(fp)
    elif ext in [".xlsx", ".xls"]:
        return extract_excel_content(fp)
    elif ext in [".ppt", ".pptx"]:
        return extract_pptx_content(fp)
    else:
        try:
            return Path(fp).read_text(encoding="utf-8").strip()
        except Exception as e:
            return f"{fp}: {e}"

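# Stream one chat completion from `host` as server-sent events. Two attempts
# are made with increasing timeouts; a listed error status, or a stream that
# ends without the terminator sentinel (RESPONSE_10), marks the key as failing.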
async def fetch_response_stream_async(host, key, model, msgs, cfg, sid, stop_event, cancel_token):
    # Make up to two attempts, first with a 5-second and then a 10-second timeout.
    for t in [5, 10]:
        try:
            async with httpx.AsyncClient(timeout=t) as client:
                async with client.stream("POST", host, json={**{"model": model, "messages": msgs, "session_id": sid, "stream": True}, **cfg}, headers={"Authorization": f"Bearer {key}"}) as response:
                    if response.status_code in LINUX_SERVER_ERRORS:
                        marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
                        return
                    async for line in response.aiter_lines():
                        if stop_event.is_set() or cancel_token["cancelled"]:
                            return
                        if not line:
                            continue
                        if line.startswith("data: "):
                            data = line[6:]
                            if data.strip() == RESPONSES["RESPONSE_10"]:
                                return
                            try:
                                j = json.loads(data)
                                if isinstance(j, dict) and j.get("choices"):
                                    for ch in j["choices"]:
                                        delta = ch.get("delta", {})
                                        if "reasoning" in delta and delta["reasoning"]:
                                            # Reasoning text arrives with backslash-escaped
                                            # characters; undo the escaping before yielding.
                                            decoded = delta["reasoning"].encode('utf-8').decode('unicode_escape')
                                            yield ("reasoning", decoded)
                                        if "content" in delta and delta["content"]:
                                            yield ("content", delta["content"])
                            except Exception:
                                continue
        except Exception:
            continue
        # Falling through the try block means the stream ended without the
        # terminator sentinel; treat the key as suspect and retry with the
        # longer timeout.
        marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
    return

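# Try every (host, key) pair in random order until one streams output; the
# first pair that yields anything wins, otherwise a fallback message is sent.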
async def chat_with_model_async(history, user_input, model_display, sess, custom_prompt):
    ensure_stop_event(sess)
    sess.stop_event.clear()
    sess.cancel_token["cancelled"] = False
    if not LINUX_SERVER_PROVIDER_KEYS or not LINUX_SERVER_HOSTS:
        yield ("content", RESPONSES["RESPONSE_3"])
        return
    if not hasattr(sess, "session_id") or not sess.session_id:
        sess.session_id = str(uuid.uuid4())
    model_key = get_model_key(model_display)
    cfg = MODEL_CONFIG.get(model_key, DEFAULT_CONFIG)
    # Rebuild the transcript in conversational order. The last history entry is
    # the in-flight turn appended by respond_async (its reply is still a
    # placeholder), so only earlier turns are replayed and the current input is
    # appended once at the end.
    msgs = []
    for u, a in history[:-1]:
        msgs.append({"role": "user", "content": u})
        if a:
            msgs.append({"role": "assistant", "content": a})
    prompt = INTERNAL_TRAINING_DATA if model_key == DEFAULT_MODEL_KEY and INTERNAL_TRAINING_DATA else (custom_prompt or SYSTEM_PROMPT_MAPPING.get(model_key, SYSTEM_PROMPT_DEFAULT))
    msgs.insert(0, {"role": "system", "content": prompt})
    msgs.append({"role": "user", "content": user_input})
    # Skip provider keys currently marked as failing; if every key is marked,
    # fall back to the full list so a request is still attempted.
    available_keys = [k for k in LINUX_SERVER_PROVIDER_KEYS if k not in LINUX_SERVER_PROVIDER_KEYS_MARKED] or LINUX_SERVER_PROVIDER_KEYS
    candidates = [(h, k) for h in LINUX_SERVER_HOSTS for k in available_keys]
    random.shuffle(candidates)
    for h, k in candidates:
        stream_gen = fetch_response_stream_async(h, k, model_key, msgs, cfg, sess.session_id, sess.stop_event, sess.cancel_token)
        got_responses = False
        async for chunk in stream_gen:
            if sess.stop_event.is_set() or sess.cancel_token["cancelled"]:
                return
            got_responses = True
            yield chunk
        if got_responses:
            return
    yield ("content", RESPONSES["RESPONSE_2"])

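# Main submit handler: fold uploaded files into the prompt, append the turn to
# the chat history, and relay streamed chunks from a background producer to
# the UI while racing against the stop button.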
async def respond_async(multi, history, model_display, sess, custom_prompt):
    ensure_stop_event(sess)
    sess.stop_event.clear()
    sess.cancel_token["cancelled"] = False
    msg_input = {"text": multi.get("text", "").strip(), "files": multi.get("files", [])}
    if not msg_input["text"] and not msg_input["files"]:
        yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
        return
    inp = ""
    # Gradio may deliver uploads as dicts or as plain file paths; extract the
    # text of each file and prepend it to the message body.
    for f in msg_input["files"]:
        fp = f.get("data", f.get("name", "")) if isinstance(f, dict) else f
        inp += f"{Path(fp).name}\n\n{extract_file_content(fp)}\n\n"
    if msg_input["text"]:
        inp += msg_input["text"]
    history.append([inp, RESPONSES["RESPONSE_8"]])
    yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
    queue = asyncio.Queue()
    # Producer: stream model output in the background, drop reasoning updates
    # once real content starts, and push UI updates onto the queue; a trailing
    # None marks completion.
    async def background():
        reasoning = ""
        responses = ""
        content_started = False
        ignore_reasoning = False
        async for typ, chunk in chat_with_model_async(history, inp, model_display, sess, custom_prompt):
            if sess.stop_event.is_set() or sess.cancel_token["cancelled"]:
                break
            if typ == "reasoning":
                if ignore_reasoning:
                    continue
                reasoning += chunk
                await queue.put(("reasoning", reasoning))
            elif typ == "content":
                if not content_started:
                    content_started = True
                    ignore_reasoning = True
                    responses = chunk
                    await queue.put(("reasoning", ""))
                    await queue.put(("replace", responses))
                else:
                    responses += chunk
                    await queue.put(("append", responses))
        await queue.put(None)
        return responses
    bg_task = asyncio.create_task(background())
    stop_task = asyncio.create_task(sess.stop_event.wait())
    try:
        while True:
            # Race the next queued chunk against the stop signal; keep a handle
            # on the queue task so it can be cancelled instead of leaking.
            queue_task = asyncio.create_task(queue.get())
            done, _ = await asyncio.wait({stop_task, queue_task}, return_when=asyncio.FIRST_COMPLETED)
            if stop_task in done:
                queue_task.cancel()
                sess.cancel_token["cancelled"] = True
                bg_task.cancel()
                history[-1][1] = RESPONSES["RESPONSE_1"]
                yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
                return
            result = queue_task.result()
            if result is None:
                raise StopAsyncIteration
            _, text = result
            history[-1][1] = text
            yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
    except StopAsyncIteration:
        pass
    finally:
        stop_task.cancel()
    # The producer has already pushed its None sentinel by this point; await it
    # so any exception it raised is not silently dropped.
    await bg_task
    yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess

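# Switching models clears the chat, issues a fresh session, and restores the
# selected model's default system prompt.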
def change_model(new):
    default = SYSTEM_PROMPT_MAPPING.get(get_model_key(new), SYSTEM_PROMPT_DEFAULT)
    return [], create_session(), new, default

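# Stop button handler: signal cancellation, stamp the last reply with the
# cancelled message, and hand out a fresh session.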
def stop_response(history, sess):
    ensure_stop_event(sess)
    sess.stop_event.set()
    sess.cancel_token["cancelled"] = True
    if history:
        history[-1][1] = RESPONSES["RESPONSE_1"]
    return history, None, create_session()

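# UI wiring: chatbot with example prompts, a multimodal input box, and a
# sidebar radio for model selection.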
with gr.Blocks(fill_height=True, fill_width=True, title=AI_TYPES["AI_TYPE_4"], head=META_TAGS) as jarvis:
    user_history = gr.State([])
    user_session = gr.State(create_session())
    selected_model = gr.State(MODEL_CHOICES[0] if MODEL_CHOICES else "")
    J_A_R_V_I_S = gr.State("")
    chatbot = gr.Chatbot(label=AI_TYPES["AI_TYPE_1"], show_copy_button=True, scale=1, elem_id=AI_TYPES["AI_TYPE_2"], examples=JARVIS_INIT)
    msg = gr.MultimodalTextbox(show_label=False, placeholder=RESPONSES["RESPONSE_5"], interactive=True, file_count="single", file_types=ALLOWED_EXTENSIONS)
    with gr.Sidebar(open=False):
        model_radio = gr.Radio(show_label=False, choices=MODEL_CHOICES, value=MODEL_CHOICES[0] if MODEL_CHOICES else None)
    model_radio.change(fn=change_model, inputs=[model_radio], outputs=[user_history, user_session, selected_model, J_A_R_V_I_S])
    def on_example_select(evt: gr.SelectData):
        return evt.value
    chatbot.example_select(fn=on_example_select, inputs=[], outputs=[msg]).then(fn=respond_async, inputs=[msg, user_history, selected_model, user_session, J_A_R_V_I_S], outputs=[chatbot, msg, user_session])
    def clear_chat(history, sess, prompt, model):
        return [], create_session(), prompt, model, []
    chatbot.clear(fn=clear_chat, inputs=[user_history, user_session, J_A_R_V_I_S, selected_model], outputs=[chatbot, user_session, J_A_R_V_I_S, selected_model, user_history])
    msg.submit(fn=respond_async, inputs=[msg, user_history, selected_model, user_session, J_A_R_V_I_S], outputs=[chatbot, msg, user_session], api_name=INTERNAL_AI_GET_SERVER, show_progress="full", show_progress_on=[chatbot, msg])
    msg.stop(fn=stop_response, inputs=[user_history, user_session], outputs=[chatbot, msg, user_session])
jarvis.queue(default_concurrency_limit=2).launch(max_file_size="1mb")