Spaces:
Paused
Paused
# Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
import torch.nn as nn | |
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, is_torch_npu_available, is_torch_xpu_available | |
from .modeling_base import PreTrainedModelWrapper | |
class ValueHead(nn.Module):
    r"""
    A head that maps every hidden state of a transformer to a single scalar (one value per token).

    The head applies an optional dropout followed by a single `nn.Linear(hidden_size, 1)` projection.

    Args:
        config (`transformers.PretrainedConfig`):
            Configuration of the wrapped model. The input size of the projection is taken from
            `config.hidden_size`; it is overridden by `config.word_embed_proj_dim` when present (models such as
            OPT-350m project before the word embeddings) or, for encoder-decoder configs, by
            `config.decoder.hidden_size`. If the config defines `summary_dropout_prob`, it is used as the
            dropout probability.
        summary_dropout_prob (`float`, *optional*, defaults to `0.1`):
            Dropout probability, only used when `config` does not define `summary_dropout_prob`.
            A falsy value (e.g. `0` or `None`) disables dropout entirely.

    Raises:
        ValueError: If no hidden size can be inferred from `config`.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        if not hasattr(config, "summary_dropout_prob"):
            summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
        else:
            summary_dropout_prob = config.summary_dropout_prob

        self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()

        hidden_size = None
        if hasattr(config, "hidden_size"):
            hidden_size = config.hidden_size
        # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
        if hasattr(config, "word_embed_proj_dim"):
            hidden_size = config.word_embed_proj_dim
        elif hasattr(config, "is_encoder_decoder"):
            if config.is_encoder_decoder and hasattr(config, "decoder"):
                if hasattr(config.decoder, "hidden_size"):
                    hidden_size = config.decoder.hidden_size

        if hidden_size is None:
            # Without this check `hidden_size` would be unbound here and fail with an obscure UnboundLocalError.
            raise ValueError(
                "Could not infer the hidden size of the value head from the model config: expected one of "
                "`hidden_size`, `word_embed_proj_dim` or `decoder.hidden_size`."
            )

        self.summary = nn.Linear(hidden_size, 1)

        # Not used by `forward`; kept so code that accesses the `flatten` attribute keeps working.
        self.flatten = nn.Flatten()

    def forward(self, hidden_states):
        r"""
        Computes one scalar per position of `hidden_states`.

        Args:
            hidden_states (`torch.Tensor`):
                Hidden states whose last dimension equals the configured hidden size.

        Returns:
            `torch.Tensor`: A tensor with the same leading dimensions as `hidden_states` and a trailing
            dimension of size 1, in the dtype of the projection weights.
        """
        output = self.dropout(hidden_states)

        # Cast the input to the dtype of the projection weights (e.g. half-precision hidden states fed
        # into an fp32 head) so the linear layer does not fail on a dtype mismatch.
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        output = self.summary(output)
        return output
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    r"""
    An autoregressive model with a value head in addition to the language model head.
    This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
    model, simply manipulate the `pretrained_model` attribute of this class.

    Class attributes:
        - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
            should be set to `transformers.AutoModelForCausalLM` for this class.
        - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
            by the `ValueHead` class. Currently, the supported args are:
            - **summary_dropout_prob** (`float`, `optional`, defaults to `0.1`) -- The dropout probability for the
                `ValueHead` class. Only used when the model config does not define `summary_dropout_prob`.
            - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
                `ValueHead` if a specific initialization strategy is selected.
            - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
                `ValueHead`. Currently, the supported strategies are:
                - **`None`** -- Keeps the default (random) initialization of the `ValueHead` linear layer. This is the
                    default strategy.
                - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution and its bias
                    with zeros.
    """

    transformers_parent_class = AutoModelForCausalLM
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )

    def __init__(self, pretrained_model, **kwargs):
        r"""
        Initializes the model.

        Args:
            pretrained_model (`transformers.PreTrainedModel`):
                The model to wrap. It should be a causal language model such as GPT2,
                or any model mapped inside the `AutoModelForCausalLM` class.
            kwargs (`dict`, `optional`):
                Additional keyword arguments. The subset selected by `self._split_kwargs` (presumably the
                entries of `supported_args`) is passed to the `ValueHead` class and to `_init_weights`.
        """
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
        self._init_weights(**v_head_kwargs)

    def _init_weights(self, **kwargs):
        r"""
        Initializes the weights of the value head. The default initialization strategy is random.
        Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument
        when calling `.from_pretrained`. Supported strategies are:
            - `normal`: initializes the weights with a normal distribution (mean 0, std
              `v_head_initializer_range`) and the bias with zeros.

        Args:
            **kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the `ValueHead` class. These arguments
                can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range`
                argument (defaults to `0.2`).
        """
        initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
        # random init by default
        init_strategy = kwargs.pop("v_head_init_strategy", None)
        if init_strategy is None:
            # keep the default initialization of the linear layer
            pass
        elif init_strategy == "normal":
            self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
            self.v_head.summary.bias.data.zero_()
        # NOTE: any other strategy value is silently ignored and the default initialization is kept.

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        return_past_key_values=False,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns the logits of the value head.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
                Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
                (see `past_key_values` input) to speed up sequential decoding. Not forwarded when the wrapped model
                is a PEFT prefix-tuning model.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the wrapped model. `output_hidden_states` is
                always forced to `True`.

        Returns:
            `tuple`: `(lm_logits, loss, value)`, plus the model's `past_key_values` as a fourth element when
            `return_past_key_values` is `True`. `lm_logits` is upcast to fp32; `value` is the value-head
            output with its trailing singleton dimension squeezed.
        """
        kwargs["output_hidden_states"] = True  # this had already been set in the LORA / PEFT examples
        kwargs["past_key_values"] = past_key_values

        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")

        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        last_hidden_state = base_model_output.hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss

        # the value head may live on a different device than the last layer (e.g. with a device map)
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        value = self.v_head(last_hidden_state).squeeze(-1)

        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        if return_past_key_values:
            return (lm_logits, loss, value, base_model_output.past_key_values)
        else:
            return (lm_logits, loss, value)

    def generate(self, *args, **kwargs):
        r"""
        A simple wrapper around the `generate` method of the wrapped model.
        Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)
        method of the wrapped model for more information about the supported arguments.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed to the `generate` method of the wrapped model.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed to the `generate` method of the wrapped model.
        """
        return self.pretrained_model.generate(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        r"""
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
        For PEFT models only the value head entries are returned.
        """
        if not self.is_peft_model:
            pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
        else:
            # if it is a peft model, only save the v_head
            pretrained_model_state_dict = {}

        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
        for k, v in v_head_state_dict.items():
            pretrained_model_state_dict[f"v_head.{k}"] = v
        return pretrained_model_state_dict

    def push_to_hub(self, *args, **kwargs):
        r"""
        Attaches the value head to the wrapped model (as `pretrained_model.v_head`) and delegates to
        the wrapped model's `push_to_hub`.
        """
        self.pretrained_model.v_head = self.v_head

        return self.pretrained_model.push_to_hub(*args, **kwargs)

    def post_init(self, state_dict):
        r"""
        Loads the value head weights and, for models dispatched with a device map, moves the value head to
        the first device of that map.

        The value head weights are stored with keys prefixed by `v_head.` (see `state_dict`). This function
        removes the `v_head.` prefix from those keys and loads the result into `self.v_head` with
        `strict=False`. Note that the keys of the passed `state_dict` are renamed in place.

        Args:
            state_dict (`dict`):
                State dictionary containing (possibly among other entries) the `v_head.`-prefixed weights.

        Raises:
            ValueError: If part of the wrapped model is offloaded to CPU or disk.
        """
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

        if hasattr(self.pretrained_model, "hf_device_map"):
            device_map_values = list(self.pretrained_model.hf_device_map.values())
            if "cpu" in device_map_values or "disk" in device_map_values:
                raise ValueError(
                    "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
                )

            # Use the first entry of the device map in insertion order. Taking an element of a `set` would be
            # arbitrary and, for string device names, could differ from one run to the next.
            first_device = device_map_values[0]
            if isinstance(first_device, int):
                if is_torch_npu_available():
                    first_device = f"npu:{first_device}"
                elif is_torch_xpu_available():
                    first_device = f"xpu:{first_device}"
                else:
                    first_device = f"cuda:{first_device}"
            self.v_head = self.v_head.to(first_device)

            def set_device_hook(module, inputs, outputs):
                # Move every tensor of the wrapper's output tuple to `first_device`; other entries
                # (e.g. a `None` loss) are passed through unchanged.
                new_output = ()
                for output in outputs:
                    if isinstance(output, torch.Tensor):
                        new_output += (output.to(first_device),)
                    else:
                        new_output += (output,)
                return new_output

            self.register_forward_hook(set_device_hook)

            self.is_sequential_parallel = True
class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
    r"""
    A seq2seq model with a value head in addition to the language model head.
    This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained` and `push_to_hub` and also provides some additional
    functionalities such as `generate`.

    Args:
        pretrained_model (`transformers.PreTrainedModel`):
            The model to wrap. It should be a sequence-to-sequence (encoder-decoder) model such as T5,
            or any model mapped inside the `AutoModelForSeq2SeqLM` class.
        kwargs:
            Additional keyword arguments passed along to the `ValueHead` class.
    """

    transformers_parent_class = AutoModelForSeq2SeqLM
    lm_head_namings = ["lm_head", "embed_out", "output_projection"]
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )

    def __init__(self, pretrained_model, **kwargs):
        r"""
        Wraps `pretrained_model` and attaches a value head to it.

        Raises:
            ValueError: If no submodule name of `pretrained_model` contains one of `lm_head_namings`.
        """
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
        self.is_encoder_decoder = True

        if not self._has_lm_head():
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        self._init_weights(**v_head_kwargs)

    def _has_lm_head(self):
        r"""
        Returns whether any submodule name of `pretrained_model` contains one of `lm_head_namings`.
        """
        return any(
            any(attribute in name for attribute in self.lm_head_namings)
            for name, _module in self.pretrained_model.named_modules()
        )

    def post_init(self, state_dict):
        r"""
        Loads the value head weights and, for models dispatched with a device map, places the value head on
        the same device as the language model head.

        The value head weights are stored with keys prefixed by `v_head.` (see `state_dict`). This function
        removes the `v_head.` prefix from those keys and loads the result into `self.v_head` with
        `strict=False`. Note that the keys of the passed `state_dict` are renamed in place.

        Args:
            state_dict (`dict`):
                State dictionary containing (possibly among other entries) the `v_head.`-prefixed weights.

        Raises:
            ValueError: If part of the wrapped model is offloaded to CPU or disk, or if no language model
                head weight can be found to determine the target device.
        """
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

        if hasattr(self.pretrained_model, "hf_device_map"):
            if (
                "cpu" in self.pretrained_model.hf_device_map.values()
                or "disk" in self.pretrained_model.hf_device_map.values()
            ):
                raise ValueError(
                    "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
                )

            # get the lm_head device: the first module whose name matches one of `lm_head_namings` and that
            # actually holds a weight tensor (a matching container module without `.weight` is skipped)
            lm_head_device = None
            for name, module in self.pretrained_model.named_modules():
                if any(attribute in name for attribute in self.lm_head_namings):
                    weight = getattr(module, "weight", None)
                    if isinstance(weight, torch.Tensor):
                        lm_head_device = weight.device
                        break
            if lm_head_device is None:
                raise ValueError(
                    "Could not find the weights of the language model head to place the value head on the same device."
                )

            # put v_head on the same device as the lm_head to avoid issues
            self.v_head = self.v_head.to(lm_head_device)

            def set_device_hook(module, inputs, outputs):
                r"""
                A forward hook that moves every tensor of the wrapper's output tuple to the device of the
                language model head; non-tensor entries (e.g. a `None` loss) are passed through unchanged.

                Args:
                    module (`nn.Module`):
                        The module to which the hook is attached.
                    inputs (`tuple`):
                        The input to the module.
                    outputs (`tuple`):
                        The output of the module.
                """
                new_output = ()
                for output in outputs:
                    if isinstance(output, torch.Tensor):
                        new_output += (output.to(lm_head_device),)
                    else:
                        new_output += (output,)
                return new_output

            self.register_forward_hook(set_device_hook)
            self.is_sequential_parallel = True

    def state_dict(self, *args, **kwargs):
        r"""
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
        For PEFT models only the value head entries are returned.
        """
        if not self.is_peft_model:
            pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
        else:
            # if it is a peft model, only save the v_head
            pretrained_model_state_dict = {}

        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
        for k, v in v_head_state_dict.items():
            pretrained_model_state_dict[f"v_head.{k}"] = v
        return pretrained_model_state_dict

    def push_to_hub(self, *args, **kwargs):
        r"""
        Attaches the value head to the wrapped model (as `pretrained_model.v_head`) and delegates to
        the wrapped model's `push_to_hub`.
        """
        self.pretrained_model.v_head = self.v_head

        return self.pretrained_model.push_to_hub(*args, **kwargs)

    def _init_weights(self, **kwargs):
        r"""
        Initializes the weights of the value head.

        With `v_head_init_strategy="normal"` the weights are drawn from a normal distribution (mean 0, std
        `v_head_initializer_range`, defaults to `0.2`) and the bias is zeroed. With `None` (the default) the
        default initialization of the linear layer is kept; any other value is silently ignored.
        """
        initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
        # random init by default
        init_strategy = kwargs.pop("v_head_init_strategy", None)
        if init_strategy is None:
            # keep the default initialization of the linear layer
            pass
        elif init_strategy == "normal":
            self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
            self.v_head.summary.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        return_past_key_values=False,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped seq2seq model and computes the value of each decoder position.

        Args:
            input_ids (`torch.LongTensor`, `optional`):
                Indices of the encoder input tokens.
            past_key_values (`optional`):
                Cached attention key/values for faster decoding. Not forwarded when the wrapped model is a
                PEFT prefix-tuning model.
            attention_mask (`torch.Tensor`, `optional`):
                Mask to avoid performing attention on padding token indices.
            return_past_key_values (`bool`):
                Whether to also return the model's `past_key_values`.
            kwargs (`dict`, `optional`):
                Additional keyword arguments passed to the wrapped model (e.g. decoder inputs).
                `output_hidden_states` is always forced to `True`.

        Returns:
            `tuple`: `(lm_logits, loss, value)`, plus the model's `past_key_values` as a fourth element when
            `return_past_key_values` is `True`. `lm_logits` is upcast to fp32; `value` is computed from the last
            decoder hidden state with its trailing singleton dimension squeezed.
        """
        # Set via kwargs (like the causal-LM wrapper) so a caller-supplied `output_hidden_states` is overridden
        # instead of raising "got multiple values for keyword argument".
        kwargs["output_hidden_states"] = True  # We force the model to output hidden states
        kwargs["past_key_values"] = past_key_values
        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")

        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        last_hidden_state = base_model_output.decoder_hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss

        # the value head may live on a different device than the last decoder layer (e.g. with a device map)
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        value = self.v_head(last_hidden_state).squeeze(-1)

        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        if return_past_key_values:
            return (lm_logits, loss, value, base_model_output.past_key_values)
        else:
            return (lm_logits, loss, value)

    def generate(self, *args, **kwargs):
        r"""
        We call `generate` on the wrapped model, forwarding all positional and keyword arguments.
        """
        return self.pretrained_model.generate(*args, **kwargs)