import logging
from typing import Dict, List, Set

import torch
import torch.fx
from torch.fx.node import Node


def _get_count(param_count: Dict, node_name: str) -> int:
    """Return the parameter count recorded for a node, resolving mutated node names.

    torch.fx appends a numeric suffix when a name is reused (e.g. ``linear``
    becomes ``linear_1``), so we fall back to the prefix before the first
    underscore when the exact name is not found.
    """
    if node_name in param_count:
        return param_count[node_name]
    elif node_name.split("_")[0] in param_count:
        # Strip the mutation suffix and retry with the base name.
        return param_count[node_name.split("_")[0]]
    else:
        raise RuntimeError(f"Unable to find match between param {param_count} and node {node_name}")
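

# For illustration (hypothetical names and values): fx mutates repeated node
# names, so "linear_1" falls back to the count recorded under "linear".
#
#   _get_count({"linear": 12}, "linear")    # -> 12
#   _get_count({"linear": 12}, "linear_1")  # -> 12 via the prefix fallback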


def _create_shard_to_param_count(param_count: Dict, node_name_to_shard_id: Dict) -> Dict:
    """Utility to create a map from shard id to param count using existing state."""
    shard_to_param_count: Dict[int, int] = {}
    for node_name in node_name_to_shard_id.keys():
        try:
            count = _get_count(param_count, node_name)
        except RuntimeError:
            # Nodes without parameters (e.g. activations) have no entry; skip them.
            continue
        if node_name_to_shard_id[node_name] in shard_to_param_count:
            shard_to_param_count[node_name_to_shard_id[node_name]] += count
        else:
            shard_to_param_count[node_name_to_shard_id[node_name]] = count
    return shard_to_param_count
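

# A minimal sketch with hypothetical values: parameterless nodes such as "relu"
# are skipped, and counts are aggregated per shard id.
#
#   _create_shard_to_param_count(
#       {"linear": 12, "linear_1": 20},
#       {"linear": 0, "relu": 0, "linear_1": 1},
#   )  # -> {0: 12, 1: 20}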


def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
    """Utility that walks an already-traced graph and identifies shard cutpoints."""

    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far = []
    param_count: Dict[str, int] = {}
    shard_to_param_count = {}

    # Record the parameter count of every submodule, keyed by its fx-style
    # name (dots replaced with underscores). The root module's entry ("")
    # holds the total.
    for name, module in traced_graph_module.named_modules():
        name = name.replace(".", "_")
        param_count[name] = sum(x.numel() for x in module.parameters())
    logging.info(f"Total number of params is {param_count['']}")
    per_shard_param = param_count[""] // shard_count
    logging.info(f"Per shard param count {per_shard_param}")

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
        elif node.op in ["get_attr", "call_function", "call_method", "call_module"]:
            min_shard_id = shard_id
            min_node_name = ""

            # Check whether any arg of this node comes from somewhere other
            # than the immediately preceding node, which would indicate a
            # skip connection reaching back into an earlier shard.
            for arg in node.args:
                # Args that are not nodes (e.g. constants) have no name.
                if not hasattr(arg, "name"):
                    continue

                if arg.name in node_name_to_shard_id and arg.name != nodes_so_far[-1]:
                    if node_name_to_shard_id[arg.name] < min_shard_id:
                        min_shard_id = node_name_to_shard_id[arg.name]
                        min_node_name = arg.name

            # If a skip connection reaches back into an earlier shard, merge
            # every shard in between into that earlier shard so that no skip
            # connection ever crosses a shard boundary.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)

            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)

            shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)
            # Start a new shard once the current one exceeds its parameter
            # budget. The shard id may be absent from the map if none of its
            # nodes carry parameters.
            if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
                shard_id += 1
        elif node.op == "output":
            break
    return node_name_to_shard_id
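

# The returned mapping assigns each fx node name to a shard id. For a small
# MLP split into three shards it might look like (hypothetical names and ids):
#
#   {"x": 0, "fc1": 0, "relu": 1, "fc2": 1, "relu_1": 2, "fc3": 2}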


class _ExtendedLeafTracer(torch.fx.Tracer):
    """Tracer with an extended set of leaf nn.Modules."""

    def __init__(self, leaf_modules: Set[torch.nn.Module]):
        """Initializes a new _ExtendedLeafTracer object.

        Args:
            leaf_modules: The set of extra nn.Module instances which will not be
                traced through but instead considered to be leaves.
        """
        super().__init__()
        self.leaf_modules = leaf_modules

    def is_leaf_module(self, m: torch.nn.Module, model_qualified_name: str) -> bool:
        return super().is_leaf_module(m, model_qualified_name) or m in self.leaf_modules
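

# A minimal sketch of how the extended tracer is used (hypothetical module):
# instances placed in the leaf set show up as single opaque call_module nodes
# instead of being traced through.
#
#   block = my_untraceable_block  # e.g. a module with data-dependent branching
#   tracer = _ExtendedLeafTracer({block})
#   graph = tracer.trace(model)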


def _trace(model: torch.nn.Module) -> torch.fx.GraphModule:
    """Traces the given model and automatically wraps untraceable modules into leaves."""
    leaf_modules = set()
    tracer = _ExtendedLeafTracer(leaf_modules)
    for name, module in model.named_modules():
        # Modules the default tracer already treats as leaves need no check.
        if tracer.is_leaf_module(module, ""):
            continue
        try:
            tracer.trace(module)
        except (TypeError, torch.fx.proxy.TraceError):
            # The module cannot be traced symbolically (e.g. data-dependent
            # control flow), so mark it as a leaf and rebuild the tracer.
            leaf_modules.add(module)
            tracer = _ExtendedLeafTracer(leaf_modules)
    graph = tracer.trace(model)
    return torch.fx.GraphModule(model, graph)
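

# For illustration, a hypothetical module with data-dependent control flow
# raises a TraceError under symbolic tracing, so _trace keeps it as a leaf:
#
#   class Gate(torch.nn.Module):
#       def forward(self, x):
#           return x if x.sum() > 0 else -x  # branch on a tensor value
#
#   gm = _trace(torch.nn.Sequential(torch.nn.Linear(4, 4), Gate()))
#   # Gate appears in gm.graph as a single call_module node.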


def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
    """Utility used to shard a model using torch.fx.

    This function traces the model twice in an attempt to identify the
    right cutpoints and then shard the model. In the first pass we calculate
    the number of parameters as we are tracing the graph and mark nodes at
    which we might want to create a new module. In the second pass we
    modify the graph by inserting placeholders and output nodes to essentially
    shard the graph.

    We don't support skip connections between shards, so all inputs and
    outputs must be self-contained within a given shard. A node from shard 1
    cannot be an input to a node from shard 3; we expect all inputs to a given
    shard to come from the last node in the previous shard. As a result, the
    model may not be split into exactly the `shard_count` requested by the
    user.

    Args:
        model (nn.Module): Model to be sharded as specified by the device count.
        shard_count (int): Number of shards that we want to split the model into.
    """
    module_list: List[torch.fx.GraphModule] = []
    num_graphs = 0
    new_graph = torch.fx.Graph()
    env: Dict[str, Node] = {}
    new_input_node = None

    traced_graph_module = _trace(model)

    # First pass: tag every node with the shard it should belong to.
    node_name_to_shard_id = _split_nodes(traced_graph_module, shard_count=shard_count)

    # Sentinel value larger than any real shard id.
    prev_shard_id = 1000
    prev_node = None
    for node in traced_graph_module.graph.nodes:
        # A jump in shard id between consecutive nodes marks a cutpoint:
        # close out the current subgraph and start a new one.
        if node.name in node_name_to_shard_id and prev_shard_id < node_name_to_shard_id[node.name]:
            assert prev_node, "prev_node cannot be None"

            with new_graph.inserting_after(prev_node):
                new_graph.output(env[prev_node.name])
            num_graphs += 1
            module_list.append(torch.fx.GraphModule(model, new_graph))
            new_graph = torch.fx.Graph()
            node_name = "placeholder" + str(num_graphs)
            pl_node = new_graph.create_node("placeholder", node_name)
            env[node_name] = pl_node
            new_input_node = pl_node

        if new_input_node is not None:
            # Rewire the first node of the new shard to consume the fresh
            # placeholder instead of the previous shard's output.
            node.args = (new_input_node,)
            new_input_node = None
        if node.op in ["placeholder", "get_attr", "call_function", "call_method", "call_module"]:
            # Copy the node from the original graph into the current subgraph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
        elif node.op == "output":
            # Last node of the traced graph: emit the final output and close
            # out the last subgraph.
            assert prev_node, "prev_node cannot be None"

            with new_graph.inserting_after(prev_node):
                new_graph.output(env[prev_node.name])
            module_list.append(torch.fx.GraphModule(model, new_graph))
            break
        prev_node = new_node
        prev_shard_id = node_name_to_shard_id[node.name]

    return module_list
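

if __name__ == "__main__":
    # A minimal usage sketch (hypothetical model, not part of the library API):
    # shard a small MLP and run the resulting shards sequentially. The number
    # of shards actually produced may differ from the requested `shard_count`.
    class _SmallMLP(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = torch.nn.Linear(16, 16)
            self.fc2 = torch.nn.Linear(16, 16)
            self.fc3 = torch.nn.Linear(16, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            return self.fc3(x)

    shards = shard_model(_SmallMLP(), shard_count=3)
    out = torch.randn(2, 16)
    # Each shard's placeholder is fed by the previous shard's output.
    for shard in shards:
        out = shard(out)
    print(f"Produced {len(shards)} shards; output shape: {tuple(out.shape)}")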