# SeedVR2-3B: projects/inference_seedvr_3b.py
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
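"""Inference script for SeedVR2-3B video restoration / super-resolution.

Reads .mp4 files from --video_path, resizes and VAE-encodes them into
condition latents, runs the diffusion transformer with classifier-free
guidance, optionally applies wavelet-based color correction, and writes the
results to --output_dir. Text embeddings are loaded from pre-extracted files
(pos_emb.pt / neg_emb.pt) rather than computed here.
"""
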
import os
import torch
import mediapy
from einops import rearrange
from omegaconf import OmegaConf
print(os.getcwd())
import datetime
from tqdm import tqdm
import gc
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange

if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
    from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
    use_colorfix = True
else:
    use_colorfix = False
    print("Note: color fix is not available.")

from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video
import argparse
from common.distributed import (
    get_device,
    init_torch,
)
from common.distributed.advanced import (
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    init_sequence_parallel,
)
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.distributed.ops import sync_data
from common.seed import set_seed
from common.partition import partition_by_groups, partition_by_size
def configure_sequence_parallel(sp_size):
    if sp_size > 1:
        init_sequence_parallel(sp_size)


def configure_runner(sp_size):
    config_path = os.path.join('./configs_3b', 'main.yaml')
    config = load_config(config_path)
    runner = VideoDiffusionInfer(config)
    OmegaConf.set_readonly(runner.config, False)

    init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
    configure_sequence_parallel(sp_size)
    runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr_ema_3b.pth')
    runner.configure_vae_model()
    # Set memory limit.
    if hasattr(runner.vae, "set_memory_limit"):
        runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner


def generation_step(runner, text_embeds_dict, cond_latents):
    def _move_to_cuda(x):
        return [i.to(get_device()) for i in x]

    # Sample i.i.d. noise for each condition latent.
    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]
    print(f"Generating with noise shape: {noises[0].size()}.")
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = list(
        map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents))
    )
    cond_noise_scale = 0.1

    # Lightly noise the condition latents at a shifted timestep.
    def _add_noise(x, aug_noise):
        t = torch.tensor([1000.0], device=get_device()) * cond_noise_scale
        shape = torch.tensor(x.shape[1:], device=get_device())[None]
        t = runner.timestep_transform(t, shape)
        print(f"Timestep shifting from {1000.0 * cond_noise_scale} to {t}.")
        x = runner.schedule.forward(x, aug_noise, t)
        return x

    conditions = [
        runner.get_condition(
            noise,
            task="sr",
            latent_blur=_add_noise(latent_blur, aug_noise),
        )
        for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
    ]

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=True,
            **text_embeds_dict,
        )

    samples = [
        (
            rearrange(video[:, None], "c t h w -> t c h w")
            if video.ndim == 3
            else rearrange(video, "c t h w -> t c h w")
        )
        for video in video_tensors
    ]
    del video_tensors
    return samples


def generation_loop(
    runner,
    video_path='./test_videos',
    output_dir='./results',
    batch_size=1,
    cfg_scale=6.5,
    cfg_rescale=0.0,
    sample_steps=50,
    seed=666,
    res_h=1280,
    res_w=720,
    sp_size=1,
):
    def _build_pos_and_neg_prompt():
        # read positive prompt
        positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
        hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
        skin pore detailing, hyper sharpness, perfect without deformations."
        # read negative prompt
        negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
        CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
        signature, jpeg artifacts, deformed, lowres, over-smooth"
        return positive_text, negative_text

    def _build_test_prompts(video_path):
        positive_text, negative_text = _build_pos_and_neg_prompt()
        original_videos = []
        prompts = {}
        video_list = os.listdir(video_path)
        for f in video_list:
            if f.endswith(".mp4"):
                original_videos.append(f)
                prompts[f] = positive_text
        print(f"Total prompts to be generated: {len(original_videos)}")
        return original_videos, prompts, negative_text

    def _extract_text_embeds():
        # Text embeddings are pre-extracted and loaded from disk;
        # the text encoder itself is not run in this script.
        positive_prompts_embeds = []
        for texts_pos in tqdm(original_videos_local):
            text_pos_embeds = torch.load('pos_emb.pt')
            text_neg_embeds = torch.load('neg_emb.pt')
            positive_prompts_embeds.append(
                {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
            )
        gc.collect()
        torch.cuda.empty_cache()
        return positive_prompts_embeds

    def cut_videos(videos, sp_size):
        # Pad along the time axis by repeating the last frame so that
        # (t - 1) is divisible by 4 * sp_size.
        t = videos.size(1)
        if t <= 4 * sp_size:
            print(f"Cut input video size: {videos.size()}")
            padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1)
            padding = torch.cat(padding, dim=1)
            videos = torch.cat([videos, padding], dim=1)
            return videos
        if (t - 1) % (4 * sp_size) == 0:
            return videos
        else:
            padding = [videos[:, -1].unsqueeze(1)] * (
                4 * sp_size - ((t - 1) % (4 * sp_size))
            )
            padding = torch.cat(padding, dim=1)
            videos = torch.cat([videos, padding], dim=1)
            assert (videos.size(1) - 1) % (4 * sp_size) == 0
            return videos

    # classifier-free guidance
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = cfg_rescale
    # sampling steps
    runner.config.diffusion.timesteps.sampling.steps = sample_steps
    runner.configure_diffusion()

    # set random seed
    set_seed(seed, same_across_ranks=True)
    os.makedirs(output_dir, exist_ok=True)
    tgt_path = output_dir

    # get test prompts
    original_videos, _, _ = _build_test_prompts(video_path)

    # divide the prompts into different groups
    original_videos_group = partition_by_groups(
        original_videos,
        get_data_parallel_world_size() // get_sequence_parallel_world_size(),
    )
    # store prompt mapping
    original_videos_local = original_videos_group[
        get_data_parallel_rank() // get_sequence_parallel_world_size()
    ]
    original_videos_local = partition_by_size(original_videos_local, batch_size)

    # load the pre-extracted text embeddings
    positive_prompts_embeds = _extract_text_embeds()

    video_transform = Compose(
        [
            NaResize(
                resolution=(res_h * res_w) ** 0.5,
                mode="area",
                # Upsample image, model only trained for high res.
                downsample_only=False,
            ),
            Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
            DivisibleCrop((16, 16)),
            Normalize(0.5, 0.5),
            Rearrange("t c h w -> c t h w"),
        ]
    )

    # generation loop
    for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
        # read condition videos and apply the transform
        cond_latents = []
        for video in videos:
            video = (
                read_video(
                    os.path.join(video_path, video), output_format="TCHW"
                )[0]
                / 255.0
            )
            print(f"Read video size: {video.size()}")
            cond_latents.append(video_transform(video.to(get_device())))

        ori_lengths = [video.size(1) for video in cond_latents]
        input_videos = cond_latents
        cond_latents = [cut_videos(video, sp_size) for video in cond_latents]

        # encode conditions with the VAE, keeping only one model on GPU at a time
        runner.dit.to("cpu")
        print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}")
        runner.vae.to(get_device())
        cond_latents = runner.vae_encode(cond_latents)
        runner.vae.to("cpu")
        runner.dit.to(get_device())

        for i, emb in enumerate(text_embeds["texts_pos"]):
            text_embeds["texts_pos"][i] = emb.to(get_device())
        for i, emb in enumerate(text_embeds["texts_neg"]):
            text_embeds["texts_neg"][i] = emb.to(get_device())

        samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
        runner.dit.to("cpu")
        del cond_latents

        # dump samples to the output directory
        if get_sequence_parallel_rank() == 0:
            for path, input, sample, ori_length in zip(
                videos, input_videos, samples, ori_lengths
            ):
                if ori_length < sample.shape[0]:
                    sample = sample[:ori_length]
                filename = os.path.join(tgt_path, os.path.basename(path))
                # color fix
                input = (
                    rearrange(input[:, None], "c t h w -> t c h w")
                    if input.ndim == 3
                    else rearrange(input, "c t h w -> t c h w")
                )
                if use_colorfix:
                    sample = wavelet_reconstruction(
                        sample.to("cpu"), input[: sample.size(0)].to("cpu")
                    )
                else:
                    sample = sample.to("cpu")
                sample = (
                    rearrange(sample[:, None], "t c h w -> t h w c")
                    if sample.ndim == 3
                    else rearrange(sample, "t c h w -> t h w c")
                )
                sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
                sample = sample.to(torch.uint8).numpy()
                if sample.shape[0] == 1:
                    mediapy.write_image(filename, sample.squeeze(0))
                else:
                    mediapy.write_video(filename, sample, fps=24)
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, default="./test_videos")
    parser.add_argument("--output_dir", type=str, default="./results")
    parser.add_argument("--cfg_scale", type=float, default=6.5)
    parser.add_argument("--sample_steps", type=int, default=50)
    parser.add_argument("--seed", type=int, default=666)
    parser.add_argument("--res_h", type=int, default=720)
    parser.add_argument("--res_w", type=int, default=1280)
    parser.add_argument("--sp_size", type=int, default=1)
    args = parser.parse_args()

    runner = configure_runner(args.sp_size)
    generation_loop(runner, **vars(args))
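
# Example invocations (paths and process counts are illustrative, not taken
# from the original repo): a single-GPU run, and a 4-GPU run with sequence
# parallelism. pos_emb.pt and neg_emb.pt are expected in the working directory.
#
#   torchrun --nproc-per-node=1 projects/inference_seedvr_3b.py \
#       --video_path ./test_videos --output_dir ./results
#
#   torchrun --nproc-per-node=4 projects/inference_seedvr_3b.py \
#       --video_path ./test_videos --output_dir ./results --sp_size 4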