shivansarora commited on
Commit
b12cada
·
verified ·
1 Parent(s): b33a755

Update controllable_blender/generation_methods.py

Browse files
controllable_blender/generation_methods.py CHANGED
@@ -61,7 +61,7 @@ class RerankedTopKSampling(TreeSearch):
61
 
62
  if all_penalties.dim() == 1:
63
  all_penalties = all_penalties.unsqueeze(0)
64
- all_penalties = all_penalties.repeat(self.beam_size, -1)
65
 
66
  penalties = torch.gather(all_penalties, -1, indices)
67
  penalised_probs = torch.mul(probs, penalties)
 
61
 
62
  if all_penalties.dim() == 1:
63
  all_penalties = all_penalties.unsqueeze(0)
64
+ all_penalties = all_penalties.repeat(self.beam_size, 1)
65
 
66
  penalties = torch.gather(all_penalties, -1, indices)
67
  penalised_probs = torch.mul(probs, penalties)