from collections.abc import Sequence | |
import torch.fx as fx | |
__all__ = ["set_trace"] | |
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Sets a breakpoint in `gm`'s generated python code so that it drops into
    pdb when `gm` gets run.

    The line ``import pdb; pdb.set_trace()`` is prepended to the generated
    body, ahead of the output of any code transformer already registered on
    the graph. `gm` is then recompiled inside the
    ``gm.graph.on_generate_code`` context, so the new code takes effect.

    NOTE(review): the transformer is only registered inside the ``with``
    block. A later ``gm.recompile()`` presumably regenerates code without the
    breakpoint. Confirm against ``Graph.on_generate_code``.

    Args:
        gm: graph module to insert breakpoint. It is then recompiled for it to
            take effect.

    Returns:
        the `gm` with breakpoint inserted (the same object, modified in place).
    """

    def _prepend_breakpoint(lines: Sequence[str]) -> list[str]:
        # Put the breakpoint first so pdb stops before any graph code executes.
        result = ["import pdb; pdb.set_trace()\n"]
        result.extend(lines)
        return result

    def _make_transformer(previous):
        # `previous` is the transformer currently registered on the graph,
        # or a falsy value when there is none.
        def _transform(body):
            if previous:
                # Preserve the existing transformation, then add our line.
                body = previous(body)
            return _prepend_breakpoint(body)

        return _transform

    with gm.graph.on_generate_code(make_transformer=_make_transformer):
        gm.recompile()
    return gm