|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from huggingface_hub import HfApi |
|
|
|
from trl.import_utils import is_mergekit_available |
|
|
|
|
|
if is_mergekit_available(): |
|
from mergekit.config import MergeConfiguration |
|
from mergekit.merge import MergeOptions, run_merge |
|
|
|
|
|
def upload_model_to_hf(folder_path: str, repo_id: str):
    """
    Upload the contents of a local folder to a model repository on the Hugging Face Hub.

    The repository is created when it does not exist yet; an already existing repository is reused.

    Args:
        folder_path (`str`): Path of the local folder to upload.
        repo_id (`str`): Identifier of the destination repository on the Hub.
    """
    api = HfApi()

    # `exist_ok=True` keeps repeated uploads to the same `repo_id` from failing because the repository
    # was already created by an earlier call.
    repo = api.create_repo(repo_id, repo_type="model", exist_ok=True)

    # Upload using the identifiers reported back by `create_repo` rather than the raw arguments.
    api.upload_folder(
        folder_path=folder_path,
        repo_id=repo.repo_id,
        repo_type=repo.repo_type,
    )
|
|
|
|
|
class MergeConfig:
    r"""
    Configuration class for merging two models using `mergekit`.

    This class provides a structured way to configure and generate merge configurations for various merge methods,
    such as `linear`, `ties`, `dare_ties`, and `slerp`.

    Args:
        method (`str`, *optional*, defaults to `"linear"`):
            Merge method to use. Supported methods include:

            - `"linear"`: Linearly combines two models with specified weights.
            - `"ties"`: Combines two models using the TIES method with density parameters.
            - `"dare_ties"`: The DARE variant of TIES; configured with the same weight and density parameters.
            - `"slerp"`: Combines models using spherical linear interpolation.

    Note:

        For more details about the merge methods and how they are implemented, see the
        [MergeKit GitHub repository](https://github.com/arcee-ai/mergekit?tab=readme-ov-file#merge-methods).

    Attributes:
        method (`str`): The merge method to use.
        policy_model_path (`str` or `None`): Path to the policy model.
        target_model_path (`str` or `None`): Path to the target model.
        policy_model_weight (`float` or `None`):
            Weight for the policy model (set for `linear`, `ties` and `dare_ties`, `None` otherwise).
        target_model_weight (`float` or `None`):
            Weight for the target model (set for `linear`, `ties` and `dare_ties`, `None` otherwise).
        policy_model_density (`list[float]` or `None`):
            Density parameters for the policy model (set for `ties` and `dare_ties`, `None` otherwise).
        target_model_density (`list[float]` or `None`):
            Density parameters for the target model (set for `ties` and `dare_ties`, `None` otherwise).
        normalize (`float` or `None`):
            Normalization factor (set for `ties` and `dare_ties`, `None` otherwise).
        t_values (`float` or `None`): Interpolation factor (set for `slerp`, `None` otherwise).
        dtype (`str`): Data type to use for merging, e.g., `"float16"`.

    Raises:
        ImportError: If `mergekit` is not available.
        ValueError: If `method` is not one of the supported merge methods.
    """

    # Merge methods this class knows how to configure.
    _SUPPORTED_METHODS = ("linear", "ties", "dare_ties", "slerp")

    def __init__(self, method: str = "linear"):
        if not is_mergekit_available():
            raise ImportError(
                "MergeConfig requires the `mergekit` extra. To install, run `pip install trl[mergekit]`."
            )
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Unsupported merge method: {method}")

        self.method = method
        self.policy_model_path = None
        self.target_model_path = None

        # Every documented attribute exists whatever the method; the ones the selected method does not
        # use stay `None` instead of being missing (which would raise `AttributeError` on access).
        self.policy_model_weight = None
        self.target_model_weight = None
        self.policy_model_density = None
        self.target_model_density = None
        self.normalize = None
        self.t_values = None
        self.dtype = "float16"

        # Method-specific defaults.
        if method == "linear":
            self.policy_model_weight = 0.5
            self.target_model_weight = 0.5
        elif method in ("ties", "dare_ties"):
            self.policy_model_weight = 1.0
            self.policy_model_density = [1.0, 0.7, 0.1]
            self.target_model_weight = 1.0
            self.target_model_density = [1.0]
            self.normalize = 1.0
        else:  # "slerp" (the only remaining supported method)
            self.t_values = 0.5

    @staticmethod
    def _model_reference(path):
        """Return the nested model-reference dictionary for the model located at `path`."""
        return {
            "model": {"path": path, "revision": None},
            "lora": None,
            "override_architecture": None,
        }

    def _create_ties_family_config(self, merge_method: str) -> "MergeConfiguration":
        """Build the configuration shared by `ties` and `dare_ties`, which differ only in `merge_method`."""
        merge_config_dict = {
            "merge_method": merge_method,
            "slices": None,
            "models": [
                {
                    "model": self._model_reference(self.target_model_path),
                    "parameters": {"density": self.target_model_density, "weight": self.target_model_weight},
                },
                {
                    "model": self._model_reference(self.policy_model_path),
                    "parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight},
                },
            ],
            "parameters": {"normalize": self.normalize},
            # The policy model is used as the base model of the merge.
            "base_model": self._model_reference(self.policy_model_path),
            "dtype": self.dtype,
            "tokenizer_source": None,
            "tokenizer": None,
            "chat_template": None,
            "out_dtype": None,
        }

        # Validate the dictionary and turn it into a `MergeConfiguration`.
        return MergeConfiguration.model_validate(merge_config_dict)

    def create_merge_config_linear(self) -> "MergeConfiguration":
        """
        Creates a merge configuration for a linear merge of two models with specified weights.
        """
        merge_config_dict = {
            "dtype": self.dtype,
            "merge_method": "linear",
            "models": [
                {"model": self.policy_model_path, "parameters": {"weight": self.policy_model_weight}},
                {"model": self.target_model_path, "parameters": {"weight": self.target_model_weight}},
            ],
        }

        # Validate the dictionary and turn it into a `MergeConfiguration`.
        return MergeConfiguration.model_validate(merge_config_dict)

    def create_merge_config_ties(self) -> "MergeConfiguration":
        """
        Creates a merge configuration for a TIES merge of two models, with specified weights and densities.
        """
        return self._create_ties_family_config("ties")

    def create_merge_config_dare_ties(self) -> "MergeConfiguration":
        """
        Creates a merge configuration for a DARE TIES merge of two models, with specified weights and densities.
        """
        return self._create_ties_family_config("dare_ties")

    def create_merge_config_slerp(self) -> "MergeConfiguration":
        """
        Creates a merge configuration for a SLERP merge of a model with a base model.

        The target model is the model being merged and the policy model is the base model.
        """
        merge_config_dict = {
            "merge_method": "slerp",
            "slices": None,
            "models": [
                {
                    "model": self._model_reference(self.target_model_path),
                    "parameters": None,
                }
            ],
            "parameters": {
                "t": self.t_values,
            },
            "base_model": self._model_reference(self.policy_model_path),
            "dtype": self.dtype,
            "tokenizer_source": None,
            "tokenizer": None,
            "chat_template": None,
            "out_dtype": None,
        }

        # Validate the dictionary and turn it into a `MergeConfiguration`.
        return MergeConfiguration.model_validate(merge_config_dict)

    def create(self) -> "MergeConfiguration":
        """
        Creates the merge configuration matching the current value of `method`.

        Returns:
            `MergeConfiguration`: The validated configuration built by the method-specific builder.

        Raises:
            ValueError: If `method` was changed to an unsupported value after construction.
        """
        if self.method == "linear":
            return self.create_merge_config_linear()
        if self.method == "ties":
            return self.create_merge_config_ties()
        if self.method == "dare_ties":
            return self.create_merge_config_dare_ties()
        if self.method == "slerp":
            return self.create_merge_config_slerp()

        # Fail loudly instead of implicitly returning `None` to the caller.
        raise ValueError(f"Unsupported merge method: {self.method}")
|
|
|
|
|
def merge_models(config: "MergeConfig | MergeConfiguration", out_path: str):
    """
    Merge two models using mergekit

    Args:
        config (`MergeConfig` or `MergeConfiguration`):
            The merge configuration. A `MergeConfig` is first converted with its `create()` method; any other
            object is handed to `run_merge` unchanged.
        out_path (`str`): The output path for the merged model.

    Raises:
        ImportError: If `mergekit` is not available.
    """
    if not is_mergekit_available():
        raise ImportError("merge_models requires the `mergekit` extra. To install, run `pip install trl[mergekit]`.")

    # `run_merge` is given the object produced by `MergeConfig.create()`, not the `MergeConfig` wrapper
    # itself, so accept the wrapper here and convert it.
    if isinstance(config, MergeConfig):
        config = config.create()

    run_merge(
        config,
        out_path=out_path,
        options=MergeOptions(
            cuda=torch.cuda.is_available(),  # use the GPU only when one is available
            copy_tokenizer=True,
            lazy_unpickle=False,
            low_cpu_memory=False,
        ),
    )
|
|