# blob: 0bd1bc0fdc229ecc77e2f4db8f8c3e0467b21150 [file] [log] [blame]
import functools
import torch
from .nested_tensor import NestedTensor
from typing import * # noqa: F403
# This module exports no public names.
__all__: List[Any] = []

# Dispatch table for jagged ops. Entries are added by `register_func` (via
# `register_jagged_func`): each key is an aten op and each value is the
# schema-checking wrapper around the registered implementation.
# `lookup_jagged` reads from it.
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    """Validate positional ``args`` against a compact schema string.

    ``schema_str`` is a ``", "``-separated list of ``name: type`` entries.
    ``type`` is one of:

    * ``t``   -- a ``torch.Tensor`` that is not a ``NestedTensor``
    * ``jt``  -- a ``NestedTensor``
    * ``any`` -- anything

    A trailing ``?`` marks the argument as optional: it may be omitted, or
    passed explicitly as ``None`` (in which case no type check is applied).

    Only positional arguments are inspected; ``kwargs`` is accepted for
    signature compatibility and ignored. ``func`` is used only for its
    ``__name__`` in error messages.

    Raises:
        ValueError: if the number of positional arguments is out of range, a
            required argument is missing, or an argument has the wrong type.
        AssertionError: if the schema names an unknown argument type.
    """
    named_arg_types = schema_str.split(", ")
    num_optional_args = sum(x.endswith("?") for x in named_arg_types)
    min_args = len(named_arg_types) - num_optional_args

    if not (min_args <= len(args) <= len(named_arg_types)):
        raise ValueError(
            f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
            f"arguments and at most {len(named_arg_types)} arguments, but got: "
            f"{len(args)} arguments"
        )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor),
        "any": lambda x: True,
    }

    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type

        if normalized_arg_type not in arg_type_check_fns:
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            # Reachable when an optional entry precedes a required one in the
            # schema, so the count check above is not sufficient on its own.
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        # An optional argument explicitly passed as None is "not provided";
        # type-checking it would wrongly reject e.g. bias=None.
        if is_optional and args[i] is None:
            continue

        if not arg_type_check_fns[normalized_arg_type](args[i]):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): {name} should be of "
                f"type {arg_type}, but got: {type(args[i])}"
            )
def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    """Ensure ``a`` and ``b`` agree on the entry at index 1 of ``_size``.

    ``func`` is used only for its ``__name__`` in the error message, and
    ``a_name`` / ``b_name`` are the labels shown for the two operands.

    Raises:
        AssertionError: if ``a._size`` does not have exactly three entries.
        RuntimeError: if ``a._size[1]`` differs from ``b._size[1]``.
    """
    lhs_size = a._size
    assert len(lhs_size) == 3, "NestedTensor must be [B, *, D]"

    # NOTE(review): the error text equates a matching _size[1] with sharing
    # the same offsets tensor -- confirm against NestedTensor's definition.
    sizes_differ = lhs_size[1] != b._size[1]
    if sizes_differ:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )
def register_func(tables, aten_ops, schema_str):
    """Build a decorator that registers an implementation for ``aten_ops``.

    Args:
        tables: a dict-like dispatch table, or a list/tuple of them.
        aten_ops: a single op, or a list/tuple of ops, used as table keys.
        schema_str: schema passed to ``check_schema`` on every call.

    Returns:
        A decorator. For each op it stores in every table a wrapper that
        validates the positional arguments with ``check_schema`` and then
        calls ``func(aten_op, *args, **kwargs)``. The decorator returns the
        original ``func`` so the decorated name stays bound to the function
        instead of ``None``.
    """
    if not isinstance(aten_ops, (list, tuple)):
        aten_ops = [aten_ops]
    if not isinstance(tables, (list, tuple)):
        tables = [tables]

    def wrapper(func):
        # Bind ``aten_op`` through a factory call so each wrapper keeps its
        # own op instead of the loop variable's final value.
        def make_entry(aten_op):
            def inner(*args, **kwargs):
                check_schema(schema_str, func, *args, **kwargs)
                return func(aten_op, *args, **kwargs)

            return inner

        for aten_op in aten_ops:
            for table in tables:
                table[aten_op] = make_entry(aten_op)

        return func

    return wrapper
# `register_func` with its `tables` argument pre-bound to JAGGED_OPS_TABLE;
# used as `@register_jagged_func(aten_op_or_list, schema_str)`.
register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    """Find the jagged implementation to use for ``func``.

    Ops not tagged ``torch.Tag.pointwise`` are looked up in
    ``JAGGED_OPS_TABLE``. Pointwise ops are handled generically according to
    how many positional arguments are tensors: one selects the unary handler,
    two selects the binary handler (after a schema check), and any other
    count yields ``None``.
    """
    if torch.Tag.pointwise not in func.tags:
        return JAGGED_OPS_TABLE.get(func, None)

    # Assume there aren't additional tensors that aren't the "unary/binary" args
    tensor_count = 0
    for candidate in args:
        if isinstance(candidate, torch.Tensor):
            tensor_count += 1

    if tensor_count == 1:
        return functools.partial(jagged_unary_pointwise, func)
    if tensor_count == 2:
        check_schema("lhs: jt, rhs: jt", func, *args, **kwargs)
        return functools.partial(jagged_binary_pointwise, func)
    return None
def extract_kwargs(arg):
    """Return the keyword arguments derived from ``arg``.

    The result is a dict with a single key, ``"offsets"``, holding the value
    of ``arg.offsets()``.
    """
    return {"offsets": arg.offsets()}
def jagged_unary_pointwise(func, *args, **kwargs):
    """Apply ``func`` to the values of ``args[0]`` and rewrap the result.

    ``func`` receives ``args[0].values()`` plus ``kwargs``; the output is
    wrapped in a ``NestedTensor`` built with ``extract_kwargs(args[0])``.
    Positional arguments after the first are not forwarded.
    """
    operand = args[0]
    result_values = func(operand.values(), **kwargs)
    return NestedTensor(result_values, **extract_kwargs(operand))
def jagged_binary_pointwise(func, *args, **kwargs):
    """Apply ``func`` to the values of ``args[0]`` and ``args[1]``.

    Both operands are first checked with ``check_ragged_dim_same``. ``func``
    then receives the two ``values()`` tensors plus ``kwargs``, and the output
    is wrapped in a ``NestedTensor`` built with ``extract_kwargs`` of the
    left operand.
    """
    lhs, rhs = args[0], args[1]
    check_ragged_dim_same(func, lhs, "lhs", rhs, "rhs")
    combined = func(lhs.values(), rhs.values(), **kwargs)
    return NestedTensor(combined, **extract_kwargs(lhs))
@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.sym_numel.default,
    ],
    "self: jt",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    """Answer the attribute queries that are supported for ``args[0]``.

    * ``is_non_overlapping_and_dense`` -> ``False``
    * ``sym_size``                     -> ``args[0]._size``
    * ``dim``                          -> ``3``
    * ``sym_numel``                    -> ``args[0].values().numel()``

    Any other ``func`` yields ``None``.
    """
    aten = torch.ops.aten
    if func == aten.is_non_overlapping_and_dense.default:
        answer = False
    elif func == aten.sym_size.default:
        answer = args[0]._size
    elif func == aten.dim.default:
        answer = 3
    elif func == aten.sym_numel.default:
        answer = args[0].values().numel()
    else:
        answer = None
    return answer
@register_jagged_func(
    [
        torch.ops.aten.size.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.is_contiguous.default,
        torch.ops.aten.is_contiguous.memory_format,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt, memory_format: any?",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    """Reject attribute queries that are not supported.

    Always raises ``RuntimeError``. ``aten.size`` gets a message pointing at
    ``nested_tensor.size()``; every other registered op gets a generic
    message about strides, storage offset and contiguity.
    """
    if func == torch.ops.aten.size.default:
        message = (
            "NestedTensors does not support directly calling torch.ops.aten.size "
            "please use `nested_tensor.size()` instead."
        )
    else:
        message = (
            "NestedTensors do not support directly querying strides, "
            "storage_offset, or contiguity."
        )
    raise RuntimeError(message)
@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    """Jagged implementation of ``aten.linear``.

    Positional args, per the registered schema: ``input`` (NestedTensor),
    ``weight`` (dense tensor) and an optional ``bias`` (dense tensor or
    ``None``).

    ``aten.linear`` computes ``input @ weight.T + bias`` with ``weight`` laid
    out as ``(out_features, in_features)``, so the weight is transposed before
    the matmul. The product is taken over ``input.values()`` and rewrapped
    with the input's offsets.
    """
    inp, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None

    # torch.mm requires 2-D operands; assumes values() is 2-D -- TODO confirm.
    values = torch.mm(inp.values(), weight.t())
    if bias is not None:
        # ``values`` is a fresh tensor, so in-place accumulation is safe.
        values += bias
    return NestedTensor(values, **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    """Jagged implementation of ``aten.linear_backward``.

    Positional args, per the registered schema: ``self`` (the forward input),
    ``grad_output``, ``weight`` and ``output_mask`` (currently unused).

    With ``weight`` laid out as ``(out_features, in_features)``:

    * grad wrt input  = ``grad_output @ weight``
    * grad wrt weight = ``grad_output.T @ input`` (same shape as ``weight``)

    Returns:
        ``(ds, dw, db)`` where ``ds`` is a NestedTensor sharing
        ``grad_output``'s offsets, ``dw`` is a dense tensor, and ``db`` is
        always ``None`` (bias gradient not yet implemented).
    """
    inp, grad_output, weight = args[0], args[1], args[2]
    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")

    grad_values = grad_output.values()
    ds = NestedTensor(torch.mm(grad_values, weight), **extract_kwargs(grad_output))
    dw = torch.mm(grad_values.t(), inp.values())
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)