from collections.abc import Sequence

import torch.fx as fx


__all__ = ["set_trace"]


def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Make `gm` stop in the pdb debugger whenever its generated code runs.

    While `gm` is recompiled, a code transformer is registered on `gm.graph`
    via `on_generate_code`. That transformer first applies whatever transformer
    was registered before it (if any), then prepends
    ``import pdb; pdb.set_trace()`` to the generated body lines. This bakes the
    breakpoint into `gm`'s recompiled code.

    Args:
        gm: graph module to instrument. It is recompiled in place so the
            breakpoint takes effect.

    Returns:
        The same `gm` object, now containing the breakpoint.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def with_breakpoint(lines: Sequence[str]) -> list[str]:
        # Put the breakpoint first so execution halts before any generated
        # statement runs.
        return [breakpoint_line, *lines]

    def build_transformer(previous):
        # `previous` is the transformer that was registered before this one;
        # it may be None/falsy, in which case the body is used untouched.
        def transformer(lines):
            staged = previous(lines) if previous else lines
            return with_breakpoint(staged)

        return transformer

    with gm.graph.on_generate_code(make_transformer=build_transformer):
        gm.recompile()
    return gm