import gradio as gr
import base64
import requests
import json
import re
import os
import uuid

# --- Configuration ---
# IMPORTANT: Set your OPENROUTER_API_KEY as a Hugging Face Space Secret
OPENROUTER_API_KEY = "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"
IMAGE_MODEL = "opengvlab/internvl3-14b:free" # Using the free tier model as specified
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# --- Global State ---
# Module-level for simplicity; note this is shared across all sessions of the
# Space and is reset each time the processing function is called.
processed_files_data = [] # Stores dicts for each file's details and status
person_profiles = {}      # Stores dicts for each identified person and their documents

# --- Helper Functions ---

def extract_json_from_text(text):
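    """Parse a JSON object out of model output.

    Handles ```json fenced blocks, stray backticks, and leading/trailing text;
    returns an error dict (never raises) if no valid JSON can be recovered.
    """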
    if not text:
        return {"error": "Empty text provided for JSON extraction."}
    match_block = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
    if match_block:
        json_str = match_block.group(1)
    else:
        # Strip any stray backtick fences (single or triple) the model may have added.
        json_str = text.strip().strip("`")
        if json_str.lower().startswith("json"):
            json_str = json_str[len("json"):].lstrip()
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        try:
            first_brace = json_str.find('{')
            last_brace = json_str.rfind('}')
            if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
                potential_json_str = json_str[first_brace : last_brace+1]
                return json.loads(potential_json_str)
            else:
                return {"error": f"Invalid JSON structure: {str(e)}", "original_text": text}
        except json.JSONDecodeError as e2:
             return {"error": f"Invalid JSON structure after attempting substring: {str(e2)}", "original_text": text}

def get_ocr_prompt():
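    """Build the instruction prompt sent with each image: it asks the model for
    a single JSON object with document type, extracted fields, optional MRZ
    data, and the full OCR text."""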
    return f"""You are an advanced OCR and information extraction AI.
Your task is to meticulously analyze this image and extract all relevant information.

Output Format Instructions:
Provide your response as a SINGLE, VALID JSON OBJECT. Do not include any explanatory text before or after the JSON.
The JSON object should have the following top-level keys:
- "document_type_detected": (string) Your best guess of the specific document type (e.g., "Passport", "National ID Card", "Driver's License", "Visa Sticker", "Hotel Confirmation Voucher", "Bank Statement", "Photo of a person").
- "extracted_fields": (object) A key-value map of all extracted information. Be comprehensive. Examples:
    - For passports/IDs: "Surname", "Given Names", "Full Name", "Document Number", "Nationality", "Date of Birth", "Sex", "Place of Birth", "Date of Issue", "Date of Expiry", "Issuing Authority", "Country Code".
    - For hotel reservations: "Guest Name", "Hotel Name", "Booking Reference", "Check-in Date", "Check-out Date".
    - For bank statements: "Account Holder Name", "Account Number", "Bank Name", "Statement Period", "Ending Balance".
    - For photos: "Description" (e.g., "Portrait of a person", "Group photo at a location"), "People Present" (array of strings if multiple).
- "mrz_data": (object or null) If a Machine Readable Zone (MRZ) is present:
    - "raw_mrz_lines": (array of strings) Each line of the MRZ.
    - "parsed_mrz": (object) Key-value pairs of parsed MRZ fields.
    If no MRZ, this field should be null.
- "full_text_ocr": (string) Concatenation of all text found on the document.

Extraction Guidelines:
1.  Prioritize accuracy.
2.  Extract all visible text. Include "Full Name" by combining given and surnames if possible.
3.  For dates, try to use ISO 8601 format (YYYY-MM-DD) if possible, but retain original format if conversion is ambiguous.

Ensure the entire output strictly adheres to the JSON format.
"""

def call_openrouter_ocr(image_filepath):
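    """Base64-encode the image, send it with the OCR prompt to the OpenRouter
    chat-completions endpoint, and return the parsed JSON result (or a dict
    with an "error" key on any failure)."""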
    if not OPENROUTER_API_KEY:
        return {"error": "OpenRouter API Key not configured."}
    try:
        with open(image_filepath, "rb") as f:
            encoded_image = base64.b64encode(f.read()).decode("utf-8")
        mime_type = "image/jpeg"
        if image_filepath.lower().endswith(".png"):
            mime_type = "image/png"
        elif image_filepath.lower().endswith(".webp"):
            mime_type = "image/webp"
        data_url = f"data:{mime_type};base64,{encoded_image}"
        prompt_text = get_ocr_prompt()
        payload = {
            "model": IMAGE_MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {"type": "image_url", "image_url": {"url": data_url}}
                    ]
                }
            ],
            "max_tokens": 3500,
            "temperature": 0.1,
        }
        headers = {
            "Authorization": f"Bearer {OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://huggingface.co/spaces/YOUR_SPACE", 
            "X-Title": "Gradio Document Processor"
        }
        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180)
        response.raise_for_status()
        result = response.json()
        if "choices" in result and result["choices"]:
            raw_content = result["choices"][0]["message"]["content"]
            return extract_json_from_text(raw_content)
        else:
            return {"error": "No 'choices' in API response from OpenRouter.", "details": result}
    except requests.exceptions.Timeout:
        return {"error": "API request timed out."}
    except requests.exceptions.RequestException as e:
        error_message = f"API Request Error: {str(e)}"
        if hasattr(e, 'response') and e.response is not None:
            error_message += f" Status: {e.response.status_code}, Response: {e.response.text}"
        return {"error": error_message}
    except Exception as e:
        return {"error": f"An unexpected error occurred during OCR: {str(e)}"}

def extract_entities_from_ocr(ocr_json):
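    """Pull name, date of birth, and passport/document number from an OCR
    result, matching common field labels case-insensitively."""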
    if not ocr_json or "extracted_fields" not in ocr_json or not isinstance(ocr_json.get("extracted_fields"), dict):
        doc_type_from_ocr = "Unknown"
        if isinstance(ocr_json, dict): # ocr_json itself might be an error dict
            doc_type_from_ocr = ocr_json.get("document_type_detected", "Unknown (error in OCR)")
        return {"name": None, "dob": None, "passport_no": None, "doc_type": doc_type_from_ocr}

    fields = ocr_json["extracted_fields"]
    doc_type = ocr_json.get("document_type_detected", "Unknown")
    name_keys = ["full name", "name", "account holder name", "guest name"]
    dob_keys = ["date of birth", "dob"]
    passport_keys = ["document number", "passport number"]
    def find_field(keys):
        # Case-insensitive lookup: return the first non-empty value whose key matches.
        for key in keys:
            for field_key, value in fields.items():
                if key == field_key.lower() and value:
                    return str(value)
        return None

    extracted_name = find_field(name_keys)
    extracted_dob = find_field(dob_keys)
    extracted_passport_no = find_field(passport_keys)
    if extracted_passport_no:
        extracted_passport_no = extracted_passport_no.replace(" ", "").upper()
    return {
        "name": extracted_name,
        "dob": extracted_dob,
        "passport_no": extracted_passport_no,
        "doc_type": doc_type
    }

def normalize_name(name):
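    """Lowercase a name and drop non-alphanumeric characters for fuzzy matching."""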
    if not name: return ""
    return "".join(filter(str.isalnum, name)).lower()

def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
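    """Assign a document to a person profile, creating the profile if needed.

    Matching cascade: passport/document number first, then normalized name +
    DOB, then name alone; otherwise a generic "unidentified" profile is made.
    Mutates current_persons_data in place and returns the profile key.
    """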
    passport_no = entities.get("passport_no")
    name = entities.get("name")
    dob = entities.get("dob")
    if passport_no:
        for p_key, p_data in current_persons_data.items():
            if passport_no in p_data.get("passport_numbers", set()):
                p_data["doc_ids"].add(doc_id)
                if name and not p_data.get("canonical_name"): p_data["canonical_name"] = name
                if dob and not p_data.get("canonical_dob"): p_data["canonical_dob"] = dob
                return p_key
        new_person_key = f"person_{passport_no}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": dob,
            "names": {normalize_name(name)} if name else set(),
            "dobs": {dob} if dob else set(),
            "passport_numbers": {passport_no}, "doc_ids": {doc_id},
            "display_name": name or f"Person (ID: {passport_no})"
        }
        return new_person_key
    if name and dob:
        norm_name = normalize_name(name)
        composite_key_nd = f"{norm_name}_{dob}"
        for p_key, p_data in current_persons_data.items():
            if norm_name in p_data.get("names", set()) and dob in p_data.get("dobs", set()):
                p_data["doc_ids"].add(doc_id)
                return p_key
        new_person_key = f"person_{composite_key_nd}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": dob,
            "names": {norm_name}, "dobs": {dob},
            "passport_numbers": set(), "doc_ids": {doc_id},
            "display_name": name
        }
        return new_person_key
    if name:
        norm_name = normalize_name(name)
        new_person_key = f"person_{norm_name}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": None,
            "names": {norm_name}, "dobs": set(), "passport_numbers": set(),
            "doc_ids": {doc_id}, "display_name": name
        }
        return new_person_key
    generic_person_key = f"unidentified_person_{str(uuid.uuid4())[:6]}"
    current_persons_data[generic_person_key] = {
        "canonical_name": "Unknown", "canonical_dob": None,
        "names": set(), "dobs": set(), "passport_numbers": set(),
        "doc_ids": {doc_id}, "display_name": f"Unknown Person ({doc_id[:6]})"
    }
    return generic_person_key

def format_dataframe_data(current_files_data):
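    """Flatten the per-file dicts into rows for the status gr.Dataframe."""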
    df_rows = []
    for f_data in current_files_data:
        entities = f_data.get("entities") or {} # CORRECTED LINE HERE
        df_rows.append([
            f_data.get("doc_id", "N/A")[:8],
            f_data.get("filename", "N/A"),
            f_data.get("status", "N/A"),
            entities.get("doc_type", "N/A"),
            entities.get("name", "N/A"),
            entities.get("dob", "N/A"),
            entities.get("passport_no", "N/A"),
            f_data.get("assigned_person_key", "N/A")
        ])
    return df_rows

def format_persons_markdown(current_persons_data, current_files_data):
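    """Render person profiles and their assigned documents as Markdown."""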
    if not current_persons_data:
        return "No persons identified yet."
    md_parts = ["## Classified Persons & Documents\n"]
    for p_key, p_data in current_persons_data.items():
        display_name = p_data.get('display_name', p_key)
        md_parts.append(f"### Person: {display_name} (Profile Key: {p_key})")
        if p_data.get("canonical_dob"): md_parts.append(f"* DOB: {p_data['canonical_dob']}")
        if p_data.get("passport_numbers"): md_parts.append(f"* Passport(s): {', '.join(p_data['passport_numbers'])}")
        md_parts.append("* Documents:")
        doc_ids_for_person = p_data.get("doc_ids", set())
        if doc_ids_for_person:
            for doc_id in doc_ids_for_person:
                doc_detail = next((f for f in current_files_data if f["doc_id"] == doc_id), None)
                if doc_detail:
                    filename = doc_detail.get("filename", "Unknown File")
                    doc_entities = doc_detail.get("entities") or {}
                    doc_type = doc_entities.get("doc_type", "Unknown Type")
                    md_parts.append(f"  - {filename} (`{doc_type}`)")
                else:
                    md_parts.append(f"  - Document ID: {doc_id[:8]} (details error)")
        else:
            md_parts.append("  - No documents currently assigned.")
        md_parts.append("\n---\n")
    return "\n".join(md_parts)

def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
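    """Generator driving the UI: OCR each uploaded file, extract entities, and
    classify by person, yielding (dataframe rows, persons markdown, OCR JSON,
    status message) after each step so Gradio streams progress live."""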
    global processed_files_data, person_profiles
    processed_files_data = []
    person_profiles = {}
    if not OPENROUTER_API_KEY:
        yield (
            [["N/A", "ERROR", "OpenRouter API Key not configured.", "N/A", "N/A", "N/A", "N/A", "N/A"]],
            "Error: OpenRouter API Key not configured. Please set it in Space Secrets.",
            "{}", "API Key Missing. Processing halted."
        )
        return
    if not files_list:
        yield ([], "No files uploaded.", "{}", "Upload files to begin.")
        return
    for i, file_obj in enumerate(files_list):
        doc_uid = str(uuid.uuid4())
        # gr.Files(type="filepath") yields plain path strings; older Gradio versions yield tempfile wrappers with a .name attribute.
        filepath = file_obj if isinstance(file_obj, str) else getattr(file_obj, "name", None)
        processed_files_data.append({
            "doc_id": doc_uid,
            "filename": os.path.basename(filepath) if filepath else f"file_{i+1}.unknown",
            "filepath": filepath,
            "status": "Queued",
            "ocr_json": None,
            "entities": None,
            "assigned_person_key": None
        })
    initial_df_data = format_dataframe_data(processed_files_data)
    initial_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (initial_df_data, initial_persons_md, "{}", f"Initialized. Found {len(files_list)} files.")
    for i, file_data_item in enumerate(progress.tqdm(processed_files_data, desc="Processing Documents")):
        current_doc_id = file_data_item["doc_id"]
        current_filename = file_data_item["filename"]
        if not file_data_item["filepath"]: # Check if filepath is valid
            file_data_item["status"] = "Error: Invalid file path"
            df_data = format_dataframe_data(processed_files_data)
            persons_md = format_persons_markdown(person_profiles, processed_files_data)
            yield(df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) Error with file {current_filename}")
            continue

        file_data_item["status"] = "OCR in Progress..."
        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) OCR for: {current_filename}")
        ocr_result = call_openrouter_ocr(file_data_item["filepath"])
        file_data_item["ocr_json"] = ocr_result
        if "error" in ocr_result:
            file_data_item["status"] = f"OCR Error: {str(ocr_result['error'])[:50]}..."
            df_data = format_dataframe_data(processed_files_data)
            yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Error on {current_filename}")
            continue
        file_data_item["status"] = "OCR Done. Extracting Entities..."
        df_data = format_dataframe_data(processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Done for {current_filename}")
        entities = extract_entities_from_ocr(ocr_result)
        file_data_item["entities"] = entities
        file_data_item["status"] = "Entities Extracted. Classifying..."
        df_data = format_dataframe_data(processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Entities for {current_filename}")
        person_key = get_person_id_and_update_profiles(current_doc_id, entities, person_profiles)
        file_data_item["assigned_person_key"] = person_key
        file_data_item["status"] = "Classified"
        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Classified {current_filename} -> {person_key}")
    final_df_data = format_dataframe_data(processed_files_data)
    final_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (final_df_data, final_persons_md, "{}", f"All {len(processed_files_data)} documents processed.")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📄 Intelligent Document Processor & Classifier")
    gr.Markdown(
        "**Upload multiple documents (images of passports, bank statements, hotel reservations, photos, etc.). "
        "The system will perform OCR, attempt to extract key entities, and classify documents by the person they belong to.**\n"
        "Ensure `OPENROUTER_API_KEY` is set as a Secret in your Hugging Face Space."
    )
    if not OPENROUTER_API_KEY:
        gr.Markdown("<h3 style='color:red;'>⚠️ ERROR: `OPENROUTER_API_KEY` is not set in Space Secrets! OCR will fail.</h3>")
    with gr.Row():
        with gr.Column(scale=1):
            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath") # Using filepath
            process_button = gr.Button("🚀 Process Uploaded Documents", variant="primary")
            overall_status_textbox = gr.Textbox(label="Overall Progress", interactive=False, lines=1)
    gr.Markdown("---")
    gr.Markdown("## Document Processing Details")
    dataframe_headers = ["Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"]
    document_status_df = gr.Dataframe(
        headers=dataframe_headers,
        datatype=["str"] * len(dataframe_headers),
        label="Individual Document Status & Extracted Entities",
        row_count=(1, "dynamic"), # Start with 1 row, dynamically grows
        col_count=(len(dataframe_headers), "fixed"),
        wrap=True
    )
    ocr_json_output = gr.Code(label="Selected Document OCR JSON", language="json", interactive=False)
    gr.Markdown("---")
    person_classification_output_md = gr.Markdown("## Classified Persons & Documents\nNo persons identified yet.")
    process_button.click(
        fn=process_uploaded_files,
        inputs=[files_input],
        outputs=[
            document_status_df,
            person_classification_output_md,
            ocr_json_output,
            overall_status_textbox
        ]
    )
    @document_status_df.select(inputs=None, outputs=ocr_json_output, show_progress="hidden")
    def display_selected_ocr(evt: gr.SelectData):
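        """Show the stored OCR JSON for the dataframe row the user selected."""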
        if evt.index is None or evt.index[0] is None:
            return "{}"
        selected_row_index = evt.index[0]
        # processed_files_data is module-level state, so it is visible here; a
        # multi-user app would keep this in gr.State instead.
        if 0 <= selected_row_index < len(processed_files_data):
            selected_doc_data = processed_files_data[selected_row_index]
            if selected_doc_data and selected_doc_data.get("ocr_json"):
                # ocr_json should already be stored as a dict; parse defensively if a string slipped through.
                ocr_data_to_display = selected_doc_data["ocr_json"]
                if isinstance(ocr_data_to_display, str):
                    try:
                        ocr_data_to_display = json.loads(ocr_data_to_display)
                    except json.JSONDecodeError:
                        return json.dumps({"error": "Stored OCR data is not valid JSON string."}, indent=2)
                return json.dumps(ocr_data_to_display, indent=2, ensure_ascii=False)
        return json.dumps({ "message": "No OCR data found for selected row or selection out of bounds (check if processing is complete). Current rows: " + str(len(processed_files_data))}, indent=2)

if __name__ == "__main__":
    demo.queue().launch(debug=True, share=os.environ.get("GRADIO_SHARE", "true").lower() == "true")