import functools
import torch
from .nested_tensor import NestedTensor
from typing import * # noqa: F403
from torch.fx.operator_schemas import normalize_function
__all__: List[Any] = []
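# Dispatch table mapping aten ops to their jagged-layout implementations;
# entries are installed via register_jagged_func / register_func below.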
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
def _wrap_jagged_dim(ndim, dim, op_name):
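    # Canonicalize dim against the NestedTensor's shape, reject the batch (0) and
    # ragged (1) dims, and shift by one to get the corresponding dim on the
    # underlying values tensor, where the batch and ragged dims are collapsed.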
from torch._prims_common import canonicalize_dims
wrapped = canonicalize_dims(ndim, dim)
if wrapped < 2:
raise RuntimeError(
f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
)
return wrapped - 1
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
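    # schema_str is a comma-separated list of "name: type" entries (e.g.
    # "self: jt, dim: any?"); type is "t" (plain tensor), "jt" (NestedTensor),
    # or "any", and a trailing "?" on the type marks the argument as optional.
    # Positional args are validated against this schema.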
named_arg_types = schema_str.split(", ")
num_optional_args = sum([x.endswith("?") for x in named_arg_types])
min_args = len(named_arg_types) - num_optional_args
if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
raise ValueError(
f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
f"arguments and at most {len(named_arg_types)} arguments, but got: "
f"{len(args)} arguments"
)
arg_type_check_fns = {
"t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor),
"any": lambda x: True,
}
for i, named_arg_type in enumerate(named_arg_types):
name, arg_type = named_arg_type.split(": ")
is_optional = arg_type.endswith("?")
normalized_arg_type = arg_type[:-1] if is_optional else arg_type
if normalized_arg_type not in arg_type_check_fns.keys():
raise AssertionError(f"Unknown arg type: {normalized_arg_type}")
if i >= len(args):
if not is_optional:
raise ValueError(
f"NestedTensor {func.__name__}({schema_str}) "
f"missing required argument: {name}"
)
continue
if not arg_type_check_fns[normalized_arg_type](args[i]):
raise ValueError(
f"NestedTensor {func.__name__}({schema_str}): {name} should be of "
f"type {arg_type}, but got: {type(args[i])}"
)
def check_ragged_dim_same(
func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # NB: compares the cached ragged sizes (from _size), which are tied to each
    # tensor's offsets.
if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
raise RuntimeError(
f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
"same exact offsets tensor."
)
def register_func(tables, aten_ops, schema_str):
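    # Decorator factory: registers the decorated implementation into each given
    # table under each aten op, wrapped with a check_schema() call.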
if not isinstance(aten_ops, list):
aten_ops = [aten_ops]
if not isinstance(tables, list):
tables = [tables]
    def wrapper(func):
        for aten_op in aten_ops:

            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        # Return the undecorated func so module-level names still refer to the
        # original implementation after decoration.
        return func

    return wrapper
register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
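    # Return the registered jagged impl for func if there is one; otherwise fall
    # back to the generic unary/binary pointwise handlers for ops tagged pointwise.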
dispatch_func = JAGGED_OPS_TABLE.get(func, None)
if dispatch_func is not None:
return dispatch_func
# Handle pointwise fallbacks
if torch.Tag.pointwise in func.tags:
        # Assume the only tensor arguments are the unary/binary operands themselves
num_tensor_args = sum([isinstance(x, torch.Tensor) for x in args])
if num_tensor_args == 1:
return functools.partial(jagged_unary_pointwise, func)
elif num_tensor_args == 2:
check_schema("lhs: jt, rhs: any", func, *args, **kwargs)
return functools.partial(jagged_binary_pointwise, func)
return None
def extract_kwargs(arg):
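    # Collect the NestedTensor constructor kwargs (offsets and ragged size)
    # needed to wrap a new values tensor with the same jagged structure.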
kwargs = {
"offsets": arg.offsets(),
"ragged_size": arg._size[arg._ragged_idx],
}
return kwargs
def jagged_unary_pointwise(func, *args, **kwargs):
return NestedTensor(
func(args[0]._values, *args[1:], **kwargs), **extract_kwargs(args[0])
)
def jagged_binary_pointwise(func, *args, **kwargs):
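    # NT op NT requires matching ragged structure and operates on the packed
    # values; NT op dense applies directly to the values unless the dense operand
    # has at least as many dims as the NT, in which case an unbind/compute/cat
    # fallback is used.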
a, b = args[0], args[1]
assert isinstance(a, NestedTensor)
if isinstance(b, NestedTensor):
check_ragged_dim_same(func, args[0], "lhs", args[1], "rhs")
b = args[1]._values
else:
# TODO: Verify this more and handle the a.dim() == b.dim() case specially if needed
if a.dim() <= b.dim():
# need to use offsets to broadcast across batch dim properly
# NB: inefficient fallback here; Triton codegen can help this
assert a.shape[0] == b.shape[0]
outputs = []
for a_comp, b_comp in zip(a.unbind(), b.unbind()):
outputs.append(func(a_comp, b_comp, **kwargs))
new_values = torch.cat(outputs, dim=0)
return NestedTensor(new_values, **extract_kwargs(a))
return NestedTensor(func(a._values, b, **kwargs), **extract_kwargs(a))
@register_jagged_func(
[
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.sym_size.default,
torch.ops.aten.dim.default,
torch.ops.aten.sym_numel.default,
torch.ops.aten.sym_stride.default,
torch.ops.aten.sym_storage_offset.default,
],
"self: jt",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
if func == torch.ops.aten.is_non_overlapping_and_dense.default:
return False
if func == torch.ops.aten.sym_size.default:
return args[0]._size
if func == torch.ops.aten.dim.default:
return len(args[0]._size)
if func == torch.ops.aten.sym_numel.default:
return args[0]._values.numel()
if func == torch.ops.aten.sym_stride.default:
return args[0]._strides
if func == torch.ops.aten.sym_storage_offset.default:
return 0
@register_jagged_func(torch.ops.prim.layout.default, "self: jt")
def prim_layout_default(func, *args, **kwargs):
return torch.jagged
@register_jagged_func(
[
torch.ops.aten.size.default,
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
],
"self: jt, memory_format: any?",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensor does not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )
raise RuntimeError(
"NestedTensors do not support directly querying strides, "
"storage_offset, or contiguity."
)
@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
    inp = new_kwargs.pop("input")
    weight = new_kwargs.pop("weight")
    bias = new_kwargs.pop("bias")

    # aten.linear's weight has shape (out_features, in_features), so multiply by
    # its transpose to match torch.nn.functional.linear semantics.
    new_values = torch.mm(inp._values, weight.t())
    if bias is not None:
        new_values += bias
return NestedTensor(new_values, **extract_kwargs(inp))
@register_jagged_func(
torch.ops.aten.linear_backward.default,
"self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
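    # Computes (grad_input, grad_weight, grad_bias) for aten.linear_backward on
    # the packed values; grad_input reuses grad_output's jagged structure.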
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
grad_output = new_kwargs.pop("grad_output")
weight = new_kwargs.pop("weight")
check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
    # grad_input = grad_output @ weight; grad_weight = grad_output^T @ input,
    # matching weight's (out_features, in_features) shape.
    ds = NestedTensor(
        torch.mm(grad_output._values, weight), **extract_kwargs(grad_output)
    )
    dw = torch.mm(grad_output._values.T, inp._values)
db = None # NYI: gradient for bias, need to reduce over ragged dim
return (ds, dw, db)
@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt")
def to_copy_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
# don't change layout
new_kwargs.pop("layout")
new_values = func(inp._values, **new_kwargs)
# NB: Purposefully keep offsets on the old device.
return NestedTensor(new_values, **extract_kwargs(inp))
register_jagged_func(
[
torch.ops.aten.ones_like.default,
torch.ops.aten.zeros_like.default,
torch.ops.aten.randn_like.default,
],
"self: jt",
)(jagged_unary_pointwise)
@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
def prod_dim_int(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
    # TODO: Figure out how to handle this better
    # keepdim=True is required to keep the result in jagged (values + offsets) form
    if not new_kwargs["keepdim"]:
        raise RuntimeError("prod(): keepdim=True must be set for NestedTensor")

    # _wrap_jagged_dim() already rejects reductions over the batch or ragged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp.shape), dim, "prod")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
@register_jagged_func(torch.ops.aten.unbind.int, "self: jt, dim: any?")
def unbind_int(func, *args, **kwargs):
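    # unbind along the batch dim: slice the packed values tensor into one
    # per-sample view using the consecutive offsets.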
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
dim = new_kwargs["dim"]
if dim != 0:
raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")
inp = new_kwargs.pop("input")
values = inp._values
offsets = inp.offsets()
views = []
start = 0
for length in offsets.diff().cpu().tolist():
        views.append(values[start : start + length, ...])
start += length
return tuple(views)
@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
    values = inp._values
# Account for collapsed jagged dim
dim = new_kwargs["dim"]
new_kwargs["dim"] = _wrap_jagged_dim(len(inp.shape) + 1, dim, "unsqueeze")
return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
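    # Strategy: expand any dense tensors to the jagged structure of the first
    # nested operand, then concatenate the underlying values tensors along the
    # wrapped dim.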
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
tensors = new_kwargs.pop("tensors")
# Convert any non-nested to nested
nested = [t for t in tensors if t.is_nested]
assert len(nested) > 0
first = nested[0]
tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]
# Account for collapsed jagged dim
dim = new_kwargs["dim"]
new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")
return NestedTensor(
func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
)
@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
other = new_kwargs.pop("other")
if (not inp.is_nested) or other.is_nested:
raise RuntimeError(
"matmul(): only supported input pattern is (nested, non-nested)"
)
return NestedTensor(func(inp._values, other, **new_kwargs), **extract_kwargs(inp))
@register_jagged_func(
torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
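    # Only the non-batch, non-ragged dims can be expanded; the collapsed
    # batch/ragged dim of the values tensor is passed as -1 (kept as-is).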
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
size = new_kwargs["size"]
assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
if list(size[:2]) != list(inp.shape[:2]):
raise RuntimeError("expand(): cannot expand if ragged dims don't match")
expand_arg = [-1, *size[2:]]
return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))
@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
other = new_kwargs.pop("other")
return NestedTensor(func(inp, other._values), **extract_kwargs(other))
@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
condition = new_kwargs.pop("condition")
inp = new_kwargs.pop("input")
other = new_kwargs.pop("other")
assert condition.shape == other.shape == inp.shape
return NestedTensor(
func(condition._values, inp._values, other._values, **new_kwargs),
**extract_kwargs(condition),
)
@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
return func(inp._values, **new_kwargs)