xinjie.wang commited on
Commit
2e0bac6
·
1 Parent(s): e990bc4
Files changed (1) hide show
  1. embodied_gen/data/utils.py +1 -1
embodied_gen/data/utils.py CHANGED
@@ -139,7 +139,7 @@ class DiffrastRender(object):
139
  vertices: torch.Tensor,
140
  matrix: torch.Tensor,
141
  ) -> torch.Tensor:
142
- verts_ones = torch.ones((len(vertices), 1)).to(vertices)
143
  verts_homo = torch.cat([vertices, verts_ones], dim=-1)
144
  trans_vertices = torch.matmul(verts_homo, matrix.permute(0, 2, 1))
145
 
 
139
  vertices: torch.Tensor,
140
  matrix: torch.Tensor,
141
  ) -> torch.Tensor:
142
+ verts_ones = torch.ones((len(vertices), 1), device=vertices.device, dtype=vertices.dtype)
143
  verts_homo = torch.cat([vertices, verts_ones], dim=-1)
144
  trans_vertices = torch.matmul(verts_homo, matrix.permute(0, 2, 1))
145