import json |
|
import logging |
|
import os |
|
from copy import deepcopy |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from accelerate import PartialState |
|
from huggingface_hub import hf_hub_download |
|
from huggingface_hub.utils import ( |
|
EntryNotFoundError, |
|
HFValidationError, |
|
LocalEntryNotFoundError, |
|
RepositoryNotFoundError, |
|
) |
|
from safetensors.torch import load_file as safe_load_file |
|
from transformers import GenerationMixin, PreTrainedModel, is_torch_npu_available, is_torch_xpu_available |
|
from transformers.utils import is_peft_available |
|
|
|
|
|
if is_peft_available(): |
|
from peft import ( |
|
PeftConfig, |
|
PeftModel, |
|
PeftModelForCausalLM, |
|
PeftModelForSeq2SeqLM, |
|
PromptLearningConfig, |
|
get_peft_model, |
|
prepare_model_for_kbit_training, |
|
) |
|
|
|
|
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
|
|
|
|
|
LAYER_PATTERNS = [ |
|
"transformer.h.{layer}", |
|
"model.decoder.layers.{layer}", |
|
"gpt_neox.layers.{layer}", |
|
"model.layers.{layer}", |
|
] |
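# Formatting a pattern with a layer index yields a parameter-name prefix, e.g.
# "transformer.h.{layer}".format(layer=2) -> "transformer.h.2", which matches every
# parameter of the third transformer block in GPT-2-style models.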
|
|
|
|
|
class PreTrainedModelWrapper(nn.Module): |
|
r""" |
|
    A wrapper class around a (`transformers.PreTrainedModel`) that exposes selected
    attributes and methods of the wrapped (`~transformers.PreTrainedModel`) class so
    that the wrapper can be used in its place.
|
|
|
Attributes: |
|
pretrained_model (`transformers.PreTrainedModel`): |
|
The model to be wrapped. |
|
        transformers_parent_class (`transformers.PreTrainedModel`):
|
The parent class of the model to be wrapped. |
|
supported_args (`list`): |
|
The list of arguments that are supported by the wrapper class. |
|
""" |
|
|
|
transformers_parent_class = None |
|
supported_args = None |
|
supported_modules = ("v_head",) |
|
supported_rm_modules = ("score",) |
|
    supported_pretrained_model_architectures = (
        (PreTrainedModel,)
        if not is_peft_available()
        else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM)
    )
|
|
|
def __init__( |
|
self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs |
|
): |
|
super().__init__() |
|
self.pretrained_model = pretrained_model |
|
|
|
self.config = pretrained_model.config |
|
self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation |
|
self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) |
|
self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) |
|
self.is_sequential_parallel = False |
|
|
|
if hasattr(pretrained_model, "gradient_checkpointing_disable"): |
|
self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable |
|
|
|
if hasattr(pretrained_model, "gradient_checkpointing_enable"): |
|
self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable |
|
|
|
if hasattr(pretrained_model, "enable_input_require_grads"): |
|
self.enable_input_require_grads = pretrained_model.enable_input_require_grads |
|
|
|
self.supports_rm_adapter = supports_rm_adapter |
|
self.rm_adapter_name = rm_adapter_name |
|
self.policy_adapter_name = "default" |
|
if score_module is not None: |
|
self.score = score_module |
|
|
|
@classmethod |
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |
|
r""" |
|
        Instantiates a new model from a pretrained model from `transformers`. The
        pretrained model is loaded using the `from_pretrained` method of the
        `transformers.PreTrainedModel` class. The arguments that are specific to the
        `transformers.PreTrainedModel` class are filtered out of `kwargs` and passed
        along to that method.
|
|
|
Args: |
|
pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`): |
|
The path to the pretrained model or its name. |
|
            *model_args (`list`, *optional*):
|
Additional positional arguments passed along to the underlying model's |
|
`from_pretrained` method. |
|
**kwargs (`dict`, *optional*): |
|
Additional keyword arguments passed along to the underlying model's |
|
`from_pretrained` method. We also pre-process the kwargs to extract |
|
the arguments that are specific to the `transformers.PreTrainedModel` |
|
class and the arguments that are specific to trl models. The kwargs |
|
also support `prepare_model_for_kbit_training` arguments from |
|
`peft` library. |
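
        Example (illustrative sketch; `AutoModelForCausalLMWithValueHead` is one concrete
        subclass of this wrapper):

        ```python
        >>> from trl import AutoModelForCausalLMWithValueHead

        >>> model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
        ```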
|
""" |
|
        # `kwargs` is always a dict here, so the trl-specific arguments can be popped
        # directly; the defaults cover the case where they were not passed
        peft_config = kwargs.pop("peft_config", None)
        reward_adapter = kwargs.pop("reward_adapter", None)
        reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter")
        is_trainable = kwargs.pop("is_trainable", False)
        trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs)
        token = pretrained_kwargs.get("token", None)
|
|
|
if reward_adapter is not None and not isinstance(reward_adapter, str): |
|
            raise ValueError(
                "The `reward_adapter` argument should be a string representing the local path or Hub id of the reward modeling adapter."
            )
|
|
|
is_peft_model = False |
|
|
|
current_device = cls._get_current_device() |
|
if isinstance(pretrained_model_name_or_path, str): |
|
            is_loaded_in_8bit = pretrained_kwargs.get("load_in_8bit", False)
            is_loaded_in_4bit = pretrained_kwargs.get("load_in_4bit", False)
|
else: |
|
is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) |
|
is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False) |
|
|
|
if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs: |
|
|
|
            logging.warning(
                "The `device_map` argument is not provided. We will override the `device_map` argument"
                " to set the entire model on the current device. If you want to set the model on"
                " multiple devices, please provide a custom `device_map` argument."
            )
|
pretrained_kwargs["device_map"] = {"": current_device} |
|
|
|
if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig): |
|
raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.") |
|
|
|
|
|
|
|
if isinstance(pretrained_model_name_or_path, str): |
|
if is_peft_available(): |
|
try: |
|
|
|
remote_adapter_config = hf_hub_download( |
|
pretrained_model_name_or_path, |
|
"adapter_config.json", |
|
token=token, |
|
) |
|
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): |
|
remote_adapter_config = None |
|
else: |
|
remote_adapter_config = None |
|
|
|
local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json")) |
|
|
|
if (local_adapter_present or remote_adapter_config is not None) and is_peft_available(): |
|
if peft_config is not None: |
|
logging.warning( |
|
"`peft_config` argument ignored since a peft config file was found in " |
|
f"{pretrained_model_name_or_path}" |
|
) |
|
|
|
|
|
if local_adapter_present: |
|
trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) |
|
else: |
|
remote_adapter_dir = os.path.dirname(remote_adapter_config) |
|
trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir) |
|
|
|
|
|
pretrained_model = cls.transformers_parent_class.from_pretrained( |
|
trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs |
|
) |
|
|
|
|
|
pretrained_model = PeftModel.from_pretrained( |
|
pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable, token=token |
|
) |
|
logging.info("Trained peft adapter loaded") |
|
else: |
|
pretrained_model = cls.transformers_parent_class.from_pretrained( |
|
pretrained_model_name_or_path, *model_args, **pretrained_kwargs |
|
) |
|
|
|
if peft_config is not None: |
|
|
|
if is_loaded_in_8bit or is_loaded_in_4bit: |
|
pretrained_model = prepare_model_for_kbit_training( |
|
pretrained_model, |
|
**peft_quantization_kwargs, |
|
) |
|
pretrained_model = get_peft_model(pretrained_model, peft_config) |
|
logging.info("peft adapter initialised") |
|
|
|
elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures): |
|
pretrained_model = pretrained_model_name_or_path |
|
|
|
if peft_config is not None and isinstance(pretrained_model, PreTrainedModel): |
|
|
|
if is_loaded_in_8bit or is_loaded_in_4bit: |
|
pretrained_model = prepare_model_for_kbit_training( |
|
pretrained_model, |
|
**peft_quantization_kwargs, |
|
) |
|
pretrained_model = get_peft_model(pretrained_model, peft_config) |
|
logging.info("peft adapter initialised") |
|
else: |
|
raise ValueError( |
|
"pretrained_model_name_or_path should be a string or a PreTrainedModel, " |
|
f"but is {type(pretrained_model_name_or_path)}" |
|
) |
|
|
|
if is_peft_available(): |
|
if isinstance(pretrained_model, PeftModel): |
|
is_peft_model = True |
|
|
|
if hasattr(pretrained_model, "active_peft_config") and isinstance( |
|
pretrained_model.active_peft_config, PromptLearningConfig |
|
): |
|
raise ValueError("PromptLearningConfig is not supported for PPO training.") |
|
|
|
|
|
if not is_peft_model and reward_adapter is not None: |
|
            raise ValueError("reward_adapter can only be used with a PeftModel.")
|
elif is_peft_model and reward_adapter is not None: |
|
score_module = cls.add_and_load_reward_modeling_adapter( |
|
pretrained_model, reward_adapter, reward_adapter_name, token=token |
|
) |
|
multi_adapter_args = { |
|
"score_module": score_module, |
|
"supports_rm_adapter": True, |
|
"rm_adapter_name": reward_adapter_name, |
|
} |
|
else: |
|
multi_adapter_args = {"supports_rm_adapter": False} |
|
|
|
|
|
model = cls(pretrained_model, **multi_adapter_args, **trl_model_args) |
|
|
|
|
|
|
|
is_resuming_training = True |
|
if isinstance(pretrained_model_name_or_path, str): |
|
safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors") |
|
filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") |
|
|
|
sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") |
|
safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") |
|
is_sharded = False |
|
use_safe = os.path.exists(safe_filename) |
|
|
|
if not (os.path.exists(filename) or os.path.exists(safe_filename)): |
|
|
|
filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( |
|
pretrained_model, |
|
pretrained_model_name_or_path, |
|
sharded_index_filename, |
|
token=token, |
|
) |
|
|
|
if filename is None and files_to_download is None: |
|
safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( |
|
pretrained_model, |
|
pretrained_model_name_or_path, |
|
safe_sharded_index_filename, |
|
token=token, |
|
model_name="model.safetensors", |
|
model_index_name="model.safetensors.index.json", |
|
) |
|
use_safe = True |
|
else: |
|
use_safe = False |
|
|
|
loading_func = safe_load_file if use_safe else torch.load |
|
load_kwargs = {} if use_safe else {"map_location": "cpu", "weights_only": True} |
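
            # `weights_only=True` keeps `torch.load` from unpickling arbitrary objects;
            # safetensors files need no such guard and are loaded directly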
|
|
|
if is_resuming_training: |
|
if is_sharded: |
|
|
|
state_dict = {} |
|
|
|
for shard_file in files_to_download: |
|
filename = hf_hub_download( |
|
pretrained_model_name_or_path, |
|
shard_file, |
|
token=token, |
|
) |
|
state_dict.update(loading_func(filename, **load_kwargs)) |
|
else: |
|
state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs) |
|
|
|
else: |
|
state_dict = pretrained_model_name_or_path.state_dict() |
|
|
|
model.is_peft_model = is_peft_model |
|
model.current_device = current_device |
|
|
|
if is_resuming_training: |
|
model.post_init(state_dict=state_dict) |
|
|
|
return model |
|
|
|
@classmethod |
|
def _get_checkpoint_from_hub( |
|
cls, |
|
pretrained_model, |
|
pretrained_model_name_or_path, |
|
index_filename, |
|
token=None, |
|
model_name="pytorch_model.bin", |
|
model_index_name="pytorch_model.bin.index.json", |
|
): |
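        """
        Look for a `model_name` checkpoint for `pretrained_model_name_or_path` on the Hub,
        falling back to the sharded checkpoint index (`index_filename` locally,
        `model_index_name` on the Hub). Returns a tuple
        `(filename, files_to_download, is_sharded, is_resuming_training)`, where
        `is_resuming_training` is `False` when neither a checkpoint nor an index is found.
        """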
|
files_to_download = None |
|
filename = None |
|
is_resuming_training = True |
|
is_sharded = False |
|
|
|
try: |
|
filename = hf_hub_download( |
|
pretrained_model_name_or_path, |
|
model_name, |
|
token=token, |
|
) |
|
|
|
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): |
|
if os.path.exists(index_filename): |
|
index_file_name = index_filename |
|
else: |
|
try: |
|
index_file_name = hf_hub_download( |
|
pretrained_model_name_or_path, |
|
model_index_name, |
|
token=token, |
|
) |
|
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): |
|
|
|
is_resuming_training = False |
|
logging.warning( |
|
f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " |
|
f"and no v_head weight is found. This IS expected if you are not resuming PPO training." |
|
) |
|
|
|
if is_resuming_training: |
|
with open(index_file_name) as f: |
|
index = json.load(f) |
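
                # the index's `weight_map` maps each parameter name to the shard file that
                # stores it; only shards holding modules we manage (e.g. `v_head`) are needed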
|
|
|
files_to_download = set() |
|
for k, v in index["weight_map"].items(): |
|
if any(module in k for module in cls.supported_modules): |
|
files_to_download.add(v) |
|
is_sharded = True |
|
|
|
return filename, files_to_download, is_sharded, is_resuming_training |
|
|
|
@classmethod |
|
def _get_current_device(cls): |
|
r""" |
|
Get the current device. For GPU, we return the local process index using the `accelerate.PartialState` |
|
object to handle corner cases when running scripts in distributed environments. |
|
|
|
Returns: |
|
current_device (`Union[int, str]`): |
|
The current device. |
|
""" |
|
state = PartialState() |
|
if is_torch_xpu_available(): |
|
return f"xpu:{state.local_process_index}" |
|
elif is_torch_npu_available(): |
|
return f"npu:{state.local_process_index}" |
|
else: |
|
return state.local_process_index if torch.cuda.is_available() else "cpu" |
|
|
|
@classmethod |
|
def _split_kwargs(cls, kwargs): |
|
""" |
|
Separate the kwargs from the arguments that we support inside |
|
`supported_args` and the ones that we don't. |
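
        Illustrative split: given `kwargs = {"torch_dtype": "auto", "use_gradient_checkpointing": False}`,
        `"torch_dtype"` lands in the kwargs forwarded to `from_pretrained`, while
        `"use_gradient_checkpointing"` is routed to `peft.prepare_model_for_kbit_training`.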
|
""" |
|
check_peft_kwargs = False |
|
|
|
if is_peft_available(): |
|
from peft import prepare_model_for_kbit_training |
|
|
|
check_peft_kwargs = True |
|
|
|
supported_kwargs = {} |
|
unsupported_kwargs = {} |
|
peft_kwargs = {} |
|
|
|
for key, value in kwargs.items(): |
|
if key in cls.supported_args: |
|
supported_kwargs[key] = value |
|
else: |
|
unsupported_kwargs[key] = value |
|
|
|
if check_peft_kwargs: |
|
if key in prepare_model_for_kbit_training.__code__.co_varnames: |
|
peft_kwargs[key] = value |
|
if key in unsupported_kwargs: |
|
unsupported_kwargs.pop(key) |
|
|
|
return supported_kwargs, unsupported_kwargs, peft_kwargs |
|
|
|
@classmethod |
|
def add_and_load_reward_modeling_adapter( |
|
cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None |
|
    ):
        r"""
        Add and load a reward modeling adapter. This method can only be used if the
        model is a `PeftModel` and if the model was initialized with the `reward_adapter`
        argument, pointing to the id of the reward modeling adapter. That adapter must
        also contain the score head in order to produce the reward.
        """
|
pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False) |
|
pretrained_model.train() |
|
|
|
filename = os.path.join(adapter_model_id, "adapter_model.bin") |
|
safe_loading = False |
|
if not os.path.exists(filename): |
|
try: |
|
local_filename = hf_hub_download( |
|
adapter_model_id, |
|
"adapter_model.bin", |
|
token=token, |
|
) |
|
except Exception: |
|
filename = os.path.join(adapter_model_id, "adapter_model.safetensors") |
|
safe_loading = True |
|
if not os.path.exists(filename): |
|
try: |
|
local_filename = hf_hub_download( |
|
adapter_model_id, |
|
"adapter_model.safetensors", |
|
token=token, |
|
) |
|
except Exception as exc: |
|
raise ValueError( |
|
"Could not find adapter model in the Hub, make sure you have the correct adapter model id." |
|
) from exc |
|
else: |
|
local_filename = filename |
|
else: |
|
local_filename = filename |
|
|
|
loading_func = safe_load_file if safe_loading else torch.load |
|
load_kwargs = {} if safe_loading else {"map_location": "cpu", "weights_only": True} |
|
|
|
adapter_state_dict = loading_func(local_filename, **load_kwargs) |
|
|
|
        score_name = None
        for score_name_candidate in cls.supported_rm_modules:
            if any(score_name_candidate in name for name in adapter_state_dict.keys()):
                score_name = score_name_candidate
                # we have found the correct head name and can stop searching
                break
        if score_name is None:
            raise ValueError(
                f"Could not find a score head matching any of {cls.supported_rm_modules} in the adapter state dict."
            )
|
|
|
score_dict = {} |
|
|
|
for name, param in adapter_state_dict.items(): |
|
if score_name in name: |
|
                key_name = name.split(".")[-1]
|
score_dict[key_name] = param.to(cls._get_current_device()) |
|
|
|
num_labels, hidden_dim = score_dict["weight"].shape |
|
has_bias = any("bias" in name for name in adapter_state_dict.keys()) |
|
|
|
score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( |
|
device=cls._get_current_device(), |
|
dtype=pretrained_model.dtype, |
|
) |
|
score.load_state_dict(score_dict) |
|
for param in score.parameters(): |
|
param.requires_grad = False |
|
|
|
return score |
|
|
|
def push_to_hub(self, *args, **kwargs): |
|
r""" |
|
Push the pretrained model to the hub. This method is a wrapper around |
|
`transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation |
|
of `transformers.PreTrainedModel.push_to_hub` for more information. |
|
|
|
Args: |
|
*args (`list`, *optional*): |
|
Positional arguments passed along to the underlying model's |
|
`push_to_hub` method. |
|
**kwargs (`dict`, *optional*): |
|
Keyword arguments passed along to the underlying model's |
|
`push_to_hub` method. |
|
""" |
|
raise NotImplementedError |
|
|
|
def save_pretrained(self, *args, **kwargs): |
|
r""" |
|
Save the pretrained model to a directory. This method is a wrapper around |
|
`transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation |
|
of `transformers.PreTrainedModel.save_pretrained` for more information. |
|
|
|
Args: |
|
*args (`list`, *optional*): |
|
Positional arguments passed along to the underlying model's |
|
`save_pretrained` method. |
|
**kwargs (`dict`, *optional*): |
|
Keyword arguments passed along to the underlying model's |
|
`save_pretrained` method. |
|
""" |
|
state_dict = kwargs.get("state_dict") |
|
if state_dict is None: |
|
state_dict = self.state_dict() |
|
kwargs["state_dict"] = state_dict |
|
|
|
|
|
|
|
        # if it is a peft model, `self.state_dict()` (implemented by subclasses) is expected
        # to contain only the extra trl modules (e.g. the `v_head`); save those separately
        # and drop `state_dict` from the kwargs so only the adapter weights are saved below
        if self.is_peft_model:
            save_path = args[0]
            save_path = os.path.join(save_path, "pytorch_model.bin")
            torch.save(state_dict, save_path)
            _ = kwargs.pop("state_dict", None)
|
|
|
return self.pretrained_model.save_pretrained(*args, **kwargs) |
|
|
|
def state_dict(self, *args, **kwargs): |
|
r""" |
|
Return the state_dict of the pretrained model. |
|
""" |
|
raise NotImplementedError |
|
|
|
def post_init(self, *args, **kwargs): |
|
r""" |
|
Post initialization method. This method is called after the model is |
|
instantiated and loaded from a checkpoint. It can be used to perform |
|
additional operations such as loading the state_dict. |
|
""" |
|
raise NotImplementedError |
|
|
|
def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): |
|
r""" |
|
        Computes the reward score for a given input. The method first enables the reward
        modeling adapter, computes the score, and then re-enables the default policy
        adapter.
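
        Example (sketch; assumes the wrapper was built with a `reward_adapter`):

        ```python
        >>> scores = model.compute_reward_score(input_ids, attention_mask=attention_mask)
        >>> scores.shape  # (batch_size, sequence_length, num_labels)
        ```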
|
""" |
|
if not self.supports_rm_adapter: |
|
raise ValueError("This model does not support reward modeling adapter.") |
|
|
|
|
|
self.pretrained_model.set_adapter(self.rm_adapter_name) |
|
self.pretrained_model.eval() |
|
|
|
with torch.no_grad(): |
|
base_model_output = self.pretrained_model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
output_hidden_states=True, |
|
return_dict=True, |
|
**kwargs, |
|
) |
|
|
|
last_hidden_states = base_model_output.hidden_states[-1] |
|
scores = self.score(last_hidden_states) |
|
|
|
self.pretrained_model.set_adapter(self.policy_adapter_name) |
|
self.pretrained_model.eval() |
|
|
|
return scores |
|
|
|
|
|
def create_reference_model( |
|
model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None |
|
) -> PreTrainedModelWrapper: |
|
""" |
|
    Creates a static reference copy of a model. Note that the returned model is in `.eval()` mode.
|
|
|
Args: |
|
model (`PreTrainedModelWrapper`): The model to be copied. |
|
num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen. |
|
        pattern (`str`, *optional*): The shared layers are selected with a string pattern
            (e.g. "transformer.h.{layer}" for GPT2); pass a custom pattern here if none of
            the defaults in `LAYER_PATTERNS` matches your model.
|
|
|
Returns: |
|
`PreTrainedModelWrapper` |
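
    Example (sketch; `AutoModelForCausalLMWithValueHead` is one concrete wrapper):

    ```python
    >>> from trl import AutoModelForCausalLMWithValueHead, create_reference_model

    >>> model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    >>> # freeze the first 6 blocks in the training model and return a fully frozen copy
    >>> ref_model = create_reference_model(model, num_shared_layers=6)
    ```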
|
""" |
|
if is_deepspeed_zero3_enabled(): |
|
raise ValueError( |
|
"DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoModelForCausalLM.from_pretrained()`." |
|
) |
|
|
|
parameter_names = [n for n, _ in model.named_parameters()] |
|
ref_model = deepcopy(model) |
|
|
|
|
|
if num_shared_layers is None: |
|
for param_name in parameter_names: |
|
param = ref_model.get_parameter(param_name) |
|
param.requires_grad = False |
|
return ref_model.eval() |
|
|
|
|
|
if pattern is not None: |
|
pattern = pattern.format(layer=num_shared_layers) |
|
else: |
|
for pattern_candidate in LAYER_PATTERNS: |
|
pattern_candidate = pattern_candidate.format(layer=num_shared_layers) |
|
if any(pattern_candidate in name for name in parameter_names): |
|
pattern = pattern_candidate |
|
break |
|
|
|
if pattern is None: |
|
raise ValueError("Layer pattern could not be matched.") |
|
|
|
|
|
shared_param_list = [] |
|
unshared_param_list = [] |
|
|
|
shared_parameter = True |
|
for name, _param in model.named_parameters(): |
|
if pattern in name: |
|
shared_parameter = False |
|
if shared_parameter: |
|
shared_param_list.append(name) |
|
else: |
|
unshared_param_list.append(name) |
|
|
|
|
|
    # freeze the shared layers of the training model; `deepcopy` already gave the
    # reference model independent copies of these parameters
    for param_name in shared_param_list:
        param = model.get_parameter(param_name)
        param.requires_grad = False
|
|
|
|
|
    # freeze every reference-model parameter beyond the shared prefix
    for param_name in unshared_param_list:
        param = ref_model.get_parameter(param_name)
        param.requires_grad = False
|
|
|
    # `pattern` is guaranteed to be set here (a ValueError is raised above otherwise)
    if len(unshared_param_list) == 0:
        logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.")
|
|
|
return ref_model.eval() |
|
|
|
|
|
class GeometricMixtureWrapper(GenerationMixin): |
|
r""" |
|
    Geometric mixture generation wrapper that samples from the geometric mixture of two
    models' token distributions.
|
|
|
Args: |
|
model (`PreTrainedModel`): The model to be wrapped. |
|
ref_model (`PreTrainedModel`): The reference model. |
|
generation_config (`GenerationConfig`): The generation config. |
|
        mixture_coef (`float`, *optional*, defaults to `0.5`): The mixture coefficient.
        device (`torch.device` or `str`, *optional*): The device on which the wrapper runs.
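
    Example (sketch; generation flows through `GenerationMixin.generate`):

    ```python
    >>> mixture_model = GeometricMixtureWrapper(
    ...     model, ref_model, model.generation_config, mixture_coef=0.5
    ... )
    >>> output_ids = mixture_model.generate(input_ids=input_ids, attention_mask=attention_mask)
    ```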
|
""" |
|
|
|
main_input_name = "input_ids" |
|
_supports_cache_class = False |
|
_supports_static_cache = False |
|
|
|
def __init__(self, model, ref_model, generation_config, mixture_coef=0.5, device=None): |
|
super().__init__() |
|
|
|
self.model = model |
|
self.config = model.config |
|
self.ref_model = ref_model |
|
self.generation_config = generation_config |
|
self.mixture_coef = mixture_coef |
|
self.device = device |
|
|
|
def __call__(self, *args, **kwargs): |
|
return self.forward(*args, **kwargs) |
|
|
|
@torch.inference_mode() |
|
def forward(self, *args, **kwargs): |
|
model_outputs = self.model(*args, **kwargs) |
|
model_logits = model_outputs.logits |
|
ref_model_logits = self.ref_model(*args, **kwargs).logits |
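
        # geometric mixture in probability space: p_mix ∝ p_model^(1 - α) · p_ref^α with
        # α = mixture_coef; taking logs turns this into a convex combination of the two
        # logit vectors, renormalized via log_softmax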
|
|
|
model_outputs.logits = torch.nn.functional.log_softmax( |
|
self.mixture_coef * ref_model_logits + (1 - self.mixture_coef) * model_logits, dim=-1 |
|
) |
|
|
|
return model_outputs |
|
|
|
def prepare_inputs_for_generation(self, *args, **kwargs): |
|
|
|
kwargs["use_cache"] = False |
|
model_inputs = self.model.prepare_inputs_for_generation(*args, **kwargs) |
|
_ = self.ref_model.prepare_inputs_for_generation(*args, **kwargs) |
|
|
|
return model_inputs |
|
|
|
def _validate_model_class(self): |
|
self.model._validate_model_class() |
|
|
|
def _validate_model_kwargs(self, model_kwargs): |
|
return self.model._validate_model_kwargs(model_kwargs) |
|
|