from typing import List, Union, Tuple, Optional
from tools.codegen.model import (Type, BaseTy, BaseType, OptionalType,
                                 ListType, OperatorName, FunctionSchema,
                                 Return, TensorOptionsArguments, Argument)
from tools.codegen.api.types import (CType, BaseCppType, BaseCType, OptionalCType,
                                     NamedCType, deviceT, layoutT,
                                     VectorCType, boolT, longT, doubleT, ListCType,
                                     stringT, scalarT, scalarTypeT, memoryFormatT)

valueT = BaseCppType('torch::lazy', 'Value')
# This is a hack: tensorListValueT is the same C++ type as valueT, but a distinct
# Python object used to mark arguments that arrive as a TensorList. TODO: refactor
# the data model to represent each arg in the schema as an object, making it easier
# to represent special properties of an arg.
tensorListValueT = BaseCppType('torch::lazy', 'Value')

def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
      (1) changing at::Tensors into lazy::Values
      (2) wrapping everything in a BaseCType
      (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values, which wrap lazy::Nodes - the representation
    Lazy IR uses for tensors. There is special handling for Optional[Tensor],
    List[Tensor], etc. - hence 'tensor-like'.

    This is incomplete: the assertions below are expected to fire as the codegen
    is used with more operators, at which point more types will need to be added here.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(valueT)
        elif typ.name == BaseTy.Scalar:
            # at::Scalar has special handling,
            # and is wrapped in a lazy::Value just like at::Tensor
            return BaseCType(valueT)
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
    elif isinstance(typ, ListType):
        if str(typ.elem) == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(valueT)))
        elif str(typ.elem) == 'Tensor':
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
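
# A minimal sketch of the mapping above (illustrative only; the constructor calls
# assume the tools.codegen.model types behave as they are used in this file):
#
#   process_ir_type(BaseType(BaseTy.Tensor))                -> BaseCType(valueT)
#   process_ir_type(OptionalType(BaseType(BaseTy.Scalar)))  -> OptionalCType(BaseCType(valueT))
#   process_ir_type(ListType(BaseType(BaseTy.int), None))   -> VectorCType(BaseCType(longT))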


def isValueType(typ: CType) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # Note: unfortunate naming here - we wrap at::Scalar in a lazy Value,
        # while preserving other 'scalar' types as scalars in the IR.
        return typ.type == valueT or typ.type == scalarT
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem)
    return False
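
# Illustrative only (assumes the conversions shown in process_ir_type):
#   isValueType(BaseCType(valueT))                 -> True   (a lazy IR value)
#   isValueType(OptionalCType(BaseCType(valueT)))  -> True
#   isValueType(BaseCType(longT))                  -> False  (stays a scalar in the IR)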

def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::Scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information.
    """
    if isinstance(typ, BaseType):
        # Note: unfortunate naming here - we wrap at::Scalar in a lazy Value,
        # while preserving other 'scalar' types as scalars in the IR.
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        return isWrappedScalarType(typ.elem)
    return False
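
# Illustrative only - isWrappedScalarType is asked of the *original* schema type,
# before process_ir_type erases the Scalar-ness:
#   isWrappedScalarType(BaseType(BaseTy.Scalar))                -> True
#   isWrappedScalarType(OptionalType(BaseType(BaseTy.Scalar)))  -> True
#   isWrappedScalarType(BaseType(BaseTy.int))                   -> False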

def isGeneratorType(typ: Type) -> bool:
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, OptionalType):
        return isGeneratorType(typ.elem)
    return False
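
# Illustrative only:
#   isGeneratorType(OptionalType(BaseType(BaseTy.Generator)))  -> True
#   isGeneratorType(BaseType(BaseTy.Tensor))                   -> False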

class LazyArgument:
    name: str
    orig_type: Type
    lazy_type_: Optional[CType]
    is_wrapped_scalar: bool
    is_generator: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument):
        self.name = arg.name
        self.orig_type = arg.type
        self.is_generator = isGeneratorType(arg.type)
        if self.is_generator:
            assert isinstance(arg.type, OptionalType), \
                "We expect all generator args to be optional, since currently they all are"
            # There is no handling for generators in TorchScript IR (or XLA),
            # so we fall back to eager if the (optional) generator has a value;
            # otherwise it is null and safe to exclude from lazy IR.
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)

        self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type)

    @property
    def lazy_type(self) -> CType:
        assert self.lazy_type_ is not None, f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
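
# A minimal sketch of how LazyArgument classifies arguments (illustrative only;
# assumes Argument.parse from tools.codegen.model accepts single-arg schema strings):
#
#   alpha = LazyArgument(Argument.parse('Scalar alpha'))
#   alpha.is_wrapped_scalar   # True: the at::Scalar is rewritten to a lazy::Value
#   alpha.is_lazy_value       # True
#
#   gen = LazyArgument(Argument.parse('Generator? generator'))
#   gen.is_generator          # True: lazy_type_ is None, and accessing .lazy_type asserts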

# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# while preserving the original argument names.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: 'OperatorName'

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple['Return', ...]

    # if this schema has a Generator arg, list its orig ctype/name but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: Optional[NamedCType] = None

    def __init__(self, func: FunctionSchema):
        positional_args = []
        for arg_field in ["pre_self_positional",
                          "self_arg",
                          "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = getattr(func.arguments, "self_arg").argument
                positional_args.append(LazyArgument(arg))
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    [LazyArgument(arg) for arg in getattr(func.arguments, arg_field)])
        self.positional_args = tuple(positional_args)

        keyword_args = []
        for arg_field in ["pre_tensor_options_kwarg_only",
                          "tensor_options",
                          "post_tensor_options_kwarg_only",
                          "out"]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert self.generator_arg is None, "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(arg.name, arg.type)
                keyword_args.extend([LazyArgument(arg) for arg in curr_args])
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns
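
    # A minimal sketch of constructing a LazyIrSchema (illustrative only; assumes
    # FunctionSchema.parse from tools.codegen.model accepts native_functions.yaml syntax):
    #
    #   schema = LazyIrSchema(FunctionSchema.parse(
    #       'add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor'))
    #   [a.name for a in schema.positional_args]  # ['self', 'other']
    #   [a.name for a in schema.keyword_args]     # ['alpha']
    #   schema.node_name                          # 'AddTensor'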

    @property
    def node_name(self) -> str:
        """
        Return the camel-cased (PascalCase) version of the op name, for use as the
        IR node class name.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return f"{self.name.name}"

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(self, positional: bool = True, keyword: bool = True,
                      values: bool = True, scalars: bool = True, generator: bool = False) -> List[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # while other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special-cased: they are needed for fallback/shape-inference but not supported
        # in TS lowerings, and are therefore also omitted from lazy IR.
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [a for a in args if not a.is_lazy_value and (generator or not a.is_generator)]

        return []
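
    # Illustrative only (assumes FunctionSchema.parse as above, for a schema like
    # 'div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor'):
    #   filtered_args(values=True, scalars=False)  -> [self, other]    (lazy IR values)
    #   filtered_args(values=False, scalars=True)  -> [rounding_mode]  (stays a scalar)
    #   filtered_args()                            -> all of the above, minus generators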

    @property
    def positional_values(self) -> List[LazyArgument]:
        return self.filtered_args(positional=True, keyword=False, values=True, scalars=False)

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(positional=True, keyword=False, values=False, scalars=True)

    @property
    def keyword_values(self) -> List[LazyArgument]:
        return self.filtered_args(positional=False, keyword=True, values=True, scalars=False)

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(positional=False, keyword=True, values=False, scalars=True)