# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "datasets",
#     "hf_transfer",
#     "huggingface-hub[hf_xet]",
#     "polars",
#     "stamina",
#     "transformers",
#     "vllm",
#     "tqdm",
#     "setuptools",
#     "flashinfer-python",
# ]
#
# ///
import argparse
import logging
import os
import sys
from typing import Optional

# Enable hf_transfer for faster Hub downloads; set before importing huggingface_hub, which reads this variable at import time
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import polars as pl
from datasets import Dataset, load_dataset
from huggingface_hub import login, dataset_info, snapshot_download
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import vllm

# Configure root logging once for the whole script: timestamped, INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Report the GPU/runtime environment up front so job logs show what the run
# actually executed on.
import torch

if not torch.cuda.is_available():
    logger.info("CUDA is not available")
else:
    # Only device index 0 is inspected.
    logger.info(f"CUDA version: {torch.version.cuda}")
    logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
    logger.info(f"CUDA capability: {torch.cuda.get_device_capability(0)}")

# Library versions are logged regardless of CUDA availability.
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"vLLM version: {vllm.__version__}")


def format_prompt(
    content: str, card_type: str, tokenizer, max_chars: Optional[int] = 4000
) -> str:
    """Build the chat-formatted prompt for a single card.

    The card text is cut to its first ``max_chars`` characters, prefixed with
    a card-type tag, wrapped as a single ``user`` message and rendered through
    the tokenizer's chat template with the generation prompt appended.

    Args:
        content: Card text to summarise.
        card_type: ``"model"`` selects the ``<MODEL_CARD>`` tag; any other
            value selects ``<DATASET_CARD>``.
        tokenizer: Object exposing ``apply_chat_template(messages,
            add_generation_prompt=..., tokenize=...)``.
        max_chars: Maximum number of characters of ``content`` to include.
            ``None`` disables truncation. Defaults to 4000.

    Returns:
        The value returned by ``tokenizer.apply_chat_template`` when called
        with ``tokenize=False``.
    """
    # Both card types share the same message shape; only the leading tag differs.
    tag = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"
    messages = [{"role": "user", "content": f"{tag}{content[:max_chars]}"}]

    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )


def load_and_filter_data(
    dataset_id: str, card_type: str, min_likes: int = 1, min_downloads: int = 1
) -> pl.DataFrame:
    """Load the ``train`` split of a card dataset and keep the usable rows.

    A ``post_yaml_content`` column is derived from ``card`` by removing a
    leading YAML front-matter block and trimming surrounding whitespace.
    Rows are kept only when that column is longer than 200 bytes and shorter
    than 120,000 bytes. For ``card_type == "model"`` rows must additionally
    have ``likes >= min_likes`` and ``downloads >= min_downloads``; for any
    other card type those two thresholds are not applied.

    Args:
        dataset_id: Identifier passed to ``load_dataset``.
        card_type: Card kind; only ``"model"`` enables the popularity filters.
        min_likes: Minimum ``likes`` value for model cards.
        min_downloads: Minimum ``downloads`` value for model cards.

    Returns:
        The collected, filtered polars DataFrame.
    """
    logger.info(f"Loading data from {dataset_id}")
    cards = load_dataset(dataset_id, split="train").to_polars().lazy()

    # Strip the ``---`` ... ``---`` front matter at the very start of the card,
    # then trim whitespace from what remains.
    body_expr = (
        pl.col("card")
        .str.replace_all(r"^---\n[\s\S]*?\n---\n", "", literal=False)
        .str.strip_chars()
        .alias("post_yaml_content")
    )
    cards = cards.with_columns(body_expr)

    # Size bounds are measured in bytes, not characters.
    body_size = pl.col("post_yaml_content").str.len_bytes()
    cards = cards.filter((body_size > 200) & (body_size < 120_000))

    if card_type == "model":
        popular_enough = (pl.col("likes") >= min_likes) & (
            pl.col("downloads") >= min_downloads
        )
        cards = cards.filter(popular_enough)

    # Everything above is lazy; the query runs here.
    result = cards.collect()
    logger.info(f"Filtered dataset has {len(result)} items")
    return result


def generate_summaries(
    model_id: str,
    input_dataset_id: str,
    output_dataset_id: str,
    card_type: str = "dataset",
    max_tokens: int = 120,
    temperature: float = 0.6,
    batch_size: int = 1000,
    min_likes: int = 1,
    min_downloads: int = 1,
    hf_token: Optional[str] = None,
):
    """Generate a summary for every filtered card and push the result to the Hub.

    Steps: authenticate (when a token is available), load and filter the input
    dataset, download the model snapshot, build one prompt per card, run vLLM
    over all prompts in a single ``generate`` call, attach the stripped text of
    each first completion as a ``summary`` column, and push the dataset.

    Args:
        model_id: Hub repo id of the summarisation model.
        input_dataset_id: Hub dataset id holding the cards.
        output_dataset_id: Hub dataset id the results are pushed to.
        card_type: ``"dataset"`` or ``"model"``; forwarded to filtering and
            prompt formatting.
        max_tokens: ``max_tokens`` for the vLLM sampling parameters.
        temperature: ``temperature`` for the vLLM sampling parameters.
        batch_size: Currently unused. All prompts are handed to vLLM in one
            call; the parameter is kept so existing callers keep working.
        min_likes: Minimum likes, forwarded to ``load_and_filter_data``.
        min_downloads: Minimum downloads, forwarded to ``load_and_filter_data``.
        hf_token: Hugging Face token; falls back to the ``HF_TOKEN``
            environment variable when not given.

    Raises:
        ValueError: If vLLM returns a different number of outputs than the
            number of prompts submitted.
    """
    # Explicit argument wins over the environment; skip login when neither is set.
    token = hf_token or os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    # Load and filter data
    df_filtered = load_and_filter_data(
        input_dataset_id, card_type, min_likes, min_downloads
    )

    # Download model to local directory first.
    # `resume_download` is intentionally not passed: huggingface_hub deprecated
    # it and resumes interrupted downloads on its own.
    logger.info(f"Downloading model {model_id} to local directory...")
    local_model_path = snapshot_download(repo_id=model_id)
    logger.info(f"Model downloaded to: {local_model_path}")

    # Initialize model and tokenizer from local path
    logger.info(f"Initializing vLLM model from local path: {local_model_path}")
    llm = LLM(model=local_model_path, enable_chunked_prefill=True)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path)
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # Prepare prompts
    logger.info("Preparing prompts")
    post_yaml_contents = df_filtered["post_yaml_content"].to_list()
    prompts = [
        format_prompt(content, card_type, tokenizer)
        for content in tqdm(post_yaml_contents, desc="Formatting prompts")
    ]

    # Hand the full prompt list to vLLM in a single call.
    logger.info(f"Generating summaries for {len(prompts)} items")
    all_outputs = llm.generate(
        prompts,
        sampling_params,
    )
    logger.info(f"Generated {len(all_outputs)} summaries")

    # Summaries are attached to rows by position below, so a count mismatch
    # cannot be pushed safely; fail with a clear message instead of a warning.
    if len(all_outputs) != len(prompts):
        raise ValueError(
            f"Generated {len(all_outputs)} summaries, but expected {len(prompts)}. Some prompts may have failed."
        )

    # Keep only the first completion per prompt, trimmed of whitespace.
    clean_results = [output.outputs[0].text.strip() for output in all_outputs]

    # Create dataset and add summaries
    ds = Dataset.from_polars(df_filtered)
    ds = ds.add_column("summary", clean_results)

    # Push to hub
    logger.info(f"Pushing dataset to hub: {output_dataset_id}")
    ds.push_to_hub(output_dataset_id, token=token)
    logger.info("Dataset successfully pushed to hub")


def main():
    """Parse command-line arguments and run ``generate_summaries``.

    Positional arguments are the model id, the input dataset id and the output
    dataset id. Every option maps one-to-one onto a keyword argument of
    ``generate_summaries``.
    """
    parser = argparse.ArgumentParser(
        description="Generate summaries for Hugging Face datasets or models using vLLM"
    )
    parser.add_argument(
        "model_id",
        help="Model ID for summary generation (e.g., davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02)",
    )
    parser.add_argument(
        "input_dataset_id",
        help="Input dataset ID (e.g., librarian-bots/dataset_cards_with_metadata)",
    )
    parser.add_argument(
        "output_dataset_id", help="Output dataset ID where results will be saved"
    )
    parser.add_argument(
        "--card-type",
        choices=["dataset", "model"],
        default="dataset",
        help="Type of cards to process (default: dataset)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=120,
        help="Maximum tokens for summary generation (default: 120)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
        help="Temperature for generation (default: 0.6)",
    )
    # The flag is still accepted so existing invocations keep working, but the
    # help text no longer promises batching that the script does not perform.
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1000,
        help="Batch size for processing (default: 1000); currently ignored, all prompts are sent to vLLM in one call",
    )
    parser.add_argument(
        "--min-likes",
        type=int,
        default=1,
        help="Minimum likes filter for models (default: 1)",
    )
    parser.add_argument(
        "--min-downloads",
        type=int,
        default=1,
        help="Minimum downloads filter for models (default: 1)",
    )
    parser.add_argument(
        "--hf-token", help="Hugging Face token (uses HF_TOKEN env var if not provided)"
    )

    args = parser.parse_args()

    generate_summaries(
        model_id=args.model_id,
        input_dataset_id=args.input_dataset_id,
        output_dataset_id=args.output_dataset_id,
        card_type=args.card_type,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        batch_size=args.batch_size,
        min_likes=args.min_likes,
        min_downloads=args.min_downloads,
        hf_token=args.hf_token,
    )


if __name__ == "__main__":
    if len(sys.argv) > 1:
        main()
    else:
        # Invoked with no arguments: print ready-to-copy example invocations
        # (an hfjobs command for dataset cards, then a plain `uv run` command
        # for model cards) instead of letting argparse report missing arguments.
        usage_lines = (
            "Example hfjobs command:",
            "hfjobs run --flavor l4x1 --secret HF_TOKEN=hf_*** ghcr.io/astral-sh/uv:debian /bin/bash -c '",
            "apt-get update && apt-get install -y python3-dev gcc && \\",
            "export HOME=/tmp && \\",
            "export USER=dummy && \\",
            "export TORCHINDUCTOR_CACHE_DIR=/tmp/torch-inductor && \\",
            "export UV_TORCH_BACKEND=auto && \\",
            "uv run generate_summaries_uv.py \\",
            "  davanstrien/Smol-Hub-tldr \\",
            "  librarian-bots/dataset_cards_with_metadata \\",
            "  your-username/datasets_with_summaries \\",
            "  --card-type dataset \\",
            "  --batch-size 2000",
            "' --project summary-generation --name dataset-summaries",
            "",  # blank separator line between the two examples
            "For models:",
            "uv run generate_summaries_uv.py \\",
            "  davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02 \\",
            "  librarian-bots/model_cards_with_metadata \\",
            "  your-username/models_with_summaries \\",
            "  --card-type model \\",
            "  --min-likes 5 \\",
            "  --min-downloads 1000",
        )
        for usage_line in usage_lines:
            print(usage_line)