Wan2GP / wan /utils /utils.py
zxymimi23451's picture
Upload 258 files
78360e7 verified
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import binascii
import os
import os.path as osp
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import imageio
import torch
import decord
import torchvision
from PIL import Image
import numpy as np
from rembg import remove, new_session
import random
__all__ = ['cache_video', 'cache_image', 'str2bool']
from PIL import Image
def seed_everything(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math
if video_fps < target_fps :
video_fps = target_fps
video_frame_duration = 1 /video_fps
target_frame_duration = 1 / target_fps
target_time = start_target_frame * target_frame_duration
frame_no = math.ceil(target_time / video_frame_duration)
cur_time = frame_no * video_frame_duration
frame_ids =[]
while True:
if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count :
break
diff = round( (target_time -cur_time) / video_frame_duration , 5)
add_frames_count = math.ceil( diff)
frame_no += add_frames_count
if frame_no >= video_frames_count:
break
frame_ids.append(frame_no)
cur_time += add_frames_count * video_frame_duration
target_time += target_frame_duration
frame_ids = frame_ids[:max_target_frames_count]
return frame_ids
def get_video_frame(file_name, frame_no):
decord.bridge.set_bridge('torch')
reader = decord.VideoReader(file_name)
frame = reader.get_batch([frame_no]).squeeze(0)
img = Image.fromarray(frame.numpy().astype(np.uint8))
return img
def resize_lanczos(img, h, w):
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def remove_background(img, session=None):
if session ==None:
session = new_session()
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
def convert_tensor_to_image(t, frame_no = -1):
t = t[:, frame_no] if frame_no >= 0 else t
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def save_image(tensor_image, name, frame_no = -1):
convert_tensor_to_image(tensor_image, frame_no).save(name)
def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims):
outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100)
frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
return frame_height, frame_width
def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8):
outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims
raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
height = int(raw_height / block_size) * block_size
extra_height = raw_height - height
raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100))
width = int(raw_width / block_size) * block_size
extra_width = raw_width - width
margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height)
if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0:
margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height)
if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height
margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width)
if extra_width != 0 and (outpainting_left + outpainting_right) != 0:
margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_height)
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
return height, width, margin_top, margin_left
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
if fit_into_canvas == None:
return height, width
if fit_into_canvas:
scale1 = min(canvas_height / height, canvas_width / width)
scale2 = min(canvas_width / height, canvas_height / width)
scale = max(scale1, scale2)
else:
scale = (canvas_height * canvas_width / (height * width))**(1/2)
new_height = round( height * scale / block_size) * block_size
new_width = round( width * scale / block_size) * block_size
return new_height, new_width
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
if rm_background > 0:
session = new_session()
output_list =[]
for i, img in enumerate(img_list):
width, height = img.size
if fit_into_canvas:
white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
scale = min(budget_height / height, budget_width / width)
new_height = int(height * scale)
new_width = int(width * scale)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
top = (budget_height - new_height) // 2
left = (budget_width - new_width) // 2
white_canvas[top:top + new_height, left:left + new_width] = np.array(resized_image)
resized_image = Image.fromarray(white_canvas)
else:
scale = (budget_height * budget_width / (height * width))**(1/2)
new_height = int( round(height * scale / 16) * 16)
new_width = int( round(width * scale / 16) * 16)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
if rm_background == 1 or rm_background == 2 and i > 0 :
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
return output_list
def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
if suffix:
if not suffix.startswith('.'):
suffix = '.' + suffix
name += suffix
return name
def cache_video(tensor,
save_file=None,
fps=30,
suffix='.mp4',
nrow=8,
normalize=True,
value_range=(-1, 1),
retry=5):
# cache file
cache_file = osp.join('/tmp', rand_name(
suffix=suffix)) if save_file is None else save_file
# save to cache
error = None
for _ in range(retry):
try:
# preprocess
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack([
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range)
for u in tensor.unbind(2)
],
dim=1).permute(1, 2, 3, 0)
tensor = (tensor * 255).type(torch.uint8).cpu()
# write video
writer = imageio.get_writer(
cache_file, fps=fps, codec='libx264', quality=8)
for frame in tensor.numpy():
writer.append_data(frame)
writer.close()
return cache_file
except Exception as e:
error = e
continue
else:
print(f'cache_video failed, error: {error}', flush=True)
return None
def cache_image(tensor,
save_file,
nrow=8,
normalize=True,
value_range=(-1, 1),
retry=5):
# cache file
suffix = osp.splitext(save_file)[1]
if suffix.lower() not in [
'.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
]:
suffix = '.png'
# save to cache
error = None
for _ in range(retry):
try:
tensor = tensor.clamp(min(value_range), max(value_range))
torchvision.utils.save_image(
tensor,
save_file,
nrow=nrow,
normalize=normalize,
value_range=value_range)
return save_file
except Exception as e:
error = e
continue
def str2bool(v):
"""
Convert a string to a boolean.
Supported true values: 'yes', 'true', 't', 'y', '1'
Supported false values: 'no', 'false', 'f', 'n', '0'
Args:
v (str): String to convert.
Returns:
bool: Converted boolean value.
Raises:
argparse.ArgumentTypeError: If the value cannot be converted to boolean.
"""
if isinstance(v, bool):
return v
v_lower = v.lower()
if v_lower in ('yes', 'true', 't', 'y', '1'):
return True
elif v_lower in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
import sys, time
# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5 # Update speed every 0.5 seconds
def progress_hook(block_num, block_size, total_size, filename=None):
"""
Simple progress bar hook for urlretrieve
Args:
block_num: Number of blocks downloaded so far
block_size: Size of each block in bytes
total_size: Total size of the file in bytes
filename: Name of the file being downloaded (optional)
"""
global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
current_time = time.time()
downloaded = block_num * block_size
# Initialize timing on first call
if _start_time is None or block_num == 0:
_start_time = current_time
_last_time = current_time
_last_downloaded = 0
_speed_history = []
# Calculate download speed only at specified intervals
speed = 0
if current_time - _last_time >= _update_interval:
if _last_time > 0:
current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
_speed_history.append(current_speed)
# Keep only last 5 speed measurements for smoothing
if len(_speed_history) > 5:
_speed_history.pop(0)
# Average the recent speeds for smoother display
speed = sum(_speed_history) / len(_speed_history)
_last_time = current_time
_last_downloaded = downloaded
elif _speed_history:
# Use the last calculated average speed
speed = sum(_speed_history) / len(_speed_history)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
file_display = filename if filename else "Unknown file"
if total_size <= 0:
# If total size is unknown, show downloaded bytes
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(80))
sys.stdout.flush()
return
downloaded = block_num * block_size
percent = min(100, (downloaded / total_size) * 100)
# Create progress bar (40 characters wide to leave room for other info)
bar_length = 40
filled = int(bar_length * percent / 100)
bar = '█' * filled + '░' * (bar_length - filled)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
# Display progress with filename first
line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(100))
sys.stdout.flush()
# Print newline when complete
if percent >= 100:
print()
# Wrapper function to include filename in progress hook
def create_progress_hook(filename):
"""Creates a progress hook with the filename included"""
global _start_time, _last_time, _last_downloaded, _speed_history
# Reset timing variables for new download
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
def hook(block_num, block_size, total_size):
return progress_hook(block_num, block_size, total_size, filename)
return hook