"""Storyboard image-generation service.

Loads the SDXL pipelines (a plain text-to-image pipeline and an
OpenPose-conditioned T2I-adapter variant) once at import time, then exposes
helpers that generate a batch of storyboard frames or regenerate a single
frame.  Generated images are uploaded to S3 and recorded in the database.
"""

import gc
import os
import random
from io import BytesIO

import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLPipeline,
    T2IAdapter,
    UniPCMultistepScheduler,
)
from diffusers.utils import load_image
from PIL import Image

import models
from database import SessionLocal
from s3 import upload_image_to_s3
from text_processor import (
    detect_and_translate_to_english,
    get_resolved_sentences,
    get_script_captions,
)

# Half precision keeps the SDXL pipelines within typical GPU memory limits.
dtype = torch.float16

# Single module-level generator: re-seeded via manual_seed before each
# generation so individual images are reproducible from their stored seed.
generator = torch.Generator()

# Prompt scaffolding shared by every generation path.  .format() on the
# template yields exactly the same strings the pipelines received before.
PROMPT_TEMPLATE = "Storyboard sketch of {}, black and white, cinematic, high quality"
NEGATIVE_PROMPT = (
    "ugly, deformed, disfigured, poor details, bad anatomy, abstract, bad physics"
)

# ---------------------------------------------------------------------------
# Models are initialized at import time so they are only loaded once and every
# request reuses the same instances.
# ---------------------------------------------------------------------------
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype, use_safetensors=True
).to("cuda")

print("Loading base pipeline...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(
    "safetensors/Storyboard_sketch.safetensors", adapter_name="sketch"
)
pipe.load_lora_weights("safetensors/anglesv2.safetensors", adapter_name="angles")
pipe.set_adapters(["sketch", "angles"], adapter_weights=[0.5, 0.5])
pipe.enable_xformers_memory_efficient_attention()

print("Loading OpenPose detector...")
openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

print("Loading T2I adapter...")
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-openpose-sdxl-1.0", torch_dtype=dtype
).to("cuda")

print("Loading adapter pipeline...")
posepipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter,
    vae=vae,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
posepipe.scheduler = UniPCMultistepScheduler.from_config(posepipe.scheduler.config)
posepipe.load_lora_weights(
    "safetensors/Storyboard_sketch.safetensors", adapter_name="sketch"
)
posepipe.load_lora_weights("safetensors/anglesv2.safetensors", adapter_name="angles")
posepipe.set_adapters(["sketch", "angles"], adapter_weights=[0.5, 0.5])
posepipe.enable_xformers_memory_efficient_attention()
print("All models loaded successfully")


def clear_cuda_cache() -> None:
    """Release Python garbage, then free cached CUDA memory (if available).

    Collecting first drops Python-side references so the subsequent
    ``empty_cache`` can actually return the freed blocks to the driver.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_dimensions(resolution: str) -> tuple[int, int]:
    """Map an aspect-ratio string to (width, height) in pixels.

    Unknown values fall back to the square 1024x1024 default.
    """
    resolution_map = {
        "16:9": (1024, 576),
        "1:1": (1024, 1024),
        "9:16": (576, 1024),
    }
    return resolution_map.get(resolution, (1024, 1024))


def generate_batch_images(
    story: str, storyboard_id: int, resolution: str = "1:1", isStory: bool = True
) -> None:
    """Generate one storyboard image per prompt extracted from *story*.

    Args:
        story: Raw story or script text; split into per-image prompts by the
            text processor.
        storyboard_id: Owning storyboard row; also the S3 folder name.
        resolution: Aspect-ratio key understood by :func:`get_dimensions`.
        isStory: When True the text is treated as a story
            (``get_resolved_sentences``); otherwise as a script
            (``get_script_captions``).

    Each image gets a fresh random seed, is uploaded to S3, and a
    ``models.Image`` row is committed per image so partial progress survives
    a mid-batch failure.  Errors are logged and rolled back, not re-raised.
    """
    # Clear cache before batch generation.
    clear_cuda_cache()
    db = SessionLocal()
    try:
        if isStory:
            prompts = get_resolved_sentences(story)
        else:
            prompts = get_script_captions(story)
        width, height = get_dimensions(resolution)
        for num, prompt in enumerate(prompts):
            # Fresh random seed per image so frames differ but stay
            # reproducible from the logged seed.
            seed = random.randint(0, 2**32 - 1)
            generator.manual_seed(seed)
            print(f"Generating image {num+1} with seed {seed}")
            result = pipe(
                prompt=PROMPT_TEMPLATE.format(prompt),
                negative_prompt=NEGATIVE_PROMPT,
                guidance_scale=8.5,
                height=height,
                width=width,
                num_inference_steps=30,
                generator=generator,
            )
            image = result.images[0]
            buf = BytesIO()
            image.save(buf, format="JPEG")
            buf.seek(0)
            s3_url = upload_image_to_s3(
                buf.read(),
                f"image_{num + 1}.jpg",
                folder=f"storyboards/{storyboard_id}",
            )
            db_image = models.Image(
                storyboard_id=storyboard_id,
                image_path=s3_url,
                caption=prompt,
            )
            db.add(db_image)
            # Commit per image so earlier frames persist if a later one fails.
            db.commit()
            db.refresh(db_image)
            print(f"Image {num+1} generated successfully")
            # Clear cache after each image.
            clear_cuda_cache()
    except Exception as e:
        print(f"Error during image generation: {e}")
        import traceback

        traceback.print_exc()
        db.rollback()
    finally:
        db.close()


def generate_single_image(
    image_id: int,
    caption: str,
    seed: int | None = None,
    resolution: str = "1:1",
    isOpenPose: bool = False,
    pose_img: Image.Image | None = None,
) -> "models.Image | None":
    """Regenerate the image for an existing ``models.Image`` row.

    Args:
        image_id: Primary key of the row to regenerate.
        caption: New caption; translated to English before prompting.
        seed: Optional fixed seed; a random one is drawn when None.
        resolution: Aspect-ratio key understood by :func:`get_dimensions`.
        isOpenPose: When True, condition generation on the pose extracted
            from *pose_img* via the T2I-adapter pipeline.
        pose_img: Source image for pose extraction (required when
            ``isOpenPose`` is True).

    Returns:
        The updated ``models.Image`` row, or None on failure (the error is
        logged and the transaction rolled back).
    """
    # Clear cache before single image generation.
    clear_cuda_cache()
    db = SessionLocal()
    try:
        # Guard first: nothing below is meaningful without the row.
        db_image = db.query(models.Image).filter(models.Image.id == image_id).first()
        if not db_image:
            raise ValueError(f"Image with id {image_id} not found.")

        processed_caption = detect_and_translate_to_english(caption)
        width, height = get_dimensions(resolution)

        # Use provided seed or generate a random one.
        current_seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        generator.manual_seed(current_seed)
        print(f"Generating single image with seed {current_seed}")

        if isOpenPose:
            print("Using OpenPose pipeline")
            image = openpose(pose_img, detect_resolution=512, image_resolution=1024)
            # Reverse the channel order of the detector output before
            # feeding it to the adapter pipeline (RGB<->BGR swap).
            image = np.array(image)[:, :, ::-1]
            image = Image.fromarray(np.uint8(image))
            # NOTE(review): `resolution` is ignored on this path — width and
            # height are computed above but never passed to posepipe, so the
            # output size comes from the pipeline defaults.  Confirm whether
            # height=/width= should be forwarded here.
            result = posepipe(
                prompt=PROMPT_TEMPLATE.format(processed_caption),
                negative_prompt=NEGATIVE_PROMPT,
                image=image,
                adapter_conditioning_scale=1,
                guidance_scale=8.5,
                num_inference_steps=30,
                generator=generator,
            )
        else:
            print("Using standard pipeline")
            result = pipe(
                prompt=PROMPT_TEMPLATE.format(processed_caption),
                negative_prompt=NEGATIVE_PROMPT,
                guidance_scale=8.5,
                num_inference_steps=30,
                width=width,
                height=height,
                generator=generator,
            )

        # Save and upload.
        image = result.images[0]
        buf = BytesIO()
        image.save(buf, format="JPEG")
        buf.seek(0)
        s3_url = upload_image_to_s3(
            buf.read(),
            f"image_{image_id}.jpg",
            folder=f"storyboards/{db_image.storyboard_id}",
        )

        # Update the existing record in place (path, caption, seed).
        db_image.image_path = s3_url
        db_image.caption = caption
        db_image.seed = current_seed
        db.commit()
        db.refresh(db_image)
        print("Single image generated successfully")

        # Clear cache after generation.
        clear_cuda_cache()
        return db_image
    except Exception as e:
        print(f"Error during image regeneration: {e}")
        import traceback

        traceback.print_exc()
        db.rollback()
        return None
    finally:
        db.close()