import gradio as gr | |
import torch | |
from llama_cpp import Llama # New import for GGUF models | |
# Path (relative to the working directory) of the GGUF model file.
# IMPORTANT: upload the 'Magistral-Small-2506_gguf' file into the root
# directory of your Hugging Face Space. Download it from its Hugging Face
# repository (the "Files and versions" tab) and name it exactly as it
# appears here.
GGUF_MODEL_FILE = "Magistral-Small-2506_gguf"  # Adjust if your uploaded file name is different

# Module-level singleton for the llama_cpp.Llama instance.
# Populated once by load_model(); stays None if loading fails.
llm = None
def load_model():
    """Load the GGUF model into the module-level ``llm`` singleton.

    Idempotent: if ``llm`` is already populated, this is a no-op, so the
    weights are never loaded twice.

    Raises:
        RuntimeError: if the GGUF file cannot be opened/parsed. The original
            exception is chained (``from e``) so the root cause stays visible
            in the Space logs.
    """
    global llm
    if llm is None:
        print(f"Loading GGUF model: {GGUF_MODEL_FILE}...")
        try:
            # `model_path` must point to the local file path of the GGUF model.
            # `n_gpu_layers=0` keeps inference CPU-only (recommended for the
            # free Spaces tier); raise it to offload layers to a GPU.
            # `n_ctx` is the context window size; adjust for your use case.
            llm = Llama(model_path=GGUF_MODEL_FILE, n_gpu_layers=0, n_ctx=2048)
            print("GGUF Model loaded successfully.")
        except Exception as e:
            print(f"Error loading GGUF model: {e}")
            # Chain the original exception so the underlying cause
            # (missing file, corrupt GGUF, OOM, ...) is not lost.
            raise RuntimeError(f"Failed to load GGUF model: {e}. Please ensure '{GGUF_MODEL_FILE}' is correctly uploaded and accessible.") from e
# Load the model eagerly at import time so the first request is not delayed;
# load_model() raises RuntimeError if the GGUF file is missing or unreadable.
load_model()
# This is the core function that will be exposed as an API endpoint.
def generate_text(prompt: str, max_new_tokens: int = 100, temperature: float = 0.7, top_k: int = 50) -> str:
    """Generate a completion for ``prompt`` with the loaded GGUF model.

    Args:
        prompt: Input text; an empty/falsy prompt short-circuits with a hint.
        max_new_tokens: Maximum number of tokens to generate. Gradio sliders
            deliver floats, so the value is coerced to ``int`` before use.
        temperature: Sampling temperature (higher = more random/creative).
        top_k: Top-K sampling cutoff; also coerced to ``int``.

    Returns:
        The generated text, or a human-readable error message (this function
        never raises, so the Gradio UI/API always gets a string back).
    """
    if not prompt:
        return "Please enter a prompt to generate text!"
    if llm is None:
        return "Model not loaded. Please check Space logs for errors."
    try:
        # `stop` ends generation at the given sequences; `echo=False`
        # keeps the prompt out of the returned text.
        output = llm.create_completion(
            prompt=prompt,
            # Gradio Slider values arrive as floats (e.g. 100.0); llama-cpp
            # expects integer token counts, so coerce explicitly.
            max_tokens=int(max_new_tokens),
            temperature=temperature,
            top_k=int(top_k),
            stop=["\nUser:", "##"],  # Example stop sequences
            echo=False
        )
        # create_completion returns an OpenAI-style dict; the text lives in
        # the first element of 'choices'.
        return output['choices'][0]['text']
    except Exception as e:
        # Log for debugging in the Space logs, then return an informative
        # message instead of propagating (keeps the UI/API responsive).
        print(f"Error during text generation: {e}")
        return f"An error occurred: {e}. Please try again with a different prompt or check the Space logs."
# Create the Gradio interface.
# Gradio generates both a web UI and a REST API endpoint from this object;
# the four input components map positionally onto generate_text's parameters.
demo = gr.Interface(
    fn=generate_text,  # The Python function to expose.
    inputs=[
        # Free-text prompt.
        gr.Textbox(label="Enter your prompt here", lines=3),
        # Maximum number of new tokens to generate (10-500).
        gr.Slider(minimum=10, maximum=500, value=100, label="Max New Tokens"),
        # Sampling temperature (randomness), in 0.05 steps.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature"),
        # Top-K sampling cutoff (diversity), integer steps.
        gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top K")
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),  # Generated text (or error message).
    title="Magistral-Small-2506_gguf Text Generation API on Hugging Face Space",
    description="Enter a prompt and Magistral-Small-2506_gguf will generate a response. Adjust parameters for different results. This function is also exposed as an API endpoint.",
    allow_flagging="never"  # Disables Gradio's data flagging feature.
)
# Script entry point: start the Gradio server.
# Hugging Face Spaces requires binding to 0.0.0.0 so the app is reachable
# publicly; 7860 is the port Spaces expects by default.
if __name__ == "__main__":
    demo.launch(server_port=7860, server_name="0.0.0.0")