shivansarora commited on
Commit
b33a755
·
verified ·
1 Parent(s): f5d3ac2

Update controllable_blender/generation_methods.py

Browse files
controllable_blender/generation_methods.py CHANGED
@@ -61,7 +61,7 @@ class RerankedTopKSampling(TreeSearch):
61
 
62
  if all_penalties.dim() == 1:
63
  all_penalties = all_penalties.unsqueeze(0)
64
- all_penalties = all_penalties.expand(self.beam_size, -1)
65
 
66
  penalties = torch.gather(all_penalties, -1, indices)
67
  penalised_probs = torch.mul(probs, penalties)
 
61
 
62
  if all_penalties.dim() == 1:
63
  all_penalties = all_penalties.unsqueeze(0)
64
+ all_penalties = all_penalties.repeat(self.beam_size, -1)
65
 
66
  penalties = torch.gather(all_penalties, -1, indices)
67
  penalised_probs = torch.mul(probs, penalties)