tensor fix
Browse files
controllable_blender/generation_methods.py
CHANGED
@@ -57,7 +57,12 @@ class RerankedTopKSampling(TreeSearch):
|
|
57 |
values, indices = logprobs.topk(self.k, dim=-1)
|
58 |
probs = torch.softmax(values, dim=-1)
|
59 |
|
60 |
-
all_penalties = self.reranker.token_penalties.
|
|
|
|
|
|
|
|
|
|
|
61 |
penalties = torch.gather(all_penalties, -1, indices)
|
62 |
penalised_probs = torch.mul(probs, penalties)
|
63 |
|
|
|
57 |
values, indices = logprobs.topk(self.k, dim=-1)
|
58 |
probs = torch.softmax(values, dim=-1)
|
59 |
|
60 |
+
all_penalties = self.reranker.token_penalties.to(probs.device)
|
61 |
+
|
62 |
+
if all_penalties.dim() == 1:
|
63 |
+
all_penalties = all_penalties.unsqueeze(0)
|
64 |
+
all_penalties = all_penalties.expand(self.beam_size, -1)
|
65 |
+
|
66 |
penalties = torch.gather(all_penalties, -1, indices)
|
67 |
penalised_probs = torch.mul(probs, penalties)
|
68 |
|