Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
from typing import Optional, Union | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
from dataset import config_utils | |
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer | |
import accelerate | |
from dataset.image_video_dataset import ARCHITECTURE_WAN, ItemInfo, save_text_encoder_output_cache_wan | |
# for t5 config: all Wan2.1 models have the same config for t5 | |
from wan.configs import wan_t2v_14B | |
import cache_text_encoder_outputs | |
import logging | |
from utils.model_utils import str_to_dtype | |
from wan.modules.t5 import T5EncoderModel | |
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def encode_and_save_batch(
    text_encoder: T5EncoderModel, batch: list[ItemInfo], device: torch.device, accelerator: Optional[accelerate.Accelerator]
):
    """Encode the captions of ``batch`` with the text encoder and cache each result.

    Args:
        text_encoder: Callable encoder, invoked as ``text_encoder(captions, device)``.
        batch: Items whose ``caption`` attribute is encoded.
        device: Device forwarded to the encoder call.
        accelerator: When not ``None``, the encoder call runs inside
            ``accelerator.autocast()``; otherwise it is called directly.
    """
    captions = []
    for info in batch:
        captions.append(info.caption)

    # Gradients are never needed here; autocast is applied only when an accelerator is given.
    with torch.no_grad():
        if accelerator is None:
            encoded = text_encoder(captions, device)
        else:
            with accelerator.autocast():
                encoded = text_encoder(captions, device)

    # Pair every item with its encoder output and write it to the cache.
    for info, embedding in zip(batch, encoded):
        save_text_encoder_output_cache_wan(info, embedding)
def main(args):
    """Cache T5 text encoder outputs for the datasets described by ``args.dataset_config``.

    Steps, in order: pick the device, build the dataset group from the user
    config, optionally create an accelerator (only when ``args.fp8_t5`` is set),
    collect cache file/path bookkeeping, load the T5 encoder, encode all batches,
    and finally post-process the cache files.

    Args:
        args: Parsed command line namespace. Reads ``device``, ``dataset_config``,
            ``fp8_t5``, ``t5``, ``num_workers``, ``skip_existing`` and ``batch_size``.
    """
    # An explicit --device wins; otherwise use CUDA when available, else CPU.
    device_spec = args.device
    if device_spec is None:
        device_spec = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_spec)

    # Build the dataset group from the user supplied dataset config.
    generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    dataset_blueprint = generator.generate(
        config_utils.load_user_config(args.dataset_config), args, architecture=ARCHITECTURE_WAN
    )
    dataset_group = config_utils.generate_dataset_group_by_blueprint(dataset_blueprint.dataset_group)
    datasets = dataset_group.datasets

    # All Wan2.1 models have the same config for T5, so the 14B t2v config is used.
    t5_config = wan_t2v_14B.t2v_14B

    # An accelerator (used for autocast) is only created for fp8 inference.
    accelerator = None
    if args.fp8_t5:
        precision = "bf16" if t5_config.t5_dtype == torch.bfloat16 else "fp16"
        accelerator = accelerate.Accelerator(mixed_precision=precision)

    # existing_cache_files: cache files already present;
    # expected_cache_paths: all cache paths in the dataset.
    cache_bookkeeping = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)
    existing_cache_files, expected_cache_paths = cache_bookkeeping

    # Load T5
    logger.info(f"Loading T5: {args.t5}")
    t5_encoder = T5EncoderModel(
        text_len=t5_config.text_len,
        dtype=t5_config.t5_dtype,
        device=device,
        weight_path=args.t5,
        fp8=args.fp8_t5,
    )

    # Encode with T5
    logger.info("Encoding with T5")

    def run_batch(batch: list[ItemInfo]):
        # Bind the encoder, device and accelerator so the batch processor only passes the batch.
        encode_and_save_batch(t5_encoder, batch, device, accelerator)

    cache_text_encoder_outputs.process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        existing_cache_files,
        expected_cache_paths,
        run_batch,
    )
    # Drop the local reference to the encoder once encoding is finished.
    del t5_encoder

    # Remove cache files that are not in the dataset.
    cache_text_encoder_outputs.post_process_cache_files(datasets, existing_cache_files, expected_cache_paths)
def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the Wan-specific T5 options on ``parser`` and return it.

    Adds the required ``--t5`` checkpoint path and the ``--fp8_t5`` switch.

    Args:
        parser: Parser to extend in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    option_specs = (
        ("--t5", {"type": str, "default": None, "required": True, "help": "text encoder (T5) checkpoint path"}),
        ("--fp8_t5", {"action": "store_true", "help": "use fp8 for Text Encoder model"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
if __name__ == "__main__":
    # Start from the common caching options, then add the Wan-specific T5 options.
    cli = wan_setup_parser(cache_text_encoder_outputs.setup_parser_common())
    main(cli.parse_args())