File size: 2,442 Bytes
89dc196
8dbcb99
69a8ec7
 
 
 
89dc196
69a8ec7
 
 
89dc196
69a8ec7
89dc196
 
24c7d1a
d359ae9
89dc196
 
 
 
 
 
 
 
 
 
 
 
24c7d1a
89dc196
 
 
69a8ec7
89dc196
69a8ec7
89dc196
69a8ec7
 
 
 
89dc196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69a8ec7
89dc196
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from diffusers import StableDiffusionPipeline

# Load models
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float16 only pays off on a GPU. When device_map="auto" falls back to the CPU,
# half-precision matmuls are slow (and unsupported on older PyTorch builds),
# so load the weights in float32 there instead.
llm_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=llm_dtype, device_map="auto")
comic_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Stable Diffusion Model
model_id = "stabilityai/sd-turbo"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
# Image generation runs on the CPU regardless of GPU availability.
pipe.to("cpu")

def _describe_panels(user_prompt, num_panels):
    """Ask the language model for panel descriptions; return exactly *num_panels* strings."""
    instruction = f"Generate a {num_panels}-panel comic strip description for the topic: {user_prompt}"
    response = comic_pipeline(
        instruction,
        max_new_tokens=400,
        do_sample=True,  # without sampling, `temperature` is ignored (greedy decoding)
        temperature=0.7,
        return_full_text=False,  # otherwise the instruction itself would become panel 1
    )[0]["generated_text"]
    panels = [line.strip() for line in response.splitlines() if line.strip()][:num_panels]
    # The model may produce fewer non-empty lines than requested (possibly none);
    # pad with generic scene descriptions so the strip always has num_panels panels.
    while len(panels) < num_panels:
        panels.append(f"{user_prompt}, scene {len(panels) + 1}")
    return panels


def _render_panel(description, art_choice):
    """Render one panel description to an image with the Stable Diffusion pipeline."""
    # Leave the style out entirely when no art style was chosen, instead of "None style".
    style = f"{art_choice} style, " if art_choice else ""
    prompt = f"{description}, {style}bold outlines, vibrant colors"
    # sd-turbo is distilled for single-step sampling without classifier-free
    # guidance (guidance_scale=0.0). `do_sample`/`temperature` are LLM options the
    # diffusion pipeline does not use, so they are not passed.
    return pipe(prompt, num_inference_steps=1, guidance_scale=0.0).images[0]


def _compose_grid(images, max_cols=3):
    """Paste equally sized *images* into one RGB image, at most *max_cols* per row."""
    panel_width, panel_height = images[0].size
    cols = min(len(images), max_cols)
    rows = (len(images) + cols - 1) // cols  # ceiling division
    strip = Image.new("RGB", (panel_width * cols, panel_height * rows))
    for index, img in enumerate(images):
        row, col = divmod(index, cols)
        strip.paste(img, (col * panel_width, row * panel_height))
    return strip


# Function to generate comic strip
def generate_comic(user_prompt, num_panels, art_choice):
    """Generate a comic strip image for *user_prompt*.

    The language model writes one description line per panel, Stable Diffusion
    renders each description, and the panels are pasted into a single grid.

    Args:
        user_prompt: Topic of the comic, as typed by the user.
        num_panels: Requested number of panels. A Gradio slider may deliver this
            as a float (e.g. ``3.0``), so it is coerced to ``int``.
        art_choice: Art-style name appended to each image prompt; may be ``None``
            when no style was selected.

    Returns:
        A PIL image with the panels laid out left to right, at most three per
        row (1-3 panels: one row; 4-6 panels: two rows of three).
    """
    num_panels = max(1, int(num_panels))
    comic_panels = _describe_panels(user_prompt, num_panels)
    comic_images = [_render_panel(panel, art_choice) for panel in comic_panels]
    return _compose_grid(comic_images)

# Gradio Interface
art_styles = ["Classic Comic", "Anime", "Cartoon", "Noir", "Cyberpunk", "Watercolor"]
interface = gr.Interface(
    fn=generate_comic,
    inputs=[
        gr.Textbox(label="Enter Comic Topic", placeholder="e.g., Iron Man vs Hulk"),
        # Explicit starting value instead of relying on Gradio's implicit slider default.
        gr.Slider(minimum=3, maximum=6, step=1, value=3, label="Number of Panels"),
        # Pre-select a style: depending on the Gradio version, an untouched
        # dropdown without a value may otherwise submit None.
        gr.Dropdown(choices=art_styles, value=art_styles[0], label="Choose Art Style"),
    ],
    outputs="image",
    title="Comic Strip Generator",
    description="Generate your own comic strip by entering a topic, choosing the number of panels, and selecting an art style."
)

interface.launch()