# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
from accelerate import PartialState
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import (
    EntryNotFoundError,
    HFValidationError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
)
from safetensors.torch import load_file as safe_load_file
from transformers import GenerationMixin, PreTrainedModel, is_torch_npu_available, is_torch_xpu_available
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available


if is_peft_available():
    from peft import (
        PeftConfig,
        PeftModel,
        PeftModelForCausalLM,
        PeftModelForSeq2SeqLM,
        PromptLearningConfig,
        get_peft_model,
        prepare_model_for_kbit_training,
    )


LAYER_PATTERNS = [
    "transformer.h.{layer}",
    "model.decoder.layers.{layer}",
    "gpt_neox.layers.{layer}",
    "model.layers.{layer}",
]


class PreTrainedModelWrapper(nn.Module):
    r"""
    A wrapper class around a (`transformers.PreTrainedModel`) that stays compatible with the
    (`~transformers.PreTrainedModel`) class by exposing some of its attributes and methods.

    Attributes:
        pretrained_model (`transformers.PreTrainedModel`):
            The model to be wrapped.
        parent_class (`transformers.PreTrainedModel`):
            The parent class of the model to be wrapped.
        supported_args (`list`):
            The list of arguments that are supported by the wrapper class.
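
    Example (illustrative sketch; `PreTrainedModelWrapper` is normally used through a concrete TRL
    subclass such as `AutoModelForCausalLMWithValueHead`, which is assumed here):

    ```python
    from trl import AutoModelForCausalLMWithValueHead

    # Wraps a causal LM and adds a value head on top of it.
    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    ```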
""" | |
transformers_parent_class = None | |
supported_args = None | |
supported_modules = ("v_head",) | |
supported_rm_modules = ("score",) | |
supported_pretrained_model_architectures = ( | |
(PreTrainedModel) | |
if not is_peft_available() | |
else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM) | |
) | |

    def __init__(
        self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs
    ):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.config = pretrained_model.config
        self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation
        self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False)
        self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False)
        self.is_sequential_parallel = False
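        # Forward gradient-checkpointing helpers from the wrapped model, when available, so that
        # callers (e.g. a Trainer) can invoke them directly on the wrapper.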
        if hasattr(pretrained_model, "gradient_checkpointing_disable"):
            self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable

        if hasattr(pretrained_model, "gradient_checkpointing_enable"):
            self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable

        if hasattr(pretrained_model, "enable_input_require_grads"):
            self.enable_input_require_grads = pretrained_model.enable_input_require_grads

        self.supports_rm_adapter = supports_rm_adapter
        self.rm_adapter_name = rm_adapter_name
        self.policy_adapter_name = "default"

        if score_module is not None:
            self.score = score_module

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiates a new model from a pretrained model from `transformers`. The
        pretrained model is loaded using the `from_pretrained` method of the
        `transformers.PreTrainedModel` class. The arguments that are specific to the
        `transformers.PreTrainedModel` class are passed along this method and filtered
        out from the `kwargs` argument.

        Args:
            pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
                The path to the pretrained model or its name.
            *model_args (`list`, *optional*):
                Additional positional arguments passed along to the underlying model's
                `from_pretrained` method.
            **kwargs (`dict`, *optional*):
                Additional keyword arguments passed along to the underlying model's
                `from_pretrained` method. We also pre-process the kwargs to extract
                the arguments that are specific to the `transformers.PreTrainedModel`
                class and the arguments that are specific to trl models. The kwargs
                also support `prepare_model_for_kbit_training` arguments from
                the `peft` library.
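
        Example (illustrative sketch; assumes TRL's `AutoModelForCausalLMWithValueHead` subclass
        and an installed `peft` library):

        ```python
        from peft import LoraConfig
        from trl import AutoModelForCausalLMWithValueHead

        # Load the base model and attach a fresh LoRA adapter on top of it.
        model = AutoModelForCausalLMWithValueHead.from_pretrained(
            "gpt2",
            peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32),
        )
        ```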
""" | |
if kwargs is not None: | |
peft_config = kwargs.pop("peft_config", None) | |
reward_adapter = kwargs.pop("reward_adapter", None) | |
reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter") | |
is_trainable = kwargs.pop("is_trainable", False) | |
trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs) | |
token = pretrained_kwargs.get("token", None) | |
else: | |
peft_config = None | |
is_trainable = False | |
trl_model_args = {} | |
pretrained_kwargs = {} | |
peft_quantization_kwargs = {} | |
token = None | |
if reward_adapter is not None and not isinstance(reward_adapter, str): | |
raise ValueError( | |
"The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter." | |
) | |

        is_peft_model = False

        current_device = cls._get_current_device()
        if isinstance(pretrained_model_name_or_path, str):
            is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
            is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
        else:
            is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
            is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False)

        if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs:
            # warn users
            logging.warning(
                "The `device_map` argument is not provided. We will override the `device_map` argument"
                " to set the entire model on the current device. If you want to set the model on multiple"
                " devices, please provide a custom `device_map` argument."
            )
            pretrained_kwargs["device_map"] = {"": current_device}

        if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig):
            raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.")

        # First, load the pre-trained model using the parent-class
        # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM`
        if isinstance(pretrained_model_name_or_path, str):
            if is_peft_available():
                try:
                    # If there is a trained peft adapter in the hub, load its config.
                    remote_adapter_config = hf_hub_download(
                        pretrained_model_name_or_path,
                        "adapter_config.json",
                        token=token,
                    )
                except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
                    remote_adapter_config = None
            else:
                remote_adapter_config = None

            local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json"))

            if (local_adapter_present or remote_adapter_config is not None) and is_peft_available():
                if peft_config is not None:
                    logging.warning(
                        "`peft_config` argument ignored since a peft config file was found in "
                        f"{pretrained_model_name_or_path}"
                    )

                # Load the trained peft adapter config
                if local_adapter_present:
                    trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)
                else:
                    remote_adapter_dir = os.path.dirname(remote_adapter_config)
                    trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir)

                # Load the pretrained base model
                pretrained_model = cls.transformers_parent_class.from_pretrained(
                    trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs
                )

                # Wrap the pretrained model with the trained peft adapter
                pretrained_model = PeftModel.from_pretrained(
                    pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable, token=token
                )
                logging.info("Trained peft adapter loaded")
            else:
                pretrained_model = cls.transformers_parent_class.from_pretrained(
                    pretrained_model_name_or_path, *model_args, **pretrained_kwargs
                )

                if peft_config is not None:
                    # Initialize a new peft adapter with the given config
                    if is_loaded_in_8bit or is_loaded_in_4bit:
                        pretrained_model = prepare_model_for_kbit_training(
                            pretrained_model,
                            **peft_quantization_kwargs,
                        )
                    pretrained_model = get_peft_model(pretrained_model, peft_config)
                    logging.info("peft adapter initialised")
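        # If a model instance (rather than a name or path) was passed in, reuse it as-is and, when
        # a `peft_config` is provided for a plain `PreTrainedModel`, wrap it with a new adapter.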
        elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures):
            pretrained_model = pretrained_model_name_or_path

            if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
                # Initialize a new peft adapter with the given config
                if is_loaded_in_8bit or is_loaded_in_4bit:
                    pretrained_model = prepare_model_for_kbit_training(
                        pretrained_model,
                        **peft_quantization_kwargs,
                    )
                pretrained_model = get_peft_model(pretrained_model, peft_config)
                logging.info("peft adapter initialised")
        else:
            raise ValueError(
                "pretrained_model_name_or_path should be a string or a PreTrainedModel, "
                f"but is {type(pretrained_model_name_or_path)}"
            )

        if is_peft_available():
            if isinstance(pretrained_model, PeftModel):
                is_peft_model = True
                # for backward compatibility
                if hasattr(pretrained_model, "active_peft_config") and isinstance(
                    pretrained_model.active_peft_config, PromptLearningConfig
                ):
                    raise ValueError("PromptLearningConfig is not supported for PPO training.")

        # Add reward modeling adapter if specified
        if not is_peft_model and reward_adapter is not None:
            raise ValueError("reward_adapter can only be used with a PeftModel.")
        elif is_peft_model and reward_adapter is not None:
            score_module = cls.add_and_load_reward_modeling_adapter(
                pretrained_model, reward_adapter, reward_adapter_name, token=token
            )
            multi_adapter_args = {
                "score_module": score_module,
                "supports_rm_adapter": True,
                "rm_adapter_name": reward_adapter_name,
            }
        else:
            multi_adapter_args = {"supports_rm_adapter": False}

        # Then, create the full model by instantiating the wrapper class
        model = cls(pretrained_model, **multi_adapter_args, **trl_model_args)

        # if resume_training, load the state_dict again - this is ok since the
        # state_dict is removed from the model after loading it.
        is_resuming_training = True
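        # Look for an existing checkpoint that may contain the extra TRL weights (e.g. the value
        # head): first locally, then on the Hub, trying `pytorch_model.bin` before falling back to
        # `model.safetensors`, including their sharded index files.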
        if isinstance(pretrained_model_name_or_path, str):
            safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors")
            filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")

            sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
            safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
            is_sharded = False
            use_safe = os.path.exists(safe_filename)

            if not (os.path.exists(filename) or os.path.exists(safe_filename)):
                # Try with `pytorch_model.bin`
                filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
                    pretrained_model,
                    pretrained_model_name_or_path,
                    sharded_index_filename,
                    token=token,
                )
                # Try with safetensors
                if filename is None and files_to_download is None:
                    safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
                        pretrained_model,
                        pretrained_model_name_or_path,
                        safe_sharded_index_filename,
                        token=token,
                        model_name="model.safetensors",
                        model_index_name="model.safetensors.index.json",
                    )
                    use_safe = True
                else:
                    use_safe = False
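            # safetensors checkpoints are read with `safe_load_file`; `.bin` checkpoints fall back
            # to `torch.load` restricted to plain tensors (`weights_only=True`) on CPU.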
            loading_func = safe_load_file if use_safe else torch.load
            load_kwargs = {} if use_safe else {"map_location": "cpu", "weights_only": True}

            if is_resuming_training:
                if is_sharded:
                    # download each file and add it to the state_dict
                    state_dict = {}

                    for shard_file in files_to_download:
                        filename = hf_hub_download(
                            pretrained_model_name_or_path,
                            shard_file,
                            token=token,
                        )
                        state_dict.update(loading_func(filename, **load_kwargs))
                else:
                    state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs)

        else:
            state_dict = pretrained_model_name_or_path.state_dict()

        model.is_peft_model = is_peft_model
        model.current_device = current_device

        if is_resuming_training:
            model.post_init(state_dict=state_dict)

        return model

    @classmethod
    def _get_checkpoint_from_hub(
        cls,
        pretrained_model,
        pretrained_model_name_or_path,
        index_filename,
        token=None,
        model_name="pytorch_model.bin",
        model_index_name="pytorch_model.bin.index.json",
    ):
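        r"""
        Try to download `model_name` from the Hub. If the single-file checkpoint is not found, fall
        back to the sharded index (`model_index_name`, locally or on the Hub) and collect only the
        shard files that contain one of `cls.supported_modules` (e.g. the `v_head`). Returns a
        `(filename, files_to_download, is_sharded, is_resuming_training)` tuple.
        """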
        files_to_download = None
        filename = None
        is_resuming_training = True
        is_sharded = False

        try:
            filename = hf_hub_download(
                pretrained_model_name_or_path,
                model_name,
                token=token,
            )
        # sharded
        except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
            if os.path.exists(index_filename):
                index_file_name = index_filename
            else:
                try:
                    index_file_name = hf_hub_download(
                        pretrained_model_name_or_path,
                        model_index_name,
                        token=token,
                    )
                except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
                    # not continue training, do not have v_head weight
                    is_resuming_training = False
                    logging.warning(
                        f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', "
                        f"and no v_head weight is found. This IS expected if you are not resuming PPO training."
                    )
            # load json
            if is_resuming_training:
                with open(index_file_name) as f:
                    index = json.load(f)
                # check filename with `v_head` or any known extra module:
                files_to_download = set()
                for k, v in index["weight_map"].items():
                    if any(module in k for module in cls.supported_modules):
                        files_to_download.add(v)
                is_sharded = True

        return filename, files_to_download, is_sharded, is_resuming_training

    @classmethod
    def _get_current_device(cls):
        r"""
        Get the current device. For GPU and XPU, we return the local process index using the
        `accelerate.PartialState` object to handle corner cases when running scripts in distributed
        environments. For NPU, the device string `npu:{index}` is returned; otherwise `"cpu"` is used.

        Returns:
            current_device (`Union[int, str]`):
                The current device.
        """
        state = PartialState()
        if torch.cuda.is_available() or is_torch_xpu_available():
            return state.local_process_index
        elif is_torch_npu_available():
            return f"npu:{state.local_process_index}"
        else:
            return "cpu"

    @classmethod
    def _split_kwargs(cls, kwargs):
        """
        Separate the kwargs from the arguments that we support inside
        `supported_args` and the ones that we don't.
        """
        check_peft_kwargs = False
        if is_peft_available():
            from peft import prepare_model_for_kbit_training

            check_peft_kwargs = True

        supported_kwargs = {}
        unsupported_kwargs = {}
        peft_kwargs = {}
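        # Route each kwarg: names in `supported_args` go to the wrapper, parameters of
        # `prepare_model_for_kbit_training` go to peft, and everything else is forwarded to the
        # base `from_pretrained` call.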
        for key, value in kwargs.items():
            if key in cls.supported_args:
                supported_kwargs[key] = value
            else:
                unsupported_kwargs[key] = value

            if check_peft_kwargs:
                if key in prepare_model_for_kbit_training.__code__.co_varnames:
                    peft_kwargs[key] = value
                    if key in unsupported_kwargs:
                        unsupported_kwargs.pop(key)

        return supported_kwargs, unsupported_kwargs, peft_kwargs

    @classmethod
    def add_and_load_reward_modeling_adapter(
        cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None
    ):
        r"""
        Add and load a reward modeling adapter. This method can only be used if the
        model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
        argument, pointing to the id of the reward modeling adapter. The adapter also needs to contain the
        score head in order to produce the reward.
        """
        pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False)
        pretrained_model.train()

        filename = os.path.join(adapter_model_id, "adapter_model.bin")
        safe_loading = False
        if not os.path.exists(filename):
            try:
                local_filename = hf_hub_download(
                    adapter_model_id,
                    "adapter_model.bin",
                    token=token,
                )
            except Exception:
                filename = os.path.join(adapter_model_id, "adapter_model.safetensors")
                safe_loading = True

                if not os.path.exists(filename):
                    try:
                        local_filename = hf_hub_download(
                            adapter_model_id,
                            "adapter_model.safetensors",
                            token=token,
                        )
                    except Exception as exc:
                        raise ValueError(
                            "Could not find adapter model in the Hub, make sure you have the correct adapter model id."
                        ) from exc
                else:
                    local_filename = filename
        else:
            local_filename = filename

        loading_func = safe_load_file if safe_loading else torch.load
        load_kwargs = {} if safe_loading else {"map_location": "cpu", "weights_only": True}

        adapter_state_dict = loading_func(local_filename, **load_kwargs)
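
        # Recover the reward head (a module whose name matches one of `supported_rm_modules`,
        # e.g. "score") from the adapter state dict and rebuild it as a frozen `nn.Linear` on the
        # current device.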
        for score_name_candidate in cls.supported_rm_modules:
            if any(score_name_candidate in name for name in adapter_state_dict.keys()):
                score_name = score_name_candidate
                # we have found the correct head name and can break
                break

        score_dict = {}

        for name, param in adapter_state_dict.items():
            if score_name in name:
                key_name = ".".join(name.split(".")[-1:])
                score_dict[key_name] = param.to(cls._get_current_device())

        num_labels, hidden_dim = score_dict["weight"].shape
        has_bias = any("bias" in name for name in adapter_state_dict.keys())

        score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
            device=cls._get_current_device(),
            dtype=pretrained_model.dtype,
        )
        score.load_state_dict(score_dict)

        for param in score.parameters():
            param.requires_grad = False

        return score

    def push_to_hub(self, *args, **kwargs):
        r"""
        Push the pretrained model to the hub. This method is a wrapper around
        `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation
        of `transformers.PreTrainedModel.push_to_hub` for more information.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed along to the underlying model's
                `push_to_hub` method.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed along to the underlying model's
                `push_to_hub` method.
        """
        raise NotImplementedError

    def save_pretrained(self, *args, **kwargs):
        r"""
        Save the pretrained model to a directory. This method is a wrapper around
        `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation
        of `transformers.PreTrainedModel.save_pretrained` for more information.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed along to the underlying model's
                `save_pretrained` method.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed along to the underlying model's
                `save_pretrained` method.
        """
        state_dict = kwargs.get("state_dict")
        if state_dict is None:
            state_dict = self.state_dict()
            kwargs["state_dict"] = state_dict

        # if it is a peft model only save the `v_head` state_dict and
        # pop the `state_dict` from the kwargs to avoid silent bugs with `peft`
        if self.is_peft_model:
            save_path = args[0]
            save_path = os.path.join(save_path, "pytorch_model.bin")
            torch.save(state_dict, save_path)
            _ = kwargs.pop("state_dict", None)

        return self.pretrained_model.save_pretrained(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        r"""
        Return the state_dict of the pretrained model.
        """
        raise NotImplementedError

    def post_init(self, *args, **kwargs):
        r"""
        Post initialization method. This method is called after the model is
        instantiated and loaded from a checkpoint. It can be used to perform
        additional operations such as loading the state_dict.
        """
        raise NotImplementedError

    def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):
        r"""
        Computes the reward score for a given input. The method first enables the reward modeling
        adapter, then computes the reward score. Afterwards, the reward modeling adapter is
        disabled and the default policy adapter is enabled again.
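
        Example (illustrative sketch; assumes the wrapper was created through TRL's
        `AutoModelForCausalLMWithValueHead` with a `reward_adapter`, and that `input_ids` and
        `attention_mask` are tensors produced by the tokenizer):

        ```python
        # `model` was loaded with `reward_adapter="<hub-id-of-a-reward-adapter>"` (hypothetical id),
        # so `supports_rm_adapter` is True and the frozen score head is available.
        rewards = model.compute_reward_score(input_ids, attention_mask=attention_mask)
        ```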
""" | |
if not self.supports_rm_adapter: | |
raise ValueError("This model does not support reward modeling adapter.") | |
# enable rm adapter | |
self.pretrained_model.set_adapter(self.rm_adapter_name) | |
self.pretrained_model.eval() | |
with torch.no_grad(): | |
base_model_output = self.pretrained_model( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
output_hidden_states=True, | |
return_dict=True, | |
**kwargs, | |
) | |
last_hidden_states = base_model_output.hidden_states[-1] | |
scores = self.score(last_hidden_states) | |
self.pretrained_model.set_adapter(self.policy_adapter_name) | |
self.pretrained_model.eval() | |
return scores | |


def create_reference_model(
    model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None
) -> PreTrainedModelWrapper:
    """
    Creates a static reference copy of a model. Note that the returned model will be in `.eval()` mode.

    Args:
        model (`PreTrainedModelWrapper`): The model to be copied.
        num_shared_layers (`int`, *optional*):
            The number of initial layers that are shared between both models and kept frozen.
        pattern (`str`, *optional*): The shared layers are selected with a string pattern
            (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here.

    Returns:
        `PreTrainedModelWrapper`
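
    Example (illustrative sketch; assumes a TRL model such as `AutoModelForCausalLMWithValueHead`
    has already been loaded as `model`):

    ```python
    from trl import create_reference_model

    # Full frozen copy used as the reference policy.
    ref_model = create_reference_model(model)

    # Alternatively, share (and freeze) the first 6 transformer layers between the two models.
    ref_model = create_reference_model(model, num_shared_layers=6)
    ```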
""" | |
if is_deepspeed_zero3_enabled(): | |
raise ValueError( | |
"DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoModelForCausalLM.from_pretrained()`." | |
) | |
parameter_names = [n for n, _ in model.named_parameters()] | |
ref_model = deepcopy(model) | |
# if no layers are shared, return copy of model | |
if num_shared_layers is None: | |
for param_name in parameter_names: | |
param = ref_model.get_parameter(param_name) | |
param.requires_grad = False | |
return ref_model.eval() | |
# identify layer name pattern | |
if pattern is not None: | |
pattern = pattern.format(layer=num_shared_layers) | |
else: | |
for pattern_candidate in LAYER_PATTERNS: | |
pattern_candidate = pattern_candidate.format(layer=num_shared_layers) | |
if any(pattern_candidate in name for name in parameter_names): | |
pattern = pattern_candidate | |
break | |
if pattern is None: | |
raise ValueError("Layer pattern could not be matched.") | |
# divide parameters in shared and unshared parameter lists | |
shared_param_list = [] | |
unshared_param_list = [] | |
shared_parameter = True | |
for name, _param in model.named_parameters(): | |
if pattern in name: | |
shared_parameter = False | |
if shared_parameter: | |
shared_param_list.append(name) | |
else: | |
unshared_param_list.append(name) | |
# create reference of the original parameter if they are shared | |
for param_name in shared_param_list: | |
param = model.get_parameter(param_name) | |
param.requires_grad = False | |
_ref_param = ref_model.get_parameter(param_name) | |
# for all other parameters just make sure they don't use gradients | |
for param_name in unshared_param_list: | |
param = ref_model.get_parameter(param_name) | |
param.requires_grad = False | |
if pattern is not None and len(unshared_param_list) == 0: | |
logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.") | |
return ref_model.eval() | |


class GeometricMixtureWrapper(GenerationMixin):
    r"""
    Geometric mixture generation wrapper that samples from the geometric mixture of the logits of
    a model and a reference model.

    Args:
        model (`PreTrainedModel`): The model to be wrapped.
        ref_model (`PreTrainedModel`): The reference model.
        generation_config (`GenerationConfig`): The generation config.
        mixture_coef (`float`, *optional*, defaults to `0.5`): The mixture coefficient.
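
    The mixed logits are computed as
    `log_softmax(mixture_coef * ref_model_logits + (1 - mixture_coef) * model_logits)`,
    i.e. the normalized geometric mixture of the two next-token distributions.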
""" | |
main_input_name = "input_ids" | |
_supports_cache_class = False | |
_supports_static_cache = False | |
def __init__(self, model, ref_model, generation_config, mixture_coef=0.5, device=None): | |
super().__init__() | |
self.model = model | |
self.config = model.config | |
self.ref_model = ref_model | |
self.generation_config = generation_config | |
self.mixture_coef = mixture_coef | |
self.device = device | |
def __call__(self, *args, **kwargs): | |
return self.forward(*args, **kwargs) | |
def forward(self, *args, **kwargs): | |
model_outputs = self.model(*args, **kwargs) | |
model_logits = model_outputs.logits | |
ref_model_logits = self.ref_model(*args, **kwargs).logits | |
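        # Mixing the raw logits and renormalizing with log_softmax yields the log-probabilities of
        # the geometric mixture of the two distributions (per-row constants cancel in the softmax).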
        model_outputs.logits = torch.nn.functional.log_softmax(
            self.mixture_coef * ref_model_logits + (1 - self.mixture_coef) * model_logits, dim=-1
        )

        return model_outputs

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # turn off cache in the generation config
        kwargs["use_cache"] = False
        model_inputs = self.model.prepare_inputs_for_generation(*args, **kwargs)
        _ = self.ref_model.prepare_inputs_for_generation(*args, **kwargs)

        return model_inputs

    def _validate_model_class(self):
        self.model._validate_model_class()

    def _validate_model_kwargs(self, model_kwargs):
        return self.model._validate_model_kwargs(model_kwargs)