NitinBot001 committed on
Commit
7ee2d3d
·
verified ·
1 Parent(s): 9fdff0d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +352 -0
app.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import LTXVideoTransformer3DModel, LTXVideoPipeline
4
+ from transformers import T5EncoderModel, T5Tokenizer
5
+ import spaces
6
+ import numpy as np
7
+ import tempfile
8
+ import os
9
+ import time
10
+ import logging
11
+ from PIL import Image
12
+ import cv2
13
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
14
+ from fastapi.responses import FileResponse
15
+ import uvicorn
16
+ import threading
17
+ import json
18
+
19
# Logging setup: INFO-level basic configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared model state. `pipe` stays None until load_model() assigns the
# pipeline; `device` is "cuda" when torch reports CUDA support, else "cpu".
pipe = None
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
26
+
27
def load_model():
    """Load the LTX-Video pipeline into the module-level ``pipe``.

    The pipeline is fetched with ``from_pretrained``, moved to ``device`` and
    has VAE tiling/slicing enabled.  Memory-efficient attention is enabled on
    a best-effort basis: a failure there is logged and does not fail the load.

    ``pipe`` is only assigned once every mandatory step has succeeded, so a
    failed load leaves it as it was.

    Returns:
        bool: True when the pipeline is ready, False if loading raised.
    """
    global pipe
    try:
        logger.info("Loading LTX-Video model...")

        # NOTE(review): confirm that the installed diffusers release exports
        # the pipeline under the name `LTXVideoPipeline`.
        loaded = LTXVideoPipeline.from_pretrained(
            "Lightricks/LTX-Video-0.9.7-dev",
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        loaded = loaded.to(device)

        # VAE optimizations
        loaded.vae.enable_tiling()
        loaded.vae.enable_slicing()

        # The denoiser is probed with getattr instead of reading `pipe.unet`
        # directly: a pipeline without a `unet` attribute would otherwise
        # raise AttributeError and make a successful load look like a failure.
        # NOTE(review): LTX pipelines presumably expose it as `transformer`.
        for attr_name in ("transformer", "unet"):
            denoiser = getattr(loaded, attr_name, None)
            enable_xformers = getattr(
                denoiser, "enable_xformers_memory_efficient_attention", None
            )
            if not callable(enable_xformers):
                continue
            try:
                enable_xformers()
            except Exception as exc:
                # Optional optimization only; keep the loaded model usable.
                logger.warning(
                    "Memory-efficient attention not enabled: %s", exc
                )
            break

        pipe = loaded
        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        logger.exception("Error loading model: %s", e)
        return False
56
+
57
def validate_inputs(prompt, duration, image=None):
    """Check the user-supplied generation parameters.

    Args:
        prompt: Text prompt. Must be a non-blank string of at most 500
            characters; ``None`` or a non-string counts as missing.
        duration: Clip length in seconds; must compare within [3, 5].
            Values that cannot be compared to numbers (or NaN) are rejected.
        image: Optional image, either a filesystem path (``str``) or an
            object exposing ``.size`` as ``(width, height)``.

    Returns:
        list[str]: Human-readable error messages; empty when all inputs
        are acceptable.
    """
    errors = []

    # Treat None / non-string prompts as missing instead of crashing on len().
    text = prompt if isinstance(prompt, str) else ""
    if not text.strip():
        errors.append("Prompt is required")

    if len(text) > 500:
        errors.append("Prompt must be less than 500 characters")

    try:
        duration_ok = 3 <= duration <= 5
    except TypeError:
        duration_ok = False
    if not duration_ok:
        errors.append("Duration must be between 3 and 5 seconds")

    if image is not None:
        try:
            if isinstance(image, str):
                # Context manager closes the file handle once the size is read.
                with Image.open(image) as img:
                    width, height = img.size
            else:
                width, height = image.size

            if width > 1024 or height > 1024:
                errors.append("Image dimensions must be less than 1024x1024")

        except Exception as e:
            errors.append(f"Invalid image: {str(e)}")

    return errors
86
+
87
def frames_to_video(frames, output_path, fps=24):
    """Encode a sequence of RGB frames into an MP4 file using OpenCV.

    Args:
        frames: Sequence of frames; each is converted with ``np.asarray``
            so array-like frames are accepted as well as ndarrays.
            Assumes H x W x 3 RGB data — TODO confirm against the caller.
        output_path: Destination path of the video file.
        fps: Frame rate passed to the video writer.

    Returns:
        bool: True when every frame was written, False on empty input,
        when the writer could not be opened, or on any error (logged).
    """
    writer = None
    try:
        if frames is None or len(frames) == 0:
            logger.error("Error creating video: no frames to write")
            return False

        height, width = np.asarray(frames[0]).shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        # Writing to a writer that failed to open can go unnoticed, so
        # check explicitly instead of reporting success with no usable file.
        if not writer.isOpened():
            logger.error(f"Error creating video: could not open {output_path}")
            return False

        for frame in frames:
            # Convert RGB to BGR for OpenCV
            frame_bgr = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
            writer.write(frame_bgr)

        return True
    except Exception as e:
        logger.error(f"Error creating video: {e}")
        return False
    finally:
        # Release even when a frame fails so the file handle is not leaked.
        if writer is not None:
            writer.release()
104
+
105
@spaces.GPU(duration=60)
def generate_video_core(prompt, negative_prompt="", duration=4, image=None):
    """Run the loaded pipeline and write the result to a temporary MP4.

    Decorated with ``spaces.GPU`` (ZeroGPU).  Reads the module-level ``pipe``
    and ``device``.

    Args:
        prompt: Text prompt for the video.
        negative_prompt: Text describing what to avoid.
        duration: Clip length in seconds; frames = ``int(duration * 24)``.
        image: Optional conditioning image, as a path or an image object
            supporting ``resize``.

    Returns:
        tuple[str, str]: Path of the generated MP4 and a status message
        with the elapsed time.

    Raises:
        Exception: ``"Generation failed: ..."`` wrapping (and chained to)
        whatever error occurred.
    """
    start_time = time.time()

    try:
        # Calculate number of frames (24 FPS)
        num_frames = int(duration * 24)

        generation_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_frames": num_frames,
            "height": 512,
            "width": 768,
            "num_inference_steps": 30,
            "guidance_scale": 7.5,
            "generator": torch.Generator(device=device).manual_seed(42),
        }

        if image is not None:
            if isinstance(image, str):
                # Resize inside the context manager: the resized copy is
                # fully loaded, so the source file can be closed right away.
                with Image.open(image) as source_image:
                    image = source_image.resize(
                        (768, 512), Image.Resampling.LANCZOS
                    )
            else:
                image = image.resize((768, 512), Image.Resampling.LANCZOS)
            # NOTE(review): confirm the loaded pipeline's __call__ accepts an
            # `image` argument; a text-to-video pipeline may reject it.
            generation_kwargs["image"] = image

        logger.info(f"Starting generation with {num_frames} frames...")

        with torch.inference_mode():
            result = pipe(**generation_kwargs)

        # First (and only) video in batch
        frames = result.frames[0]

        if torch.is_tensor(frames):
            # Cast first: NumPy has no bfloat16, which is the load dtype.
            frames = frames.float().cpu().numpy()
        else:
            # The pipeline may hand back a list of frames (e.g. PIL images),
            # which has no `.dtype`; stack it into one array.
            frames = np.stack([np.asarray(frame) for frame in frames])

        # Ensure frames are 0-255 uint8.
        # Assumes float frames are scaled to [0, 1] — TODO confirm; clipping
        # keeps out-of-range values from wrapping around in the uint8 cast.
        if frames.dtype != np.uint8:
            frames = (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8)

        # The directory is intentionally not removed here: the caller still
        # has to serve the file after this function returns.
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "generated_video.mp4")

        success = frames_to_video(frames, video_path, fps=24)
        if not success:
            raise Exception("Failed to create video file")

        generation_time = time.time() - start_time
        logger.info(f"Video generated successfully in {generation_time:.2f} seconds")

        return video_path, f"Generated in {generation_time:.2f}s"

    except Exception as e:
        logger.error(f"Error generating video: {e}")
        raise Exception(f"Generation failed: {str(e)}") from e
171
+
172
def generate_video_gradio(prompt, negative_prompt, duration, image):
    """Gradio callback: validate the form values, then generate a video.

    Returns:
        tuple: ``(video_path, status)`` on success, or ``(None, message)``
        when validation fails, the model is not loaded, or generation raises.
    """
    try:
        problems = validate_inputs(prompt, duration, image)
        if problems:
            joined = '; '.join(problems)
            return None, f"Validation errors: {joined}"

        model_ready = pipe is not None
        if not model_ready:
            return None, "Model not loaded. Please wait for initialization."

        clip_path, message = generate_video_core(
            prompt, negative_prompt, duration, image
        )
        return clip_path, message

    except Exception as e:
        # Report failures in the status box rather than raising.
        logger.error(f"Gradio generation error: {e}")
        return None, f"Error: {str(e)}"
191
+
192
+ # Create Gradio interface
193
def create_gradio_interface():
    """Build and return the Gradio Blocks UI.

    Layout: a left column with the optional image, prompt, negative prompt,
    duration slider and generate button; a right column with the video and
    status outputs.  The button is wired to ``generate_video_gradio``.
    """
    example_rows = [
        ["A cat playing with a ball of yarn", "", 4, None],
        ["Ocean waves crashing on a beach at sunset", "", 3, None],
        ["A person walking through a forest", "blurry, low quality", 5, None],
    ]

    with gr.Blocks(title="LTX-Video Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎬 LTX-Video Generator")
        gr.Markdown("Generate 3-5 second videos using the LTX-Video model from Lightricks")

        with gr.Row():
            # Left column: inputs
            with gr.Column(scale=1):
                image_input = gr.File(
                    label="Input Image (Optional)",
                    file_types=[".png", ".jpg", ".jpeg"],
                    type="filepath",
                )
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                    lines=3,
                    max_lines=5,
                )
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                    lines=2,
                    max_lines=3,
                )
                duration_slider = gr.Slider(
                    minimum=3,
                    maximum=5,
                    value=4,
                    step=0.5,
                    label="Duration (seconds)",
                )
                generate_btn = gr.Button("🎬 Generate Video", variant="primary")
                gr.Markdown("**Estimated time:** 4-6 seconds")

            # Right column: outputs
            with gr.Column(scale=1):
                video_output = gr.Video(label="Generated Video")
                status_output = gr.Textbox(label="Status", interactive=False)

        # The same component order feeds both the click handler and examples.
        form_inputs = [
            prompt_input,
            negative_prompt_input,
            duration_slider,
            image_input,
        ]

        generate_btn.click(
            fn=generate_video_gradio,
            inputs=form_inputs,
            outputs=[video_output, status_output],
        )

        gr.Examples(examples=example_rows, inputs=form_inputs)

    return demo
256
+
257
# FastAPI application object; the route decorators below register on it.
app = FastAPI(
    title="LTX-Video API",
    description="Generate videos using LTX-Video model",
)
259
+
260
@app.post("/generate_video")
async def api_generate_video(
    prompt: str = Form(..., description="Text prompt for video generation"),
    negative_prompt: str = Form("", description="Negative prompt (optional)"),
    duration: float = Form(4.0, description="Duration in seconds (3-5)"),
    image: UploadFile = File(None, description="Input image (optional)"),
):
    """Generate a video from form data and return it as an MP4 download.

    Responses:
        200: the generated file (``video/mp4``).
        400: validation failed; ``detail`` holds ``{"errors": [...]}``.
        503: the model has not been loaded.
        500: any other failure; ``detail`` is the error text.
    """
    import shutil  # local import: only needed for upload cleanup

    upload_dir = None
    try:
        image_path = None
        if image:
            # The filename comes from the client: keep only its final path
            # component so it cannot escape the temp directory, and fall
            # back to a fixed name when it is missing or degenerate.
            raw_name = (image.filename or "").replace("\\", "/")
            safe_name = os.path.basename(raw_name)
            if safe_name in ("", ".", ".."):
                safe_name = "upload"

            upload_dir = tempfile.mkdtemp()
            image_path = os.path.join(upload_dir, safe_name)
            content = await image.read()
            with open(image_path, "wb") as f:
                f.write(content)

        errors = validate_inputs(prompt, duration, image_path)
        if errors:
            raise HTTPException(status_code=400, detail={"errors": errors})

        if pipe is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        video_path, status = generate_video_core(
            prompt, negative_prompt, duration, image_path
        )

        return FileResponse(
            video_path,
            media_type="video/mp4",
            filename=f"generated_video_{int(time.time())}.mp4",
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"API generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded image is only needed until generation finishes;
        # the generated video lives in a different directory.
        if upload_dir is not None:
            shutil.rmtree(upload_dir, ignore_errors=True)
301
+
302
@app.get("/")
async def root():
    """Return a small JSON description of the API and a curl example.

    The example targets port 7861, the port ``run_api`` serves this app on
    (7860 is used by the Gradio UI).
    """
    return {
        "message": "LTX-Video API",
        "endpoints": {
            "/generate_video": "POST - Generate video",
            "/docs": "GET - API documentation",
        },
        "curl_example": """
curl -X POST "http://localhost:7861/generate_video" \\
  -F "prompt=A cat playing with a ball" \\
  -F "duration=4" \\
  -F "negative_prompt=blurry" \\
  -F "image=@your_image.jpg" \\
  --output generated_video.mp4
""",
    }
320
+
321
def run_api():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:7861.

    Used as the target of the API thread started in ``main``.
    """
    server_options = {
        "host": "0.0.0.0",
        "port": 7861,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)
324
+
325
def main():
    """Entry point: load the model, start the API thread, launch the UI.

    Returns early (after logging an error) when the model fails to load.
    """
    logger.info("Initializing LTX-Video Generator...")
    if not load_model():
        logger.error("Failed to load model. Exiting.")
        return

    demo = create_gradio_interface()

    # The API server runs in a daemon thread so it does not keep the
    # process alive once the Gradio launch below ends.
    threading.Thread(target=run_api, daemon=True).start()
    logger.info("API server started on http://localhost:7861")

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": False,
    }
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()