from __future__ import annotations

import re
from dataclasses import dataclass
from typing import cast, TYPE_CHECKING

from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX


if TYPE_CHECKING:
    from collections.abc import Sequence

# Represents a saved attribute involved in a backward calculation.
# It can be a derived property of an input argument, e.g. we may save
# `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # The name and cpp type used to store the attribute (a suffix is appended
    # for derived properties, e.g. `other_scalar_type`).
    nctype: NamedCType

    # The expression used to read the value at save time,
    # e.g. `other.scalar_type()`.
    expr: str

# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (a C++ expression), with references to input
    # arguments replaced by the corresponding saved attributes where needed,
    # e.g. `other.scalar_type()` becomes `other_scalar_type`.
    formula: str

    # The formula string before saved-attribute replacement.
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: set[str]

# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (a C++ expression).
    formula: str

    # Names of the output arguments for which this formula calculates
    # forward derivatives.
    var_names: tuple[str, ...]

    # Types of the output arguments for which this formula calculates
    # forward derivatives.
    var_types: tuple[Type, ...]

    # Inputs whose forward gradients are required by this formula, if any.
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs whose primal values are required by this formula, if any.
    required_inputs_primal: tuple[str, ...] | None

    # Whether this formula requires the original value of self
    # (only relevant for in-place operations).
    required_original_self_value: bool

    # Whether this formula re-uses the out-of-place formula for an in-place
    # variant instead of an explicitly specified in-place formula.
    is_reusing_outplace_formula: bool

@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The name of the operator this entry applies to (as written in
    # derivatives.yaml, possibly including an overload, e.g. "name.overload").
    name: str

    # The matching native function.
    func: NativeFunction

    # The name of the generated autograd function; only set if we will
    # calculate a derivative, i.e. `args_with_derivatives` is not empty.
    op: str | None

    # The backward derivative formulas.
    derivatives: Sequence[Derivative]

    # The forward derivative formulas.
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of saved inputs referenced by the backward formulas.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of saved outputs referenced by the backward formulas.
    all_saved_outputs: Sequence[SavedAttribute]

    # Named gradients that are available for use in the formulas.
    available_named_gradients: Sequence[str]

    # Named gradients that are actually used by one or more formulas.
    used_named_gradients: set[str]

    # The arguments for which a derivative is defined.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments that are explicitly marked as non-differentiable.
    non_differentiable_arg_names: Sequence[str]

    # Per-output differentiability flags, if specified in derivatives.yaml.
    output_differentiability: list[bool] | None

    # Conditions (C++ expressions) under which each output is differentiable,
    # if specified; used together with `output_differentiability`.
    output_differentiability_conditions: list[str] | None

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same derivative
    # information, but with a new operator name. This is used when generating
    # the "copy" variants of view ops, which can re-use the view op's formulas.
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        if g.view_copy is None:
            return None
        f = g.view_copy

        name_split_by_period = self.name.split(".", maxsplit=2)
        # Append a "_copy" suffix to the base name, keeping any overload name.
        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
            name_split_by_period[1:]
        )
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op.
            name=view_copy_name,
            func=f,
            op=view_copy_op_name,
            # But keep all derivative info the same.
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )

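# Illustration of the renaming above (hypothetical entry, not taken from
# derivatives.yaml): for a view op recorded with name="transpose.int" and
# op="TransposeBackward0", create_view_copy_from_view_derivative produces an
# info with name="transpose_copy.int" and op="TransposeBackward0_copy",
# re-using the derivative formulas unchanged.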
def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "retain_variables")


def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "grad")

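# Minimal illustration (hypothetical formulas, not real derivatives.yaml
# entries): IDENT_REGEX matches whole identifiers only, so for a formula like
# "grad * 2 + grads[0].sum()", uses_ident(info, "grad") is True because of the
# bare `grad`, while a formula containing only `grads[0]` would not count as
# using "grad".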
# Represents an argument as seen by the autograd codegen when collecting
# differentiable inputs (self or regular arguments, not TensorOptions).
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # The C++ type of the argument, rendered as a plain string.
    cpp_type: str


# Represents a return value as seen by the autograd codegen when collecting
# differentiable outputs (see `gen_differentiable_outputs` below). Unlike
# `Return`, the name is always populated (via `cpp.return_names()`).
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # The C++ type of the return, rendered as a plain string.
    cpp_type: str

@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
    info: dict[str, DifferentiabilityInfo] | None
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
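# Note: `info` and `fw_derivatives` above are keyed by dispatch key strings
# (e.g. "Default"). `match_differentiability_info` below builds these dicts,
# and infos generated for foreach ops are stored under the "Default" key.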
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration?  There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance).  Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors.  If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable).
    """
    if fn.func.is_abstract or (
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
    ):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.
        #
        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer the derived type's
        # implementation with unpacked tensors because it is more performant
        # in some cases: any internal calls to other ATen functions won't have
        # their history tracked.
        return "use_derived"
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we assume that it is actually
        # implemented out of differentiable functions. (This assumption might
        # not hold, but then you'll see gradcheck fail.)
        return "use_type"
def is_foreach_func(f: NativeFunction) -> bool:
    return f.func.name.name.base.startswith("_foreach_")
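# Illustration: "_foreach_add.List" and "_foreach_zero_" pass this test, while
# "add.Tensor" does not; only the base name before the overload is inspected.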
# In-place foreach ops whose reference derivative is taken from the in-place
# (rather than functional) variant of the corresponding non-foreach op.
_foreach_with_inplace_ref = {"_foreach_zero_"}

# Foreach ops that also accept a plain Tensor (not just a TensorList) argument.
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}

# Foreach ops for which the argument-count check against the reference schema
# is skipped in `is_reference_for_foreach` (see the note after this block).
_skip_argument_len_check = {
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
}
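# Why the length check is skipped for these (an assumption based on the public
# ATen schemas, abbreviated here): the non-foreach references carry an extra
# `alpha` argument that the foreach variants do not, e.g.
#   add.Scalar(Tensor self, Scalar other, Scalar alpha=1)   # 3 non-out args
#   _foreach_add.Scalar(Tensor[] self, Scalar scalar)       # 2 non-out args
# so a strict `len(flat_non_out)` comparison would reject the match.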
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    """Check if `function_schema` is the reference (non-foreach) schema for the foreach function `f`.

    The reference must have the same base name (minus the "_foreach_" prefix),
    must not be in-place unless the foreach op is listed in
    `_foreach_with_inplace_ref`, must have a matching number of non-out
    arguments (unless the foreach op is listed in `_skip_argument_len_check`),
    and each reference argument type must equal either the foreach argument
    type or its element type.
    """
    return (
        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
        and (
            not function_schema.name.name.inplace
            or str(f.func.name) in _foreach_with_inplace_ref
        )
        and (
            str(f.func.name) in _skip_argument_len_check
            or len(f.func.arguments.flat_non_out)
            == len(function_schema.arguments.flat_non_out)
        )
        and all(
            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
            for arg, ref_arg in zip(
                f.func.arguments.flat_non_out,
                function_schema.arguments.flat_non_out,
            )
        )
    )
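# Illustration (schemas abbreviated; the exact signatures are an assumption
# about ATen, not defined in this module):
#   _foreach_mul.List(Tensor[] self, Tensor[] other)  vs.
#   mul.Tensor(Tensor self, Tensor other)
# is a match: the base names agree ("mul"), both have two non-out arguments,
# and each reference argument type equals the element type of the
# corresponding foreach TensorList argument.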
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
    """Generate DifferentiabilityInfo for out-of-place foreach functions; return the existing one for in-place.

    The second return value indicates whether the info is generated in this function.
    """
    ref_diff_info: DifferentiabilityInfo | None = None

    # Find the reference (non-foreach) info among the functional ops.
    for function_schema, diff_info in functional_info_by_signature.items():
        if not is_reference_for_foreach(foreach_function, function_schema):
            continue
        ref_diff_info = diff_info[dispatch_key]
        if ref_diff_info is not None:
            break

    # Some in-place foreach ops (see `_foreach_with_inplace_ref`) reference the
    # derivative of the in-place variant of the non-foreach op instead.
    if (
        ref_diff_info is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        for function_schema, diff_info in non_functional_info_by_signature.items():
            if not is_reference_for_foreach(foreach_function, function_schema):
                continue
            ref_diff_info = diff_info[dispatch_key]
            if ref_diff_info is not None:
                break
    if ref_diff_info is None:
        return None, False

    # In-place foreach ops re-use the reference info as is.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    # NB: `function_schema` below is the reference schema matched in the loops above.
    map_refarg2foreacharg, map_name2arg = {}, {}
    for i, (arg, ref_arg) in enumerate(
        zip(
            foreach_function.func.arguments.flat_non_out,
            function_schema.arguments.flat_non_out,
        )
    ):
        map_refarg2foreacharg[ref_arg.name] = arg.name
        map_name2arg[arg.name] = arg

    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
    modified_derivative_formulas = []
    for i, derivative in enumerate(ref_diff_info.derivatives):
        # Index the gradient and the result, since foreach backward formulas
        # are applied per list element.
        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
            "result", "result[i]"
        )
        saved_inputs, saved_outputs = [], []
        # Rewrite the saved inputs/outputs of the reference formula in terms of
        # the foreach function's arguments.
        with local.parametrize(
            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
        ):
            for ref_input in derivative.saved_inputs:
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                if isinstance(map_name2arg[mapped_name].type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name == "result":
                    saved_outputs.append(
                        SavedAttribute(
                            nctype=NamedCType(
                                name="result", type=BaseCType(tensorListT)
                            ),
                            expr="result",
                        )
                    )
                else:
                    raise RuntimeError(
                        "gen_foreach_derivativeinfo: only `result` is supported "
                        f"as a saved output, got {ref_output.nctype.name}"
                    )
        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
        all_var_names.extend(var_names)
        all_saved_inputs.extend(saved_inputs)
        all_saved_outputs.extend(saved_outputs)
        modified_derivative = Derivative(
            formula=modified_formula,
            original_formula=derivative.formula,
            var_names=tuple(var_names),
            saved_inputs=tuple(saved_inputs),
            saved_outputs=tuple(saved_outputs),
            named_gradients=set(),
        )
        modified_derivative_formulas.append(modified_derivative)

    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_function.func.arguments.flat_non_out
            if arg.name in all_var_names
        ]

    forward_derivatives: list[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        var_names: list[str] = list(fw_derivative.var_names)
        var_types: list[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: list[str] = []
        required_inputs_primal: list[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # The result of a foreach function is a TensorList, so index it.
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Adapt the reference forward formula to the foreach arguments.
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Scalar-list argument: index it per element.
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # TensorList (or a plain Tensor for the Tensor overloads):
                # rename the primal/tangent placeholders.
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Scalar argument: just rename it if the names differ.
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # Keep var names/types and required inputs in sync with the renaming.
            for i, name in enumerate(var_names):
                if name == ref_arg.name:
                    var_names[i] = foreach_arg.name
                    var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(var_names),
                var_types=tuple(var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )
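# Sketch of the rewrite performed above (hypothetical formula, not an actual
# derivatives.yaml entry): given a reference backward formula for a binary op
# such as "grad * other" with `other` saved, the generated entry for the
# corresponding foreach op becomes "grads[i] * other[i]", with `other` saved as
# a TensorList and the gradient indexed per list element. The generated node is
# named by prefixing the reference op with "Foreach" and appending the overload
# name, e.g. "ForeachMulBackward0List" (naming is an assumption for this example).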
def match_differentiability_info(
    native_functions: list[NativeFunction],
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to the matching autograd function.

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }

    def find_info(
        f: NativeFunction,
    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
        # Don't bother matching info to generated out= variants.
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match.
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If there is no exact match, check if the out-of-place variant of
        # this operator has a match, e.g. mul() for mul_() or mul_out().
        # Foreach functions are excluded here because in-place foreach functions
        # use the backward defined for the existing native functions instead of
        # their out-of-place counterparts.
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature and not is_foreach_func(f):
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula. For the generated out-of-place variant,
        # use the mutable variant's formula if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info_dict = non_functional_info_by_signature[f_sig]
            assert not any(
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
                for info in info_dict.values()
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
 to be used automatically by its functional variant ("{str(f.func)}").
 This is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information for foreach functions if none is
        # defined in derivatives.yaml.
        if is_foreach_func(f):
            assert f.func not in differentiability_infos
            diff_info, is_generated = gen_foreach_derivativeinfo(
                f,
                functional_info_by_signature,
                non_functional_info_by_signature,
            )
            if diff_info is None:
                return None, False
            # The generated info is stored under the "Default" dispatch key.
            diff_info_dict = {"Default": diff_info}
            if is_generated:
                differentiability_infos[f.func] = diff_info_dict
                functional_info_by_signature[f.func] = diff_info_dict
            return diff_info_dict, is_generated

        return None, False

    result: list[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not
        # support 'self' derivatives of an in-place function, so we must check
        # for this case.
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if not info_dict:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            if f.func.kind() == SchemaKind.inplace:
                # For in-place functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified
                #     in place is not used:
                #    - If there is a formula for the in-place variant of the function
                #      (is_exact_match == True), make sure that the original value of
                #      the input that is being modified in place (self_p) is not used
                #      in the formula. Note that the formula can use "original_self_p"
                #      here, which triggers a clone of the original input.
                #    - If we are re-using the out-of-place formula
                #      (is_exact_match == False), replace every occurrence of self_p
                #      and self_t by original_self_p and original_self_t. These will
                #      be populated by a cloned version of the original input.
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p": in the in-place function codegen
                #     the result is simply called self (as it is modified in place).
                #  4) Update the required primals in case they used to contain
                #     "result" but should now contain "self".
                #  5) If it is not an exact match, the user formula does not modify
                #     the existing forward grad in place as it should, so add code
                #     that makes sure we do so if the forward grad already exists.

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single-output in-place ops are expected here.
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: re.Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original
                        # value of self to be used.
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When re-using the out-of-place formula, save a clone of
                        # the primal value: replace "self_p"/"self_t" in the formula
                        # by "original_self_p"/"original_self_t".
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # Replace "result" in the formula by "self_p".
                def repl(m: re.Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do the
                    # update in place, i.e. instead of self_t.copy_(self_t.op()) we
                    # do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) is the same op the derivative is for
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: str | None = None
                    between_parens: str | None = None
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # Make sure that the parentheses of the initial call are
                        # never closed before the end of the formula, i.e. the
                        # whole formula really is a single method call on self_t.
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                        directly_do_inplace = (
                            is_single_method_on_self_t and op_name == info.name
                        )

                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified in place when
                        # the original formula is out of place.
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result
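# Worked example of the in-place forward-AD rewrite above (hypothetical formula,
# not an actual derivatives.yaml entry): re-using an out-of-place tangent formula
# "self_t + other_t" for an in-place variant (is_exact_match == False) leaves
# self_t untouched (there is no self_p in it), does not match the single-method
# pattern, and therefore becomes
#   "self_t_raw.defined() ? self_t_raw.copy_(self_t + other_t) : self_t + other_t"
# so that an existing forward grad is updated in place.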
def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    return type.is_tensor_like() and (
        info is None or name not in info.non_differentiable_arg_names
    )
def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    f = fn.func
    info = fn.info[key] if fn.info else None
    outputs: list[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}) "
                f"does not match the number of outputs ({len(outputs)})."
            )
        differentiable_outputs: list[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs

    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
    )
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
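# Typical use by the autograd codegen (a sketch; the surrounding variable names
# are assumptions, not defined in this module):
#
#   fns = match_differentiability_info(native_functions, differentiability_infos)
#   for fn in fns:
#       strategy = dispatch_strategy(fn)            # "use_derived" or "use_type"
#       diff_outs = gen_differentiable_outputs(fn)  # uses the "Default" key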