shivansarora commited on
Commit
f5d3ac2
·
verified ·
1 Parent(s): 504bc77

tensor fix

Browse files
controllable_blender/generation_methods.py CHANGED
@@ -57,7 +57,12 @@ class RerankedTopKSampling(TreeSearch):
57
  values, indices = logprobs.topk(self.k, dim=-1)
58
  probs = torch.softmax(values, dim=-1)
59
 
60
- all_penalties = self.reranker.token_penalties.repeat(self.beam_size, 1).to(probs.device)
 
 
 
 
 
61
  penalties = torch.gather(all_penalties, -1, indices)
62
  penalised_probs = torch.mul(probs, penalties)
63
 
 
57
  values, indices = logprobs.topk(self.k, dim=-1)
58
  probs = torch.softmax(values, dim=-1)
59
 
60
+ all_penalties = self.reranker.token_penalties.to(probs.device)
61
+
62
+ if all_penalties.dim() == 1:
63
+ all_penalties = all_penalties.unsqueeze(0)
64
+ all_penalties = all_penalties.expand(self.beam_size, -1)
65
+
66
  penalties = torch.gather(all_penalties, -1, indices)
67
  penalised_probs = torch.mul(probs, penalties)
68