File size: 2,046 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
import torch
from torch.fx import Node, GraphModule
from deepspeed.compile.util import get_last_uses
from ..graph_param import DSGraphParamManager
def add_offload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
    """Insert a ``dc.offload_parameter`` call immediately after ``node``.

    The new node calls ``torch.ops.dc.offload_parameter.default`` with the
    positional arguments ``(node, graph_id, ds_id)`` and no keyword arguments.

    Args:
        graph_id: Graph identifier forwarded to the op.
        gm: Graph module whose graph is modified in place.
        node: Node after which the new op is placed. It is also passed as the
            op's first argument (presumably to pin the op's position via a data
            dependency — confirm against the op's registration).
        ds_id: Parameter id forwarded to the op.

    Returns:
        The newly created ``Node``.
    """
    op_args = (node, graph_id, ds_id)
    with gm.graph.inserting_after(node):
        offload_node = gm.graph.create_node('call_function',
                                            torch.ops.dc.offload_parameter.default,
                                            op_args, {},
                                            name="offload_parameter")
    return offload_node
def add_reload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
    """Insert a ``dc.reload_parameter`` call immediately after ``node``.

    The new node calls ``torch.ops.dc.reload_parameter.default`` with the
    positional arguments ``(node, graph_id, ds_id)`` and no keyword arguments.

    Args:
        graph_id: Graph identifier forwarded to the op.
        gm: Graph module whose graph is modified in place.
        node: Node after which the new op is placed. It is also passed as the
            op's first argument (presumably to pin the op's position via a data
            dependency — confirm against the op's registration).
        ds_id: Parameter id forwarded to the op.

    Returns:
        The newly created ``Node``.
    """
    args = (node, graph_id, ds_id)
    with gm.graph.inserting_after(node):
        # Plain string literal: the original f-string had no placeholders.
        new_node = gm.graph.create_node('call_function',
                                        torch.ops.dc.reload_parameter.default,
                                        args, {},
                                        name="reload_parameter")
    return new_node
def get_ds_id(node: Node):
    """Return the ds_id carried by a ``dc.allgather_param`` node.

    The id is read from the node's third positional argument (``args[2]``).
    When assertions are enabled, an ``AssertionError`` is raised if ``node``
    does not target ``torch.ops.dc.allgather_param.default``.
    """
    expected_op = torch.ops.dc.allgather_param.default
    assert node.target == expected_op
    ds_id = node.args[2]
    return ds_id
def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                          mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
    """Insert parameter reload/offload ops around every ``dc.allgather_param`` node.

    For each node targeting ``torch.ops.dc.allgather_param.default``:

    * a ``reload_parameter`` op is inserted right after the node's first
      argument (``node.args[0]``);
    * an ``offload_parameter`` op is inserted right after the last user of the
      node, as reported by ``get_last_uses``.

    The graph is then linted, and ``gm`` (modified in place) is returned.

    Only ``gm`` and ``graph_id`` are used. The remaining parameters are accepted
    but ignored, presumably to match a shared compiler-pass signature (verify
    against the pass registry).

    Raises:
        KeyError: if an ``allgather_param`` node has no entry in the last-use
            map (presumably when its result has no users — confirm against
            ``get_last_uses``).
    """
    node_to_last_use, _ = get_last_uses(gm.graph)

    # Snapshot the targets up front: the loop inserts new nodes into the graph,
    # so avoid iterating the live node list while mutating it.
    allgather_nodes = [n for n in gm.graph.nodes if n.target == torch.ops.dc.allgather_param.default]

    for node in allgather_nodes:
        ds_id = get_ds_id(node)
        add_reload_parameter(graph_id, gm, node.args[0], ds_id)
        add_offload_parameter(graph_id, gm, node_to_last_use[node], ds_id)

    gm.graph.lint()
    return gm
|