import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
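
# A small Streamlit demo: TinyLlama writes a short comic-style story and
# SD-Turbo renders a matching image for it. Assuming this file is saved as
# app.py (filename is illustrative), launch it with:
#   streamlit run app.py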

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
st.write(f"Using device: {device}")  # Debug message

# Load text model (TinyLlama) with optimizations
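# @st.cache_resource keeps the returned pipeline in memory across Streamlit
# reruns, so the weights are downloaded and initialized only once.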
@st.cache_resource
def load_text_model():
    try:
        st.write("⏳ Loading text model...")
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Load model in FP16 on GPU (FP32 on CPU) to reduce memory use
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True 
        ).to(device)

        st.write("βœ… Text model loaded successfully!")
        return pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)
    
    except Exception as e:
        st.error(f"❌ Error loading text model: {e}")
        return None

story_generator = load_text_model()

# Load image model (Stable Diffusion) with optimizations
@st.cache_resource
def load_image_model():
    try:
        st.write("⏳ Loading image model...")
        model_id = "stabilityai/sd-turbo"
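        # SD-Turbo is a distilled diffusion model built for very few denoising
        # steps and no classifier-free guidance, which keeps generation fast.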
        model = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)
        model.enable_attention_slicing()  # Optimize GPU memory
        st.write("βœ… Image model loaded successfully!")
        return model
    except Exception as e:
        st.error(f"❌ Error loading image model: {e}")
        return None

image_generator = load_image_model()

# Function to generate a short story
def generate_story(prompt):
    if not story_generator:
        return "❌ Error: Story model not loaded."

    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
    
    try:
        st.write("⏳ Generating story...")
        story_output = story_generator(
            formatted_prompt,
            max_new_tokens=100,  # Cap on newly generated tokens, kept small for speed
            do_sample=True,
            temperature=0.7,
            top_k=30,
            num_return_sequences=1,
            truncation=True
        )[0]['generated_text']
        st.write("βœ… Story generated successfully!")
        return story_output.replace(formatted_prompt, "").strip()
    except Exception as e:
        st.error(f"❌ Error generating story: {e}")
        return "Error generating story."

# Streamlit UI
st.title("πŸ¦Έβ€β™‚οΈ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

# User input
user_prompt = st.text_input("πŸ“ Enter your story prompt:")

if user_prompt:
    st.subheader("πŸ“– AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("πŸ–ΌοΈ AI-Generated Image")
    
    if not image_generator:
        st.error("❌ Error: Image model not loaded.")
    else:
        with st.spinner("⏳ Generating image..."):
            try:
                image = image_generator(
                    user_prompt,
                    num_inference_steps=8,   # Few steps; SD-Turbo is built for fast sampling
                    guidance_scale=0.0,      # SD-Turbo is trained without classifier-free guidance
                    height=256, width=256    # Smaller size to reduce memory usage
                ).images[0]
                st.write("βœ… Image generated successfully!")
                st.image(image, caption="Generated Comic Image", use_container_width=True)
            except Exception as e:
                st.error(f"❌ Error generating image: {e}")