import os

# 1) Fully disable Dynamo (the env var TorchDynamo actually reads is TORCHDYNAMO_DISABLE)
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# 2) Disable Triton's cudagraphs optimization
os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"

# (Optional) ignore specific warnings
import warnings
warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")

import torch
# Enable TensorFloat32 matmuls (performance optimization)
torch.set_float32_matmul_precision('high')

import torch._inductor
torch._inductor.config.triton.cudagraphs = False

import torch._dynamo
# suppress_errors (fall back to eager on errors)
torch._dynamo.config.suppress_errors = True

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from threading import Thread
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
import json
from datetime import datetime
import pyarrow.parquet as pq
import pypdf
import io
import platform
import subprocess
import pytesseract
from pdf2image import convert_from_path
import queue
import time

# -------------------- Imports for PDF-to-Markdown conversion --------------------
try:
    import re
    import requests
    from bs4 import BeautifulSoup
    import urllib.request
    import ocrmypdf
    import pytz
    import urllib.parse
    from pypdf import PdfReader
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        f"A required module is missing: {e.name}. Please install it.\n"
        f"e.g.: pip install {e.name}"
    ) from e
# ---------------------------------------------------------------------------

# Global variables
current_file_context = None

# Environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODEL_NAME = MODEL_ID.split("/")[-1]

model = None  # managed globally; loaded lazily on the first request
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

# (1) Load the Wikipedia dataset
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)

# (2) Initialize and fit the TF-IDF vectorizer (on a subset only)
print("Starting TF-IDF vectorization...")
questions = wiki_dataset['train']['question'][:10000]
vectorizer = TfidfVectorizer(max_features=1000)
question_vectors = vectorizer.fit_transform(questions)
print("TF-IDF 벡터화 완료")

# ------------------------- ChatHistory class -------------------------
class ChatHistory:
    def __init__(self):
        self.history = []
        self.history_file = "/tmp/chat_history.json"
        self.load_history()

    def add_conversation(self, user_msg: str, assistant_msg: str):
        conversation = {
            "timestamp": datetime.now().isoformat(),
            "messages": [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg}
            ]
        }
        self.history.append(conversation)
        self.save_history()

    def format_for_display(self):
        formatted = []
        for conv in self.history:
            formatted.append([
                conv["messages"][0]["content"],
                conv["messages"][1]["content"]
            ])
        return formatted

    def get_messages_for_api(self):
        messages = []
        for conv in self.history:
            messages.extend([
                {"role": "user", "content": conv["messages"][0]["content"]},
                {"role": "assistant", "content": conv["messages"][1]["content"]}
            ])
        return messages

    def clear_history(self):
        self.history = []
        self.save_history()

    def save_history(self):
        try:
            with open(self.history_file, 'w', encoding='utf-8') as f:
                json.dump(self.history, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"히스토리 저장 실패: {e}")

    def load_history(self):
        try:
            if os.path.exists(self.history_file):
                with open(self.history_file, 'r', encoding='utf-8') as f:
                    self.history = json.load(f)
        except Exception as e:
            print(f"히스토리 로드 실패: {e}")
            self.history = []

chat_history = ChatHistory()

# ------------------------- Wikipedia document search (TF-IDF) -------------------------
def find_relevant_context(query, top_k=3):
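    # TfidfVectorizer L2-normalizes rows by default, so the sparse dot product
    # below is equivalent to cosine similarity between the query and questions.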
    query_vector = vectorizer.transform([query])
    similarities = (query_vector * question_vectors.T).toarray()[0]
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    relevant_contexts = []
    for idx in top_indices:
        if similarities[idx] > 0:
            relevant_contexts.append({
                'question': questions[idx],
                'answer': wiki_dataset['train']['answer'][idx],
                'similarity': similarities[idx]
            })
    return relevant_contexts
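
# Illustrative sketch (not wired into the UI): query the retriever directly to
# sanity-check it outside Gradio. The example query ("What were King Sejong's
# achievements?") is an arbitrary placeholder, not a string from the dataset.
def _demo_retrieval(query: str = "세종대왕의 업적은 무엇인가?", top_k: int = 3):
    for ctx in find_relevant_context(query, top_k=top_k):
        print(f"[{ctx['similarity']:.3f}] Q: {ctx['question']}")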

def init_msg():
    return "Analyzing the file..."

# -------------------- Utilities for converting PDF files to Markdown --------------------
def extract_text_from_pdf(reader: PdfReader) -> str:
    full_text = ""
    for idx, page in enumerate(reader.pages):
        text = page.extract_text() or ""
        if len(text) > 0:
            full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
    return full_text.strip()

def convert_pdf_to_markdown(pdf_file: str):
    try:
        reader = PdfReader(pdf_file)
    except Exception as e:
        return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None

    raw_meta = reader.metadata
    metadata = {
        "author": raw_meta.author if raw_meta else None,
        "creator": raw_meta.creator if raw_meta else None,
        "producer": raw_meta.producer if raw_meta else None,
        "subject": raw_meta.subject if raw_meta else None,
        "title": raw_meta.title if raw_meta else None,
    }

    full_text = extract_text_from_pdf(reader)

    image_count = sum(len(page.images) for page in reader.pages)
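    # Heuristic: if the PDF contains images but yields little text, assume it is
    # a scanned document and run OCR over it with ocrmypdf.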
    if image_count > 0 and len(full_text) < 1000:
        try:
            out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader_ocr = PdfReader(out_pdf_file)
            full_text = extract_text_from_pdf(reader_ocr)
        except Exception as e:
            full_text = f"OCR 처리 중 오류 발생: {e}\n\n원본 PDF 텍스트:\n\n" + full_text

    return full_text, metadata, pdf_file
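
# Illustrative helper (not wired into the UI): convert a local PDF and print a
# short preview, handy for exercising the OCR fallback from a REPL. The path is
# caller-supplied; no default file is assumed.
def _demo_pdf_preview(path: str, n_chars: int = 500):
    text, metadata, _ = convert_pdf_to_markdown(path)
    print("Metadata:", metadata)
    print(text[:n_chars])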

# ------------------------- File analysis functions -------------------------
def analyze_file_content(content, file_type):
    if file_type in ['parquet', 'csv']:
        try:
            lines = content.split('\n')
            header = lines[0]
            columns = header.count('|') - 1
            rows = len(lines) - 3  # exclude the header, separator, and trailing line
            return f"📊 Dataset Structure: {columns} columns, {rows} rows"
        except Exception:
            return "❌ Failed to analyze dataset structure"

    lines = content.split('\n')
    total_lines = len(lines)
    non_empty_lines = len([line for line in lines if line.strip()])

    if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
        functions = len([line for line in lines if 'def ' in line])
        classes = len([line for line in lines if 'class ' in line])
        imports = len([line for line in lines if 'import ' in line or 'from ' in line])
        return f"💻 Code Structure: {total_lines} lines (Functions: {functions}, Classes: {classes}, Imports: {imports})"

    paragraphs = content.count('\n\n') + 1
    words = len(content.split())
    return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"

def read_uploaded_file(file):
    if file is None:
        return "", ""

    from tabulate import tabulate  # pyarrow and pandas are already imported at module level

    try:
        file_ext = os.path.splitext(file.name)[1].lower()

        if file_ext == '.parquet':
            try:
                table = pq.read_table(file.name)
                df = table.to_pandas()

                content = f"📊 Parquet File Analysis:\n\n"
                content += f"1. Basic Information:\n"
                content += f"- Total Rows: {len(df):,}\n"
                content += f"- Total Columns: {len(df.columns)}\n"
                mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
                content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"

                content += f"2. Column Information:\n"
                for col in df.columns:
                    content += f"- {col} ({df[col].dtype})\n"

                content += f"\n3. Data Preview:\n"
                content += tabulate(df.head(5), headers='keys', tablefmt='pipe', showindex=False)

                content += f"\n\n4. Missing Values:\n"
                null_counts = df.isnull().sum()
                for col, count in null_counts[null_counts > 0].items():
                    rate = count / len(df) * 100
                    content += f"- {col}: {count:,} ({rate:.1f}%)\n"

                numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
                if len(numeric_cols) > 0:
                    content += f"\n5. Numeric Column Statistics:\n"
                    stats_df = df[numeric_cols].describe()
                    content += tabulate(stats_df, headers='keys', tablefmt='pipe')

                return content, "parquet"
            except Exception as e:
                return f"Error reading Parquet file: {str(e)}", "error"

        elif file_ext == '.pdf':
            try:
                markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
                if metadata is None:
                    return f"PDF 파일 변환 오류 또는 읽기 실패.\n\n원본 메시지:\n{markdown_text}", "error"

                content = "# PDF to Markdown Conversion\n\n"
                content += "## Metadata\n"
                for k, v in metadata.items():
                    content += f"**{k.capitalize()}**: {v}\n\n"
                content += "## Extracted Text\n\n"
                content += markdown_text

                return content, "pdf"
            except Exception as e:
                return f"Error reading PDF file: {str(e)}", "error"

        elif file_ext == '.csv':
            encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
            for encoding in encodings:
                try:
                    df = pd.read_csv(file.name, encoding=encoding)
                    content = f"📊 CSV File Analysis:\n\n"
                    content += f"1. Basic Information:\n"
                    content += f"- Total Rows: {len(df):,}\n"
                    content += f"- Total Columns: {len(df.columns)}\n"
                    mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
                    content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"

                    content += f"2. Column Information:\n"
                    for col in df.columns:
                        content += f"- {col} ({df[col].dtype})\n"

                    content += f"\n3. Data Preview:\n"
                    content += df.head(5).to_markdown(index=False)

                    content += f"\n\n4. Missing Values:\n"
                    null_counts = df.isnull().sum()
                    for col, count in null_counts[null_counts > 0].items():
                        rate = count / len(df) * 100
                        content += f"- {col}: {count:,} ({rate:.1f}%)\n"

                    return content, "csv"
                except UnicodeDecodeError:
                    continue
            raise ValueError(
                f"Unable to read file with supported encodings ({', '.join(encodings)})"
            )

        else:
            encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
            for encoding in encodings:
                try:
                    with open(file.name, 'r', encoding=encoding) as f:
                        content = f.read()

                    lines = content.split('\n')
                    total_lines = len(lines)
                    non_empty_lines = len([line for line in lines if line.strip()])
                    is_code = any(
                        keyword in content.lower()
                        for keyword in ['def ', 'class ', 'import ', 'function']
                    )

                    analysis = "\n📝 File Analysis:\n"
                    if is_code:
                        functions = sum('def ' in line for line in lines)
                        classes = sum('class ' in line for line in lines)
                        imports = sum(
                            ('import ' in line) or ('from ' in line)
                            for line in lines
                        )
                        analysis += f"- File Type: Code\n"
                        analysis += f"- Total Lines: {total_lines:,}\n"
                        analysis += f"- Functions: {functions}\n"
                        analysis += f"- Classes: {classes}\n"
                        analysis += f"- Import Statements: {imports}\n"
                    else:
                        words = len(content.split())
                        chars = len(content)
                        analysis += f"- File Type: Text\n"
                        analysis += f"- Total Lines: {total_lines:,}\n"
                        analysis += f"- Non-empty Lines: {non_empty_lines:,}\n"
                        analysis += f"- Word Count: {words:,}\n"
                        analysis += f"- Character Count: {chars:,}\n"

                    return content + analysis, "text"

                except UnicodeDecodeError:
                    continue

            raise ValueError(
                f"Unable to read file with supported encodings ({', '.join(encodings)})"
            )

    except Exception as e:
        return f"Error reading file: {str(e)}", "error"

# ------------------------- CSS -------------------------
CSS = """
/* (omitted: same as before) */
"""

def clear_cuda_memory():
    # hasattr(torch.cuda, 'empty_cache') is always true; guard on actual
    # CUDA availability instead.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ------------------------- Model loading -------------------------
@spaces.GPU
def load_model():
    try:
        clear_cuda_memory()
        loaded_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
            token=HF_TOKEN,
        )
        # (Important) the KV cache can also be disabled in the model's default config
        loaded_model.config.use_cache = False
        return loaded_model
    except Exception as e:
        print(f"Model loading error: {str(e)}")
        raise

def build_prompt(conversation: list) -> str:
    prompt = ""
    for msg in conversation:
        if msg["role"] == "user":
            prompt += "User: " + msg["content"] + "\n"
        elif msg["role"] == "assistant":
            prompt += "Assistant: " + msg["content"] + "\n"
    prompt += "Assistant: "
    return prompt
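
# Note: Command R-family tokenizers ship a chat template; a minimal sketch of
# rendering the conversation with the tokenizer's apply_chat_template instead of
# the plain "User:/Assistant:" format above. Kept as an optional alternative
# (not wired into stream_chat); the fallback branch is a defensive assumption.
def build_prompt_with_template(conversation: list) -> str:
    try:
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return build_prompt(conversation)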

# ------------------------- Message streaming -------------------------
@spaces.GPU
def stream_chat(
    message: str,
    history: list,
    uploaded_file,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float
):
    global model, current_file_context

    try:
        if model is None:
            model = load_model()

        print(f'[User input] message: {message}')
        print(f'[History] {history}')

        # 1) Handle an uploaded file
        file_context = ""
        if uploaded_file and message == "Analyzing the file...":
            current_file_context = None
            try:
                content, file_type = read_uploaded_file(uploaded_file)
                if content:
                    file_analysis = analyze_file_content(content, file_type)
                    file_context = (
                        f"\n\n📄 File analysis result:\n{file_analysis}"
                        f"\n\nFile content:\n```\n{content}\n```"
                    )
                    current_file_context = file_context
                    message = "Please analyze the uploaded file."
            except Exception as e:
                print(f"[File analysis error] {str(e)}")
                file_context = f"\n\n❌ An error occurred while analyzing the file: {str(e)}"
        elif current_file_context:
            file_context = current_file_context

        # 2) Wikipedia context
        wiki_context = ""
        try:
            relevant_contexts = find_relevant_context(message)
            if relevant_contexts:
                wiki_context = "\n\nRelated Wikipedia information:\n"
                for ctx in relevant_contexts:
                    wiki_context += (
                        f"Q: {ctx['question']}\n"
                        f"A: {ctx['answer']}\n"
                        f"Similarity: {ctx['similarity']:.3f}\n\n"
                    )
        except Exception as e:
            print(f"[Context retrieval error] {str(e)}")

        # 3) Trim the conversation history
        max_history_length = 10
        if len(history) > max_history_length:
            history = history[-max_history_length:]

        conversation = []
        for prompt, answer in history:
            conversation.extend([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer}
            ])

        # 4) Build the final message (prepend any available context)
        contexts = file_context + wiki_context
        if contexts:
            final_message = contexts + "\nCurrent question: " + message
        else:
            final_message = message

        conversation.append({"role": "user", "content": final_message})

        # 5) Tokenize
        input_ids_str = build_prompt(conversation)
        max_context = 8192
        tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
        input_length = tokenized_input["input_ids"].shape[1]

        # 6) Truncate when the input exceeds the context window
        if input_length > max_context - max_new_tokens:
            print(f"[Warning] Input too long: {input_length} tokens -> truncating.")
            min_generation = min(256, max_new_tokens)
            new_desired_input_length = max_context - min_generation
            tokens = tokenizer.encode(input_ids_str)
            if len(tokens) > new_desired_input_length:
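                # Keep the most recent tokens so the latest turns survive truncation.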
                tokens = tokens[-new_desired_input_length:]
                input_ids_str = tokenizer.decode(tokens)
            tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
            input_length = tokenized_input["input_ids"].shape[1]

        print(f"[토큰 길이] {input_length}")
        inputs = tokenized_input.to(model.device)

        # 7) Cap max_new_tokens by the remaining token budget
        remaining = max_context - input_length
        if remaining < max_new_tokens:
            print(f"[max_new_tokens 조정] {max_new_tokens} -> {remaining}")
            max_new_tokens = remaining

        # 8) Configure the TextIteratorStreamer
        streamer = TextIteratorStreamer(
            tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
        )
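        # With timeout=30.0, iterating the streamer raises queue.Empty if no new
        # token arrives within 30 s, surfacing stalls instead of hanging the UI.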

        # ★ Set use_cache=False (important) ★
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=penalty,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=(tokenizer.pad_token_id if tokenizer.pad_token_id is not None
                          else tokenizer.eos_token_id),
            eos_token_id=tokenizer.eos_token_id,
            use_cache=False,   # ← this is the key!
        )

        clear_cuda_memory()

        # 9) Run generation on a separate thread
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        # 10) Stream the output
        buffer = ""
        partial_message = ""
        last_yield_time = time.time()

        try:
            for new_text in streamer:
                buffer += new_text
                partial_message += new_text

                # Update the UI on a time interval or after enough new text
                current_time = time.time()
                if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
                    yield "", history + [[message, buffer]]
                    partial_message = ""
                    last_yield_time = current_time

            # Final flush
            if buffer:
                yield "", history + [[message, buffer]]

            # Save the conversation history
            chat_history.add_conversation(message, buffer)

        except Exception as e:
            print(f"[Streaming error] {str(e)}")
            if not buffer:
                buffer = f"An error occurred while generating the response: {str(e)}"
            yield "", history + [[message, buffer]]

        if thread.is_alive():
            thread.join(timeout=5.0)

        clear_cuda_memory()

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
        print(f"[Stream chat 오류] {error_message}")
        clear_cuda_memory()
        yield "", history + [[message, error_message]]

# ------------------------- Gradio UI -------------------------
def create_demo():
    with gr.Blocks(css=CSS) as demo:
        with gr.Column(elem_classes="markdown-style"):
            gr.Markdown("""
                # 🤖 RAGOndevice
                #### 📊 RAG: Upload and Analyze Files (TXT, CSV, PDF, Parquet files)
                Upload your files for data analysis and learning
            """)

        chatbot = gr.Chatbot(
            value=[],
            height=600,
            label="GiniGEN AI Assistant",
            elem_classes="chat-container"
        )

        with gr.Row(elem_classes="input-container"):
            with gr.Column(scale=1, min_width=70):
                file_upload = gr.File(
                    type="filepath",
                    elem_classes="file-upload-icon",
                    scale=1,
                    container=True,
                    interactive=True,
                    show_label=False
                )

            with gr.Column(scale=3):
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here... 💭",
                    container=False,
                    elem_classes="input-textbox",
                    scale=1
                )

            with gr.Column(scale=1, min_width=70):
                send = gr.Button(
                    "Send",
                    elem_classes="send-button custom-button",
                    scale=1
                )

            with gr.Column(scale=1, min_width=70):
                clear = gr.Button(
                    "Clear",
                    elem_classes="clear-button custom-button",
                    scale=1
                )

        # Advanced settings
        with gr.Accordion("🎮 Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    temperature = gr.Slider(
                        minimum=0, maximum=1, step=0.1, value=0.8,
                        label="Creativity Level 🎨"
                    )
                    max_new_tokens = gr.Slider(
                        minimum=128, maximum=8000, step=1, value=4000,
                        label="Maximum Token Count 📝"
                    )
                with gr.Column(scale=1):
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, step=0.1, value=0.8,
                        label="Diversity Control 🎯"
                    )
                    top_k = gr.Slider(
                        minimum=1, maximum=20, step=1, value=20,
                        label="Selection Range 📊"
                    )
                    penalty = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                        label="Repetition Penalty 🔄"
                    )

        # Example inputs
        gr.Examples(
            examples=[
                ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n    if n <= 1: return n\n    return fibonacci(n-1) + fibonacci(n-2)"],
                ["Please analyze this data and provide insights:\nAnnual Revenue (Million)\n2019: 1200\n2020: 980\n2021: 1450\n2022: 2100\n2023: 1890"],
                ["Please solve this math problem step by step: 'When a circle's area is twice that of its inscribed square, find the relationship between the circle's radius and the square's side length.'"],
                ["Please analyze this marketing campaign's ROI and suggest improvements:\nTotal Cost: $50,000\nReach: 1M users\nClick Rate: 2.3%\nConversion Rate: 0.8%\nAverage Purchase: $35"],
            ],
            inputs=msg
        )

        # Reset the conversation
        def clear_conversation():
            global current_file_context
            current_file_context = None
            return [], None, "Start a new conversation..."

        # Message submission (Submit)
        msg.submit(
            stream_chat,
            inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
            outputs=[msg, chatbot]
        )
        send.click(
            stream_chat,
            inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
            outputs=[msg, chatbot]
        )

        # File upload events
        file_upload.change(
            fn=lambda: ("Processing...", [["System", "Analyzing the file. Please wait a moment..."]]),
            outputs=[msg, chatbot],
            queue=False
        ).then(
            fn=init_msg,
            outputs=msg,
            queue=False
        ).then(
            fn=stream_chat,
            inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
            outputs=[msg, chatbot],
            queue=True
        )

        # Clear button
        clear.click(
            fn=clear_conversation,
            outputs=[chatbot, file_upload, msg],
            queue=False
        )

        return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch()