import os
import json
import math
import random
import traceback

import cv2
import torch
import librosa
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
from einops import rearrange
from torch.utils.data import Dataset
from decord import VideoReader, cpu
from transformers import CLIPImageProcessor
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage


def get_audio_feature(feature_extractor, audio_path):
    # Load the audio at 16 kHz, the sampling rate Whisper-style feature extractors expect.
    audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
    assert sampling_rate == 16000

    # Encode the waveform in windows of 750 * 640 samples (30 s at 16 kHz,
    # i.e. one feature-extractor context window) to bound peak memory.
    audio_features = []
    window = 750 * 640
    for i in range(0, len(audio_input), window):
        audio_feature = feature_extractor(
            audio_input[i: i + window],
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features
        audio_features.append(audio_feature)

    audio_features = torch.cat(audio_features, dim=-1)
    # 640 samples per frame at 16 kHz corresponds to 25 fps.
    return audio_features, len(audio_input) // 640
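

# Usage sketch for `get_audio_feature` (illustrative only; the extractor is
# supplied by the caller, and the Whisper checkpoint named below is an
# assumption, since this module never constructs one itself):
#
#   from transformers import AutoFeatureExtractor
#   feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
#   audio_feats, audio_len = get_audio_feature(feature_extractor, "sample.wav")
#   # audio_feats: (1, n_mels, 3000 * n_windows); audio_len: frame count at 25 fps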


class VideoAudioTextLoaderVal(Dataset):
    def __init__(
        self,
        image_size: int,
        meta_file: str,
        **kwargs,
    ):
        super().__init__()
        self.meta_file = meta_file
        self.image_size = image_size
        self.text_encoder = kwargs.get("text_encoder", None)
        self.text_encoder_2 = kwargs.get("text_encoder_2", None)
        self.feature_extractor = kwargs.get("feature_extractor", None)
        self.meta_files = []

        # The metadata CSV provides one row per sample with the columns:
        # videoid, image, audio, prompt, fps.
        csv_data = pd.read_csv(meta_file)
        for idx in range(len(csv_data)):
            self.meta_files.append(
                {
                    "videoid": str(csv_data["videoid"][idx]),
                    "image_path": str(csv_data["image"][idx]),
                    "audio_path": str(csv_data["audio"][idx]),
                    "prompt": str(csv_data["prompt"][idx]),
                    "fps": float(csv_data["fps"][idx]),
                }
            )

        # LLaVA-style preprocessing: 336x336 bilinear resize, CLIP normalization stats.
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

        self.device = torch.device("cuda")
        self.weight_dtype = torch.float16

    def __len__(self):
        return len(self.meta_files)

    @staticmethod
    def get_text_tokens(text_encoder, description, dtype_encode="video"):
        # Tokenize the prompt with the encoder's own wrapper and drop the batch dim.
        text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
        text_ids = text_inputs["input_ids"].squeeze(0)
        text_mask = text_inputs["attention_mask"].squeeze(0)
        return text_ids, text_mask
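
    # Note: `text2tokens` is not a Hugging Face tokenizer method; it is assumed
    # to be the wrapper API of the encoder object passed in via **kwargs,
    # returning a dict with "input_ids" and "attention_mask" of shape
    # (1, seq_len), which squeeze(0) reduces to (seq_len,).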

    def get_batch_data(self, idx):
        meta_file = self.meta_files[idx]
        videoid = meta_file["videoid"]
        image_path = meta_file["image_path"]
        audio_path = meta_file["audio_path"]
        prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
        fps = meta_file["fps"]

        img_size = self.image_size
        ref_image = Image.open(image_path).convert('RGB')

        # Scale the short side to img_size and snap both sides to multiples of 64.
        w, h = ref_image.size
        scale = img_size / min(w, h)
        new_w = round(w * scale / 64) * 64
        new_h = round(h * scale / 64) * 64

        # At 704 resolution, cap the area at 704 x 1216 to bound memory use.
        if img_size == 704:
            img_size_long = 1216
            if new_w * new_h > img_size * img_size_long:
                scale = math.sqrt(img_size * img_size_long / w / h)
                new_w = round(w * scale / 64) * 64
                new_h = round(h * scale / 64) * 64

        ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
        ref_image = torch.from_numpy(np.array(ref_image))

        audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
        audio_prompts = audio_input[0]

        # Fixed motion-bucket ids for head and expression motion at validation time.
        motion_bucket_id_heads = torch.from_numpy(np.array([25] * 4))
        motion_bucket_id_exps = torch.from_numpy(np.array([30] * 4))
        fps = torch.from_numpy(np.array(fps))

        to_pil = ToPILImage()
        pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w")

        # Prepare the reference image for the LLaVA and CLIP encoders.
        pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
        pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
        pixel_value_ref_clip = self.clip_image_processor(
            images=Image.fromarray(pixel_value_ref[0].permute(1, 2, 0).data.cpu().numpy().astype(np.uint8)),
            return_tensors="pt",
        ).pixel_values[0]
        pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)

        # Tokenize the prompt with both text encoders.
        text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
        text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)

        batch = {
            "text_prompt": prompt,
            "videoid": videoid,
            "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16),
            "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16),
            "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16),
            "audio_prompts": audio_prompts.to(dtype=torch.float16),
            "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
            "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
            "fps": fps.to(dtype=torch.float16),
            "text_ids": text_ids.clone(),
            "text_mask": text_mask.clone(),
            "text_ids_2": text_ids_2.clone(),
            "text_mask_2": text_mask_2.clone(),
            "audio_len": audio_len,
            "image_path": image_path,
            "audio_path": audio_path,
        }
        return batch

    def __getitem__(self, idx):
        return self.get_batch_data(idx)
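

# Minimal usage sketch (hypothetical names; the text encoders and feature
# extractor are supplied by the surrounding pipeline and must expose the
# `text2tokens` / Whisper-style APIs used above):
#
#   dataset = VideoAudioTextLoaderVal(
#       image_size=704,
#       meta_file="val_meta.csv",
#       text_encoder=text_encoder,
#       text_encoder_2=text_encoder_2,
#       feature_extractor=feature_extractor,
#   )
#   batch = dataset[0]
#   print(batch["pixel_value_ref"].shape, batch["audio_len"])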