shivansarora commited on
Commit
195bef2
·
verified ·
1 Parent(s): b12cada

test tensor adjustment

Browse files
controllable_blender/generation_methods.py CHANGED
@@ -57,11 +57,12 @@ class RerankedTopKSampling(TreeSearch):
57
  values, indices = logprobs.topk(self.k, dim=-1)
58
  probs = torch.softmax(values, dim=-1)
59
 
 
60
  all_penalties = self.reranker.token_penalties.to(probs.device)
61
 
62
  if all_penalties.dim() == 1:
63
  all_penalties = all_penalties.unsqueeze(0)
64
- all_penalties = all_penalties.repeat(self.beam_size, 1)
65
 
66
  penalties = torch.gather(all_penalties, -1, indices)
67
  penalised_probs = torch.mul(probs, penalties)
 
57
  values, indices = logprobs.topk(self.k, dim=-1)
58
  probs = torch.softmax(values, dim=-1)
59
 
60
+ batch_size = logprobs.size(0)
61
  all_penalties = self.reranker.token_penalties.to(probs.device)
62
 
63
  if all_penalties.dim() == 1:
64
  all_penalties = all_penalties.unsqueeze(0)
65
+ all_penalties = all_penalties.expand(batch_size, 1)
66
 
67
  penalties = torch.gather(all_penalties, -1, indices)
68
  penalised_probs = torch.mul(probs, penalties)