import logging
from typing import List, Optional, Union

import torch
from PIL import Image

logger = logging.getLogger(__name__)

T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Do not change the user input intent, just enhance it.
Keep within 150 words.
For best results, build your prompts using this structure:
Start with main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""

I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Keep within 150 words.
For best results, build your prompts using this structure:
Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input.
Start with main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Align to the image caption if it contradicts the user text input.
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""


def tensor_to_pil(tensor):
    """Convert a (C, H, W) tensor with values in [-1, 1] to a PIL image."""
    assert tensor.min() >= -1 and tensor.max() <= 1

    # Rescale from [-1, 1] to [0, 1].
    tensor = (tensor + 1) / 2

    # (C, H, W) -> (H, W, C), the layout PIL expects.
    tensor = tensor.permute(1, 2, 0)

    # Move to CPU and quantize to 8-bit.
    numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")

    return Image.fromarray(numpy_image)
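
# For reference: tensor_to_pil(torch.zeros(3, 64, 64)) yields a 64x64 mid-gray
# RGB image, since 0.0 rescales to (0.0 + 1) / 2 * 255 = 127.5 -> 127.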


def generate_cinematic_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompt: Union[str, List[str]],
    images: Optional[List[Image.Image]] = None,
    max_new_tokens: int = 256,
) -> List[str]:
    """Enhance user prompts for video generation; caption images first if given."""
    prompts = [prompt] if isinstance(prompt, str) else prompt

    if images is None:
        # Text-to-video: enhance the text prompt directly.
        prompts = _generate_t2v_prompt(
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            max_new_tokens,
            T2V_CINEMATIC_PROMPT,
        )
    else:
        # Image-to-video: caption the conditioning frames, then enhance.
        prompts = _generate_i2v_prompt(
            image_caption_model,
            image_caption_processor,
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            images,
            max_new_tokens,
            I2V_CINEMATIC_PROMPT,
        )

    return prompts


def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
    # media_item is a (batch, channels, frames, height, width) tensor; take the
    # first frame of each batch element and convert it to a PIL image.
    frames_tensor = conditioning_item.media_item
    return [
        tensor_to_pil(frames_tensor[i, :, 0, :, :])
        for i in range(frames_tensor.shape[0])
    ]


def _generate_t2v_prompt(
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}"},
        ]
        for p in prompts
    ]

    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]

    out_prompts = []
    for text in texts:
        model_inputs = prompt_enhancer_tokenizer(text, return_tensors="pt").to(
            prompt_enhancer_model.device
        )
        out_prompts.append(
            _generate_and_decode_prompts(
                prompt_enhancer_model,
                prompt_enhancer_tokenizer,
                model_inputs,
                max_new_tokens,
            )[0]
        )

    return out_prompts
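
# Note: the enhancer helpers above and below tokenize and generate one prompt
# at a time; true batched generation would additionally require a pad token
# and left padding on the enhancer tokenizer.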


def _generate_i2v_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    first_frames: List[Image.Image],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    image_captions = _generate_image_captions(
        image_caption_model, image_caption_processor, first_frames
    )

    # A single caption is broadcast across all prompts.
    if len(image_captions) == 1 and len(image_captions) < len(prompts):
        image_captions *= len(prompts)

    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
        ]
        for p, c in zip(prompts, image_captions)
    ]

    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]

    out_prompts = []
    for text in texts:
        model_inputs = prompt_enhancer_tokenizer(text, return_tensors="pt").to(
            prompt_enhancer_model.device
        )
        out_prompts.append(
            _generate_and_decode_prompts(
                prompt_enhancer_model,
                prompt_enhancer_tokenizer,
                model_inputs,
                max_new_tokens,
            )[0]
        )

    return out_prompts


def _generate_image_captions(
    image_caption_model,
    image_caption_processor,
    images: List[Image.Image],
    system_prompt: str = "<DETAILED_CAPTION>",
) -> List[str]:
    # "<DETAILED_CAPTION>" is the task token the captioning model expects.
    image_caption_prompts = [system_prompt] * len(images)
    inputs = image_caption_processor(
        image_caption_prompts, images, return_tensors="pt"
    ).to(image_caption_model.device)

    with torch.inference_mode():
        generated_ids = image_caption_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            # Deterministic beam search for stable captions across runs.
            do_sample=False,
            num_beams=3,
        )

    return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)


def _generate_and_decode_prompts(
    prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
) -> List[str]:
    with torch.inference_mode():
        outputs = prompt_enhancer_model.generate(
            **model_inputs, max_new_tokens=max_new_tokens
        )
        # Strip the prompt tokens so only the newly generated text is decoded.
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
        ]
        decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )

    return decoded_prompts
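

# Usage sketch (illustrative only): any chat-tuned causal LM with a chat
# template fits the enhancer interface used above; the checkpoint name below
# is an assumption, not pinned by this module, and may require license
# acceptance on the Hugging Face Hub.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    enhancer_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct"
    )
    enhancer_tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct"
    )

    # The captioning model/processor are only used when `images` is passed
    # (image-to-video); this text-to-video example leaves them as None.
    enhanced = generate_cinematic_prompt(
        None,
        None,
        enhancer_model,
        enhancer_tokenizer,
        prompt="a dog runs on the beach",
    )
    print(enhanced[0])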