Spaces:
Runtime error
Runtime error
from transformers import Trainer | |
class UniVATrainer(Trainer): | |
def create_optimizer(self): | |
decay_parameters = self.get_decay_parameter_names(self.model) | |
optimizer_grouped_parameters = [] | |
trainable_params = [] | |
if self.optimizer is None: | |
for n, p in self.model.named_parameters(): | |
if n in decay_parameters and p.requires_grad: | |
optimizer_grouped_parameters += [ | |
{ | |
"params": [p], | |
"weight_decay": self.args.weight_decay, | |
} | |
] | |
trainable_params.append(n) | |
elif n not in decay_parameters and p.requires_grad: | |
optimizer_grouped_parameters += [ | |
{ | |
"params": [p], | |
"weight_decay": 0.0, | |
} | |
] | |
trainable_params.append(n) | |
else: | |
p.requires_grad_(False) | |
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( | |
self.args | |
) | |
self.optimizer = optimizer_cls( | |
optimizer_grouped_parameters, **optimizer_kwargs | |
) | |
return self.optimizer | |