# blob: 91779629eef99b9486ce54a21a4b6a5345096072 [file] [log] [blame]
from typing import List, Union, Tuple, Optional
from torchgen.model import (
Type,
BaseTy,
BaseType,
OptionalType,
ListType,
OperatorName,
FunctionSchema,
Return,
TensorOptionsArguments,
Argument,
)
from torchgen.api.types import (
CType,
BaseCppType,
BaseCType,
OptionalCType,
NamedCType,
deviceT,
layoutT,
VectorCType,
boolT,
longT,
doubleT,
ListCType,
stringT,
scalarT,
scalarTypeT,
memoryFormatT,
SymIntT,
)
# Module-wide IR value type; stays None until setValueT() stores something here.
_valueT = None


def getValueT():
    """
    Return the IR value type currently stored in the module-level `_valueT`.

    Raises:
        NotImplementedError: if `_valueT` is still unset (or otherwise falsy).
    """
    if _valueT:
        return _valueT
    raise NotImplementedError(
        "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
    )
def setValueT(val):
    """Store `val` in the module-level `_valueT`, the value that getValueT() hands back."""
    global _valueT

    _valueT = val
# HACK: dedicated C++ type for TensorList arguments; process_ir_type() wraps it in a
# BaseCType when it sees a List[Tensor].
# The data model should be refactored to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")
def process_ir_type(
    typ: Type,
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    Translate a NativeFunctions type into the C++ type used by lazy tensor codegen.

    The translation currently
    (1) turns at::Tensors into lazy::Values (which wrap lazy::Nodes, the way Lazy IR
        represents tensors); Scalar and SymInt are mapped to the same value type,
    (2) wraps every base type in a BaseCType, and
    (3) replaces cpp-reference types with cpp-value types (e.g. a vector instead of
        IntArrayRef).

    Optional[Tensor], List[Tensor] and List[Tensor?] get dedicated handling, which is
    why the result is described as 'tensor-like'.

    Coverage is deliberately incomplete: unsupported types raise AssertionError so that
    support can be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        base = typ.name

        # at::Tensor, at::Scalar and SymInt are all represented by the lazy value type
        if base in (BaseTy.Tensor, BaseTy.Scalar, BaseTy.SymInt):
            return BaseCType(getValueT())

        # every other supported base type keeps a plain C++ counterpart
        plain_ctypes = {
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        if base in plain_ctypes:
            return BaseCType(plain_ctypes[base])

        raise AssertionError(f"TODO add support for type {repr(typ)}")

    if isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))

    if isinstance(typ, ListType):
        elem_repr = str(typ.elem)
        if elem_repr == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like below
            return ListCType(OptionalCType(BaseCType(getValueT())))
        if elem_repr == "Tensor":
            # a TensorList arrives from GetTensorList as a single Value
            return BaseCType(tensorListValueT)
        return VectorCType(process_ir_type(typ.elem))

    raise AssertionError(f"unrecognized type {repr(typ)}")
def isValueType(typ: CType) -> bool:
    """
    Tell whether an already-transformed type is Value-like (i.e. Tensor-like).

    A BaseCType counts when its inner type equals the lazy value type, scalarT or
    SymIntT. Optional/List/Vector wrappers are judged by their element type.
    Anything else is not Value-like.
    """
    if isinstance(typ, BaseCType):
        # at::scalar gets wrapped in a lazy value, whereas the other 'scalar' types
        # stay plain scalars in the IR
        value_like = (getValueT(), scalarT, SymIntT)
        return any(typ.type == candidate for candidate in value_like)
    if isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem)
    return False
def isSymIntType(typ: Type) -> bool:
    """Return True only for the bare SymInt base type; wrappers are not unwrapped."""
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.SymInt
def isWrappedScalarType(typ: Type) -> bool:
    """
    Tell whether a type is a c10::scalar that will be wrapped in a lazy Value.

    Changing the type from scalarT to valueT loses the fact that the argument used to
    be a scalar; this predicate lets callers record which arguments were wrapped.
    Optional and List wrappers are peeled off before the base type is inspected.
    """
    inner = typ
    while not isinstance(inner, BaseType):
        if not isinstance(inner, (OptionalType, ListType)):
            return False
        inner = inner.elem
    # only at::scalar is wrapped in a lazy value; the other 'scalar' types are kept
    # as scalars in the IR
    return inner.name == BaseTy.Scalar
def isGeneratorType(typ: Type) -> bool:
    """Return True for the Generator base type, looking through Optional wrappers."""
    inner = typ
    while not isinstance(inner, BaseType):
        if not isinstance(inner, OptionalType):
            return False
        inner = inner.elem
    return inner.name == BaseTy.Generator
class LazyArgument:
    """One schema argument together with the facts lazy codegen derives from its type."""

    name: str
    orig_type: Type
    # None for generator arguments, which have no lazy IR representation
    lazy_type_: Optional[CType]
    is_wrapped_scalar: bool
    is_generator: bool
    is_symint_or_list: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument):
        self.name = arg.name
        self.orig_type = arg.type
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = isSymIntType(arg.type)
        self.is_generator = isGeneratorType(arg.type)

        if not self.is_generator:
            self.lazy_type_ = process_ir_type(arg.type)
            self.is_lazy_value = isValueType(self.lazy_type)
            return

        assert isinstance(
            arg.type, OptionalType
        ), "We expect all generators are optional since currently they are"
        # Generators have no handling in TorchScript IR (or XLA): when the (optional)
        # generator carries a value we fall back to eager, and when it is null it can
        # safely be left out of lazy IR.
        self.lazy_type_ = None
        self.is_lazy_value = False

    @property
    def lazy_type(self) -> CType:
        """The lazy C++ type of this argument; asserts that one was computed."""
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
# A LazyIrSchema is modeled after FunctionSchema and holds the schema of a Lazy IR node.
# It has no round-trippable string form (relating to the YAML) the way a FunctionSchema
# does; instead it carries the type information of a native FunctionSchema, adapted for
# use with IR nodes, while keeping the original argument names.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # When the schema has a Generator keyword argument, its original name and type are
    # recorded here. The argument also appears in keyword_args, flagged is_generator,
    # and filtered_args() leaves it out by default.
    generator_arg: Optional[NamedCType] = None

    def __init__(self, func: FunctionSchema):
        """Collect the positional and keyword arguments of `func` as LazyArguments."""
        arguments = func.arguments

        # positional arguments, in schema order: pre-self, self, post-self
        positional: List[LazyArgument] = []
        if arguments.pre_self_positional is not None:
            positional.extend(
                [LazyArgument(arg) for arg in arguments.pre_self_positional]
            )
        if arguments.self_arg is not None:
            positional.append(LazyArgument(arguments.self_arg.argument))
        if arguments.post_self_positional is not None:
            positional.extend(
                [LazyArgument(arg) for arg in arguments.post_self_positional]
            )
        self.positional_args = tuple(positional)

        # keyword-only arguments, again in schema order
        keyword: List[LazyArgument] = []
        for field_name in (
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ):
            group = getattr(arguments, field_name)
            if group is None:
                continue
            if isinstance(group, TensorOptionsArguments):
                group = group.all()
            # remember the (single) generator argument before wrapping the group
            for arg in group:
                if not isGeneratorType(arg.type):
                    continue
                assert (
                    self.generator_arg is None
                ), "We expect there is only one generator arg"
                self.generator_arg = NamedCType(arg.name, arg.type)
            keyword.extend([LazyArgument(arg) for arg in group])
        self.keyword_args = tuple(keyword)

        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return the camel-case node name of the op.
        Note: any `overload_name` of the operation is appended as well.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        lowered = f"{self.name.name}_{self.name.overload_name}".lower()
        pieces = lowered.split("_")
        # empty pieces (e.g. from a missing overload name) contribute nothing
        return "".join(piece.capitalize() for piece in pieces)

    @property
    def aten_name(self) -> str:
        """The operator name without its overload name, formatted as a string."""
        return f"{self.name.name}"

    @property
    def base_name(self) -> str:
        """The `base` part of the operator name, formatted as a string."""
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = False,
    ) -> List[LazyArgument]:
        """
        Return a filtered view of the arguments that keeps their original order.

        Some callers care about args vs kwargs (TS lowerings); others care about
        whether an argument is wrapped in a lazy value or left alone. Generators are
        special-cased: they are needed for fallback/shape-inference, but TS lowerings
        do not support them, so they are omitted from lazy IR unless `generator` is set.

        Args:
            positional: include the positional arguments.
            keyword: include the keyword arguments.
            values: include arguments that are (or contain) a lazy IR value.
            scalars: include arguments that are not lazy IR values.
            generator: also include generator arguments; has no effect when only
                `values` is requested.
        """
        selected: List[LazyArgument] = []
        if positional:
            selected.extend(self.positional_args)
        if keyword:
            selected.extend(self.keyword_args)

        if values and scalars:
            if generator:
                return selected
            return [arg for arg in selected if not arg.is_generator]
        if values:
            return [arg for arg in selected if arg.is_lazy_value]
        if scalars:
            return [
                arg
                for arg in selected
                if not (arg.is_lazy_value or (arg.is_generator and not generator))
            ]
        return []

    @property
    def positional_values(self) -> List[LazyArgument]:
        """Positional arguments that are lazy IR values."""
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        """Positional arguments that are not lazy IR values (generators excluded)."""
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> List[LazyArgument]:
        """Keyword arguments that are lazy IR values."""
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        """Keyword arguments that are not lazy IR values (generators excluded)."""
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )