File size: 811 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from collections.abc import Sequence
import torch.fx as fx
# Names exported by ``from <this module> import *``; only ``set_trace`` is public.
__all__ = ["set_trace"]
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Make ``gm`` stop in pdb as soon as its generated python code starts running.

    A body transformer is registered on ``gm.graph`` that prepends
    ``import pdb; pdb.set_trace()`` to the generated code body, and ``gm`` is
    recompiled while that transformer is registered so the breakpoint ends up
    in its code.

    Args:
        gm: graph module to insert the breakpoint into. It is recompiled in
            place for the breakpoint to take effect.

    Returns:
        The same ``gm`` object, now carrying the breakpoint.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def make_transformer(previous):
        # ``previous`` is whatever transformer was registered before (may be
        # falsy); chain through it so earlier codegen hooks still apply.
        def transformer(body: Sequence[str]) -> list[str]:
            lines = body if not previous else previous(body)
            return [breakpoint_line, *lines]

        return transformer

    # Recompile while the transformer is registered so the generated code
    # actually contains the breakpoint line.
    with gm.graph.on_generate_code(make_transformer=make_transformer):
        gm.recompile()

    return gm
|