# HackNight / app.py
# Source page metadata (captured from the Hugging Face file viewer):
#   author: Lovitra · commit 69a8ec7 "Update app.py" (verified) · 5.08 kB
# Import libraries
import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from diffusers import StableDiffusionPipeline
from IPython.display import display
### --- STEP 1: Load TinyLlama for Text Generation --- ###
# Small chat-tuned causal LM that writes the per-panel scene descriptions.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Weights are loaded in half precision; device_map="auto" lets transformers
# decide where the layers are placed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Wrap the already-loaded model/tokenizer in a text-generation pipeline so
# later steps can call it with a plain prompt string.
comic_pipeline = pipeline(task="text-generation", tokenizer=tokenizer, model=model)
### --- STEP 2: Load SD-Turbo for Fast Comic-Panel Images --- ###
# NOTE: "stabilityai/sd-turbo" is a distilled Stable Diffusion 2.1 model (not
# SDXL); per its model card it is meant to run with very few denoising steps.
model_id = "stabilityai/sd-turbo"
# Half precision only pays off on a GPU. Without CUDA, fall back to CPU and
# float32 so the script still runs (slowly) instead of crashing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
sd_dtype = torch.float16 if device == "cuda" else torch.float32
# variant="fp16" selects the repo's fp16 weight files; they are cast to
# sd_dtype on load.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=sd_dtype, variant="fp16")
pipe.to(device)
### --- STEP 3: User Inputs a Prompt & Number of Panels --- ###
# Ask for the comic topic, re-prompting until something non-blank is entered:
# an empty topic would leave the text model with nothing to write about.
while True:
    user_prompt = input("Enter a topic for the comic strip: ").strip()  # Example: "Government of India"
    if user_prompt:
        break
    print("❌ Please enter a non-empty topic.")

# Get number of panels from the user; only integers in the range 3..6 are accepted.
while True:
    try:
        num_panels = int(input("Enter the number of comic panels (3 to 6): "))
    except ValueError:
        print("❌ Invalid input! Please enter a number between 3 and 6.")
        continue
    if 3 <= num_panels <= 6:
        break
    print("❌ Please enter a number between 3 and 6.")
### --- STEP 4: User Chooses an Art Style --- ###
# Menu key -> style name; the chosen name is later passed to generate_comic_image.
art_styles = {
    "1": "Classic Comic",
    "2": "Anime",
    "3": "Cartoon",
    "4": "Noir",
    "5": "Cyberpunk",
    "6": "Watercolor",
}
print("\n🎨 Choose an Art Style for the Comic:")
for key, style in art_styles.items():
    print(f"{key}. {style}")
# Re-prompt until the user enters one of the menu keys. Surrounding
# whitespace is ignored so " 2" is accepted as "2".
while True:
    art_choice = input("\nEnter the number for your preferred art style: ").strip()
    if art_choice in art_styles:
        chosen_style = art_styles[art_choice]
        print(f"✅ You selected: {chosen_style}")
        break
    print("❌ Invalid choice! Please enter a valid number.")
### --- STEP 5: Generate Comic-Style Breakdown Using TinyLlama --- ###
# Plain-text instruction; the model continues the text after "Comic Strip Panels:".
instruction = (
    f"Generate a structured {num_panels}-panel comic strip description for the topic. "
    "Each panel should have a simple but clear scene description. "
    "Keep it short and focus on visuals for easy image generation.\n\n"
    "Topic: " + user_prompt + "\n\n"
    "Comic Strip Panels:\n"
)
response = comic_pipeline(
    instruction,
    max_new_tokens=400,  # cap on generated tokens; output may stop earlier or be cut off
    temperature=0.7,
    repetition_penalty=1.1,
    do_sample=True,
    # Return only the newly generated continuation, so the prompt does not
    # have to be stripped back out with str.replace (which would also delete
    # any later occurrence of the same text).
    return_full_text=False,
)[0]["generated_text"]
comic_breakdown = response.strip()
# Treat each non-blank line as one panel description; keep the first num_panels.
comic_panels = [line.strip() for line in comic_breakdown.splitlines() if line.strip()][:num_panels]
if len(comic_panels) < num_panels:
    print(f"⚠️ Only {len(comic_panels)} of {num_panels} panel descriptions were generated.")
print("\n🔹 Comic Strip Breakdown:\n" + "\n".join(comic_panels))  # Show generated panels
### --- STEP 6: Generate High-Quality Comic-Style Images --- ###
def generate_comic_image(
    description,
    style,
    *,
    sd_pipe=None,
    num_inference_steps=4,
    guidance_scale=0.0,
):
    """
    Generate one comic panel image with the SD-Turbo pipeline.

    Args:
        description: Scene description for the panel (one line of the STEP 5 breakdown).
        style: Art style name. Names outside the accepted set fall back to "Comic".
        sd_pipe: Diffusers pipeline (or compatible callable) to run. Defaults to
            the module-level ``pipe`` loaded in STEP 2.
        num_inference_steps: Denoising steps. SD-Turbo is distilled to work in
            1-4 steps, so larger values mostly cost time.
        guidance_scale: Classifier-free guidance scale. SD-Turbo is meant to run
            without guidance (0.0). Values above 1 enable CFG, which is also
            what makes the negative prompt take effect.

    Returns:
        The first entry of ``sd_pipe(...).images`` (a PIL image for a standard
        diffusers pipeline), or None if generation raised an exception.
    """
    if sd_pipe is None:
        sd_pipe = pipe  # module-level pipeline from STEP 2
    # Accepted style names: every option in the STEP 4 menu plus the extra
    # names the original prompt builder supported. Without the menu entries,
    # a user's "Classic Comic", "Cartoon" or "Noir" would be silently replaced
    # by the fallback.
    valid_styles = {
        "Comic", "Classic Comic", "Anime", "Cartoon", "Noir",
        "Cyberpunk", "Watercolor", "Pixel Art",
    }
    chosen_style = style if style in valid_styles else "Comic"
    # Short, SD-Turbo-friendly prompt.
    prompt = f"{description}, {chosen_style} style, bold outlines, vibrant colors, dynamic action."
    # Only used when guidance_scale > 1 (classifier-free guidance enabled).
    negative_prompt = "blurry, distorted, text, watermark, low quality, extra limbs, messy background"
    try:
        result = sd_pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
    except Exception as e:  # best effort: one failed panel must not abort the strip
        print(f"❌ Error generating image: {e}")
        return None
    return result.images[0]
# Render one image per panel description. A failed render comes back as None
# and is dropped, so the grid below only ever sees real images.
comic_images = [
    img
    for img in (generate_comic_image(panel, chosen_style) for panel in comic_panels)
    if img is not None
]
if comic_images:
    ### --- STEP 7: Arrange Images in a Grid Based on Panel Count --- ###
    # (rows, cols) per image count. 5 images use a 2x3 grid, so one cell stays
    # empty (Image.new's default fill, black).
    grid_map = {3: (1, 3), 4: (2, 2), 5: (2, 3), 6: (2, 3)}
    # Counts not in the map (e.g. only 1-2 images survived failures) go in a single row.
    rows, cols = grid_map.get(len(comic_images), (1, len(comic_images)))
    # Cell size is taken from the first image; all panels come from the same pipeline.
    panel_width, panel_height = comic_images[0].size
    comic_strip = Image.new("RGB", (panel_width * cols, panel_height * rows))
    # Paste row by row: image i goes to row i // cols, column i % cols.
    for i, img in enumerate(comic_images):
        row, col = divmod(i, cols)
        comic_strip.paste(img, (col * panel_width, row * panel_height))
    # Display and save the comic strip
    display(comic_strip)
    comic_strip.save("comic_strip.png")
    print("\n✅ Comic strip saved as 'comic_strip.png'")
else:
    print("\n❌ No images were generated.")