Commit
·
7391a72
1
Parent(s):
d493b2e
Update device handling in TwoStagePipeline to default to 'cuda' and adjust model loading to use 'cpu' for compatibility. Ensure consistent device usage across stages and improve code clarity.
Browse files- pipelines.py +4 -5
pipelines.py
CHANGED
@@ -16,7 +16,7 @@ class TwoStagePipeline(object):
|
|
16 |
stage2_model_config,
|
17 |
stage1_sampler_config,
|
18 |
stage2_sampler_config,
|
19 |
-
device=
|
20 |
dtype=torch.float16,
|
21 |
resize_rate=1,
|
22 |
) -> None:
|
@@ -25,15 +25,14 @@ class TwoStagePipeline(object):
|
|
25 |
- the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config
|
26 |
- the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config
|
27 |
"""
|
28 |
-
device = torch.device("cuda")
|
29 |
self.resize_rate = resize_rate
|
30 |
|
31 |
self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
|
32 |
-
self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location=
|
33 |
self.stage1_model = self.stage1_model.to(device).to(dtype)
|
34 |
|
35 |
self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
|
36 |
-
sd = torch.load(stage2_model_config.resume, map_location=
|
37 |
self.stage2_model.load_state_dict(sd, strict=False)
|
38 |
self.stage2_model = self.stage2_model.to(device).to(dtype)
|
39 |
|
@@ -168,4 +167,4 @@ if __name__ == "__main__":
|
|
168 |
np_imgs = np.concatenate(stage1_images, 1)
|
169 |
np_xyzs = np.concatenate(stage2_images, 1)
|
170 |
Image.fromarray(np_imgs).save("pixel_images.png")
|
171 |
-
Image.fromarray(np_xyzs).save("xyz_images.png")
|
|
|
16 |
stage2_model_config,
|
17 |
stage1_sampler_config,
|
18 |
stage2_sampler_config,
|
19 |
+
device="cuda",
|
20 |
dtype=torch.float16,
|
21 |
resize_rate=1,
|
22 |
) -> None:
|
|
|
25 |
- the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config
|
26 |
- the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config
|
27 |
"""
|
|
|
28 |
self.resize_rate = resize_rate
|
29 |
|
30 |
self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
|
31 |
+
self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
|
32 |
self.stage1_model = self.stage1_model.to(device).to(dtype)
|
33 |
|
34 |
self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
|
35 |
+
sd = torch.load(stage2_model_config.resume, map_location="cpu")
|
36 |
self.stage2_model.load_state_dict(sd, strict=False)
|
37 |
self.stage2_model = self.stage2_model.to(device).to(dtype)
|
38 |
|
|
|
167 |
np_imgs = np.concatenate(stage1_images, 1)
|
168 |
np_xyzs = np.concatenate(stage2_images, 1)
|
169 |
Image.fromarray(np_imgs).save("pixel_images.png")
|
170 |
+
Image.fromarray(np_xyzs).save("xyz_images.png")
|