# Hugging Face Space: Young-Children-Storyteller (status at scrape time: Runtime error)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

# A Mistral-7B fine-tune for generating children's stories.
model_name = "ajibawa-2023/Young-Children-Storyteller-Mistral-7B"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Try 8-bit quantization to reduce memory usage; fall back to the default
# precision if bitsandbytes is unavailable or the quantized load fails.
try:
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=bnb_config,
    )
except Exception:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt propagate
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# The model was loaded with device_map="auto" (accelerate dispatch), so the
# pipeline must NOT also receive an explicit `device`: passing device=0 here
# raises "The model has been loaded with `accelerate` and therefore cannot be
# moved to a specific device" — the likely cause of the Space's runtime error.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
def generate_story(prompt):
    """Generate a short children's story continuation for *prompt*.

    Parameters
    ----------
    prompt : str
        The user's story prompt from the UI text box.

    Returns
    -------
    str
        The generated text (the pipeline's output includes the prompt itself).
    """
    # max_new_tokens replaces the original max_length=150: max_length counts
    # the prompt's tokens too, so a long prompt would leave almost no budget
    # for the actual story. max_new_tokens always allows 150 generated tokens.
    outputs = generator(
        prompt,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    return outputs[0]["generated_text"]
# Web UI: one multi-line prompt box in, the generated story text out.
prompt_box = gr.Textbox(lines=3, placeholder="Enter your story prompt here...")

iface = gr.Interface(
    fn=generate_story,
    inputs=prompt_box,
    outputs="text",
    title="Young Children Storyteller",
    description="Generate children's stories using Mistral 7B",
)

if __name__ == "__main__":
    iface.launch()