import os
import logging
import gradio as gr
from typing import Iterator

from dialog import get_dialog_box
from gateway import check_server_health, request_generation

# Set up logging
logging.basicConfig(level=logging.INFO)

# CONSTANTS
# Maximum number of new tokens to generate; defaults to 2048 if MAX_NEW_TOKENS is not set
MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", 2048))

# Validate environment variables
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
if not CLOUD_GATEWAY_API:
    raise EnvironmentError("API_ENDPOINT is not set.")

MODEL_NAME: str = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    raise EnvironmentError("MODEL_NAME is not set.")

# Get API Key
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise EnvironmentError("API_KEY is not set.")

# Build the request header once instead of recreating it on every call
HEADER = {"x-api-key": API_KEY}


def toggle_ui():
    """
    Function to toggle the visibility of the UI based on the server health
    Returns:
        hide/show main ui/dialog
    """
    health = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API, header=HEADER)
    if health:
        return gr.update(visible=True), gr.update(
            visible=False
        )  # Show main UI, hide dialog
    else:
        return gr.update(visible=False), gr.update(
            visible=True
        )  # Hide main UI, show dialog
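

# gateway.check_server_health is imported from gateway.py, which is not part of
# this file. The sketch below is purely illustrative and is not used by the app:
# it assumes the gateway exposes a simple GET /health route and uses the
# `requests` library; both the route and the helper name are assumptions, not
# the gateway's actual contract.
def _check_server_health_sketch(cloud_gateway_api: str, header: dict) -> bool:
    import requests  # local import keeps the sketch self-contained

    try:
        # Treat any 2xx status as healthy; error responses or network failures
        # mean the dialog should be shown instead of the main UI.
        response = requests.get(
            f"{cloud_gateway_api}/health", headers=header, timeout=5
        )
        return response.ok
    except requests.RequestException:
        return False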


def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to backend, fetch the streaming responses and emit to the UI.

    Args:
        message (str): input message from the user
        chat_history (list[tuple[str, str]]): entire chat history of the session
        system_prompt (str): system prompt
        max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the number of tokens in the
                                        prompt. Defaults to 1024.
        temperature (float, optional): the value used to module the next token probabilities. Defaults to 0.6.
        top_p (float, optional): if set to float<1, only the smallest set of most probable tokens with probabilities
                                    that add up to top_p or higher are kept for generation. Defaults to 0.9.
        top_k (int, optional): the number of highest probability vocabulary tokens to keep for top-k-filtering.
                                Defaults to 50.
        repetition_penalty (float, optional): the parameter for repetition penalty. 1.0 means no penalty.
                                Defaults to 1.2.

    Yields:
        Iterator[str]: Streaming responses to the UI
    """
    # Accumulate streamed chunks and yield the concatenated text so far on each update
    outputs = []
    for text in request_generation(
        header=HEADER,
        message=message,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            value="You are a highly capable AI assistant. Provide accurate, concise, and fact-based responses that are directly relevant to the user's query. Avoid speculation, ensure logical consistency, and maintain clarity in longer outputs. Keep answers well-structured and under 1200 tokens unless explicitly requested otherwise.",
            lines=3,
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=2048,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        ),
        gr.Slider(
            label="Frequency penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
        gr.Slider(
            label="Presence penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Plan a three-day trip to Washington DC for Cherry Blossom Festival."],
        [
            "Compose a short, joyful musical piece for kids celebrating spring sunshine and blossom."
        ],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
    ],
    cache_examples=False,
)

with gr.Blocks(css="style.css", fill_height=True) as demo:
    # Get the server status before displaying UI
    visibility = check_server_health(CLOUD_GATEWAY_API, header=HEADER)

    # Container for the main interface
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown(
            f"""
            # Gemma 3 27b Instruct
            <span style="font-size:24px;">⚠️ **Maintenance Notice** ⚠️ We are currently undergoing scheduled maintenance. Please check back soon. Thank you for your patience.</span>

            This Space is an Alpha release that demonstrates the [Gemma-3-27B-It](https://huggingface.co/google/gemma-3-27b-it) model running on AMD MI300 infrastructure. The Space is built under the Google Gemma 3 [License](https://ai.google.dev/gemma/terms). Feel free to play with it!
            """
        )
        chat_interface.render()

    # Dialog box using Markdown for the error message
    with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
        # Add spinner and message
        get_dialog_box()

    # Timer to check server health every 10 seconds and update the UI
    timer = gr.Timer(value=10)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])


if __name__ == "__main__":
    # QUEUE and CONCURRENCY_LIMIT must be set in the environment; they have no defaults.
    demo.queue(
        max_size=int(os.getenv("QUEUE")),
        default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")),
    ).launch()