IceClear commited on
Commit
bfeb62a
·
1 Parent(s): 303cd3c
projects/video_diffusion_sr/infer.py CHANGED
@@ -71,7 +71,7 @@ class VideoDiffusionInfer():
71
 
72
  @log_on_entry
73
  @log_runtime
74
- def configure_dit_model(self, device="cpu", checkpoint=None):
75
  # Load dit checkpoint.
76
  # For fast init & resume,
77
  # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
@@ -83,7 +83,7 @@ class VideoDiffusionInfer():
83
  self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
84
 
85
  if checkpoint:
86
- state = torch.load(checkpoint, map_location="cpu", mmap=True)
87
  loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
88
  print(f"Loading pretrained ckpt from {checkpoint}")
89
  print(f"Loading info: {loading_info}")
 
71
 
72
  @log_on_entry
73
  @log_runtime
74
+ def configure_dit_model(self, device="cuda", checkpoint=None):
75
  # Load dit checkpoint.
76
  # For fast init & resume,
77
  # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
 
83
  self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
84
 
85
  if checkpoint:
86
+ state = torch.load(checkpoint, map_location=self.device, mmap=True)
87
  loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
88
  print(f"Loading pretrained ckpt from {checkpoint}")
89
  print(f"Loading info: {loading_info}")