"""
This module provides utilities for generating Python bytecode in PyTorch's Dynamo system.

It includes functionality for:
- Constructing bytecode sequences for Python operations
- Managing stack operations and variable tracking
- Handling graph outputs and their conversions
- Supporting different Python versions (3.11+, 3.12+, 3.13+)
- Converting high-level operations to low-level bytecode instructions
- Managing constant loading and attribute access
- Supporting function creation and closure handling
"""

import collections
import dataclasses
import re
import sys
import types
from collections import Counter
from typing import Optional, Union

import torch.nn
from torch.utils._ordered_set import OrderedSet

from . import graph_break_hints, utils
from .bytecode_transformation import (
    add_push_null,
    add_push_null_call_function_ex,
    create_call_function,
    create_call_method,
    create_dup_top,
    create_instruction,
    create_load_const,
    create_load_method,
    create_rot_n,
    Instruction,
)
from .exc import IncorrectUsage, unimplemented_v2
from .source import AttrSource, ChainedSource, DictGetItemSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import ValueMutationExisting, VariableTracker
from .variables.functions import (
    ContextlibContextManagerLocalGeneratorObjectVariable,
    LocalGeneratorObjectVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


@dataclasses.dataclass
class GraphOutputEntry:
    index: int
    variable: VariableTracker


class PyCodegen:
    """
    Helper class used for constructing Python bytecode
    """

    def __init__(
        self,
        tx=None,
        root: Optional[torch.nn.Module] = None,
        graph_output_var: Optional[str] = None,
        tempvars=None,
        overridden_sources=None,
    ) -> None:
        self.root = root
        self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
        self.uses: Counter[VariableTracker] = collections.Counter()
        self.graph_outputs: dict[int, GraphOutputEntry] = {}
        self._output: list[Instruction] = []
        # Maps a VariableTracker/Source to the name of the temporary local
        # variable its reconstructed value was stored in, so that later uses
        # can be served by a cheap LOAD_FAST instead of re-reconstruction.
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var
        self.value_from_source: bool = True
        # Lets callers substitute a different Source when a given Source is
        # codegen-ed, without mutating the original Source object (which may
        # be shared with other components, e.g., guards).
        self.overridden_sources: dict[Source, Source] = overridden_sources or {}

    def restore_stack(self, stack_values, *, value_from_source=True):
        prev = self.value_from_source
        self.value_from_source &= value_from_source
        try:
            self.foreach(stack_values)
        finally:
            self.value_from_source = prev

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def call_reconstruct(self, value):
        res = value.reconstruct(self)
        assert res is None, f"reconstruct!=None {value}"

    def add_push_null(self, gen_fn, call_function_ex=False):
        """
        `gen_fn` generates instructions via PyCodegen methods
        that push a single callable to the stack.

        `add_push_null` pushes a NULL to the stack before or after the
        instructions generated by `gen_fn`, depending on Python version.

        Will attempt to use the NULL push bit for instructions
        with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
        """
        old_len = len(self._output)
        if sys.version_info < (3, 13):
            # Clear top-of-stack first: if gen_fn were served by a DUP_TOP,
            # the NULL would end up pushed between the duplicated values.
            self.clear_tos()
        gen_fn()
        # Rewrite, in place, the instructions gen_fn just appended.
        added_insts = self._output[old_len:]
        del self._output[old_len:]
        if call_function_ex:
            self._output.extend(add_push_null_call_function_ex(added_insts))
        else:
            self._output.extend(add_push_null(added_insts))
        if sys.version_info >= (3, 13):
            # In 3.13+ the NULL is pushed after the callable, so it is now TOS.
            self.clear_tos()
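
    # Illustrative use (mirrors call sites below): wrapping a loader in
    # add_push_null makes the emitted callable satisfy the 3.11+ calling
    # convention, which pairs a NULL with the callable:
    #
    #     self.add_push_null(
    #         lambda: self.load_import_from(utils.__name__, "to_subclass")
    #     )
    #     ...push arguments...
    #     self.extend_output(create_call_function(2, False))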

    def __call__(self, value, allow_cache=True):
        """
        Generate code such that top-of-stack (TOS) is set to value.

        `allow_cache` controls the behavior in the following manner. `value`
        can be either a VariableTracker or a Source.

        If `value` is a `Source`, `allow_cache` must be True (invariant
        asserted below). If the source was reconstructed earlier, we reuse the
        generated code by loading from the top of stack or tempvars.

        If `value` is a `VariableTracker`, we have the following cases:

        1) `allow_cache=True`
           a) If value.source is not None, we emit the code based on
              `value.source` to handle aliasing.
           b) If value.source is None (e.g., when reconstructing a local list
              returned by the compiled function), we reconstruct the variable
              tracker (without any source) to emit bytecode that generates a
              new python object.

           In both cases, if the value was reconstructed earlier, we reuse the
           generated code by loading from the top of stack or tempvars.

        2) `allow_cache=False` - This is a special case (allow_cache defaults
           to True).
           a) If value.source is not None, we reconstruct the variable tracker
              and emit a new python object. What about aliasing? The places
              that use this config follow up with code that assigns the
              original python object to this new python value to handle
              aliasing (check side_effects.py and search for
              allow_cache=False).
           b) If value.source is None, this is not allowed. TODO - assert
              this.

        Notable effects:
        1. `self.top_of_stack` will be set to `value`, if we don't codegen
           `value` based on source.
        2. `self.uses[value]` will increment, if we don't codegen `value`
           based on source or cache/top-of-stack reuse; in other words, if we
           codegen as if `value` is modelling some brand new python value.
        """
        if isinstance(value, Source):
            # If this source was overridden, codegen the substitute instead.
            source = self.overridden_sources.get(value, value)
            assert allow_cache is True, "allow_cache must be True for Source"
            if self.top_of_stack is value:
                self._output.append(create_dup_top())
                return

            if self.tempvars.get(source) is not None:
                self._output.append(self.create_load(self.tempvars[source]))
                self.top_of_stack = source
                return

            try:
                self.call_reconstruct(source)
            except NotImplementedError:
                unimplemented_v2(
                    gb_type="Reconstruction failure: source.reconstruct not implemented",
                    context=str(source),
                    explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
                    hints=[*graph_break_hints.DYNAMO_BUG],
                )

            self._output.append(create_dup_top())
            self.add_cache(source)
            self.top_of_stack = source

            return

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if allow_cache:
            if self.top_of_stack is value:
                output.append(create_dup_top())
                return

            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.is_realized() and isinstance(
            value, ContextlibContextManagerLocalGeneratorObjectVariable
        ):
            raise IncorrectUsage(
                "NYI: Returning a @contextmanager object from a torch.compile function"
            )
        if (
            value.source is not None
            and allow_cache
            and not (
                value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
            )
        ):
            # Prefer to codegen from the original source, so that repeated
            # reconstructions of the same object alias each other rather than
            # creating fresh objects. Callers can disable `value_from_source`
            # (see `restore_stack`); when it is disabled, only values that
            # mutate a pre-existing python object are still loaded from their
            # source.
            if (
                isinstance(value.mutation_type, ValueMutationExisting)
                or self.value_from_source
            ):
                return self(value.source)

        if value.is_python_constant() and is_safe_constant(value.as_python_constant()):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(value, TensorWithTFOverrideVariable):
            graph_outputs_key = self.add_graph_output(value)

            self.add_push_null(
                lambda: self.load_import_from(utils.__name__, "to_subclass")
            )
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
                    value.global_mangled_class_name(self.tx), add=True
                )
            )
            output.extend(create_call_function(2, False))
        elif (
            isinstance(value, SymNodeVariable)
            and value.python_type() == float
            and not self.tx.export
        ):
            # A symbolic Python float is returned from the graph as a 0-dim
            # float64 tensor; reconstruct the float by calling `.item()` on
            # that graph output. For export, this branch is skipped and the
            # generic SymNodeVariable handling below applies.
            graph_outputs_key = self.add_graph_output(
                value.as_tensor(self.tx, torch.float64)
            )

            def gen_fn():
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.append(self.create_load_attr("item"))

            self.add_push_null(gen_fn)
            output.extend(create_call_function(0, False))
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            graph_outputs_key = self.add_graph_output(value)

            if isinstance(value, NumpyNdarrayVariable):
                self.add_push_null(
                    lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
                )
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.extend(create_call_function(1, False))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
                # Unwrap the tensor graph output back to a Python scalar.

                def gen_fn():
                    self.load_graph_output(graph_outputs[graph_outputs_key].index)
                    output.append(self.create_load_attr("item"))

                self.add_push_null(gen_fn)
                output.extend(create_call_function(0, False))
            else:
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_const_unchecked(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                self.call_reconstruct(value)
            except NotImplementedError:
                unimplemented_v2(
                    gb_type="Reconstruction failure",
                    context=str(value),
                    explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
                    hints=[
                        "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
                        "that Dynamo cannot reconstruct, then remove it from the return statement.",
                        *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                        "Report an issue to PyTorch if you need reconstruction support. Note that objects that don't have "
                        "reconstruction rules may be fundamentally unreconstructable.",
                    ],
                )
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value

    def add_graph_output(self, value):
        graph_outputs_key = id(value.as_proxy())
        if graph_outputs_key not in self.graph_outputs:
            self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
                len(self.graph_outputs), value
            )
        return graph_outputs_key

    def load_graph_output(self, index):
        output = self._output
        output.append(self.create_load(self.graph_output_var))
        output.append(self.create_load_const(index))
        output.append(self.create_binary_subscr())
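
    # Illustrative: load_graph_output(2) emits LOAD_FAST of
    # self.graph_output_var, LOAD_CONST 2, then BINARY_SUBSCR -- the bytecode
    # for `graph_out_var[2]`, indexing into the sequence of graph outputs.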

    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def create_binary_subscr(self) -> Instruction:
        return create_instruction("BINARY_SUBSCR")

    def setup_globally_cached(self, name, value):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert id(f_globals[name]) == id(value)
        else:
            f_globals[name] = value
        return [self.create_load_global(name, add=True)]

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self) -> list[Instruction]:
        return self._output

    def create_load(self, name) -> Instruction:
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)

    def create_load_closure(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
        return create_instruction(inst_name, argval=name)

    def create_load_deref(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        return create_instruction("LOAD_DEREF", argval=name)

    def create_store(self, name) -> Instruction:
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("STORE_FAST", argval=name)

    def create_store_deref(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        return create_instruction("STORE_DEREF", argval=name)

    def create_load_global(self, name, add=False) -> Instruction:
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_instruction("LOAD_GLOBAL", argval=name)

    def create_load_const(self, value) -> Instruction:
        return create_load_const(value)

    def create_load_const_unchecked(self, value) -> Instruction:
        return create_load_const(value, checked=False)

    def load_method(self, name):
        self.tx.output.update_co_names(name)
        self.append_output(create_load_method(name))

    def call_method(self, nargs):
        self.extend_output(create_call_method(nargs))

    def create_load_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("LOAD_ATTR", argval=name)

    def load_attr(self, name):
        self.append_output(self.create_load_attr(name))

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]

    def create_store_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("STORE_ATTR", argval=name)

    def store_attr(self, name):
        self.append_output(self.create_store_attr(name))

    def load_function_name(self, fn_name, push_null, num_on_stack=0):
        """Load the global fn_name on the stack num_on_stack down"""
        output = []
        if push_null and sys.version_info >= (3, 11):
            output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
            if num_on_stack > 0:
                output.extend(
                    [
                        *self.rot_n(num_on_stack + 2),
                        *self.rot_n(num_on_stack + 2),
                    ]
                )
        else:
            output.extend(
                [
                    self.create_load_global(fn_name, add=True),
                    *self.rot_n(num_on_stack + 1),
                ]
            )
        return output

    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # The desired rotate instruction does not exist on this Python
            # version; emulate it by packing the top n values into a tuple,
            # rotating them through a helper function, and unpacking the
            # result back onto the stack.
            return [
                create_instruction("BUILD_TUPLE", arg=n),
                self.create_load_const_unchecked(rot_n_helper(n)),
                *create_rot_n(2),
                create_instruction("CALL_FUNCTION_EX", arg=0),
                create_instruction("UNPACK_SEQUENCE", arg=n),
            ]
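
    # Illustrative stack effect: for n=3 with stack (bottom -> top) [a, b, c],
    # the emitted instructions leave [c, a, b] -- the old top is moved down
    # two slots, matching CPython's ROT_THREE.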

    def pop_null(self):
        # POP_TOP doesn't work for a NULL, so pop the NULL by pushing a dummy
        # callable, calling it (which consumes the NULL), and popping the
        # result.
        assert sys.version_info >= (3, 11)
        return [
            self.create_load_const_unchecked(lambda: None),
            # 3.13 swapped the order of NULL and the callable
            *(
                (create_instruction("SWAP", arg=2),)
                if sys.version_info >= (3, 13)
                else ()
            ),
            *create_call_function(0, False),
            create_instruction("POP_TOP"),
        ]

    def pop_top(self):
        self.append_output(create_instruction("POP_TOP"))

    def call_function(self, nargs: int, push_null: bool):
        self.extend_output(create_call_function(nargs, push_null=push_null))

    def dup_top(self):
        self.append_output(create_dup_top())

    def store(self, varname):
        self.append_output(self.create_store(varname))

    def load_deref(self, varname):
        self.append_output(self.create_load_deref(varname))

    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output

        def gen_fn():
            # Build the closure tuple from the enclosing frame's cells, then
            # materialize the function object from the code object (the 0x08
            # flag/attribute marks the closure tuple).
            for var in freevars:
                assert var in self.cell_and_freevars()
                output.append(self.create_load_closure(var))
            output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
            output.append(self.create_load_const(code))
            if sys.version_info < (3, 11):
                output.append(self.create_load_const(fn_name))
            if sys.version_info >= (3, 13):
                output.extend(
                    [
                        create_instruction("MAKE_FUNCTION"),
                        create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
                    ]
                )
            else:
                output.append(create_instruction("MAKE_FUNCTION", arg=0x08))

        if push_null and sys.version_info >= (3, 11):
            self.add_push_null(gen_fn)
            output.extend(self.rot_n(num_on_stack + 2))
            output.extend(self.rot_n(num_on_stack + 2))
        else:
            gen_fn()
            output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()

    def create_load_python_module(self, mod) -> Instruction:
        """
        Generate a LOAD_GLOBAL instruction to fetch a given python module.
        """
        output = self.tx.output
        global_scope = output.global_scope
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if global_scope.get(name, None) is mod:
            return self.create_load_global(name, add=True)
        prefix = f"___module_{name}"
        global_name = self.tx.output.install_global_by_id(prefix, mod)
        return self.create_load_global(global_name, add=True)

    def mark_source_temp(self, source: Source) -> None:
        """
        Mark a source as a temp variable, so that it can be reused.
        """
        if source not in self.tempvars:
            self.tempvars[source] = None

    def make_call_generated_code(self, fn_name: str) -> None:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs

        seen_sources: OrderedSet[Source] = OrderedSet()

        def collect_temp_source(source):
            if source in seen_sources:
                # This source is used more than once, so mark it as a temp
                # variable that later uses can load cheaply.
                self.mark_source_temp(source)
                # Do not recurse further; this keeps the number of temp
                # sources from growing unnecessarily.
                return

            seen_sources.add(source)

            if isinstance(source, ChainedSource):
                collect_temp_source(source.base)

            if isinstance(source, DictGetItemSource) and isinstance(
                source.index, Source
            ):
                collect_temp_source(source.index)

        # Mark any source shared by multiple graph args as a temp variable.
        for arg in graphargs:
            if arg.source is not None:
                collect_temp_source(arg.source)

        for arg in graphargs:
            if arg.pass_arg_as_tensor:
                self.add_push_null(
                    lambda: self.extend_output(
                        [
                            self.create_load_python_module(torch),
                            self.create_load_attr("_as_tensor_fullprec"),
                        ]
                    )
                )
                self.call_reconstruct(arg)
                self.extend_output(create_call_function(1, False))
            else:
                self.call_reconstruct(arg)

        self.extend_output(create_call_function(len(graphargs), False))

    def load_import_from(self, module_name, object_name) -> None:
        source = AttrSource(self.tx.import_source(module_name), object_name)
        # Eagerly mark the import source as a temp var (normally a source is
        # only marked temp once it is seen more than once); imported helpers
        # tend to be reused, and a cached LOAD_FAST is cheap.
        self.mark_source_temp(source)
        self(source)

    def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instruction]:
        if sys.version_info >= (3, 13):
            output = create_call_function(nargs, push_null)
            assert output[-1].opname == "CALL"
            output.insert(-1, self.create_load_const(kw_names))
            output[-1] = create_instruction("CALL_KW", arg=nargs)
            return output
        elif sys.version_info >= (3, 11):
            output = create_call_function(nargs, push_null)
            if sys.version_info >= (3, 12):
                idx = -1
                expected_inst = "CALL"
            else:
                idx = -2
                expected_inst = "PRECALL"
            assert output[idx].opname == expected_inst
            kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
            output.insert(idx, kw_names_inst)
            return output
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", arg=nargs),
        ]
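
    # Illustrative: for a call like f(x, y=1) (nargs=2, kw_names=("y",)),
    # Python 3.13 emits the usual CALL sequence with the constant ("y",)
    # loaded immediately before a CALL_KW instruction; 3.11/3.12 prepend a
    # KW_NAMES instruction instead, and <=3.10 uses CALL_FUNCTION_KW.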

    def create_delete(self, value) -> Instruction:
        return create_instruction("DELETE_FAST", argval=value)