|
from typing import Any, Optional, Union |
|
|
|
from torchgen.model import ( |
|
Annotation, |
|
Argument, |
|
Arguments, |
|
BaseOperatorName, |
|
BaseTy, |
|
BaseType, |
|
CustomClassType, |
|
FunctionSchema, |
|
ListType, |
|
OperatorName, |
|
Return, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TypeGen:
    """Infer a torchgen schema type from an example runtime value."""

    # Maps an exact Python scalar type to its torchgen base type. The lookup in
    # ``from_example`` keys on ``type(obj)``, so subclasses never match a parent
    # entry (e.g. ``True`` resolves through ``bool``, not ``int``).
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        """Return the schema type that describes ``obj``.

        Args:
            obj: Example value. Supported kinds are ``torch.fx.GraphModule``,
                ``torch.Tensor``, ``torch.SymInt``, ``torch.SymBool``,
                ``torch.ScriptObject``, non-empty homogeneous lists/tuples of
                supported values, and exact ``int``/``float``/``str``/``bool``.

        Returns:
            A ``BaseType`` for tensors, graph modules, symbolic values and
            Python scalars; a ``CustomClassType`` named after the script
            object's type for ``torch.ScriptObject``; a ``ListType`` of the
            shared element type, sized to ``len(obj)``, for lists and tuples.

        Raises:
            AssertionError: If ``obj`` is an empty list or tuple.
            RuntimeError: If a list/tuple mixes element types, or if the type
                of ``obj`` is not supported.
        """
        # Function-scope import, kept from the original; presumably avoids
        # importing torch when this module is loaded -- confirm before hoisting.
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        elif isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        elif isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        elif isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        elif isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())
        elif isinstance(obj, (list, tuple)):
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O`` (which would otherwise surface as an
            # IndexError below). The exception type is unchanged for callers.
            if not obj:
                raise AssertionError(
                    "Cannot generate schema for an empty list or tuple: "
                    "the element type cannot be inferred."
                )
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            # More than one distinct inferred type means the sequence is
            # heterogeneous and cannot be expressed as a single ListType.
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a seqeunce of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and give proper names to them if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])
|
|
|
|
|
class ReturnGen:
    """Build ``Return`` schema entries from example output values."""

    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        """Create a ``Return`` whose type is inferred from ``obj``.

        Args:
            name: Name forwarded to ``Return``; may be ``None``.
            obj: Example value passed to ``TypeGen.from_example`` to infer the
                return type.
            annotation: Annotation forwarded unchanged; may be ``None``.

        Returns:
            ``Return(name, inferred_type, annotation)``.

        Raises:
            Whatever ``TypeGen.from_example`` raises for ``obj``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Return(name, inferred_type, annotation)
|
|
|
|
|
class ArgumentGen:
    """Build ``Argument`` schema entries from example input values."""

    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        """Create an ``Argument`` whose type is inferred from ``obj``.

        Args:
            name: Argument name forwarded to ``Argument``.
            obj: Example value passed to ``TypeGen.from_example`` to infer the
                argument type.
            default: Default forwarded as the ``default`` keyword; may be
                ``None``.
            annotation: Annotation forwarded as the ``annotation`` keyword; may
                be ``None``.

        Returns:
            The constructed ``Argument``.

        Raises:
            Whatever ``TypeGen.from_example`` raises for ``obj``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Argument(name, inferred_type, default=default, annotation=annotation)
|
|
|
|
|
class FunctionSchemaGen:
    """Assemble a ``FunctionSchema`` from example inputs and outputs."""

    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: tuple[tuple[str, Any], ...],
        example_outputs: tuple[Any, ...],
    ) -> FunctionSchema:
        """Generate a schema for ``op_name`` from example values.

        Args:
            op_name: Base name of the operator.
            example_inputs: ``(name, value)`` pairs; each becomes an argument
                with no default and no annotation, typed from ``value``.
            example_outputs: Example return values; each becomes an unnamed,
                unannotated return typed from the value.

        Returns:
            A ``FunctionSchema`` with an empty overload name, the inferred
            arguments, and the inferred returns.

        Raises:
            Whatever ``TypeGen.from_example`` raises for any input or output.
        """
        inferred_args = tuple(
            ArgumentGen.from_example(arg_name, value, None, None)
            for arg_name, value in example_inputs
        )
        # Only the third positional slot of ``Arguments`` is populated; every
        # other group is left empty / ``None``.
        arguments = Arguments((), None, inferred_args, (), None, (), ())
        returns = tuple(
            ReturnGen.from_example(None, output, None) for output in example_outputs
        )
        # All three boolean flags of ``BaseOperatorName`` are False, and the
        # second argument to ``OperatorName`` is the empty string.
        base_name = BaseOperatorName(op_name, False, False, False)
        operator_name = OperatorName(base_name, "")
        return FunctionSchema(operator_name, arguments, returns)
|
|