File size: 3,719 Bytes
5265cb1 eb07b95 5265cb1 eb07b95 5265cb1 eb07b95 5265cb1 cbd648d 5265cb1 eb07b95 cbd648d eb07b95 cbd648d f232606 cbd648d eb07b95 cbd648d eb07b95 5265cb1 eb07b95 5265cb1 cbd648d 5265cb1 eb07b95 35e1ad5 cbd648d f232606 cbd648d eb07b95 5265cb1 eb07b95 2a03292 eb07b95 f478c80 2a03292 eb07b95 5265cb1 eb07b95 5265cb1 eb07b95 5265cb1 eb07b95 2a03292 eb07b95 005fdf4 eb07b95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
# Pick the torch device: use CUDA when a GPU is visible, otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Debug message: show the chosen device in the app.
st.write(f"Using device: {device}")
# Load text model (TinyLlama): FP16 weights on CUDA, FP32 on CPU.
@st.cache_resource
def _load_text_pipeline():
    """Build the TinyLlama text-generation pipeline; cached across reruns.

    Errors propagate instead of being converted to ``None``. The decorator
    only memoizes successful returns, so a failed load (e.g. a transient
    download error) is retried on the next rerun rather than cached forever.
    """
    st.write("⏳ Loading text model...")
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Half precision only on GPU; CPU kernels generally need float32.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    ).to(device)
    st.write("✅ Text model loaded successfully!")
    # pipeline() uses the integer device convention: 0 = first GPU, -1 = CPU.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if device == "cuda" else -1,
    )


def load_text_model():
    """Return the cached text-generation pipeline, or ``None`` on failure.

    Any exception raised while loading is reported via ``st.error`` and
    swallowed, so the rest of the app can still render. The failure is not
    cached, so the next rerun tries the load again.
    """
    try:
        return _load_text_pipeline()
    except Exception as e:
        st.error(f"❌ Error loading text model: {e}")
        return None


story_generator = load_text_model()
# Load image model (SD-Turbo via StableDiffusionPipeline) with optimizations.
@st.cache_resource
def _load_image_pipeline():
    """Build the SD-Turbo diffusion pipeline; cached across reruns.

    Errors propagate instead of being converted to ``None``. The decorator
    only memoizes successful returns, so a failed load is retried on the
    next rerun rather than cached for the life of the server.
    """
    st.write("⏳ Loading image model...")
    model_id = "stabilityai/sd-turbo"
    # Half precision only on GPU; CPU kernels generally need float32.
    model = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    # Compute attention in slices to lower peak memory at some speed cost.
    model.enable_attention_slicing()
    st.write("✅ Image model loaded successfully!")
    return model


def load_image_model():
    """Return the cached diffusion pipeline, or ``None`` on failure.

    Any exception raised while loading is reported via ``st.error`` and
    swallowed, so the rest of the app can still render. The failure is not
    cached, so the next rerun tries the load again.
    """
    try:
        return _load_image_pipeline()
    except Exception as e:
        st.error(f"❌ Error loading image model: {e}")
        return None


image_generator = load_image_model()
# Function to generate a short story
def generate_story(prompt):
    """Generate a short comic-style story about *prompt*.

    Args:
        prompt: The user's story idea, inserted into a fixed instruction
            template before being sent to the text-generation pipeline.

    Returns:
        The generated story text (the instruction prompt is not included).
        Returns an error string instead if the text model failed to load or
        generation raised; generation errors are also shown via ``st.error``.
    """
    if story_generator is None:
        return "❌ Error: Story model not loaded."
    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
    try:
        st.write("⏳ Generating story...")
        story_output = story_generator(
            formatted_prompt,
            # Cap only the generated tokens (kept small for speed).
            # max_length would also count the prompt and truncate the input
            # to that length, so long prompts would leave no room for a story.
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_k=30,
            num_return_sequences=1,
            truncation=True,
            # Return just the continuation. Stripping the prompt afterwards
            # with str.replace fails if detokenization alters its text.
            return_full_text=False,
        )[0]["generated_text"]
        st.write("✅ Story generated successfully!")
        return story_output.strip()
    except Exception as e:
        st.error(f"❌ Error generating story: {e}")
        return "Error generating story."
# Streamlit UI: a prompt box that drives both story and image generation.
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")
# User input
# NOTE(review): the emoji in the next label and in the "AI-Generated Story"
# subheader were unrecoverably garbled in the source (only the lead byte
# survived); 📝 / 📖 are reconstructions — confirm against the original.
user_prompt = st.text_input("📝 Enter your story prompt:")
if user_prompt:
    st.subheader("📖 AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("🖼️ AI-Generated Image")
    if image_generator is None:
        st.error("❌ Error: Image model not loaded.")
    else:
        with st.spinner("⏳ Generating image..."):
            try:
                image = image_generator(
                    user_prompt,
                    num_inference_steps=8,  # Reduced for faster generation
                    # SD-Turbo is distilled to run without classifier-free
                    # guidance. The default CFG scale degrades its output and
                    # runs an extra unconditional UNet pass on every step.
                    guidance_scale=0.0,
                    # Smaller than SD-Turbo's native 512x512 to reduce memory use.
                    height=256,
                    width=256,
                ).images[0]
                st.write("✅ Image generated successfully!")
                st.image(image, caption="Generated Comic Image", use_container_width=True)
            except Exception as e:
                st.error(f"❌ Error generating image: {e}")
|