# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List, Set
import torch
import torch.fx
from torch.fx.node import Node
def _get_count(param_count: Dict, node_name: str) -> int:
"""Identify different mutations of a given node name."""
# TODO(anj): This is not very stable since it is possible that the name
# may not be in the same format. Is there another way to identify nodes
# in a graph?
if node_name in param_count:
return param_count[node_name]
elif node_name.split("_")[0] in param_count:
return param_count[node_name.split("_")[0]]
else:
raise RuntimeError(f"Unable to find match between param {param_count} and node {node_name}")
def _create_shard_to_param_count(param_count: Dict, node_name_to_shard_id: Dict) -> Dict:
"""Utility to create a map from shard id to param count using existing state."""
shard_to_param_count: Dict[int, int] = {}
for node_name in node_name_to_shard_id.keys():
try:
count = _get_count(param_count, node_name)
except RuntimeError:
continue
if node_name_to_shard_id[node_name] in shard_to_param_count:
shard_to_param_count[node_name_to_shard_id[node_name]] += count
else:
shard_to_param_count[node_name_to_shard_id[node_name]] = count
return shard_to_param_count
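
# A minimal sketch of the shapes involved (names and numbers are illustrative only):
#   param_count           == {"": 165, "linear1": 110, "linear2": 55}
#   node_name_to_shard_id == {"x": 0, "linear1": 0, "linear2": 1}
#   _create_shard_to_param_count(param_count, node_name_to_shard_id) == {0: 110, 1: 55}
# Node names with no matching entry in param_count (here "x") are skipped.
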
def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
"""Utility used to trace a graph and identify shard cutpoints."""
node_name_to_shard_id: Dict[str, int] = {}
shard_id = 0
nodes_so_far = []
param_count: Dict[str, int] = {}
shard_to_param_count = {}
# Find the total number of params in the model and
# the number of params per shard we are aiming for.
for name, module in traced_graph_module.named_modules():
name = name.replace(".", "_")
param_count[name] = sum([x.numel() for x in module.parameters()])
logging.info(f"Total number of params are {param_count['']}")
per_shard_param = param_count[""] // shard_count
logging.info(f"Per shard param count {per_shard_param}")
for node in traced_graph_module.graph.nodes:
if node.op == "placeholder":
node_name_to_shard_id[node.name] = shard_id
nodes_so_far.append(node.name)
elif node.op in ["get_attr", "call_function", "call_method", "call_module"]:
min_shard_id = shard_id
min_node_name = ""
# For each of the args of a given node, find the arg that is not the
# last node we traversed. This is to help us find skip connections
# across shards.
for arg in node.args:
# If the node has args that are inputs to the forward function, they
# may not have explicit names.
if not hasattr(arg, "name"):
continue
if arg.name in node_name_to_shard_id and arg.name != nodes_so_far[-1]:
if node_name_to_shard_id[arg.name] < min_shard_id:
min_shard_id = node_name_to_shard_id[arg.name]
min_node_name = arg.name
            # If there is an input that is not from the previous shard,
            # we collapse all the shards in between into a single shard
            # and update the param count per shard accordingly.
if min_shard_id < shard_id:
for node_name in reversed(nodes_so_far):
node_name_to_shard_id[node_name] = min_shard_id
if node_name == min_node_name:
break
shard_id = min_shard_id
# TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.
shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)
# Update state that is tracking node -> shard id and shard id -> param count.
node_name_to_shard_id[node.name] = shard_id
nodes_so_far.append(node.name)
# TODO(anj): This could just be an update, we don't need to recreate the map.
shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)
            # If we have gone over the per-shard param count that we want to
            # achieve, we should start a new shard.
            # The shard_id may not have been updated in the map if we are at a node
            # that does not have params.
if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
shard_id += 1
elif node.op == "output":
break
return node_name_to_shard_id
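
# The returned map assigns every non-output node a shard id, e.g. roughly
# {"x": 0, "linear1": 0, "linear2": 0, "linear3": 1} for a small three-layer model
# (illustrative names only). Shard ids are non-decreasing in graph traversal order,
# so shard_model can later cut the graph wherever the id changes.
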
class _ExtendedLeafTracer(torch.fx.Tracer):
"""Tracer with an extended set of leaf nn.Modules."""
def __init__(self, leaf_modules: Set[torch.nn.Module]):
"""Initializes a new _ExtendedLeafTracer object.
Args:
            leaf_modules: The set of extra nn.Module instances which will not be traced
                through but instead considered to be leaves.
"""
super().__init__()
self.leaf_modules = leaf_modules
def is_leaf_module(self, m: torch.nn.Module, model_qualified_name: str) -> bool:
return super().is_leaf_module(m, model_qualified_name) or m in self.leaf_modules
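
# Illustrative usage (the submodule name is made up): trace a model while keeping one
# child opaque, so it shows up as a single call_module node in the resulting graph.
#   tracer = _ExtendedLeafTracer({model.encoder})
#   graph = tracer.trace(model)
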
# TODO(ehotaj): Extend this method to wrap at the least granular level. One way to do
# this would be to wrap the Module tree bottom up, first wrapping untraceable children and
# only wrapping parents if they are also untraceable.
def _trace(model: torch.nn.Module) -> torch.fx.GraphModule:
"""Traces the given model and automatically wraps untracable modules into leaves."""
leaf_modules = set()
tracer = _ExtendedLeafTracer(leaf_modules)
for name, module in model.named_modules():
# TODO(ehotaj): The default is_leaf_module includes everything in torch.nn.
# This means that some coarse modules like nn.TransformerEncoder are treated
# as leaves, not traced, and are unable to be sharded. We may want to extend our
# sharding code to trace through these modules as well.
if tracer.is_leaf_module(module, ""):
continue
try:
tracer.trace(module)
except (TypeError, torch.fx.proxy.TraceError):
leaf_modules.add(module)
tracer = _ExtendedLeafTracer(leaf_modules)
graph = tracer.trace(model)
return torch.fx.GraphModule(model, graph)
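
# Sketch of the intended behavior (model is any nn.Module): children whose forward()
# cannot be symbolically traced are retried as leaves, so they appear as single
# call_module nodes in the returned GraphModule instead of being traced through.
#   graph_module = _trace(model)
#   print(graph_module.graph)
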
def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
"""Utility used to shard a model using torch.fx.
This function traces the model twice in an attempt to identify the
right cutpoints and then shard the model. In the first pass we calculate
the number of parameters as we are tracing the graph and mark nodes at
which we might want to create a new module. In the second pass we
modify the graph by inserting placeholders and output nodes to essentially
shard the graph.
    We don't support skip connections between shards. This means that all
    inputs and outputs are self-contained within a given shard. A node from
    shard 1 cannot be an input to a node from shard 3. We expect all inputs
    to a given shard to come from the last node in the previous shard. As a
    result, we may not be able to split the model into the exact
    `shard_count` requested by the user.
Args:
model (nn.Module): Model to be sharded as specified by the device count.
shard_count (int): Number of shards that we want to split the model into.
"""
module_list: List[torch.fx.GraphModule] = []
num_graphs = 0
new_graph = torch.fx.Graph() # type: ignore
env: Dict[str, Node] = {}
new_input_node = None
traced_graph_module = _trace(model)
# This is the first pass where we attempt to get a map of where
# we need to insert placeholder and output nodes.
node_name_to_shard_id = _split_nodes(traced_graph_module, shard_count=shard_count)
    # Dummy value, larger than any real shard id, so the first node never triggers a split.
prev_shard_id = 1000
prev_node = None
for node in traced_graph_module.graph.nodes:
# If the current node is in the next shard, we insert an output node.
# A new graph is created and a placeholder is added for the next shard.
if node.name in node_name_to_shard_id and prev_shard_id < node_name_to_shard_id[node.name]:
assert prev_node, "prev_node cannot be None"
with new_graph.inserting_after(prev_node):
new_graph.output(env[prev_node.name])
num_graphs += 1
module_list.append(torch.fx.GraphModule(model, new_graph))
new_graph = torch.fx.Graph()
node_name = "placeholder" + str(num_graphs)
pl_node = new_graph.create_node("placeholder", node_name)
env[node_name] = pl_node
new_input_node = pl_node
if new_input_node is not None:
# Account for a placeholder in the new graph.
node.args = (new_input_node,)
new_input_node = None
if node.op in ["placeholder", "get_attr", "call_function", "call_method", "call_module"]:
# Copy the nodes from the existing graph to the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
elif node.op == "output":
# If this is the last node, we should add an output
# node and add the last graph to the list.
assert prev_node, "prev_node cannot be None"
with new_graph.inserting_after(prev_node):
new_graph.output(env[prev_node.name])
module_list.append(torch.fx.GraphModule(model, new_graph))
break
prev_node = new_node
prev_shard_id = node_name_to_shard_id[node.name]
return module_list
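
# A minimal usage sketch (MyModel is a stand-in for any traceable nn.Module that takes a
# single tensor input; because skip connections across shards are not supported, the
# number of returned shards may be smaller than shard_count):
#   model = MyModel()
#   shards = shard_model(model, shard_count=2)
#   x = torch.randn(4, 10)
#   for shard in shards:
#       x = shard(x)  # each shard's output feeds the next shard's placeholder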