Lovitra committed
Commit 69a8ec7 · verified · 1 Parent(s): edc89b4

Update app.py

Files changed (1): app.py +136 -151
app.py CHANGED
@@ -1,154 +1,139 @@
- import gradio as gr
- import numpy as np
- import random
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
  }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, # Replace with defaults that work for your model
-                 )
-
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
              prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )
-
- if __name__ == "__main__":
-     demo.launch()
+ # Import libraries
  import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from diffusers import StableDiffusionPipeline
+ from IPython.display import display
+
+ ### --- STEP 1: Load TinyLlama for Text Generation --- ###
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
+
+ # Initialize text generation pipeline
+ comic_pipeline = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer
+ )
+
+ ### --- STEP 2: Load Stable Diffusion Turbo for High-Quality Images --- ###
+ model_id = "stabilityai/sd-turbo" # Fast, lightweight model suited to a stylized comic look
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
+ pipe.to("cuda") # Move to GPU for better performance
+
+ ### --- STEP 3: User Inputs a Prompt & Number of Panels --- ###
+ user_prompt = input("Enter a topic for the comic strip: ") # Example: "Government of India"
+
+ # Get number of panels from the user
+ while True:
+     try:
+         num_panels = int(input("Enter the number of comic panels (3 to 6): "))
+         if 3 <= num_panels <= 6:
+             break
+         else:
+             print("❌ Please enter a number between 3 and 6.")
+     except ValueError:
+         print("❌ Invalid input! Please enter a number between 3 and 6.")
+
+ ### --- STEP 4: User Chooses an Art Style --- ###
+ art_styles = {
+     "1": "Classic Comic",
+     "2": "Anime",
+     "3": "Cartoon",
+     "4": "Noir",
+     "5": "Cyberpunk",
+     "6": "Watercolor"
  }
+
+ print("\n🎨 Choose an Art Style for the Comic:")
+ for key, style in art_styles.items():
+     print(f"{key}. {style}")
+
+ while True:
+     art_choice = input("\nEnter the number for your preferred art style: ")
+     if art_choice in art_styles:
+         chosen_style = art_styles[art_choice]
+         print(f"✅ You selected: {chosen_style}")
+         break
+     else:
+         print("❌ Invalid choice! Please enter a valid number.")
+
+ ### --- STEP 5: Generate Comic-Style Breakdown Using TinyLlama --- ###
+ instruction = (
+     f"Generate a structured {num_panels}-panel comic strip description for the topic. "
+     "Each panel should have a simple but clear scene description. "
+     "Keep it short and focus on visuals for easy image generation.\n\n"
+     "Topic: " + user_prompt + "\n\n"
+     "Comic Strip Panels:\n"
+ )
+
+ response = comic_pipeline(
+     instruction,
+     max_new_tokens=400, # Ensure full response
+     temperature=0.7,
+     repetition_penalty=1.1,
+     do_sample=True
+ )[0]['generated_text']
+
+ # Extract only the structured comic description
+ comic_breakdown = response.replace(instruction, "").strip()
+ comic_panels = [line.strip() for line in comic_breakdown.split("\n") if line.strip()][:num_panels]
+
+ print("\n🔹 Comic Strip Breakdown:\n", "\n".join(comic_panels)) # Show generated panels
+
+ ### --- STEP 6: Generate High-Quality Comic-Style Images --- ###
+ def generate_comic_image(description, style):
+     """
+     Generates a comic panel image using Stable Diffusion Turbo.
+     """
+     # Validate style input (fall back to "Classic Comic" if invalid)
+     valid_styles = list(art_styles.values())
+     chosen_style = style if style in valid_styles else "Classic Comic"
+
+     # Refined prompt (shorter, SD-Turbo-friendly)
+     prompt = f"{description}, {chosen_style} style, bold outlines, vibrant colors, dynamic action."
+
+     # Negative prompt (avoiding unwanted elements)
+     negative_prompt = "blurry, distorted, text, watermark, low quality, extra limbs, messy background"
+
+     try:
+         # Generate the panel image
+         image = pipe(
              prompt,
+             negative_prompt=negative_prompt,
+             num_inference_steps=30, # Note: SD-Turbo is distilled for very low step counts (1-4)
+             guidance_scale=7
+         ).images[0]
+         return image
+     except Exception as e:
+         print(f"❌ Error generating image: {e}")
+         return None # Return None if generation fails
+
+ # Generate images for each panel
+ comic_images = [generate_comic_image(panel, chosen_style) for panel in comic_panels]
+
+ # Remove None values if any images failed to generate
+ comic_images = [img for img in comic_images if img is not None]
+
+ if comic_images:
+     ### --- STEP 7: Arrange Images in a Grid Based on Panel Count --- ###
+     grid_map = {3: (1, 3), 4: (2, 2), 5: (2, 3), 6: (2, 3)}
+     rows, cols = grid_map.get(len(comic_images), (1, len(comic_images)))
+
+     panel_width, panel_height = comic_images[0].size
+     comic_strip = Image.new("RGB", (panel_width * cols, panel_height * rows))
+
+     # Paste images in grid format
+     for i, img in enumerate(comic_images):
+         x_offset = (i % cols) * panel_width
+         y_offset = (i // cols) * panel_height
+         comic_strip.paste(img, (x_offset, y_offset))
+
+     # Display and save the comic strip
+     display(comic_strip)
+     comic_strip.save("comic_strip.png")
+     print("\n✅ Comic strip saved as 'comic_strip.png'")
+ else:
+     print("\n❌ No images were generated.")
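
The STEP 7 layout can be checked without loading TinyLlama or SD-Turbo. Below is a minimal sketch, not part of the commit, that reuses the same grid mapping with solid-color placeholder panels; `build_strip` and `GRID_MAP` are hypothetical names introduced here only for illustration.

```python
from PIL import Image

GRID_MAP = {3: (1, 3), 4: (2, 2), 5: (2, 3), 6: (2, 3)}  # same mapping as STEP 7

def build_strip(panels):
    """Paste equal-sized panels into the row/column grid used by STEP 7."""
    rows, cols = GRID_MAP.get(len(panels), (1, len(panels)))
    width, height = panels[0].size
    strip = Image.new("RGB", (width * cols, height * rows))
    for i, img in enumerate(panels):
        strip.paste(img, ((i % cols) * width, (i // cols) * height))
    return strip

# Four 128x128 placeholder panels should produce a 2x2 (256x256) strip.
placeholders = [Image.new("RGB", (128, 128), c) for c in ("red", "green", "blue", "gray")]
assert build_strip(placeholders).size == (256, 256)
```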