Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
7391a72
·
1 Parent(s): d493b2e

Update device handling in TwoStagePipeline to default to 'cuda' and adjust model loading to use 'cpu' for compatibility. Ensure consistent device usage across stages and improve code clarity.

Browse files
Files changed (1) hide show
  1. pipelines.py +4 -5
pipelines.py CHANGED
@@ -16,7 +16,7 @@ class TwoStagePipeline(object):
16
  stage2_model_config,
17
  stage1_sampler_config,
18
  stage2_sampler_config,
19
- device=None,
20
  dtype=torch.float16,
21
  resize_rate=1,
22
  ) -> None:
@@ -25,15 +25,14 @@ class TwoStagePipeline(object):
25
  - the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config
26
  - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config
27
  """
28
- device = torch.device("cuda")
29
  self.resize_rate = resize_rate
30
 
31
  self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
32
- self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location=device), strict=False)
33
  self.stage1_model = self.stage1_model.to(device).to(dtype)
34
 
35
  self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
36
- sd = torch.load(stage2_model_config.resume, map_location=device)
37
  self.stage2_model.load_state_dict(sd, strict=False)
38
  self.stage2_model = self.stage2_model.to(device).to(dtype)
39
 
@@ -168,4 +167,4 @@ if __name__ == "__main__":
168
  np_imgs = np.concatenate(stage1_images, 1)
169
  np_xyzs = np.concatenate(stage2_images, 1)
170
  Image.fromarray(np_imgs).save("pixel_images.png")
171
- Image.fromarray(np_xyzs).save("xyz_images.png")
 
16
  stage2_model_config,
17
  stage1_sampler_config,
18
  stage2_sampler_config,
19
+ device="cuda",
20
  dtype=torch.float16,
21
  resize_rate=1,
22
  ) -> None:
 
25
  - the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config
26
  - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config
27
  """
 
28
  self.resize_rate = resize_rate
29
 
30
  self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
31
+ self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
32
  self.stage1_model = self.stage1_model.to(device).to(dtype)
33
 
34
  self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
35
+ sd = torch.load(stage2_model_config.resume, map_location="cpu")
36
  self.stage2_model.load_state_dict(sd, strict=False)
37
  self.stage2_model = self.stage2_model.to(device).to(dtype)
38
 
 
167
  np_imgs = np.concatenate(stage1_images, 1)
168
  np_xyzs = np.concatenate(stage2_images, 1)
169
  Image.fromarray(np_imgs).save("pixel_images.png")
170
+ Image.fromarray(np_xyzs).save("xyz_images.png")