import torch |
|
import time |
|
import os |
|
import deepspeed |
|
from deepspeed import comm as dist |
|
from deepspeed.utils.logging import log_dist |
|
|
|
from torch.nn.modules import Module |
|
from packaging import version as pkg_version |
|
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine |
|
from deepspeed.utils.timer import SynchronizedWallClockTimer |
|
from deepspeed.runtime.compiler import is_compile_supported |
|
from ..runtime.state_dict_factory import SDLoaderFactory |
|
from ..runtime.weight_quantizer import WeightQuantization |
|
from ..module_inject import replace_transformer_layer, generic_injection |
|
from ..comm.comm import init_distributed |
|
from ..pipe import PipelineModule |
|
from ..moe.utils import has_moe_layers |
|
from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing |
|
from deepspeed.accelerator import get_accelerator |
|
from ..module_inject.policy import TransformerPolicy |
|
from ..module_inject.auto_tp import AutoTP |
|
|
|
from ..module_inject.replace_policy import generic_policies |
|
from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor, get_alibi_mask |
|
from ..ops.transformer.inference.ds_attention import DeepSpeedSelfAttention |
|
from ..model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference |
|
|
|
from torch import nn

DS_INFERENCE_ENABLED = False
|
|
|
INFERENCE_MODEL_TIMER = "model-forward-inference" |
|
|
|
|
|
class InferenceEngine(Module): |
|
inference_mp_group = None |
|
inference_ep_group = None |
|
expert_mp_group = None |
|
|
|
def __init__(self, model, config): |
|
""" |
|
Args: |
|
model: torch.nn.Module |
|
config: DeepSpeedInferenceConfig |
|
""" |
|
global DS_INFERENCE_ENABLED |
|
DS_INFERENCE_ENABLED = True |
|
|
|
super().__init__() |
|
if DeepSpeedTransformerInference.workspace is not None: |
|
self.destroy() |
|
|
|
self.module = model |
|
self._config = config |
|
|
|
self._get_model_config_generate(config) |
|
|
|
|
|
if hasattr(self.module, "generate"): |
|
self.generate = self._generate |
|
|
|
if hasattr(self.module, "config"): |
|
TransformerPolicy.hf_model_config = self.module.config |
|
|
|
if config.dtype not in get_accelerator().supported_dtypes(): |
|
raise ValueError( |
|
f"Data type {config.dtype} is not supported by {get_accelerator().device_name()} accelerator") |
|
|
|
|
|
|
|
self.injection_dict = config.injection_policy |
|
|
|
|
|
self.mp_group = config.tensor_parallel.tp_group |
|
self.mpu = config.tensor_parallel.mpu |
|
|
|
self.quantize_merge_count = 1 |
|
self.quantization_scales = None |
|
|
|
|
|
self.ep_group = None |
|
self.expert_mp_group = None |
|
|
|
self.cuda_graph_created = False |
|
self.checkpoint_engine = TorchCheckpointEngine() |
|
quantization_setting = None |
|
        self._init_quantization_setting(quantization_setting)
|
self.model_profile_enabled = False |
|
self._model_times = [] |
|
|
|
        if not self.injection_dict and config.replace_with_kernel_inject:
            # drop the HF-side attention-mask preparation for the BLOOM architecture
            self.remove_mask_prepare_for_bloom()
|
|
|
        if self.injection_dict or not config.replace_with_kernel_inject:
            # make the alibi/attention-bias builders tensor-parallel aware
            if config.tensor_parallel.tp_size > 1:
                self.build_alibi_tensor()
                self.build_attn_bias()
|
|
|
if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph: |
|
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ |
|
"If you want to use cuda graph, please upgrade torch to at least v1.10" |
|
|
|
|
|
        # convert the model to the requested dtype
        if config.dtype:
            self._convert_to_dtype(config)
|
|
|
if self.mpu: |
|
config.tensor_parallel.tp_size = dist.get_world_size(group=self.mpu.get_model_parallel_group()) |
|
self.mp_group = self.mpu.get_model_parallel_group() |
|
elif config.tensor_parallel.tp_size > 1: |
|
self._create_model_parallel_group(config) |
|
config.tensor_parallel.tp_group = self.mp_group |
|
|
|
if isinstance(self.module, torch.nn.Module): |
|
moe, _ = has_moe_layers(self.module) |
|
else: |
|
moe = False |
|
|
|
if moe and dist.get_world_size() > 1: |
|
self._create_ep_parallel_group(config.moe.moe_experts) |
|
|
|
|
|
        if self.injection_dict:
            # 1. user-specified injection policy
            assert not config.replace_with_kernel_inject, "Cannot use both user specified injection policy and kernel injection"
|
for client_module, injection_policy in self.injection_dict.items(): |
|
|
|
assert issubclass(client_module, |
|
torch.nn.Module), f"{client_module} is not a subclass of torch.nn.Module" |
|
|
|
|
|
if isinstance(injection_policy, str): |
|
config.injection_policy_tuple = (injection_policy, ) |
|
else: |
|
config.injection_policy_tuple = injection_policy |
|
|
|
layer_names = [name for name, _ in self.module.named_modules()] |
|
for policy in config.injection_policy_tuple: |
|
if not any(name.endswith(policy) for name in layer_names): |
|
raise ValueError(f"Injection policy layer'{policy}' not valid.") |
|
|
|
self._apply_injection_policy(config, client_module) |
|
else: |
|
            if config.replace_with_kernel_inject:
                # 2. DeepSpeed kernel injection
                self._apply_injection_policy(config)
|
            elif config.tensor_parallel.tp_size > 1:
                # 3. automatic tensor parallelism
                parser_dict = AutoTP.tp_parser(model)
                print("AutoTP: ", parser_dict)
|
for client_module, injection_policy in parser_dict: |
|
if isinstance(injection_policy, str): |
|
config.injection_policy_tuple = (injection_policy, ) |
|
else: |
|
config.injection_policy_tuple = injection_policy |
|
self._apply_injection_policy(config, client_module) |
|
|
|
device = get_accelerator().current_device_name() |
|
|
|
is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta' |
|
if is_meta_device: |
|
self.module.to_empty(device=device) |
|
elif not config.keep_module_on_host: |
|
self.module.to(device) |
|
|
|
if config.tensor_parallel.tp_size > 1: |
|
_rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name()) |
|
dist.broadcast(_rng_state, 0) |
|
get_accelerator().set_rng_state(_rng_state.cpu()) |
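            # rank 0 broadcasts its RNG state so all tensor-parallel ranks draw
            # identical random numbers (e.g. when sampling during generation)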
|
|
|
if config.tensor_parallel.tp_size > 1: |
|
assert not config.enable_cuda_graph, "Cuda graph is not supported for model parallelism" |
|
|
|
|
|
self.local_cuda_graph = self._local_cuda_graph_used(self.module) |
|
self._is_compiled = False |
|
|
|
def destroy(self): |
|
DeepSpeedTransformerInference.layer_id = 0 |
|
DeepSpeedSelfAttention.num_layers = 0 |
|
        if DeepSpeedTransformerInference.workspace is not None and DeepSpeedTransformerInference.workspace.is_allocated():
|
DeepSpeedTransformerInference.workspace.release_workspace() |
|
DeepSpeedTransformerInference.workspace = None |
|
|
|
def profile_model_time(self, use_cuda_events=True): |
|
if not self.model_profile_enabled and not self._config.enable_cuda_graph: |
|
self.module.register_forward_pre_hook(self._pre_forward_hook) |
|
self.module.register_forward_hook(self._post_forward_hook) |
|
self.model_profile_enabled = True |
|
self.use_cuda_events = use_cuda_events |
|
if self.use_cuda_events: |
|
self.timers = SynchronizedWallClockTimer() |
|
|
|
|
|
def _get_model_config_generate(self, config): |
|
|
|
self.config = getattr(self.module, 'config', None) if config.config is None else config.config |
|
|
|
def remove_mask_prepare_for_bloom(self): |
|
if hasattr(self.module, 'transformer'): |
|
if hasattr(self.module.transformer, '_prepare_attn_mask'): |
|
self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask |
|
|
|
def build_alibi_tensor(self): |
|
if hasattr(self.module, 'transformer'): |
|
if hasattr(self.module.transformer, 'build_alibi_tensor'): |
|
self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor |
|
if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'): |
|
self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor |
|
self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor |
|
if hasattr(self.module, 'model'): |
|
if hasattr(self.module.model, 'get_alibi_mask'): |
|
self.module.model.get_alibi_mask_orig = self.module.model.get_alibi_mask |
|
self.module.model.__class__.get_alibi_mask = get_alibi_mask |
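        # BLOOM's builder is patched on the instance, while the MPT/Baichuan-style
        # builders are patched on the class (they are called as bound methods), with the
        # originals saved under *_orig so the replacements can account for TP sharding.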
|
|
|
def build_attn_bias(self): |
|
if hasattr(self.module, 'transformer'): |
|
if hasattr(self.module.transformer, '_attn_bias'): |
|
self.module.transformer._attn_bias_orig = self.module.transformer._attn_bias |
|
self.module.transformer.__class__._attn_bias = build_mpt_atten_bias_tensor |
|
|
|
def _pre_forward_hook(self, module, *inputs, **kwargs): |
|
if self.use_cuda_events: |
|
self.timers(INFERENCE_MODEL_TIMER).start() |
|
else: |
|
get_accelerator().synchronize() |
|
self._start = time.time() |
|
|
|
def _post_forward_hook(self, module, input, output): |
|
if self.use_cuda_events: |
|
self.timers(INFERENCE_MODEL_TIMER).stop() |
|
elapsed_time = self.timers(INFERENCE_MODEL_TIMER).elapsed(reset=True) |
|
else: |
|
get_accelerator().synchronize() |
|
self._end = time.time() |
|
elapsed_time = (self._end - self._start) * 1e3 |
|
self._model_times.append(elapsed_time) |
|
|
|
def _create_model_parallel_group(self, config): |
|
|
|
if InferenceEngine.inference_mp_group is None: |
|
init_distributed() |
|
local_rank = int(os.getenv('LOCAL_RANK', '0')) |
|
get_accelerator().set_device(local_rank) |
|
|
|
ranks = [i for i in range(config.tensor_parallel.tp_size)] |
|
self.mp_group = dist.new_group(ranks) |
|
InferenceEngine.inference_mp_group = self.mp_group |
|
else: |
|
self.mp_group = InferenceEngine.inference_mp_group |
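        # Example: tp_size=4 creates a single process group over ranks [0, 1, 2, 3];
        # the group is cached on the class so every engine in this process reuses it.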
|
|
|
def _create_ep_parallel_group(self, moe_experts): |
|
|
|
self.ep_group = {} |
|
self.expert_mp_group = {} |
|
moe_experts = moe_experts if type(moe_experts) is list else [moe_experts] |
|
for e in moe_experts: |
|
self.ep_group.update({e: None}) |
|
self.expert_mp_group.update({e: None}) |
|
for moe_ep_size in self.ep_group.keys(): |
|
num_ep_groups = dist.get_world_size() // moe_ep_size |
|
for i in range(num_ep_groups): |
|
ep_cnt = i * moe_ep_size |
|
size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size |
|
ranks = list(range(ep_cnt, ep_cnt + size)) |
|
_ep_group = dist.new_group(ranks) |
|
if dist.get_rank() in ranks: |
|
self.ep_group.update({moe_ep_size: _ep_group}) |
|
|
|
if dist.get_world_size() > moe_ep_size: |
|
num_expert_mp_groups = dist.get_world_size() // num_ep_groups |
|
expert_mp_size = dist.get_world_size() // moe_ep_size |
|
for i in range(num_expert_mp_groups): |
|
expert_mp_comm_ranks = [i + nr * moe_ep_size for nr in range(expert_mp_size)] |
|
_expert_mp_group = dist.new_group(expert_mp_comm_ranks) |
|
if dist.get_rank() in expert_mp_comm_ranks: |
|
self.expert_mp_group.update({moe_ep_size: _expert_mp_group}) |
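        # Worked example: world_size=8, moe_ep_size=4 -> two expert-parallel groups over
        # ranks [0..3] and [4..7]; the expert model-parallel groups then pair ranks
        # across them: [0,4], [1,5], [2,6], [3,7].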
|
|
|
def _init_quantization_setting(self, quantization_setting): |
|
self.quantize_bits = 8 |
|
self.mlp_extra_grouping = False |
|
self.quantize_groups = 1 |
|
if type(quantization_setting) is tuple: |
|
self.mlp_extra_grouping, \ |
|
self.quantize_groups = quantization_setting |
|
elif quantization_setting is not None: |
|
self.quantize_groups = quantization_setting |
|
log_dist( |
|
f"quantize_bits = {self.quantize_bits} " |
|
f"mlp_extra_grouping = {self.mlp_extra_grouping}, " |
|
f"quantize_groups = {self.quantize_groups}", [0]) |
|
|
|
def load_model_with_checkpoint(self, r_module): |
|
self.mp_replace = ReplaceWithTensorSlicing( |
|
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) |
|
        def load(module, state_dict, prefix):
|
if hasattr(module, 'weight'): |
|
if module.weight.data.is_meta: |
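                    # meta tensors cannot be cast or copied into, so materialize an empty
                    # CPU tensor of the same shape before copying the checkpoint shard
                    # (the same pattern is used for the norm and bias branches below)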
|
|
|
module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, |
|
device="cpu"), |
|
requires_grad=module.weight.data.requires_grad) |
|
if 'query_key_value' in prefix: |
|
module.weight = self.mp_replace.strided_copy(module.weight.data, |
|
state_dict[prefix + 'weight'], |
|
num_splits=3) |
|
else: |
|
module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight']) |
|
else: |
|
if module.norm.weight.data.is_meta: |
|
|
|
module.norm.weight = torch.nn.parameter.Parameter( |
|
data=torch.empty_like(module.norm.weight.data, device="cpu"), |
|
requires_grad=module.norm.weight.data.requires_grad) |
|
module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight']) |
|
if prefix + 'bias' in self.key_list: |
|
if hasattr(module, 'norm'): |
|
if module.norm.bias.data.is_meta: |
|
|
|
module.norm.bias = torch.nn.parameter.Parameter( |
|
data=torch.empty_like(module.norm.bias.data, device="cpu"), |
|
requires_grad=module.norm.bias.data.requires_grad) |
|
module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias']) |
|
else: |
|
if module.bias.data.is_meta: |
|
|
|
module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, |
|
device="cpu"), |
|
requires_grad=module.bias.data.requires_grad) |
|
data = state_dict[prefix + 'bias'] |
|
data = data.to(get_accelerator().current_device_name()) |
|
module.bias = self.mp_replace.copy(module.bias, data) |
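        # map leaf module types to the loader above; container modules fall through
        # to the recursion in load_module_recursive below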
|
|
|
layer_policies = { |
|
nn.Linear: load, |
|
nn.Embedding: load, |
|
nn.LayerNorm: load, |
|
LinearLayer: load, |
|
LinearAllreduce: load |
|
} |
|
|
|
def load_module_recursive(module, prefix='', level=0): |
|
for name, child in module.named_children(): |
|
if child.__class__ in layer_policies: |
|
checking_key = prefix + name + '.' |
|
if not any(checking_key in item for item in self.key_list): |
|
continue |
|
if len(list(child.parameters())) > 0 and list(child.parameters())[0].numel() == 0: |
|
if len(child.weight.ds_shape) == 1: |
|
child = Normalize(dim=child.weight.ds_shape[-1], dtype=child.weight.dtype, eps=child.eps) |
|
setattr(module, name, child) |
|
load(child, self.sd, prefix + name + '.') |
|
else: |
|
load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1) |
|
|
|
load_module_recursive(r_module) |
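        # Checkpoints with tied embeddings may not store lm_head.weight; if it is
        # still a meta tensor after loading, point it at the input embedding weight.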
|
|
|
embedding_weight = None |
|
|
|
for n, p in r_module.named_parameters(): |
|
if "word_embeddings." in n or "embed_tokens." in n or "wte." in n: |
|
embedding_weight = p |
|
if embedding_weight is not None and hasattr(r_module, "lm_head") and hasattr( |
|
r_module.lm_head, "weight") and r_module.lm_head.weight.is_meta: |
|
r_module.lm_head.weight = embedding_weight |
|
|
|
def _apply_injection_policy(self, config, client_module=None): |
|
|
|
checkpoint_dir = config.checkpoint |
|
checkpoint = SDLoaderFactory.get_sd_loader_json(checkpoint_dir, |
|
self.checkpoint_engine) if checkpoint_dir is not None else None |
|
|
|
generic_injection(self.module, dtype=config.dtype, enable_cuda_graph=config.enable_cuda_graph) |
|
|
|
if isinstance(self.module, torch.nn.Module): |
|
|
|
replace_transformer_layer(client_module, self.module, checkpoint, config, self.config) |
|
|
|
def _get_all_ckpt_names(self, checkpoints_path, tag): |
|
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*") |
|
import glob |
|
|
|
ckpt_files = glob.glob(ckpt_file_pattern) |
|
ckpt_files.sort() |
|
return ckpt_files |
|
|
|
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): |
|
if mp_placeholder is not None: |
|
mp_rank_str = mp_placeholder |
|
else: |
|
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() |
|
mp_rank_str = "{:02d}".format(mp_rank) |
|
|
|
ckpt_name = os.path.join( |
|
checkpoints_path, |
|
"mp_rank_" + mp_rank_str + "_model_states.pt", |
|
) |
|
return ckpt_name |
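        # e.g. mp_rank 0 -> "<checkpoints_path>/mp_rank_00_model_states.pt";
        # mp_placeholder="*" yields a glob pattern for _get_all_ckpt_names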
|
|
|
def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): |
|
is_pipe_parallel = isinstance(self.module, PipelineModule) |
|
if is_pipe_parallel: |
|
raise RuntimeError('pipeline parallelism is currently not supported in inference.') |
|
if not isinstance(load_dir, dict) and os.path.isdir(load_dir): |
|
if tag is None: |
|
latest_path = os.path.join(load_dir, "latest") |
|
if os.path.isfile(latest_path): |
|
with open(latest_path, "r") as fd: |
|
tag = fd.read().strip() |
|
|
|
ckpt_list = self._get_all_ckpt_names(load_dir, tag) |
|
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine) |
|
else: |
|
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine) |
|
|
|
checkpoint = sd_loader['checkpoints'] |
|
|
|
if type(checkpoint) is list: |
|
self.sd = torch.load(checkpoint[0], map_location='cpu', weights_only=False) |
|
self.key_list = list(self.sd.keys()) |
|
|
|
self.load_model_with_checkpoint(self.module) |
|
|
|
for i in range(1, len(checkpoint)): |
|
if not dist.is_initialized() or dist.get_rank() == 0: |
|
print(f"loading checkpoint ({i})") |
|
self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name(), weights_only=False) |
|
self.key_list = list(self.sd.keys()) |
|
self.load_model_with_checkpoint(self.module) |
|
else: |
|
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() |
|
|
|
load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size, |
|
mp_rank, |
|
is_pipe_parallel=is_pipe_parallel, |
|
quantize=(self._config.dtype is torch.int8), |
|
quantize_groups=self.quantize_groups, |
|
mlp_extra_grouping=self.mlp_extra_grouping) |
|
|
|
self.quantization_scales, self.quantize_merge_count = quantize_config |
|
|
|
moe, _ = has_moe_layers(self.module) |
|
if moe: |
|
from deepspeed.runtime.engine import DeepSpeedEngine |
|
old_moe_load = False |
|
if not isinstance(checkpoint['num_experts'], list): |
|
old_moe_load = True |
|
DeepSpeedEngine.load_moe_state_dict(load_dir, |
|
tag, |
|
state_dict=checkpoint[self._choose_module_key(checkpoint)], |
|
old_moe_load=old_moe_load, |
|
model=self.module, |
|
mpu=self.mpu, |
|
checkpoint_engine=self.checkpoint_engine) |
|
|
|
self.module.load_state_dict(state_dict=checkpoint[self._choose_module_key(checkpoint)], |
|
strict=load_module_strict) |
|
|
|
def _choose_module_key(self, sd): |
|
assert not ('module' in sd |
|
and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed" |
|
        assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' nor 'module' keys, not sure how to proceed"
|
if 'module' in sd: |
|
return 'module' |
|
elif 'model' in sd: |
|
return 'model' |
|
|
|
def _convert_to_dtype(self, config): |
|
if not isinstance(self.module, torch.nn.Module): |
|
return |
|
|
|
        if False:  # int8 weight-quantization path, currently disabled
|
quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping) |
|
model, self.quantization_scales = quantizer.model_quantize(self.module, self.injection_dict, |
|
self.quantize_bits, self.quantize_groups) |
|
elif config.dtype == torch.half: |
|
self.module.half() |
|
elif config.dtype == torch.bfloat16: |
|
self.module.bfloat16() |
|
elif config.dtype == torch.float: |
|
self.module.float() |
|
|
|
def _create_cuda_graph(self, *inputs, **kwargs): |
|
|
|
cuda_stream = get_accelerator().Stream() |
|
cuda_stream.wait_stream(get_accelerator().current_stream()) |
|
with get_accelerator().stream(cuda_stream): |
|
for i in range(3): |
|
ret = self.module(*inputs, **kwargs) |
|
get_accelerator().current_stream().wait_stream(cuda_stream) |
|
|
|
|
|
self._cuda_graphs = get_accelerator().create_graph() |
|
self.static_inputs = inputs |
|
self.static_kwargs = kwargs |
|
|
|
with get_accelerator().capture_to_graph(self._cuda_graphs): |
|
self.static_output = self.module(*self.static_inputs, **self.static_kwargs) |
|
|
|
self.cuda_graph_created = True |
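        # Graph capture pins tensor addresses: later calls must copy fresh data into
        # static_inputs/static_kwargs and replay the captured graph (see _graph_replay);
        # the three iterations above warm up kernels on a side stream before capture.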
|
|
|
def _graph_replay(self, *inputs, **kwargs): |
|
for i in range(len(inputs)): |
|
if torch.is_tensor(inputs[i]): |
|
self.static_inputs[i].copy_(inputs[i]) |
|
for k in kwargs: |
|
if torch.is_tensor(kwargs[k]): |
|
self.static_kwargs[k].copy_(kwargs[k]) |
|
get_accelerator().replay_graph(self._cuda_graphs) |
|
return self.static_output |
|
|
|
def model_times(self): |
|
assert self.model_profile_enabled, "model profiling is not enabled" |
|
model_times = self._model_times |
|
if self._config.enable_cuda_graph and len(self._model_times) == 0: |
|
raise ValueError("Model times are empty and cuda graph is enabled. If " |
|
"this is a GPT-style model this combo is not supported. If this is a " |
|
"BERT-style model this is a bug, please report it. " |
|
f"Model type is: {type(self.module)}") |
|
self._model_times = [] |
|
return model_times |
|
|
|
def _module_match(self, module): |
|
for policy in generic_policies: |
|
policy = policy() |
|
if policy.match_replaced(module): |
|
return True |
|
return False |
|
|
|
def _local_cuda_graph_used(self, module): |
|
        # a plain nn.Module carries no replaced sub-modules with their own CUDA graphs;
        # only pipeline-style wrapper objects are scanned below
        if isinstance(module, torch.nn.Module):
            return False
|
else: |
|
sub_module_cuda_graph = False |
|
for name in module.__dict__.keys(): |
|
sub_module = getattr(module, name) |
|
|
|
if self._module_match(sub_module) and hasattr(sub_module, "enable_cuda_graph"): |
|
sub_module_cuda_graph = True |
|
|
|
return sub_module_cuda_graph |
|
|
|
def forward(self, *inputs, **kwargs): |
|
"""Execute forward propagation |
|
|
|
Arguments: |
|
*inputs: Variable length input list |
|
**kwargs: variable length keyword arguments |
|
""" |
|
start = None |
|
if self.model_profile_enabled and get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph: |
|
get_accelerator().synchronize() |
|
start = time.time() |
|
|
|
if get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph: |
|
if self.cuda_graph_created: |
|
outputs = self._graph_replay(*inputs, **kwargs) |
|
else: |
|
self._create_cuda_graph(*inputs, **kwargs) |
|
outputs = self._graph_replay(*inputs, **kwargs) |
|
|
|
else: |
|
outputs = self.module(*inputs, **kwargs) |
|
|
|
if self.model_profile_enabled and self._config.enable_cuda_graph: |
|
get_accelerator().synchronize() |
|
duration = (time.time() - start) * 1e3 |
|
self._model_times.append(duration) |
|
|
|
return outputs |
|
|
|
def _generate(self, *inputs, **kwargs): |
|
|
|
if hasattr(self.module, 'reset_cache'): |
|
self.module.reset_cache() |
|
num_beams = 1 |
|
if "generation_config" in kwargs: |
|
gen_config = kwargs["generation_config"] |
|
num_beams = getattr(gen_config, "num_beams", 1) |
|
if "num_beams" in kwargs: |
|
num_beams = kwargs["num_beams"] |
|
|
|
if num_beams > 1: |
|
raise NotImplementedError("DeepSpeed does not support `num_beams` > 1, if this is important to you please " |
|
"add your request to: https://github.com/deepspeedai/DeepSpeed/issues/2506") |
|
|
|
if ("input_ids" in kwargs) and (kwargs["input_ids"].dim() == 2): |
|
for input_tensor in kwargs["input_ids"]: |
|
tensor_length = input_tensor.shape[-1] |
|
if tensor_length > self._config.max_out_tokens: |
|
raise RuntimeError( |
|
f"Input with size {tensor_length} exceeds maximum length of {self._config.max_out_tokens}. Please increase max_tokens in the DeepSpeed Inference Config." |
|
) |
|
|
|
return self.module.generate(*inputs, **kwargs) |
|
|
|
    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs=None) -> None:
|
""" |
|
Compile the module using the specified backend and kwargs. |
|
""" |
|
if not is_compile_supported(): |
|
raise RuntimeError("compile is not supported in your version of PyTorch.") |
|
|
|
if self._is_compiled: |
|
return |
|
|
|
|
|
        # disable nvtx ranges, which otherwise cause graph breaks under torch.compile
        deepspeed.utils.nvtx.enable_nvtx = False
        self.module.compile(backend=backend, **(compile_kwargs or {}))
|
self._is_compiled = True |
|
|
|
@property |
|
def is_compiled(self) -> bool: |
|
return self._is_compiled |
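    # Illustrative usage (a sketch; compile_kwargs is passed straight to torch.compile):
    #   engine.compile(compile_kwargs={"mode": "reduce-overhead"})
    #   assert engine.is_compiled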
|
|