|
|
|
|
|
|
|
|
|
|
|
import os |
|
import inspect |
|
from deepspeed.utils import get_caller_func |
|
|
|
|
|
def get_local_rank_from_launcher():
    """Return this process's local (per-node) rank as set by the launcher.

    The environment is consulted in order: ``LOCAL_RANK`` (DeepSpeed / torch
    launchers), then ``OMPI_COMM_WORLD_LOCAL_RANK`` (OpenMPI). The first
    variable that is present wins; if neither is set, 0 is returned.

    Returns:
        int: the local rank.

    Raises:
        ValueError: if the selected variable is not an integer string.
    """
    for env_key in ('LOCAL_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        env_value = os.environ.get(env_key)
        if env_value is not None:
            return int(env_value)
    # No launcher information available: behave as a single-process run.
    return 0
|
|
|
|
|
def get_world_rank_from_launcher():
    """Return this process's global rank as set by the launcher.

    The environment is consulted in order: ``RANK`` (DeepSpeed / torch
    launchers), then ``OMPI_COMM_WORLD_RANK`` (OpenMPI). The first variable
    that is present wins; if neither is set, 0 is returned.

    Returns:
        int: the global rank.

    Raises:
        ValueError: if the selected variable is not an integer string.
    """
    for env_key in ('RANK', 'OMPI_COMM_WORLD_RANK'):
        env_value = os.environ.get(env_key)
        if env_value is not None:
            return int(env_value)
    # No launcher information available: behave as a single-process run.
    return 0
|
|
|
|
|
def get_world_size_from_launcher():
    """Return the world size advertised by the launcher's environment.

    The size is read from ``WORLD_SIZE`` (DeepSpeed / torch launchers), then
    ``OMPI_COMM_WORLD_SIZE`` (OpenMPI), defaulting to 1 when neither is set.
    The process whose launcher-provided global rank (``RANK``, else
    ``OMPI_COMM_WORLD_RANK``) is ``'0'`` prints the resolved size.

    Returns:
        int: the world size.

    Raises:
        ValueError: if the selected size variable is not an integer string.
    """
    size = os.environ.get('WORLD_SIZE')
    if size is None:
        size = os.environ.get('OMPI_COMM_WORLD_SIZE')
    if size is None:
        size = 1

    # Same lookup order as get_world_rank_from_launcher.
    rank = os.environ.get('RANK')
    if rank is None:
        rank = os.environ.get('OMPI_COMM_WORLD_RANK')

    # Environment values are strings; comparing against the int 0 (as before)
    # was never true, so the message was never printed. A textual comparison
    # keeps a malformed rank value from raising here.
    if rank is not None and rank.strip() == '0':
        print(f"set world size to {size}")

    return int(size)
|
|
|
|
|
def get_default_args(func):
    """Map each parameter of ``func`` that declares a default to that default.

    Parameters without a default (``inspect.Parameter.empty``) are omitted.
    Keyword-only parameters with defaults are included.

    Returns:
        dict: parameter name -> default value, in signature order.
    """
    defaults = {}
    for param_name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[param_name] = param.default
    return defaults
|
|
|
|
|
|
|
def get_tensor_position(func):
    """Return the positional index of ``func``'s tensor parameter, or -1.

    The parameter names are checked in priority order: ``tensor``,
    ``tensors``, ``input_list``, ``input_tensor_list``. The index is that of
    the first candidate found in ``func``'s signature.

    Returns:
        int: zero-based index into the signature's parameters, or -1 if none
        of the candidate names is present.
    """
    param_names = list(inspect.signature(func).parameters)
    for candidate in ('tensor', 'tensors', 'input_list', 'input_tensor_list'):
        if candidate in param_names:
            return param_names.index(candidate)
    return -1
|
|
|
|
|
def get_tensor_kwarg(func, kwargs):
    """Return the tensor-like argument supplied via ``kwargs`` or a default.

    ``func``'s declared defaults are overlaid with ``kwargs``. The value of
    the first key present, in priority order ``tensor``, ``tensors``,
    ``input_list``, ``input_tensor_list``, is returned, even if that value
    is None.

    Returns:
        The matched value, or None when no candidate key is present.
    """
    merged_args = get_default_args(func)
    merged_args.update(kwargs)
    for key in ('tensor', 'tensors', 'input_list', 'input_tensor_list'):
        if key in merged_args:
            return merged_args[key]
    return None
|
|
|
|
|
def get_msg_size_from_args(func, *args, **kwargs):
    """Return the byte size of the tensor argument(s) in a call to ``func``.

    The tensor argument is looked up positionally first, using
    ``get_tensor_position``. If it is not found there and keyword arguments
    were given, it is looked up via ``get_tensor_kwarg``, which also
    considers ``func``'s defaults.

    The found object is assumed to expose ``element_size()`` and
    ``nelement()`` (presumably torch tensors; verify against callers). A list
    or tuple of such objects is summed.

    Returns:
        int: total bytes, or 0 when no tensor argument is found.
    """
    tensor_arg = None

    if args:
        tensor_arg_position = get_tensor_position(func)
        # The tensor may have been passed by keyword while other arguments
        # were positional, so only index when the position lies within args.
        if 0 <= tensor_arg_position < len(args):
            tensor_arg = args[tensor_arg_position]

    if tensor_arg is None and kwargs:
        tensor_arg = get_tensor_kwarg(func, kwargs)

    if tensor_arg is None:
        return 0

    # Collective ops that take several tensors receive them as a sequence.
    if isinstance(tensor_arg, (list, tuple)):
        return sum(t.element_size() * t.nelement() for t in tensor_arg)
    return tensor_arg.element_size() * tensor_arg.nelement()
|
|
|
|
|
def get_debug_log_name(func_args, debug):
    """Return the log name from ``func_args``, optionally tagged with the caller.

    When ``debug`` is truthy, the suffix ``' | [Caller Func: <name>]'`` is
    appended, where ``<name>`` comes from ``get_caller_func()``.
    """
    log_name = func_args['log_name']
    if not debug:
        return log_name
    # Call get_caller_func directly here (not through a helper): it presumably
    # inspects the call stack, so extra frames could change its result.
    return log_name + ' | [Caller Func: ' + get_caller_func() + ']'
|
|