Spaces:
Running
on
Zero
Running
on
Zero
IceClear
commited on
Commit
·
bfeb62a
1
Parent(s):
303cd3c
update
Browse files
projects/video_diffusion_sr/infer.py
CHANGED
@@ -71,7 +71,7 @@ class VideoDiffusionInfer():
|
|
71 |
|
72 |
@log_on_entry
|
73 |
@log_runtime
|
74 |
-
def configure_dit_model(self, device="
|
75 |
# Load dit checkpoint.
|
76 |
# For fast init & resume,
|
77 |
# when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
|
@@ -83,7 +83,7 @@ class VideoDiffusionInfer():
|
|
83 |
self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
|
84 |
|
85 |
if checkpoint:
|
86 |
-
state = torch.load(checkpoint, map_location=
|
87 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
88 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
89 |
print(f"Loading info: {loading_info}")
|
|
|
71 |
|
72 |
@log_on_entry
|
73 |
@log_runtime
|
74 |
+
def configure_dit_model(self, device="cuda", checkpoint=None):
|
75 |
# Load dit checkpoint.
|
76 |
# For fast init & resume,
|
77 |
# when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
|
|
|
83 |
self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
|
84 |
|
85 |
if checkpoint:
|
86 |
+
state = torch.load(checkpoint, map_location=self.device, mmap=True)
|
87 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
88 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
89 |
print(f"Loading info: {loading_info}")
|