"""
Demo for Self-Forcing.
"""

import os
import re
import random
import time
import base64
import argparse
import hashlib
import subprocess
import urllib.request
from io import BytesIO

from PIL import Image
import numpy as np
import torch
from omegaconf import OmegaConf
from flask import Flask, render_template, jsonify
from flask_socketio import SocketIO, emit
import queue
from threading import Thread, Event

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
from demo_utils.utils import generate_timestamp
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation

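# Command-line options: server host/port, generator checkpoint, model config, and an optional TensorRT VAE decoder.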
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=5001)
parser.add_argument('--host', type=str, default='0.0.0.0')
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
parser.add_argument('--trt', action='store_true')
args = parser.parse_args()

print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
low_memory = get_cuda_free_memory_gb(gpu) < 40

config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)

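# The Wan text encoder is created once at startup and shared across all generation requests.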
text_encoder = WanTextEncoder()

current_vae_decoder = None
current_use_taehv = False
fp8_applied = False
torch_compile_applied = False
frame_number = 0
anim_name = ""
frame_rate = 6

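# Three decoder options are supported: a TensorRT engine (--trt), the lightweight TAEHV decoder, or the default Wan2.1 VAE decoder.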
def initialize_vae_decoder(use_taehv=False, use_trt=False):
    """Initialize the VAE decoder based on the selected option."""
    global current_vae_decoder, current_use_taehv

    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        current_vae_decoder = VAETRTWrapper()
        return current_vae_decoder

    if use_taehv:
        from demo_utils.taehv import TAEHV

        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"taew2_1.pth not found at {taehv_checkpoint_path}. Downloading...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
                print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
            except Exception as e:
                print(f"Failed to download taew2_1.pth: {e}")
                raise

        class DotDict(dict):
            __getattr__ = dict.__getitem__
            __setattr__ = dict.__setitem__

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                # Map the decoder output from [0, 1] to [-1, 1] (x * 2 - 1) to match the default VAE's range.
                return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)

        current_vae_decoder = TAEHVDiffusersWrapper()
    else:
        current_vae_decoder = VAEDecoderWrapper()
        vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
        decoder_state_dict = {}
        for key, value in vae_state_dict.items():
            if 'decoder.' in key or 'conv2' in key:
                decoder_state_dict[key] = value
        current_vae_decoder.load_state_dict(decoder_state_dict)

    current_vae_decoder.eval()
    current_vae_decoder.to(dtype=torch.float16)
    current_vae_decoder.requires_grad_(False)
    current_vae_decoder.to(gpu)
    current_use_taehv = use_taehv

    print(f"VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
    return current_vae_decoder

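# Module-level setup: build the default VAE decoder, load the causal generator (EMA weights), and freeze everything for inference.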
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

transformer = WanDiffusionWrapper(is_causal=True)
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
transformer.load_state_dict(state_dict['generator_ema'])

text_encoder.eval()
transformer.eval()

transformer.to(dtype=torch.float16)
text_encoder.to(dtype=torch.bfloat16)

text_encoder.requires_grad_(False)
transformer.requires_grad_(False)

pipeline = CausalInferencePipeline(
    config,
    device=gpu,
    generator=transformer,
    text_encoder=text_encoder,
    vae=vae_decoder
)

if low_memory:
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
    text_encoder.to(gpu)
transformer.to(gpu)

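# Flask + Socket.IO server. Frames are pushed over the 'frame_ready' event as soon as they are decoded; playback timing is handled by the frontend.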
app = Flask(__name__)
app.config['SECRET_KEY'] = 'frontend_buffered_demo'
socketio = SocketIO(app, cors_allowed_origins="*")

generation_active = False
stop_event = Event()
frame_send_queue = queue.Queue()
sender_thread = None
models_compiled = False

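# Each frame is JPEG-encoded for the websocket stream and also written to ./images/<anim_name>/ so ffmpeg can assemble an MP4 afterwards.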
def tensor_to_base64_frame(frame_tensor):
    """Convert a single frame tensor to a base64 image string."""
    global frame_number, anim_name

    frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
    frame = frame.to(torch.uint8).cpu().numpy()

    if len(frame.shape) == 3:
        frame = np.transpose(frame, (1, 2, 0))

    if frame.shape[2] == 3:
        image = Image.fromarray(frame, 'RGB')
    else:
        image = Image.fromarray(frame)

    buffer = BytesIO()
    image.save(buffer, format='JPEG', quality=100)
    if not os.path.exists("./images/%s" % anim_name):
        os.makedirs("./images/%s" % anim_name)
    frame_number += 1
    image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"

def frame_sender_worker():
    """Background thread that drains the frame send queue so generation never blocks on the socket."""
    global frame_send_queue, generation_active, stop_event

    print("Frame sender thread started")

    while True:
        frame_data = None
        try:
            frame_data = frame_send_queue.get(timeout=1.0)

            if frame_data is None:  # None is the shutdown sentinel.
                frame_send_queue.task_done()
                break

            frame_tensor, frame_index, block_index, job_id = frame_data

            base64_frame = tensor_to_base64_frame(frame_tensor)

            try:
                socketio.emit('frame_ready', {
                    'data': base64_frame,
                    'frame_index': frame_index,
                    'block_index': block_index,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"Failed to send frame {frame_index}: {e}")

            frame_send_queue.task_done()

        except queue.Empty:
            # Nothing arrived for a second; exit once generation has finished and the queue is drained.
            if not generation_active and frame_send_queue.empty():
                break
        except Exception as e:
            print(f"Frame sender error: {e}")
            if frame_data is not None:
                try:
                    frame_send_queue.task_done()
                except Exception as e:
                    print(f"Failed to mark frame task as done: {e}")
            break

    print("Frame sender thread stopped")

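# Main generation routine: denoise the latents block-by-block with the causal transformer, decode each block with the VAE, and queue frames for streaming.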
@torch.no_grad()
def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
    """Generate video and push frames to the frontend as soon as they are decoded."""
    global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name

    try:
        generation_active = True
        stop_event.clear()
        job_id = generate_timestamp()

        if sender_thread is None or not sender_thread.is_alive():
            sender_thread = Thread(target=frame_sender_worker, daemon=True)
            sender_thread.start()

        def emit_progress(message, progress):
            try:
                socketio.emit('progress', {
                    'message': message,
                    'progress': progress,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"Failed to emit progress: {e}")

        emit_progress('Starting generation...', 0)

        if use_taehv != current_use_taehv:
            emit_progress('Switching VAE decoder...', 2)
            print(f"Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
            current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
            pipeline.vae = current_vae_decoder

        if enable_fp8 and not fp8_applied:
            emit_progress('Applying FP8 quantization...', 3)
            print("Applying FP8 quantization to transformer")
            from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
            quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
            fp8_applied = True

        emit_progress('Encoding text prompt...', 8)
        conditional_dict = text_encoder(text_prompts=[prompt])
        for key, value in conditional_dict.items():
            conditional_dict[key] = value.to(dtype=torch.float16)
        if low_memory:
            gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
            move_model_to_device_with_memory_preservation(
                text_encoder, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)

        torch_compile_applied = enable_torch_compile
        if enable_torch_compile and not models_compiled:
            transformer.compile(mode="max-autotune-no-cudagraphs")
            if not current_use_taehv and not low_memory and not args.trt:
                current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")

        emit_progress('Initializing generation...', 12)

        rnd = torch.Generator(gpu).manual_seed(seed)

        pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
        pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)

        noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)

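        # The latent sequence is produced in num_blocks chunks of pipeline.num_frame_per_block frames;
        # each chunk attends to previously generated chunks through the KV cache.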
        num_blocks = 7
        current_start_frame = 0
        num_input_frames = 0
        all_num_frames = [pipeline.num_frame_per_block] * num_blocks
        if current_use_taehv:
            vae_cache = None
        else:
            vae_cache = ZERO_VAE_CACHE
            for i in range(len(vae_cache)):
                vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)

        total_frames_sent = 0
        generation_start_time = time.time()

        emit_progress('Generating frames... (frontend handles timing)', 15)

        for idx, current_num_frames in enumerate(all_num_frames):
            if not generation_active or stop_event.is_set():
                break

            progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15

            if idx == 0 and torch_compile_applied and not models_compiled:
                emit_progress(
                    f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
                print(f"Processing block {idx+1}/{len(all_num_frames)}")
                models_compiled = True
            else:
                emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
                print(f"Processing block {idx+1}/{len(all_num_frames)}")

            block_start_time = time.time()

            noisy_input = noise[:, current_start_frame -
                                num_input_frames:current_start_frame + current_num_frames - num_input_frames]

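            # Few-step denoising: every step except the last re-noises the prediction to the next
            # timestep in denoising_step_list before running the transformer again.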
            denoising_start = time.time()
            for index, current_timestep in enumerate(pipeline.denoising_step_list):
                if not generation_active or stop_event.is_set():
                    break

                timestep = torch.ones([1, current_num_frames], device=noise.device,
                                      dtype=torch.int64) * current_timestep

                if index < len(pipeline.denoising_step_list) - 1:
                    _, denoised_pred = transformer(
                        noisy_image_or_video=noisy_input,
                        conditional_dict=conditional_dict,
                        timestep=timestep,
                        kv_cache=pipeline.kv_cache1,
                        crossattn_cache=pipeline.crossattn_cache,
                        current_start=current_start_frame * pipeline.frame_seq_length
                    )
                    next_timestep = pipeline.denoising_step_list[index + 1]
                    noisy_input = pipeline.scheduler.add_noise(
                        denoised_pred.flatten(0, 1),
                        torch.randn_like(denoised_pred.flatten(0, 1)),
                        next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                    ).unflatten(0, denoised_pred.shape[:2])
                else:
                    _, denoised_pred = transformer(
                        noisy_image_or_video=noisy_input,
                        conditional_dict=conditional_dict,
                        timestep=timestep,
                        kv_cache=pipeline.kv_cache1,
                        crossattn_cache=pipeline.crossattn_cache,
                        current_start=current_start_frame * pipeline.frame_seq_length
                    )

            if not generation_active or stop_event.is_set():
                break

            denoising_time = time.time() - denoising_start
            print(f"Block {idx+1} denoising completed in {denoising_time:.2f}s")

            if idx != len(all_num_frames) - 1:
                # Re-run the transformer on the clean prediction at timestep 0 so the KV cache
                # stores this block's context for the next block.
                transformer(
                    noisy_image_or_video=denoised_pred,
                    conditional_dict=conditional_dict,
                    timestep=torch.zeros_like(timestep),
                    kv_cache=pipeline.kv_cache1,
                    crossattn_cache=pipeline.crossattn_cache,
                    current_start=current_start_frame * pipeline.frame_seq_length,
                )

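            # Decode this block's latents to pixels; the first few decoded frames of block 0 are dropped below.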
            print(f"Decoding block {idx+1} to pixels...")
            decode_start = time.time()
            if args.trt:
                all_current_pixels = []
                for i in range(denoised_pred.shape[1]):
                    is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
                        torch.tensor(0.0).cuda().half()
                    outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
                    current_pixels, vae_cache = outputs[0], outputs[1:]
                    print(current_pixels.max(), current_pixels.min())
                    all_current_pixels.append(current_pixels.clone())
                pixels = torch.cat(all_current_pixels, dim=1)
                if idx == 0:
                    pixels = pixels[:, 3:, :, :, :]
            else:
                if current_use_taehv:
                    if vae_cache is None:
                        vae_cache = denoised_pred
                    else:
                        denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
                        vae_cache = denoised_pred[:, -3:, :, :, :]
                    pixels = current_vae_decoder.decode(denoised_pred)
                    print(f"denoised_pred shape: {denoised_pred.shape}")
                    print(f"pixels shape: {pixels.shape}")
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]
                    else:
                        pixels = pixels[:, 12:, :, :, :]
                else:
                    pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]

            decode_time = time.time() - decode_start
            print(f"Block {idx+1} VAE decoding completed in {decode_time:.2f}s")

            block_frames = pixels.shape[1]
            print(f"Queueing {block_frames} frames from block {idx+1} for sending...")
            queue_start = time.time()

            for frame_idx in range(block_frames):
                if not generation_active or stop_event.is_set():
                    break

                frame_tensor = pixels[0, frame_idx].cpu()

                frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
                total_frames_sent += 1

            queue_time = time.time() - queue_start
            block_time = time.time() - block_start_time
            print(f"Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")

            current_start_frame += current_num_frames

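        # All blocks are done: wait for the sender thread to drain the queue, then assemble the saved
        # JPEG frames into an MP4 with ffmpeg.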
        generation_time = time.time() - generation_start_time
        print(f"Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")

        emit_progress('Waiting for all frames to be sent...', 97)
        print("Waiting for all frames to be sent...")
        frame_send_queue.join()
        print("All frames sent successfully!")

        generate_mp4_from_images("./images", "./videos/" + anim_name + ".mp4", frame_rate)

        emit_progress('Generation complete!', 100)

        try:
            socketio.emit('generation_complete', {
                'message': 'Video generation completed!',
                'total_frames': total_frames_sent,
                'generation_time': f"{generation_time:.2f}s",
                'job_id': job_id
            })
        except Exception as e:
            print(f"Failed to emit generation complete: {e}")

    except Exception as e:
        print(f"Generation failed: {e}")
        try:
            socketio.emit('error', {
                'message': f'Generation failed: {str(e)}',
                'job_id': job_id
            })
        except Exception as e:
            print(f"Failed to emit error: {e}")
    finally:
        generation_active = False
        stop_event.set()

        # Unblock the sender thread so it can shut down cleanly.
        try:
            frame_send_queue.put(None)
        except Exception as e:
            print(f"Failed to put None in frame_send_queue: {e}")

def generate_mp4_from_images(image_directory, output_video_path, fps=24):
    """
    Generate an MP4 video from a directory of images ordered alphabetically.

    :param image_directory: Path to the directory containing images.
    :param output_video_path: Path where the output MP4 will be saved.
    :param fps: Frames per second for the output video.
    """
    global anim_name

    # Requires ffmpeg to be available on PATH.
    cmd = [
        'ffmpeg',
        '-framerate', str(fps),
        '-i', os.path.join(image_directory, anim_name + '/' + anim_name + '_%03d.jpg'),
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        output_video_path
    ]
    try:
        subprocess.run(cmd, check=True)
        print(f"Video saved to {output_video_path}")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

def calculate_sha256(data):
    """Return the hex SHA-256 digest of a string or bytes object."""
    if isinstance(data, str):
        data = data.encode()

    sha256_hash = hashlib.sha256(data).hexdigest()
    return sha256_hash

@socketio.on('connect')
def handle_connect():
    print('Client connected')
    emit('status', {'message': 'Connected to frontend-buffered demo server'})


@socketio.on('disconnect')
def handle_disconnect():
    print('Client disconnected')

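# Start a generation job: derive a unique animation name from the prompt, seed, and prompt hash, then run the pipeline in a background task.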
@socketio.on('start_generation')
def handle_start_generation(data):
    global generation_active, frame_number, anim_name, frame_rate

    frame_number = 0
    if generation_active:
        emit('error', {'message': 'Generation already in progress'})
        return

    prompt = data.get('prompt', '')

    seed = data.get('seed', -1)
    if seed == -1:
        seed = random.randint(0, 2**32)

    # Use the prompt text up to the first punctuation mark (falling back to the first line) as a readable name prefix.
    words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip() if prompt else ''
    if not words_up_to_punctuation:
        words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()

    sha256_hash = calculate_sha256(prompt)

    anim_name = f"{words_up_to_punctuation[:20]}_{str(seed)}_{sha256_hash[:10]}"

    generation_active = True
    generation_start_time = time.time()
    enable_torch_compile = data.get('enable_torch_compile', False)
    enable_fp8 = data.get('enable_fp8', False)
    use_taehv = data.get('use_taehv', False)
    frame_rate = data.get('fps', 6)

    if not prompt:
        emit('error', {'message': 'Prompt is required'})
        return

    socketio.start_background_task(generate_video_stream, prompt, seed,
                                   enable_torch_compile, enable_fp8, use_taehv)
    emit('status', {'message': 'Generation started - frames will be sent immediately'})

@socketio.on('stop_generation')
def handle_stop_generation():
    global generation_active, stop_event, frame_send_queue
    generation_active = False
    stop_event.set()

    # Wake the sender thread so it can exit.
    try:
        frame_send_queue.put(None)
    except Exception as e:
        print(f"Failed to put None in frame_send_queue: {e}")

    emit('status', {'message': 'Generation stopped'})

@app.route('/')
def index():
    return render_template('demo.html')


@app.route('/api/status')
def api_status():
    return jsonify({
        'generation_active': generation_active,
        'free_vram_gb': get_cuda_free_memory_gb(gpu),
        'fp8_applied': fp8_applied,
        'torch_compile_applied': torch_compile_applied,
        'current_use_taehv': current_use_taehv
    })

if __name__ == '__main__':
    print(f"Starting demo on http://{args.host}:{args.port}")
    socketio.run(app, host=args.host, port=args.port, debug=False)