shivansarora commited on
Commit
e53591d
·
verified ·
1 Parent(s): 195bef2

Update controllable_blender/generation_methods.py

Browse files
controllable_blender/generation_methods.py CHANGED
@@ -62,7 +62,7 @@ class RerankedTopKSampling(TreeSearch):
62
 
63
  if all_penalties.dim() == 1:
64
  all_penalties = all_penalties.unsqueeze(0)
65
- all_penalties = all_penalties.expand(batch_size, 1)
66
 
67
  penalties = torch.gather(all_penalties, -1, indices)
68
  penalised_probs = torch.mul(probs, penalties)
 
62
 
63
  if all_penalties.dim() == 1:
64
  all_penalties = all_penalties.unsqueeze(0)
65
+ all_penalties = all_penalties.expand(batch_size, 8008)
66
 
67
  penalties = torch.gather(all_penalties, -1, indices)
68
  penalised_probs = torch.mul(probs, penalties)