from typing import Optional, Union

import torch

from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaForCausalLM


class SuperModel(LlamaForCausalLM):
    """Llama causal LM whose forward pass scales the output logits by 2**4.

    Subclasses LlamaForCausalLM (rather than LlamaModel) so the output actually
    carries a `.logits` field, matching the CausalLMOutputWithPast annotation.
    """

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        # Pass everything by keyword so the call does not silently misalign if
        # the parent signature changes order or gains parameters.
        out = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # Scale the logits in place. This assumes a dict-style output
        # (return_dict=True, the default); a plain tuple has no `.logits`.
        out.logits *= 2**4
        return out
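

# Minimal usage sketch (hypothetical, not part of the original module): build a
# tiny randomly initialized SuperModel from a LlamaConfig and run one forward
# pass to see the scaled logits. The small config values below are arbitrary
# assumptions chosen only to keep the demo fast.
if __name__ == "__main__":
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = SuperModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print(out.logits.shape)  # torch.Size([1, 8, 128])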