import gradio as gr
import torch  # Optional, but good practice if using a PyTorch model
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 1. Load a simple, small pre-trained LLM and its tokenizer ---
# We'll use DistilGPT2 for speed and small size.
# You can replace this with another small model if you prefer.
model_name = "distilgpt2"
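# For example, "gpt2" (the ~124M-parameter base model) should work as a
# drop-in replacement here; any small causal LM from the Hub that loads via
# AutoModelForCausalLM will do, at some cost in size and load time.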

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # If you have a GPU, uncomment the next line
    # model.to("cuda" if torch.cuda.is_available() else "cpu")
    model_loaded = True
    print(f"Successfully loaded model and tokenizer for: {model_name}")
except Exception as e:
    print(f"Error loading model: {e}")
    model_loaded = False
    # generate_text_from_llm below checks model_loaded, so the Gradio
    # interface still launches and reports the error instead of crashing.
    tokenizer = None  # avoids NameErrors if tokenizer-specific code runs later

if model_loaded and tokenizer:
    # Ensure pad_token is set if it's not already (important for generate)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

# --- 2. Define the LLM inference function ---
def generate_text_from_llm(prompt_text):
    """
    Generates a short text continuation using the loaded LLM.
    """
    if not model_loaded:
        return "Error: Model could not be loaded. Please check server logs."
    if not prompt_text:
        return "Please enter a starting prompt!"
    try:
        # Encode the input prompt
        inputs = tokenizer.encode(prompt_text, return_tensors="pt", truncation=True, max_length=512)
        # If you have a GPU, uncomment the next line
        # inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Generate text:
        # - max_length is the total length of prompt + generated text
        # - num_return_sequences=1 means we want a single completion
        # - no_repeat_ngram_size helps avoid repetitive text
        outputs = model.generate(
            inputs,
            max_length=inputs.shape[1] + 50,  # generate up to 50 new tokens
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,  # use EOS for padding during generation
            no_repeat_ngram_size=2,  # avoid repeating 2-grams
        )
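
        # A possible variation (a sketch, not used above): sampling usually
        # gives more varied continuations than the greedy decoding used here.
        # These are standard transformers generate() kwargs; the values are
        # illustrative, not tuned.
        # outputs = model.generate(
        #     inputs,
        #     max_length=inputs.shape[1] + 50,
        #     do_sample=True,    # sample from the token distribution
        #     temperature=0.8,   # <1.0 = more conservative, >1.0 = more random
        #     top_p=0.9,         # nucleus sampling: keep top 90% probability mass
        #     pad_token_id=tokenizer.eos_token_id,
        # )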

        # Decode the generated text (prompt + continuation)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # For simplicity we return the whole thing; see the sketch after this
        # function for returning only the newly generated part.
        return generated_text
    except Exception as e:
        print(f"Error during generation: {e}")
        return f"Error during text generation: {e}"

# --- 3. Create the Gradio Interface ---
demo = gr.Interface(
    fn=generate_text_from_llm,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Start typing here...",
            lines=5,
        )
    ],
    outputs=[
        gr.Textbox(label="LLM Generated Text", lines=10)
    ],
    title="Simple LLM Text Generator",
    description="Enter a prompt and a small LLM (DistilGPT2) will try to continue it. This is a basic demo for learning purposes.",
    examples=[
        ["Once upon a time, in a land far away,"],
        ["The best way to learn programming is"],
        ["Artificial intelligence is rapidly changing the world by"],
    ],
    theme=gr.themes.Soft(),  # you can try other themes, e.g. gr.themes.Default()
)
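
# Optional: on shared hardware such as a free Spaces CPU, Gradio's request
# queue keeps the UI responsive when several users submit at once.
# demo.queue()  # uncomment to enable request queueing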

# --- 4. Launch the app ---
# When deploying to Hugging Face Spaces, Spaces runs this launch() call.
# For local testing with a shareable link, use share=True.
if __name__ == "__main__":
    if model_loaded:
        demo.launch(debug=True, share=True)  # share=True creates a temporary public link
    else:
        print("Model failed to load. Gradio app will run with an error message function.")
        # Launch a minimal placeholder so the UI still appears and explains the failure
        demo_error = gr.Interface(
            fn=lambda x: "Error: Model could not be loaded.",
            inputs="textbox",
            outputs="textbox",
            title="LLM Demo - MODEL LOAD ERROR",
        )
        demo_error.launch(debug=True, share=True)
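
# To test locally (assuming this file is saved as app.py):
#   pip install gradio torch transformers
#   python app.py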