# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, List, Optional, Set

import torch
import torch.fx
from torch.fx.node import Node


def _candidate_module_names(node_name: str) -> List[str]:
    """List the ``param_count`` keys that the fx node ``node_name`` may refer to, most specific first.

    ``param_count`` keys are module paths with "." replaced by "_" (see ``_split_nodes``). Node names can
    differ from those keys: a repeated call may carry a numeric de-duplication suffix (e.g. "linear_1"), and
    names that would start with a digit (e.g. ``nn.Sequential`` children "0", "1", ...) may be prefixed with
    "_" by torch.fx (observed behaviour; exact naming depends on the torch version). The empty key, which
    holds the parameter count of the whole model, is never returned.
    """
    raw_candidates = [node_name]
    base, sep, suffix = node_name.rpartition("_")
    if sep and base and suffix.isdigit():
        # Strip a de-duplication suffix such as the "_1" in "linear_1".
        raw_candidates.append(base)
    # Original heuristic: the text before the first underscore.
    raw_candidates.append(node_name.split("_")[0])

    candidates: List[str] = []
    for name in raw_candidates:
        variants = [name]
        if name.startswith("_") and name[1:2].isdigit():
            # Undo the "_" prefix given to names that would otherwise start with a digit.
            variants.append(name[1:])
        for variant in variants:
            # "" is the root module; matching it would credit a node with every parameter of the model.
            if variant and variant not in candidates:
                candidates.append(variant)
    return candidates


def _get_count(param_count: Dict, node_name: str) -> int:
    """Return the parameter count of the module that the node ``node_name`` refers to.

    Candidates from ``_candidate_module_names`` are tried in order. Unlike a plain prefix match this never
    falls back to the root entry ``""``, which would attribute all parameters of the model to any node
    whose name starts with "_".

    Raises:
        RuntimeError: if no candidate name is a key of ``param_count``.
    """
    # TODO(anj): This is not very stable since it is possible that the name
    # may not be in the same format. Is there another way to identify nodes
    # in a graph?
    for candidate in _candidate_module_names(node_name):
        if candidate in param_count:
            return param_count[candidate]
    raise RuntimeError(f"Unable to find match between param {param_count} and node {node_name}")


def _create_shard_to_param_count(param_count: Dict, node_name_to_shard_id: Dict) -> Dict:
    """Sum the parameter counts of the nodes assigned to each shard.

    Nodes that cannot be matched to a module (``_get_count`` raises) contribute nothing; a shard in
    which no node matched is absent from the returned map.
    """
    shard_to_param_count: Dict[int, int] = {}
    for node_name, shard_id in node_name_to_shard_id.items():
        try:
            count = _get_count(param_count, node_name)
        except RuntimeError:
            continue
        shard_to_param_count[shard_id] = shard_to_param_count.get(shard_id, 0) + count
    return shard_to_param_count


def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
    """Assign the nodes of ``traced_graph_module`` to shards.

    Nodes are visited in graph order and added to the current shard; once the shard's parameter count
    exceeds ``total_params // shard_count``, subsequent nodes go into a new shard. If a node reads a value
    produced in an earlier shard by any node other than the one visited immediately before it (a skip
    connection), all shards from that producer onward are collapsed into the producer's shard, since
    shards only communicate through a single value.

    Returns:
        A map from node name to shard id for every placeholder, get_attr and call_* node that precedes
        the output node. The output node itself is not included.
    """
    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far: List[str] = []
    param_count: Dict[str, int] = {}

    # Find the total number of params in the model and
    # the number of params per shard we are aiming for.
    for name, module in traced_graph_module.named_modules():
        name = name.replace(".", "_")
        param_count[name] = sum(x.numel() for x in module.parameters())
    logging.info("Total number of params are %s", param_count[""])
    per_shard_param = param_count[""] // shard_count
    logging.info("Per shard param count %s", per_shard_param)

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
        elif node.op in ["get_attr", "call_function", "call_method", "call_module"]:
            min_shard_id = shard_id
            min_node_name = ""
            last_node_name = nodes_so_far[-1] if nodes_so_far else None
            # Among the inputs of this node that are not the last node we traversed, find the one from
            # the earliest shard. This is to help us find skip connections across shards.
            # all_input_nodes also covers Nodes nested in tuples/lists (e.g. torch.cat) and in kwargs.
            for input_node in node.all_input_nodes:
                if input_node.name in node_name_to_shard_id and input_node.name != last_node_name:
                    if node_name_to_shard_id[input_node.name] < min_shard_id:
                        min_shard_id = node_name_to_shard_id[input_node.name]
                        min_node_name = input_node.name

            # If there is an input that is not from the previous shard,
            # we collapse all the shards in between to be part of 1 shard.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.

            # Update state that is tracking node -> shard id and shard id -> param count.
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
            # TODO(anj): This could just be an update, we don't need to recreate the map.
            shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)
            # If we have gone over the number of params per shard count that we want to
            # achieve, we should add a new shard.
            # The shard_id may not have been updated in the map if we are at a node that does not
            # have params.
            if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
                shard_id += 1
        elif node.op == "output":
            break

    return node_name_to_shard_id


class _ExtendedLeafTracer(torch.fx.Tracer):
    """Tracer with an extended set of leaf nn.Modules."""

    def __init__(self, leaf_modules: Set[torch.nn.Module]):
        """Initializes a new _ExtendedLeafTracer object.

        Args:
            leaf_modules: The set of extra nn.Modules instances which will not be traced
                through but instead considered to be leaves. The set is stored by reference,
                so modules added to it later are also treated as leaves.
        """
        super().__init__()
        self.leaf_modules = leaf_modules

    def is_leaf_module(self, m: torch.nn.Module, model_qualified_name: str) -> bool:
        """Return True for the default fx leaves and for any module in ``leaf_modules``."""
        return super().is_leaf_module(m, model_qualified_name) or m in self.leaf_modules


# TODO(ehotaj): Extend this method to wrap at the least granular level. One way to do
# would be to wrap the Module tree bottom up, first wrapping untracable children and
# only wrapping parents if they are also untracable.
def _trace(model: torch.nn.Module) -> torch.fx.GraphModule:
    """Traces the given model and automatically wraps untracable modules into leaves.

    Every non-leaf submodule is test-traced on its own; a module whose tracing raises ``TypeError`` or
    ``TraceError`` is registered as a leaf, and the whole model is then traced with that leaf set.
    """
    leaf_modules: Set[torch.nn.Module] = set()
    tracer = _ExtendedLeafTracer(leaf_modules)
    for _, module in model.named_modules():
        # TODO(ehotaj): The default is_leaf_module includes everything in torch.nn.
        # This means that some coarse modules like nn.TransformerEncoder are treated
        # as leaves, not traced, and are unable to be sharded. We may want to extend our
        # sharding code to trace through these modules as well.
        if tracer.is_leaf_module(module, ""):
            continue
        try:
            tracer.trace(module)
        except (TypeError, torch.fx.proxy.TraceError):
            leaf_modules.add(module)
            # Start from a fresh tracer after a failed trace; it shares the same leaf set.
            tracer = _ExtendedLeafTracer(leaf_modules)
    graph = tracer.trace(model)
    return torch.fx.GraphModule(model, graph)


class _CrossShardReferenceError(RuntimeError):
    """Raised when a node reads a value that is not available inside its own shard."""


def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
    """Utility used to shard a model using torch.fx.

    This function traces the model once and makes two passes over the traced graph to identify
    the right cutpoints and then shard the model. In the first pass (``_split_nodes``) we count
    parameters while walking the graph and mark the nodes at which a new shard should start. In
    the second pass we copy the nodes into one graph per shard, inserting a placeholder at the
    start and an output node at the end of each shard.

    We don't support skip connections between shards. This means that all input and output is
    self contained within a given shard. A node from shard 1 cannot be an input to a node from
    shard 3. We expect all inputs to a given shard to be coming from the last node in the previous
    shard. This means that we may not be able to shard models by the specified `shard_count`
    mentioned by the user.

    Args:
        model (nn.Module): Model to be sharded as specified by the device count.
        shard_count (int): Number of shards that we want to split the model into.

    Returns:
        The shards as ``GraphModule`` objects in execution order. The first shard takes the model's
        inputs and every later shard takes the previous shard's output. The last shard returns the
        model's own output value when that value is computable inside it; otherwise (a warning is
        logged) it returns the last node of the last shard.

    Raises:
        ValueError: if ``shard_count`` is smaller than 1.
        RuntimeError: if a node reads a value produced in an earlier shard other than that shard's
            output (an unsupported skip connection).
    """
    if shard_count < 1:
        raise ValueError(f"shard_count must be at least 1, got {shard_count}")

    module_list: List[torch.fx.GraphModule] = []
    num_graphs = 0
    new_graph = torch.fx.Graph()  # type: ignore
    # Maps original node names to their copies in ``new_graph``; replaced for every new shard.
    env: Dict[str, Node] = {}
    # Placeholder of the shard being built and the original name of the node whose value it
    # receives. Both stay None while the first shard is built.
    shard_input: Optional[Node] = None
    shard_input_source: Optional[str] = None

    traced_graph_module = _trace(model)

    # This is the first pass where we attempt to get a map of where
    # we need to insert placeholder and output nodes.
    node_name_to_shard_id = _split_nodes(traced_graph_module, shard_count=shard_count)

    def lookup(arg: Node) -> Node:
        # Resolve a node of the traced graph to its value inside the shard being built. This reads
        # the *current* bindings of env / shard_input / shard_input_source, which change per shard.
        if arg.name in env:
            return env[arg.name]
        if shard_input is not None and arg.name == shard_input_source:
            return shard_input
        raise _CrossShardReferenceError(
            f"Node {arg.name} is not available in shard {num_graphs}: "
            "skip connections between shards are not supported"
        )

    prev_shard_id: Optional[int] = None  # None until the first node has been copied.
    prev_node: Optional[Node] = None  # Last node copied into new_graph.
    prev_node_name: Optional[str] = None  # Original (traced-graph) name of prev_node.
    for node in traced_graph_module.graph.nodes:
        shard_id = node_name_to_shard_id.get(node.name)
        # If the current node is in the next shard, we insert an output node.
        # A new graph is created and a placeholder is added for the next shard.
        if shard_id is not None and prev_shard_id is not None and prev_shard_id < shard_id:
            # prev_node is the last node appended to new_graph, so it becomes this shard's output.
            new_graph.output(prev_node)
            num_graphs += 1
            module_list.append(torch.fx.GraphModule(model, new_graph))
            new_graph = torch.fx.Graph()
            env = {}
            shard_input = new_graph.create_node("placeholder", "placeholder" + str(num_graphs))
            shard_input_source = prev_node_name

        if node.op in ["placeholder", "get_attr", "call_function", "call_method", "call_module"]:
            # Copy the node: Node-valued args/kwargs are resolved through lookup (references to the
            # previous shard's output become this shard's placeholder); constant args are kept.
            new_node = new_graph.node_copy(node, lookup)
            env[node.name] = new_node
        elif node.op == "output":
            # If this is the last node, return the model's own output value from the last graph
            # and add the last graph to the list.
            try:
                result = torch.fx.node.map_arg(node.args[0], lookup)
            except _CrossShardReferenceError:
                logging.warning(
                    "Model output %s reads values from earlier shards; the last shard returns %s instead",
                    node.args[0],
                    prev_node_name,
                )
                result = prev_node
            new_graph.output(result)
            module_list.append(torch.fx.GraphModule(model, new_graph))
            break
        else:
            raise RuntimeError(f"Unsupported node op {node.op}")

        prev_node = new_node
        prev_node_name = node.name
        prev_shard_id = shard_id
    return module_list