|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
from typing import Any, Callable |
|
|
|
from transformers import is_torch_available, is_torch_xpu_available |
|
from transformers.testing_utils import ( |
|
TestCasePlus, |
|
backend_device_count, |
|
backend_torch_accelerator_module, |
|
execute_subprocess_async, |
|
get_torch_dist_unique_port, |
|
require_torch_multi_accelerator, |
|
torch_device, |
|
) |
|
from transformers.utils import is_ccl_available, is_ipex_available |
|
|
|
|
|
if is_torch_available(): |
|
import functools |
|
|
|
import torch |
|
|
|
if is_torch_xpu_available(): |
|
if is_ipex_available(): |
|
import intel_extension_for_pytorch |
|
if is_ccl_available(): |
|
import oneccl_bindings_for_pytorch |
|
import torch.distributed |
|
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method |
|
from torch.distributed.device_mesh import init_device_mesh |
|
from torch.distributed.fsdp import FullyShardedDataParallel |
|
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy |
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Block |
|
|
|
# Prompt pool: the two sentences below repeated four times (eight entries).
# The generate helpers in this file pick one entry with `data[rank]`.
data = [
    "Hello world!",
    "The quick brown fox jumps over the lazy dog.",
] * 4
|
|
|
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
    """Manage the creation and destruction of the distributed process group for the wrapped function.

    Args:
        func: Callable to execute while a process group is initialized.

    Returns:
        A wrapper that calls ``torch.distributed.init_process_group`` with
        ``world_size`` set to ``backend_device_count(torch_device)``, invokes
        ``func`` with the forwarded arguments, returns its result, and always
        calls ``torch.distributed.destroy_process_group`` afterwards. The
        wrapper carries ``func``'s metadata (``__name__``, ``__doc__``,
        ``__wrapped__``).
    """
    # Imported locally so the decorator is self-contained: the file-level
    # `import functools` lives in the body of the `if is_torch_available():` guard.
    import functools

    @functools.wraps(func)  # keep the decorated function's name/docstring in tracebacks
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        device_count = backend_device_count(torch_device)
        torch.distributed.init_process_group(world_size=device_count)
        try:
            return func(*args, **kwargs)
        finally:
            # Tear the group down even when `func` raises.
            torch.distributed.destroy_process_group()

    return wrapped
|
|
|
@manage_process_group
def fsdp_generate():
    """Run ``generate`` on a tiny GPT-2 wrapped in ``FullyShardedDataParallel``.

    The current rank selects both the device passed to ``set_device`` and the
    prompt taken from ``data``. Generation is called on the inner module while
    the full parameters are summoned; the generated output is discarded.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"

    accelerator = backend_torch_accelerator_module(torch_device)
    rank = torch.distributed.get_rank()
    device = torch.device(rank)
    accelerator.set_device(device)

    base_model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    # Auto-wrap every GPT2Block as its own FSDP unit.
    wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block})
    sharded_model = FullyShardedDataParallel(
        base_model,
        auto_wrap_policy=wrap_policy,
        limit_all_gathers=True,
        use_orig_params=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)

    with FullyShardedDataParallel.summon_full_params(sharded_model):
        sharded_model.module.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=30,
        )
|
|
|
@manage_process_group
def fsdp2_generate():
    """Run ``generate`` on a tiny GPT-2 sharded with ``fully_shard``.

    The current rank selects both the device passed to ``set_device`` and the
    prompt taken from ``data``. Every ``GPT2Block`` is sharded, then the root
    model, over a one-dimensional device mesh sized to the world size. The
    generated output is discarded.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"

    accelerator = backend_torch_accelerator_module(torch_device)
    rank = torch.distributed.get_rank()
    device = torch.device(rank)
    accelerator.set_device(device)

    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    mesh_shape = (torch.distributed.get_world_size(),)
    mesh = init_device_mesh(device.type, mesh_shape)

    # Shard the transformer blocks first, then the root module.
    # The generator keeps the traversal lazy, exactly like a plain loop over model.modules().
    gpt2_blocks = (module for module in model.modules() if isinstance(module, GPT2Block))
    for block in gpt2_blocks:
        fully_shard(block, mesh=mesh)
    fully_shard(model, mesh=mesh)

    # Register `generate` as an FSDP forward method of the root model.
    register_fsdp_forward_method(model, "generate")

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)

    model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=30,
    )
|
|
|
|
|
class TestFSDPGeneration(TestCasePlus):
    """Launch ``test_fsdp.py`` under ``torchrun`` for the ``--fsdp`` and ``--fsdp2`` modes."""

    def _run_under_torchrun(self, mode_flag: str) -> None:
        """Spawn ``torchrun`` on ``test_fsdp.py`` with one process per accelerator.

        Args:
            mode_flag: CLI switch forwarded to the script (``"--fsdp"`` or ``"--fsdp2"``).
        """
        device_count = backend_device_count(torch_device)
        # Build argv as a list rather than whitespace-splitting one formatted string,
        # so a test directory whose path contains spaces stays a single argument.
        cmd = [
            "torchrun",
            f"--nproc_per_node={device_count}",
            f"--master_port={get_torch_dist_unique_port()}",
            f"{self.test_file_dir}/test_fsdp.py",
            mode_flag,
        ]
        execute_subprocess_async(cmd, env=self.get_env())

    @require_torch_multi_accelerator
    def test_fsdp_generate(self):
        # Runs the script's --fsdp mode.
        self._run_under_torchrun("--fsdp")

    @require_torch_multi_accelerator
    def test_fsdp2_generate(self):
        # Runs the script's --fsdp2 mode.
        self._run_under_torchrun("--fsdp2")
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: pick the generation scenario with --fsdp or --fsdp2.

    class CLIArgs(argparse.Namespace):
        # Typed namespace: one boolean attribute per supported flag.
        fsdp: bool
        fsdp2: bool

    parser = argparse.ArgumentParser()
    mode_group = parser.add_mutually_exclusive_group()
    for flag in ("--fsdp", "--fsdp2"):
        mode_group.add_argument(flag, action="store_true")
    cli_args = parser.parse_args(namespace=CLIArgs())

    # The flags are mutually exclusive, so at most one is set; reject the case where none is.
    if not (cli_args.fsdp or cli_args.fsdp2):
        raise ValueError("Missing test selection")

    if cli_args.fsdp:
        fsdp_generate()
    else:
        fsdp2_generate()
|
|