"""Gradio UI and FastAPI endpoint for generating short clips with LTX-Video."""

import gradio as gr
import torch
# NOTE(review): confirm these class names against the installed diffusers
# release; recent releases document `LTXPipeline` / `LTXImageToVideoPipeline`.
from diffusers import LTXVideoTransformer3DModel, LTXVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer
import spaces
import numpy as np
import tempfile
import os
import shutil
import time
import logging
from PIL import Image
import cv2
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import FileResponse
import uvicorn
import threading
import json

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for model
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Output video parameters shared by generation and encoding.
FPS = 24
OUTPUT_WIDTH = 768
OUTPUT_HEIGHT = 512

# Input limits enforced by validate_inputs().
MAX_PROMPT_CHARS = 500
MIN_DURATION_S = 3
MAX_DURATION_S = 5
MAX_IMAGE_SIDE = 1024


def load_model():
    """Load the LTX-Video pipeline into the module-level ``pipe``.

    Returns:
        True if the pipeline was loaded, moved to ``device`` and configured;
        False if any step raised (the error is logged, not propagated).
    """
    global pipe
    try:
        logger.info("Loading LTX-Video model...")

        # Load the pipeline
        loaded = LTXVideoPipeline.from_pretrained(
            "Lightricks/LTX-Video-0.9.7-dev",
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )

        # Move to device
        loaded = loaded.to(device)

        # Enable optimizations
        loaded.vae.enable_tiling()
        loaded.vae.enable_slicing()

        # Enable memory efficient attention if available. The pipeline may not
        # expose a `unet` attribute at all, so look it up defensively instead
        # of letting an AttributeError abort the whole load.
        unet = getattr(loaded, "unet", None)
        if unet is not None and hasattr(
            unet, "enable_xformers_memory_efficient_attention"
        ):
            try:
                unet.enable_xformers_memory_efficient_attention()
            except Exception as e:
                # Optional optimization: keep the model usable without it.
                logger.warning("Could not enable xformers attention: %s", e)

        # Publish the pipeline only once it is fully configured.
        pipe = loaded
        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False


def validate_inputs(prompt, duration, image=None):
    """Validate user-supplied generation parameters.

    Args:
        prompt: Text prompt; must be non-blank and at most 500 characters.
        duration: Requested clip length in seconds; must lie in [3, 5].
        image: Optional conditioning image, either a file path or an object
            exposing ``.size`` as ``(width, height)`` (e.g. a PIL image).

    Returns:
        A list of human-readable error strings; empty when inputs are valid.
    """
    errors = []

    # `elif` keeps len() away from a None prompt.
    if not prompt or not prompt.strip():
        errors.append("Prompt is required")
    elif len(prompt) > MAX_PROMPT_CHARS:
        errors.append("Prompt must be less than 500 characters")

    # The chained comparison is False for NaN, so NaN is rejected here rather
    # than blowing up later in int(duration * FPS).
    if duration is None or not (MIN_DURATION_S <= duration <= MAX_DURATION_S):
        errors.append("Duration must be between 3 and 5 seconds")

    if image is not None:
        try:
            if isinstance(image, str):
                # Context manager closes the file handle after reading size.
                with Image.open(image) as img:
                    width, height = img.size
            else:
                width, height = image.size

            # Check image dimensions
            if width > MAX_IMAGE_SIDE or height > MAX_IMAGE_SIDE:
                errors.append("Image dimensions must be less than 1024x1024")
        except Exception as e:
            errors.append(f"Invalid image: {str(e)}")

    return errors


def frames_to_video(frames, output_path, fps=24):
    """Encode RGB frames to an mp4 file using OpenCV.

    Args:
        frames: Sequence of RGB frames; each item is a HxWx3 uint8 array or
            anything ``np.asarray`` converts to one (e.g. a PIL image).
        output_path: Destination file path.
        fps: Frames per second of the output video.

    Returns:
        True on success, False if encoding failed (the error is logged).
    """
    writer = None
    try:
        if frames is None or len(frames) == 0:
            raise ValueError("No frames to write")

        height, width = np.asarray(frames[0]).shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise IOError(f"Could not open video writer for {output_path}")

        for frame in frames:
            # Convert RGB to BGR for OpenCV
            frame_bgr = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
            writer.write(frame_bgr)

        return True
    except Exception as e:
        logger.error(f"Error creating video: {e}")
        return False
    finally:
        # Always release so the file is finalized and the handle is not leaked.
        if writer is not None:
            writer.release()


def _frame_count(duration, fps=FPS):
    """Return the frame count of the form 8*k + 1 nearest to duration*fps."""
    raw = int(duration * fps)
    return max((raw + 4) // 8, 1) * 8 + 1


def _frames_to_uint8(frames):
    """Normalize pipeline output frames to one uint8 numpy array."""
    if torch.is_tensor(frames):
        # numpy has no bfloat16, so widen floating tensors before converting.
        if frames.is_floating_point():
            frames = frames.float()
        frames = frames.cpu().numpy()
    elif not isinstance(frames, np.ndarray):
        # Typically a list of PIL images or per-frame arrays.
        frames = np.stack([np.asarray(frame) for frame in frames])

    if frames.dtype != np.uint8:
        # Assumes float frames are scaled to [0, 1]; clip so out-of-range
        # values saturate instead of wrapping around in uint8.
        frames = (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8)
    return frames


@spaces.GPU(duration=60)
def generate_video_core(prompt, negative_prompt="", duration=4, image=None):
    """Run the pipeline and write the result to a temporary mp4 file.

    Args:
        prompt: Text prompt.
        negative_prompt: Text describing what to avoid.
        duration: Clip length in seconds (converted to frames at 24 FPS).
        image: Optional conditioning image (file path or PIL image).

    Returns:
        Tuple ``(video_path, status_message)``.

    Raises:
        Exception: With message ``"Generation failed: ..."`` wrapping any
            error raised during generation or encoding.
    """
    start_time = time.time()

    try:
        # Snap to 8*k + 1 frames, the count LTX-Video documents as supported.
        num_frames = _frame_count(duration, FPS)

        # Prepare generation parameters
        generation_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_frames": num_frames,
            "height": OUTPUT_HEIGHT,
            "width": OUTPUT_WIDTH,
            "num_inference_steps": 30,
            "guidance_scale": 7.5,
            "generator": torch.Generator(device=device).manual_seed(42),
        }

        # Add image if provided
        if image is not None:
            if isinstance(image, str):
                # Load fully inside the context so the file handle is closed.
                with Image.open(image) as src:
                    image = src.convert("RGB")
            else:
                image = image.convert("RGB")
            # Resize image to match output dimensions
            image = image.resize(
                (OUTPUT_WIDTH, OUTPUT_HEIGHT), Image.Resampling.LANCZOS
            )
            # NOTE(review): confirm the loaded pipeline accepts an `image`
            # argument; a text-to-video-only pipeline may reject it.
            generation_kwargs["image"] = image

        logger.info(f"Starting generation with {num_frames} frames...")

        # Generate video
        with torch.inference_mode():
            result = pipe(**generation_kwargs)

        # First (and only) video in batch, normalized to uint8 RGB frames.
        frames = _frames_to_uint8(result.frames[0])

        # Create temporary video file
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "generated_video.mp4")

        # Convert frames to video
        if not frames_to_video(frames, video_path, fps=FPS):
            # Do not leave an empty temp dir behind on failure.
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise Exception("Failed to create video file")

        generation_time = time.time() - start_time
        logger.info(f"Video generated successfully in {generation_time:.2f} seconds")

        return video_path, f"Generated in {generation_time:.2f}s"

    except Exception as e:
        logger.error(f"Error generating video: {e}")
        raise Exception(f"Generation failed: {str(e)}") from e


def generate_video_gradio(prompt, negative_prompt, duration, image):
    """Gradio callback: validate inputs, generate, and report status.

    Returns:
        Tuple ``(video_path_or_None, status_text)``; never raises.
    """
    try:
        # Validate inputs
        errors = validate_inputs(prompt, duration, image)
        if errors:
            return None, f"Validation errors: {'; '.join(errors)}"

        # Check if model is loaded
        if pipe is None:
            return None, "Model not loaded. Please wait for initialization."

        # Generate video
        video_path, status = generate_video_core(prompt, negative_prompt, duration, image)
        return video_path, status

    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return None, f"Error: {str(e)}"


# Create Gradio interface
def create_gradio_interface():
    """Build and return the Gradio Blocks UI (not launched here)."""
    with gr.Blocks(title="LTX-Video Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎬 LTX-Video Generator")
        gr.Markdown("Generate 3-5 second videos using the LTX-Video model from Lightricks")

        with gr.Row():
            with gr.Column(scale=1):
                # Input controls
                image_input = gr.File(
                    label="Input Image (Optional)",
                    file_types=[".png", ".jpg", ".jpeg"],
                    type="filepath",
                )
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                    lines=3,
                    max_lines=5,
                )
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                    lines=2,
                    max_lines=3,
                )
                duration_slider = gr.Slider(
                    minimum=3,
                    maximum=5,
                    value=4,
                    step=0.5,
                    label="Duration (seconds)",
                )
                generate_btn = gr.Button("🎬 Generate Video", variant="primary")
                gr.Markdown("**Estimated time:** 4-6 seconds")

            with gr.Column(scale=1):
                # Output controls
                video_output = gr.Video(label="Generated Video")
                status_output = gr.Textbox(label="Status", interactive=False)

        # Event handlers
        generate_btn.click(
            fn=generate_video_gradio,
            inputs=[prompt_input, negative_prompt_input, duration_slider, image_input],
            outputs=[video_output, status_output],
        )

        # Examples
        gr.Examples(
            examples=[
                ["A cat playing with a ball of yarn", "", 4, None],
                ["Ocean waves crashing on a beach at sunset", "", 3, None],
                ["A person walking through a forest", "blurry, low quality", 5, None],
            ],
            inputs=[prompt_input, negative_prompt_input, duration_slider, image_input],
        )

    return demo


# FastAPI setup
app = FastAPI(title="LTX-Video API", description="Generate videos using LTX-Video model")


@app.post("/generate_video")
async def api_generate_video(
    prompt: str = Form(..., description="Text prompt for video generation"),
    negative_prompt: str = Form("", description="Negative prompt (optional)"),
    duration: float = Form(4.0, description="Duration in seconds (3-5)"),
    image: UploadFile = File(None, description="Input image (optional)"),
):
    """Generate a video from form data and return it as an mp4 download.

    Raises:
        HTTPException: 400 on validation errors, 503 if the model is not
            loaded, 500 on any other failure.
    """
    upload_dir = None
    try:
        # Validate inputs
        image_path = None
        if image is not None:
            # Save uploaded image temporarily. The client-supplied filename is
            # untrusted: keep only its extension so it can never steer the
            # write outside the temp directory.
            upload_dir = tempfile.mkdtemp()
            extension = os.path.splitext(os.path.basename(image.filename or ""))[1]
            image_path = os.path.join(upload_dir, f"upload{extension}")
            with open(image_path, "wb") as f:
                f.write(await image.read())

        errors = validate_inputs(prompt, duration, image_path)
        if errors:
            raise HTTPException(status_code=400, detail={"errors": errors})

        if pipe is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Generate video
        # NOTE(review): this call is synchronous inside an async handler, so it
        # occupies the event loop for the whole generation.
        video_path, status = generate_video_core(prompt, negative_prompt, duration, image_path)

        # Return video file
        return FileResponse(
            video_path,
            media_type="video/mp4",
            filename=f"generated_video_{int(time.time())}.mp4",
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"API generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # The uploaded image is fully read during generation; remove it.
        if upload_dir is not None:
            shutil.rmtree(upload_dir, ignore_errors=True)


@app.get("/")
async def root():
    """Return a short description of the API and a curl usage example."""
    return {
        "message": "LTX-Video API",
        "endpoints": {
            "/generate_video": "POST - Generate video",
            "/docs": "GET - API documentation"
        },
        # Port matches run_api(), which serves this app on 7861.
        "curl_example": """
        curl -X POST "http://localhost:7861/generate_video" \\
        -F "prompt=A cat playing with a ball" \\
        -F "duration=4" \\
        -F "negative_prompt=blurry" \\
        -F "image=@your_image.jpg" \\
        --output generated_video.mp4
        """
    }


def run_api():
    """Run the FastAPI server on port 7861 (blocks the calling thread)."""
    uvicorn.run(app, host="0.0.0.0", port=7861, log_level="info")


def main():
    """Load the model, start the API thread, and launch the Gradio UI."""
    # Load model
    logger.info("Initializing LTX-Video Generator...")
    model_loaded = load_model()

    if not model_loaded:
        logger.error("Failed to load model. Exiting.")
        return

    # Create Gradio interface
    demo = create_gradio_interface()

    # Start API server in a separate thread
    api_thread = threading.Thread(target=run_api, daemon=True)
    api_thread.start()
    logger.info("API server started on http://localhost:7861")

    # Launch Gradio interface
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
    )


if __name__ == "__main__":
    main()