```python
import torch
import torch.nn as nn
from transformers import DebertaV2Model


class MultiTaskBiasModel(nn.Module):
    """DeBERTa-v3 encoder with one 3-class classification head per bias task."""

    def __init__(self, model_name="microsoft/deberta-v3-base"):
        super().__init__()
        self.bert = DebertaV2Model.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        # One independent MLP head per task; all heads share the encoder.
        self.heads = nn.ModuleDict({
            task: nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden, 3),
            )
            for task in ["political", "gender", "immigration"]
        })

    def forward(self, input_ids, attention_mask, tasks):
        # Use the [CLS] token embedding as a pooled sentence representation.
        pooled = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        # Route each example through the head matching its task label.
        logits = [
            self.heads[task](pooled[i].unsqueeze(0))
            for i, task in enumerate(tasks)
        ]
        return torch.cat(logits, dim=0)
```
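
For reference, a minimal sketch of how the model might be called, assuming the standard Hugging Face `AutoTokenizer` and a batch in which each example carries its own task label; the sample texts and task assignments are purely illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = MultiTaskBiasModel()
model.eval()

# Hypothetical inputs: one task label per example in the batch.
texts = ["The senator's plan is reckless.", "She was hired despite her age."]
tasks = ["political", "gender"]

batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    logits = model(batch["input_ids"], batch["attention_mask"], tasks)
print(logits.shape)  # torch.Size([2, 3]): three bias classes per example
```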