from transformers import is_torch_available
from transformers.testing_utils import (
    TestCasePlus,
    backend_device_count,
    execute_subprocess_async,
    get_torch_dist_unique_port,
    require_accelerate,
    require_fp8,
    require_torch_multi_accelerator,
    run_first,
    torch_device,
)


if is_torch_available():
    import torch
    import torch.distributed
    import torch.utils.data

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForSeq2Seq,
        EvalPrediction,
        GenerationConfig,
        HfArgumentParser,
        PreTrainedTokenizerBase,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
    )

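    # Tiny eval dataset: eight short sentences tokenized up front, with labels
    # mirroring input_ids so the causal-LM loss and accuracy are well defined.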
    class DummyTextDataset(torch.utils.data.Dataset[dict[str, torch.Tensor]]):
        def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
            data = 4 * [
                "Hello world!",
                "The quick brown fox jumps over the lazy dog.",
            ]
            self.data = [
                {k: v.squeeze(0) for k, v in tokenizer(item, return_tensors="pt", return_attention_mask=True).items()}
                for item in data
            ]
            for item in self.data:
                item["labels"] = item["input_ids"]

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, i: int) -> dict[str, torch.Tensor]:
            return self.data[i]


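# Each test below re-launches this very file under `accelerate launch --use_fsdp`;
# the `if __name__ == "__main__"` block at the bottom is what the spawned worker
# processes execute.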
class TestFSDPTrainer(TestCasePlus):
    @require_torch_multi_accelerator
    @require_accelerate
    @run_first
    def test_trainer(self):
        output_dir = self.get_auto_remove_tmp_dir()
        cmd = [
            "accelerate",
            "launch",
            "--use_fsdp",
            "--main_process_port",
            f"{get_torch_dist_unique_port()}",
            "--num_processes",
            f"{backend_device_count(torch_device)}",
            "--fsdp_transformer_layer_cls_to_wrap",
            "GPT2Block",
            f"{self.test_file_dir}/test_trainer_fsdp.py",
            "--output_dir",
            f"{output_dir}",
            "--report_to",
            "none",
        ]
        execute_subprocess_async(cmd, env=self.get_env())


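# Same launch as above, but with FP8 mixed precision (hence the extra
# `--mixed_precision fp8` flag and the `require_fp8` guard).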
class TestFSDPTrainerFP8(TestCasePlus):
    @require_torch_multi_accelerator
    @require_accelerate
    @require_fp8
    @run_first
    def test_trainer(self):
        output_dir = self.get_auto_remove_tmp_dir()
        cmd = [
            "accelerate",
            "launch",
            "--use_fsdp",
            "--main_process_port",
            f"{get_torch_dist_unique_port()}",
            "--num_processes",
            f"{backend_device_count(torch_device)}",
            "--mixed_precision",
            "fp8",
            "--fsdp_transformer_layer_cls_to_wrap",
            "GPT2Block",
            f"{self.test_file_dir}/test_trainer_fsdp.py",
            "--output_dir",
            f"{output_dir}",
            "--report_to",
            "none",
        ]
        execute_subprocess_async(cmd, env=self.get_env())


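# Variant that passes `--auto_find_batch_size True`, letting the Trainer back
# off the batch size on out-of-memory errors while the model is wrapped in FSDP.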
class TestFSDPTrainerWrap(TestCasePlus):
    @require_torch_multi_accelerator
    @require_accelerate
    @run_first
    def test_trainer(self):
        output_dir = self.get_auto_remove_tmp_dir()
        cmd = [
            "accelerate",
            "launch",
            "--use_fsdp",
            "--main_process_port",
            f"{get_torch_dist_unique_port()}",
            "--num_processes",
            f"{backend_device_count(torch_device)}",
            "--fsdp_transformer_layer_cls_to_wrap",
            "GPT2Block",
            f"{self.test_file_dir}/test_trainer_fsdp.py",
            "--output_dir",
            f"{output_dir}",
            "--report_to",
            "none",
            "--auto_find_batch_size",
            "True",
        ]
        execute_subprocess_async(cmd, env=self.get_env())


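# Variant that adds `--torch_compile_mode default` to exercise FSDP combined
# with torch.compile.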
class TestFSDPTrainerTorchCompile(TestCasePlus):
    @require_torch_multi_accelerator
    @require_accelerate
    @run_first
    def test_trainer(self):
        output_dir = self.get_auto_remove_tmp_dir()
        cmd = [
            "accelerate",
            "launch",
            "--use_fsdp",
            "--main_process_port",
            f"{get_torch_dist_unique_port()}",
            "--num_processes",
            f"{backend_device_count(torch_device)}",
            "--fsdp_transformer_layer_cls_to_wrap",
            "GPT2Block",
            f"{self.test_file_dir}/test_trainer_fsdp.py",
            "--torch_compile_mode",
            "default",
            "--output_dir",
            f"{output_dir}",
            "--report_to",
            "none",
        ]
        execute_subprocess_async(cmd, env=self.get_env())


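# Worker entry point: executed in each process spawned by `accelerate launch`
# from the tests above, not when the test module is merely imported.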
if __name__ == "__main__":
    parser = HfArgumentParser((Seq2SeqTrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]
    training_args.per_device_eval_batch_size = 1
    training_args.use_legacy_prediction_loop = False
    training_args.predict_with_generate = True
    training_args.generation_config = GenerationConfig(max_length=30)

    pretrained_model_name = "hf-internal-testing/tiny-random-gpt2"
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # Each spawned process picks the accelerator whose index matches its rank.
    device = torch.device(torch.distributed.get_rank())
    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to(device)

    def compute_metrics(p: EvalPrediction) -> dict[str, float]:
        return {"accuracy": (p.predictions == p.label_ids).mean()}

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model),
        eval_dataset=DummyTextDataset(tokenizer),
        compute_metrics=compute_metrics,
    )

    metrics = trainer.evaluate()