|
# Fine-tune a causal LM on an interleaved mix of text-QA and TTS batches with HF Trainer + FSDP.
import torch
|
from datasets import load_dataset,load_from_disk |
|
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer |
|
import numpy as np |
|
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)
|
from torch.utils.data import DataLoader, Dataset |
|
from torch.utils.data.distributed import DistributedSampler |
|
import yaml |
|
import wandb |
|
from huggingface_hub import HfApi |
|
|
|
config_file = "config.yaml" |
|
|
|
with open(config_file, "r") as file: |
|
config = yaml.safe_load(file) |
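
# The config is expected to provide at least the keys read below; illustrative
# example only (values are placeholders, not from the original):
#
#   text_QA_dataset: /path/to/text_qa_dataset   # saved with save_to_disk
#   TTS_dataset: /path/to/tts_dataset
#   model_name: /path/to/base_model
#   tokenizer_name: /path/to/base_tokenizer
#   run_name: my_run
#   project_name: my_project
#   save_folder: my_checkpoints
#   epochs: 1
#   batch_size: 8
#   save_steps: 1000
#   pad_token: 0            # token id used to pad input_ids
#   number_processes: 8     # data-parallel world size
#   learning_rate: 5.0e-5
#   ratio: 4                # text batches per audio batch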
|
|
|
dsn1 = config["text_QA_dataset"] |
|
dsn2 = config["TTS_dataset"] |
|
|
|
model_name = config["model_name"] |
|
tokenizer_name = config["tokenizer_name"] |
|
|
|
run_name = config["run_name"] |
|
project_name = config["project_name"] |
|
base_repo_id = config["save_folder"] |
|
|
|
epochs = config["epochs"] |
|
batch_size = config["batch_size"] |
|
save_steps = config["save_steps"] |
|
pad_token = config["pad_token"] |
|
number_processes = config["number_processes"] |
|
learning_rate = config["learning_rate"] |
|
config_ratio = config["ratio"] |
|
|
|
|
|
|
|
|
|
class BatchedRatioDataset(Dataset): |
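    """Interleaves two datasets at (global) batch granularity: `ratio` batches
    drawn from dataset1 (the text QA data) are followed by one batch from
    dataset2 (the TTS data), repeated for as many full cycles as both datasets
    can supply."""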
|
def __init__(self, dataset1, dataset2, batch_total, ratio=config_ratio): |
|
self.dataset1 = dataset1 |
|
self.dataset2 = dataset2 |
|
self.batch_total = batch_total |
|
self.ratio = ratio |
|
|
|
num_cycles_ds1 = len(dataset1) // (batch_total * ratio) |
|
num_cycles_ds2 = len(dataset2) // batch_total |
|
self.num_cycles = min(num_cycles_ds1, num_cycles_ds2) |
|
|
|
self.length = self.num_cycles * (ratio + 1) * batch_total |
|
|
|
    def __len__(self):
        return int(self.length)
|
|
|
def __getitem__(self, index): |
|
|
|
cycle_length = (self.ratio + 1) * self.batch_total |
|
cycle = index // cycle_length |
|
pos_in_cycle = index % cycle_length |
|
|
|
if pos_in_cycle < self.ratio * self.batch_total: |
|
batch_in_cycle = pos_in_cycle // self.batch_total |
|
sample_in_batch = pos_in_cycle % self.batch_total |
|
ds1_index = cycle * self.ratio * self.batch_total + batch_in_cycle * self.batch_total + sample_in_batch |
|
return self.dataset1[ds1_index] |
|
else: |
|
|
|
sample_in_batch = pos_in_cycle - self.ratio * self.batch_total |
|
ds2_index = cycle * self.batch_total + sample_in_batch |
|
return self.dataset2[ds2_index] |
|
|
|
|
|
|
|
class AlternatingDistributedSampler(DistributedSampler): |
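    """Distributed sampler that keeps the dataset order fixed (no shuffling)
    and assigns indices to ranks in a strided fashion, so each global batch of
    `batch_total` consecutive indices is split evenly across processes and
    every rank sees samples from the same underlying dataset on a given step."""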
|
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False): |
|
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
|
|
def __iter__(self): |
|
indices = list(range(len(self.dataset))) |
|
indices = indices[self.rank:self.total_size:self.num_replicas] |
|
return iter(indices) |
|
|
|
|
|
class FSDPTrainer(Trainer): |
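    """Trainer that wires in the alternating sampler, logs the text and audio
    losses to separate W&B series, and saves full (un-sharded) FSDP
    checkpoints."""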
|
def __init__(self, *args, log_ratio=config_ratio, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.repo_id = base_repo_id |
|
self.api = HfApi() |
|
|
|
self.log_ratio = log_ratio |
|
self.text_step = 0 |
|
self.audio_step = 0 |
|
|
|
def get_train_dataloader(self): |
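        # Shuffling is deliberately disabled: the dataset's text/audio batch
        # interleaving must be preserved, and each rank takes a stride of the
        # same global batch.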
|
sampler = AlternatingDistributedSampler( |
|
self.train_dataset, |
|
num_replicas=torch.distributed.get_world_size(), |
|
rank=torch.distributed.get_rank(), |
|
shuffle=False, |
|
) |
|
|
|
return DataLoader( |
|
self.train_dataset, |
|
batch_size=self.args.per_device_train_batch_size, |
|
sampler=sampler, |
|
collate_fn=self.data_collator, |
|
drop_last=self.args.dataloader_drop_last, |
|
num_workers=0, |
|
pin_memory=self.args.dataloader_pin_memory, |
|
) |
|
|
|
    def log(self, logs, start_time=None):
        super().log(logs, start_time)
        # Only rank 0 reports to W&B, and only for entries that carry a loss
        # (the final training summary log has no "loss" key).
        if self.is_world_process_zero() and "loss" in logs:
            global_step = self.state.global_step
            cycle_length = self.log_ratio + 1
            # Each cycle is `ratio` text batches followed by one audio batch,
            # so the audio loss arrives on every (ratio + 1)-th step.
            if global_step % cycle_length == 0:
                wandb.log({"audio_loss": logs["loss"], "audio_step": self.audio_step})
                self.audio_step += 1
            else:
                wandb.log({"text_loss": logs["loss"], "text_step": self.text_step})
                self.text_step += 1
|
|
|
def save_model(self, output_dir=None, _internal_call=False): |
|
if output_dir is None: |
|
output_dir = self.args.output_dir |
|
self.save_and_push_model(output_dir) |
|
|
|
    def save_and_push_model(self, output_dir):
        # Gather a full, CPU-offloaded state dict; with rank0_only=True the
        # complete weights are materialised on the main process only.
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
            cpu_state_dict = self.model.state_dict()
        # Only the main process writes the checkpoint to disk.
        if self.is_world_process_zero():
            self.model.save_pretrained(output_dir, state_dict=cpu_state_dict)
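        # Assumption: the `save_folder` value from the config names a Hugging
        # Face Hub repo this account can write to; if so, push the checkpoint
        # from the main process. Drop this block for purely local saving.
        if self.is_world_process_zero():
            self.api.create_repo(self.repo_id, exist_ok=True)
            self.api.upload_folder(repo_id=self.repo_id, folder_path=output_dir)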
|
|
|
|
|
def data_collator(features):
    input_ids = [f["input_ids"] for f in features]

    # Fall back to an all-ones attention mask when the dataset does not
    # provide one.
    if any("attention_mask" not in f for f in features):
        attention_mask = [[1] * len(ids) for ids in input_ids]
    else:
        attention_mask = [f["attention_mask"] for f in features]

    # Default to causal-LM style labels (the inputs themselves) when no
    # explicit labels are present.
    if any("labels" not in f for f in features):
        labels = input_ids
    else:
        labels = [f["labels"] for f in features]

    # Right-pad everything to the longest sequence in the batch; label padding
    # uses -100 so padded positions are ignored by the loss.
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(i, dtype=torch.long) for i in input_ids],
        batch_first=True, padding_value=pad_token)
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(m, dtype=torch.long) for m in attention_mask],
        batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(l, dtype=torch.long) for l in labels],
        batch_first=True, padding_value=-100)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
|
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) |
|
model = AutoModelForCausalLM.from_pretrained(model_name, local_files_only=True) |
|
|
|
model.cuda() |
|
|
|
# Register the extra custom placeholder tokens and resize the embedding matrix to match.
number_add_tokens = 7 * 4096 + 10
|
new_tokens = [f"<custom_token_{i}>" for i in range(0, number_add_tokens + 1)] |
|
tokenizer.add_tokens(new_tokens) |
|
model.resize_token_embeddings(len(tokenizer)) |
|
|
|
|
|
# Both datasets are expected to be pre-tokenised (providing "input_ids") and saved with save_to_disk.
ds1 = load_from_disk(dsn1, keep_in_memory=True)
|
ds2 = load_from_disk(dsn2, keep_in_memory=True) |
|
|
|
|
|
# Global batch size across all data-parallel processes; interleaving happens at this granularity.
batch_total = batch_size * number_processes
|
train_dataset = BatchedRatioDataset(ds1, ds2, batch_total, ratio=config_ratio)
|
training_args = TrainingArguments( |
|
overwrite_output_dir=True, |
|
num_train_epochs=epochs, |
|
per_device_train_batch_size=batch_size, |
|
logging_steps=1, |
|
bf16=True, |
|
output_dir=f"./{base_repo_id}", |
|
report_to="tensorboard", |
|
save_steps=save_steps, |
|
remove_unused_columns=True, |
|
learning_rate=learning_rate, |
|
lr_scheduler_type="cosine", |
|
) |
|
|
|
# Use the FSDPTrainer defined above so training gets the alternating sampler,
# the separate text/audio loss logging, and FSDP-aware checkpointing. This
# assumes sharding itself is configured through the launcher (e.g. an
# accelerate/torchrun FSDP config).
trainer = FSDPTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
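
# FSDPTrainer.log reports losses via wandb.log, so initialise Weights & Biases
# on the main process, using the project and run names from the config.
if trainer.is_world_process_zero():
    wandb.init(project=project_name, name=run_name)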
|
|
|
trainer.train() |
|
|