blob: 76cd737417dfd982ae94e2cd5d7f47236782382f [file] [log] [blame]
import torch
import numpy as np
import inspect
import functools
import pprint
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.overrides import (
handle_torch_function,
has_torch_function,
get_overridable_functions,
get_testing_overrides,
is_tensor_method_or_property
)
Tensor = torch.Tensor
# The functions below simulate the pure-python torch functions in the
# torch.functional namespace. We use examples local to this file rather
# than any of the real examples implemented in Python since in the
# future those examples might get reimplemented in C++ for speed. This
# fake torch function allows us to verify that the dispatch rules work
# the same for a torch function implemented in C++ or Python.
def foo(a, b, c=None):
"""A function multiple arguments and an optional argument"""
if any(type(t) is not Tensor for t in (a, b, c)) and has_torch_function((a, b, c)):
return handle_torch_function(foo, (a, b, c), a, b, c=c)
if c:
return a + b + c
return a + b
def bar(a):
"""A function with one argument"""
if type(a) is not Tensor and has_torch_function((a,)):
return handle_torch_function(bar, (a,), a)
return a
def baz(a, b):
"""A function with multiple arguments"""
if type(a) is not Tensor or type(b) is not Tensor and has_torch_function((a, b)):
return handle_torch_function(baz, (a, b), a, b)
return a + b
def quux(a):
"""Used to test that errors raised in user implementations get propagated"""
if type(a) is not Tensor and has_torch_function((a,)):
return handle_torch_function(quux, (a,), a)
return a
# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that
# DiagonalTensor.__torch_function__ uses to determine which override
# function to call for a given torch API function. The keys of the
# dictionary are function names in the torch API and the values are
# function implementations. Implementations are added to
# HANDLED_FUNCTION_DIAGONAL by decorating a python function with
# implements_diagonal. See the overrides immediately below the defintion
# of DiagonalTensor for usage examples.
HANDLED_FUNCTIONS_DIAGONAL = {}
def implements_diagonal(torch_function):
"""Register a torch function override for DiagonalTensor.
This decorator takes a function in the torch API as a
parameter. Applying this decorator to a function adds that function
as the registered override for the torch function passed as a
parameter to the decorator. See DiagonalTensor.__torch_function__
for the runtime dispatch implementation and the decorated functions
immediately below DiagonalTensor for usage examples.
"""
@functools.wraps(torch_function)
def decorator(func):
HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
return func
return decorator
class DiagonalTensor(object):
"""A class with __torch_function__ and a specific diagonal representation
This class has limited utility and is mostly useful for verifying that the
dispatch mechanism works as expected. It is based on the `DiagonalArray
example`_ in the NumPy documentation.
Note that this class does *not* inherit from ``torch.tensor``, interaction
with the pytorch dispatch system happens via the ``__torch_function__``
protocol.
``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
diagonal entries set to *value* and all other entries set to zero. The
main functionality of ``DiagonalTensor`` is to provide a more compact
string representation of a diagonal tensor than in the base tensor class:
>>> d = DiagonalTensor(5, 2)
>>> d
DiagonalTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 0., 0.],
[0., 0., 0., 2., 0.],
[0., 0., 0., 0., 2.]])
Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
returns 0:
>>> torch.mm(d, d)
0
.. _DiagonalArray example:
https://numpy.org/devdocs/user/basics.dispatch.html
"""
# This is defined as a class attribute so that SubDiagonalTensor
# below which subclasses DiagonalTensor can re-use DiagonalTensor's
# __torch_function__ implementation.
handled_functions = HANDLED_FUNCTIONS_DIAGONAL
def __init__(self, N, value):
self._N = N
self._i = value
def __repr__(self):
return "DiagonalTensor(N={}, value={})".format(self._N, self._i)
def __array__(self):
return self._i * np.eye(self._N)
def tensor(self):
return self._i * torch.eye(self._N)
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in self.handled_functions:
return NotImplemented
return self.handled_functions[func](*args, **kwargs)
def __eq__(self, other):
if type(other) is type(self):
if self._N == other._N and self._i == other._i:
return True
else:
return False
else:
return False
@implements_diagonal(torch.mean)
def mean(mat):
return float(mat._i) / mat._N
@implements_diagonal(torch.mm)
def diagonal_mm(mat1, mat2):
return 0
@implements_diagonal(torch.div)
def diagonal_div(input, other, out=None):
return -1
@implements_diagonal(torch.add)
def add(mat1, mat2):
raise ValueError
@implements_diagonal(foo)
def diagonal_foo(a, b, c=None):
return -1
@implements_diagonal(bar)
def diagonal_bar(a):
return -1
@implements_diagonal(quux)
def diagonal_quux(a):
raise ValueError
# The dispatch table for SubTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB = {}
def implements_sub(torch_function):
"Register a torch function override for SubTensor"
@functools.wraps(torch_function)
def decorator(func):
HANDLED_FUNCTIONS_SUB[torch_function] = func
return func
return decorator
class SubTensor(torch.Tensor):
"""A subclass of torch.Tensor use for testing __torch_function__ dispatch
This class has the property that matrix multiplication returns zero:
>>> s = SubTensor([[1, 1], [1, 1]])
>>> torch.mm(s, s)
0
>>> t = torch.tensor([[1, 1], [1, 1]])
>>> torch.mm(s, t)
0
>>> torch.mm(t, s)
0
>>> torch.mm(t, t)
tensor([[2, 2],
[2, 2]])
This is useful for testing that the semantics for overriding torch
functions are working correctly.
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
if(kwargs is None):
kwargs = {}
if func not in HANDLED_FUNCTIONS_SUB:
return NotImplemented
return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)
class SubTensor2(torch.Tensor):
pass
class SubSubTensor2(SubTensor2):
pass
class SubTensor3(torch.Tensor):
pass
@implements_sub(torch.mean)
def sub_mean(mat):
return 0
@implements_sub(torch.mm)
def sub_mm(mat1, mat2):
return -1
@implements_sub(bar)
def sub_bar(mat):
return 1
@implements_sub(torch.div)
def sub_div(input, other, out=None):
return NotImplemented
# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}
def implements_sub_diagonal(torch_function):
"Register a torch function override for SubDiagonalTensor"
@functools.wraps(torch_function)
def decorator(func):
HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
return func
return decorator
class SubDiagonalTensor(DiagonalTensor):
"""A subclass of ``DiagonalTensor`` to test custom dispatch
This class tests semantics for defining ``__torch_function__`` on a
subclass of another class that defines ``__torch_function__``. The
only difference compared with the superclass is that this class
provides a slightly different repr as well as custom implementations
of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.
"""
handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL
def __repr__(self):
return "SubDiagonalTensor(N={}, value={})".format(self._N, self._i)
@implements_sub_diagonal(torch.mean)
def sub_diagonal_mean(mat):
return 10 * float(mat._i) / mat._N
@implements_sub_diagonal(bar)
def sub_diagonal_bar(mat):
return 0
@implements_sub_diagonal(torch.mm)
def sub_diagonal_mm(mat1, mat2):
return 1
@implements_sub_diagonal(torch.div)
def sub_diagonal_div(input, other, out=None):
return NotImplemented
@implements_sub_diagonal(foo)
def sub_diagonal_foo(a, b, c=None):
return NotImplemented
# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_TENSOR_LIKE = {}
# Note: _triggered wrapper
# Dict that wraps the implementations from get_testing_overrides into another
# function with a _triggered slot/flag. The triggered flag is set when the
# implementation is called.
WRAPPED_TRIGGERED_IMPLS = {}
def triggered_wrapper(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
wrapped._triggered = True
return f(*args, **kwargs)
wrapped._triggered = False
return wrapped
def implements_tensor_like(torch_function):
"Register a torch function override for TensorLike"
@functools.wraps(torch_function)
def decorator(func):
HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
return func
return decorator
def generate_tensor_like_torch_implementations():
torch_vars = vars(torch)
untested_funcs = []
testing_overrides = get_testing_overrides()
for namespace, funcs in get_overridable_functions().items():
for func in funcs:
if func not in testing_overrides:
untested_funcs.append("{}.{}".format(namespace, func.__name__))
msg = (
"The following functions are not tested for __torch_function__ "
"support, please ensure there is an entry in the dict returned by "
"torch._overrides.get_testing_overrides for this function or if a "
"__torch_function__ override does not make sense, add an entry to "
"the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
)
assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
for func, override in testing_overrides.items():
# decorate the overrides with implements_tensor_like if it's not a
# torch.Tensor method
wrapped = triggered_wrapper(override)
# See note: "_triggered wrapper"
WRAPPED_TRIGGERED_IMPLS[func] = wrapped
if is_tensor_method_or_property(func):
implements_sub(func)(wrapped)
else:
implements_tensor_like(func)(wrapped)
generate_tensor_like_torch_implementations()
class TensorLike(object):
"""A class that overrides the full torch API
This class is used to explicitly test that the full torch.tensor API
can be overriden with a class that defines __torch_function__.
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
if(kwargs is None):
kwargs = {}
if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
return NotImplemented
# In this case _torch_function_ should override TensorLike objects
return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
class TestTorchFunctionOverride(TestCase):
def test_mean_semantics(self):
"""Test that a function with one argument can be overrided"""
t1 = DiagonalTensor(5, 2)
t2 = SubTensor([[1, 2], [1, 2]])
t3 = SubDiagonalTensor(5, 2)
self.assertEqual(torch.mean(t1), 0.4)
self.assertEqual(bar(t1), -1)
self.assertEqual(torch.mean(t2), 0)
self.assertEqual(bar(t2), 1)
self.assertEqual(torch.mean(t3), 4.0)
self.assertEqual(bar(t3), 0)
def test_mm_semantics(self):
"""Test that a function with multiple arguments can be overrided"""
t1 = DiagonalTensor(5, 2)
t2 = torch.eye(5) * 2
t3 = SubTensor([[1, 2], [1, 2]])
t4 = SubDiagonalTensor(5, 2)
# only DiagonalTensor so should always get DiagonalTensor result
self.assertEqual(torch.mm(t1, t1), 0)
# tensor and DiagonalTensor, always return DiagonalTensor result
self.assertEqual(torch.mm(t1, t2), 0)
self.assertEqual(torch.mm(t2, t1), 0)
# only SubTensor so should always get SubTensor result
self.assertEqual(torch.mm(t3, t3), -1)
# tensor and SubTensor so should always get SubTensor result
self.assertEqual(torch.mm(t3, t2), -1)
self.assertEqual(torch.mm(t2, t3), -1)
# DiagonalTensor and SubTensor are unrelated classes so the result
# depends on which argument appears first
self.assertEqual(torch.mm(t3, t1), -1)
self.assertEqual(torch.mm(t1, t3), 0)
# SubDiagonalTensor should take precedence over DiagonalTensor
# but should behave otherwise the same as DiagonalTensor
self.assertEqual(torch.mm(t4, t4), 1)
self.assertEqual(torch.mm(t4, t1), 1)
self.assertEqual(torch.mm(t1, t4), 1)
self.assertEqual(torch.mm(t4, t2), 1)
self.assertEqual(torch.mm(t2, t4), 1)
self.assertEqual(torch.mm(t3, t4), -1)
self.assertEqual(torch.mm(t4, t3), 1)
def test_precedence_semantics(self):
"""Test semantics for __torch_function__ for functions that take
multiple arugments
For functions that take multiple arguments, the appropriate
__torch_function__ implementation to call is determined by
examining the types of the arguments. The precedence order is
left-to-right in the argument list, except subclasses are always
checked before superclasses. The first result of calling the
implementations in precedence order that is not NotImplemented
is returned to the user. If all implementations return
NotImplemented, a TypeError is raised.
All cases are tested with functions implemented in C++ and
either foo or baz, which are python functions defined above that
are instrumented to obey the same dispatch rules as the
functions in torch.functional.
"""
# DiagonalTensor has a valid override and SubDiagonal has an
# override that returns NotImplemented so we should call the
# DiagonalTensor implementation, returning -1
t1 = DiagonalTensor(5, 2)
t2 = SubDiagonalTensor(5, 2)
self.assertEqual(torch.div(t1, t2), -1)
self.assertEqual(torch.div(t2, t1), -1)
self.assertEqual(foo(t1, t2), -1)
self.assertEqual(foo(t2, t1), -1)
# SubTensor has an implementation that returns NotImplemented as
# well so it should behave exactly like SubDiagonalTensor in the
# test above
t3 = SubTensor([[1, 2], [1, 2]])
self.assertEqual(torch.div(t1, t3), -1)
self.assertEqual(torch.div(t3, t1), -1)
self.assertEqual(foo(t1, t3), -1)
self.assertEqual(foo(t3, t1), -1)
# div between SubTensor and SubDiagonalTensor should raise
# TypeError since both have an implementation that
# explicitly returns NotImplemented
with self.assertRaises(TypeError):
torch.div(t2, t3)
with self.assertRaises(TypeError):
torch.div(t3, t2)
with self.assertRaises(TypeError):
foo(t2, t3)
with self.assertRaises(TypeError):
foo(t3, t2)
# none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a
# mul or a baz implementation so all ops should raise TypeError
with self.assertRaises(TypeError):
torch.mul(t1, t1)
with self.assertRaises(TypeError):
torch.mul(t1, t2)
with self.assertRaises(TypeError):
torch.mul(t1, t3)
with self.assertRaises(TypeError):
torch.mul(t2, t1)
with self.assertRaises(TypeError):
torch.mul(t2, t2)
with self.assertRaises(TypeError):
torch.mul(t2, t3)
with self.assertRaises(TypeError):
torch.mul(t3, t1)
with self.assertRaises(TypeError):
torch.mul(t3, t2)
with self.assertRaises(TypeError):
torch.mul(t3, t3)
with self.assertRaises(TypeError):
baz(t1, t1)
with self.assertRaises(TypeError):
baz(t1, t2)
with self.assertRaises(TypeError):
baz(t1, t3)
with self.assertRaises(TypeError):
baz(t2, t1)
with self.assertRaises(TypeError):
baz(t2, t2)
with self.assertRaises(TypeError):
baz(t2, t3)
with self.assertRaises(TypeError):
baz(t3, t1)
with self.assertRaises(TypeError):
baz(t3, t2)
with self.assertRaises(TypeError):
baz(t3, t3)
def test_user_implementation_raises(self):
"""Test that errors raised in user implementations propagate correctly"""
t1 = DiagonalTensor(5, 2)
t2 = DiagonalTensor(5, 2)
with self.assertRaises(ValueError):
torch.add(t1, t2)
with self.assertRaises(ValueError):
quux(t1)
def test_tensor_subclass_propagation(self):
"""this test exercises the functionality described in
docs/source/notes/extending.rst#subclassing-torchtensor"""
t1 = torch.tensor([5])
t2 = torch.tensor([6])
s1 = SubTensor2([5])
s2 = SubTensor2([6])
ss1 = SubSubTensor2([5])
ss2 = SubSubTensor2([6])
sn1 = SubTensor3([5])
sn2 = SubTensor3([6])
# Check that leaf subclass is kept regardless of order
self.assertTrue(isinstance(s1 + t2, SubTensor2))
self.assertTrue(isinstance(t1 + s2, SubTensor2))
self.assertTrue(isinstance(s1 + s2, SubTensor2))
# Check indexing subclass is kept
self.assertTrue(isinstance(s1[0], SubTensor2))
# Check case for subclass of subclass.
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1 + s2, SubSubTensor2))
self.assertTrue(isinstance(s1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1 + t2, SubSubTensor2))
self.assertTrue(isinstance(t1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1[0], SubSubTensor2))
# Make sure unrelated class trees are not merged.
with self.assertRaises(TypeError):
s1 + sn2
with self.assertRaises(TypeError):
sn1 + s2
def generate_tensor_like_override_tests(cls):
from torch.testing._internal.generated.annotated_fn_args import annotated_args
def test_generator(func, override):
# If func corresponds to a torch.Tensor method or property.
if is_tensor_method_or_property(func):
# Generate an instance by using SubTensor,
def instance_gen():
return SubTensor([5])
else:
# Otherwise, TensorLike.
def instance_gen():
return TensorLike()
func_args = []
if func in annotated_args:
for arg in annotated_args[func]:
# Guess valid input to aten function based on type of argument
t = arg['simple_type']
if t.endswith('?'):
t = t[:-1]
if t == 'Tensor':
if arg['name'] == 'self' and is_tensor_method_or_property(func):
# See "Note: properties and __get__"
func = func.__get__(instance_gen())
continue
func_args.append(instance_gen())
elif t == 'TensorList':
func_args.append([instance_gen(), instance_gen()])
elif t == 'c10::List<c10::optional<Tensor>>':
func_args.append([instance_gen(), instance_gen()])
elif t == 'IntArrayRef':
size = arg.get('size', 2)
if size == 1:
func_args.append(1)
else:
func_args.append([1] * size)
elif t == 'Scalar':
func_args.append(3.5)
elif t == 'bool':
func_args.append(False)
elif t.startswith('int') or t in {'Dimname', 'DimnameList'}:
func_args.append(0)
elif t in {'Stream'}:
func_args.append(torch.Stream())
elif t.startswith('float') or t == 'double':
func_args.append(1.0)
elif t in {'Generator', 'MemoryFormat', 'TensorOptions'}:
func_args.append(None)
elif t == 'ScalarType':
func_args.append(torch.float32)
elif t == 'std::string':
func_args.append('')
else:
raise RuntimeError(f"Unsupported argument type {t} for {arg['name']} of function {func}")
else:
args = inspect.getfullargspec(override)
try:
func_args = inspect.getfullargspec(func)
# Remove annotations from argspec
func_args = type(func_args)(**{**func_args, 'annotations': None})
if func_args != args:
raise RuntimeError(f"Override for {func} doesn't match its argspec.\n"
+ f"Original: {inspect.signature(func)}\n"
+ f"Override: {inspect.signature(override)}")
except TypeError:
pass
nargs = len(args.args)
if args.defaults is not None:
nargs -= len(args.defaults)
func_args = [instance_gen() for _ in range(nargs)]
if args.varargs is not None:
func_args += [instance_gen(), instance_gen()]
def test(self):
ret = func(*func_args)
# ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
# This is currently the best check but doesn't work for, for example,
# Tensor.__add__ because it redirects to Tensor.add.
# See note "_triggered wrapper"
if ret is None:
self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
return
self.assertEqual(ret, -1)
return test
for func, override in get_testing_overrides().items():
test_method = test_generator(func, override)
if func.__name__ == "__get__":
# Note: properties and __get__
# __get__ is part of the descriptor protocol.
# https://docs.python.org/3/howto/descriptor.html
# This is used for properties of the form
# torch.Tensor.<property>, with the method __get__
# In this case we get the property name in two ways:
# This case for properties defined in C.
module = getattr(
func.__self__,
"__qualname__",
None
)
# This one for properties defined in Python.
if module is None:
module = "Tensor." + func.__self__.fget.__name__
# Unfortunately I couldn't find a way to unify these two cases
# and there is no way for general descriptors.
elif is_tensor_method_or_property(func):
module = "Tensor"
else:
module = func.__module__
if module:
name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
else:
name = 'test_{}'.format(func.__name__)
test_method.__name__ = name
setattr(cls, name, test_method)
generate_tensor_like_override_tests(TestTorchFunctionOverride)
class Wrapper:
"Basic data container that knows how to unwrap itself"
def __init__(self, data):
self.__dict__["_data"] = data
self.__dict__["used_attrs"] = set()
self.__dict__["used_calls"] = set()
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
self.used_attrs.add(name)
val = getattr(self._data, name)
# If it's a method
if callable(val):
c = getattr(type(self._data), name)
# Don't append self to args if classmethod/staticmethod
if c is val:
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
# Otherwise append self to args
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))
return wrap(val)
def __setattr__(self, name, value):
if name in self.__dict__:
self.__dict__[name] = value
self.used_attrs.add(name)
setattr(self._data, name, unwrap(value))
def __setitem__(self, key, value):
self._data[unwrap(key)] = unwrap(value)
def __getitem__(self, key):
return wrap(self._data[unwrap(key)])
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
self.used_calls.add(func)
args = unwrap(tuple(args))
kwargs = {k: unwrap(v) for k, v in kwargs.items()}
return wrap(func(*args, **kwargs))
def __add__(self, other):
return self.__torch_function__(torch.add, (Wrapper,), (self, other))
def __mul__(self, other):
return self.__torch_function__(torch.mul, (Wrapper,), (self, other))
def __sub__(self, other):
return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
def __truediv__(self, other):
return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))
def __floordiv__(self, other):
return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))
def __ge__(self, other):
return self.__torch_function__(torch.ge, (Wrapper,), (self, other))
def __gt__(self, other):
return self.__torch_function__(torch.gt, (Wrapper,), (self, other))
def __lt__(self, other):
return self.__torch_function__(torch.lt, (Wrapper,), (self, other))
def __le__(self, other):
return self.__torch_function__(torch.le, (Wrapper,), (self, other))
def __eq__(self, other):
return self.__torch_function__(torch.eq, (Wrapper,), (self, other))
def __ne__(self, other):
return self.__torch_function__(torch.ne, (Wrapper,), (self, other))
def __bool__(self):
return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))
def __int__(self):
return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))
# unwrap inputs if necessary
def unwrap(v):
if type(v) in {tuple, list}:
return type(v)(unwrap(vi) for vi in v)
return v._data if isinstance(v, Wrapper) else v
# wrap inputs if necessary
def wrap(v):
if type(v) in {tuple, list}:
return type(v)(wrap(vi) for vi in v)
return Wrapper(v) if isinstance(v, torch.Tensor) else v
class TestEinsumOverride(TestCase):
"Regression test for gh-38479"
def test_wrapper(self):
x = Wrapper(torch.randn(5))
y = Wrapper(torch.randn(4))
self.assertTrue(torch.allclose(torch.einsum('i,j->ij', x, y),
torch.ger(x, y)))
# in the old einsum interface, `operands` is a list
a = Wrapper(torch.randn(2, 3))
b = Wrapper(torch.randn(5, 3, 7))
c = Wrapper(torch.randn(2, 7))
self.assertTrue(torch.allclose(torch.einsum('ik,jkl,il->ij', [a, b, c]),
torch.nn.functional.bilinear(a, c, b)))
class TestGradCheckOverride(TestCase):
"Test that wrappers work with gradcheck."
def test_gradcheck(self):
from torch.testing._internal.common_utils import gradcheck, gradgradcheck
a = wrap(torch.tensor(5.0, dtype=torch.double))
b = wrap(torch.tensor(6.0, dtype=torch.double))
a.requires_grad = True
b.requires_grad = True
gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)
gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)
total_used_attrs = a.used_attrs.union(b.used_attrs)
total_used_calls = a.used_calls.union(b.used_calls)
# These attributes (and the functions below) may change
# if the gradcheck implementation changes. It's best to
# aim for attributes that may be commonly present on other
# Tensor-likes.
self.assertEqual(total_used_attrs, {
'data',
'device',
'dtype',
'is_complex',
'is_floating_point',
'is_sparse',
'layout',
'nelement',
'new_zeros',
'requires_grad',
'retain_grad',
'size',
'stride',
})
self.assertEqual(total_used_calls, {
torch.Tensor.new_zeros,
torch.Tensor.size,
torch.Tensor.is_complex,
torch.Tensor.is_floating_point,
torch.Tensor.nelement,
torch.Tensor.retain_grad,
torch.Tensor.stride,
torch.autograd.grad,
torch.add,
})
class TestNamedTuple(TestCase):
""" Regression test for gh-47090 """
def test_max(self):
x = torch.tensor([1, 2])
xs = x.as_subclass(SubTensor2)
r = torch.max(x, dim=0)
rs = torch.max(xs, dim=0)
self.assertEqual(type(r), type(rs))
self.assertEqual(r, rs)
class TestGradNewOnesOverride(TestCase):
""" Regression test for gh-47069 """
def test_newones(self):
t = torch.tensor([1, 2]).as_subclass(SubTensor2)
n = t.new_ones((1, 2))
self.assertEqual(type(n), SubTensor2)
class TestBroadcastAllOverride(TestCase):
""" test for gh-37141 """
def test_broadcast_all(self):
from torch.distributions.utils import broadcast_all
a = torch.tensor([1.2, 3.4, 5.6])
a_w = Wrapper(a)
b = torch.tensor(5.0)
b_w = Wrapper(b)
c = torch.tensor([5.0, 5.0, 5.0])
o_1 = broadcast_all(a_w, b_w)
self.assertTrue(isinstance(o_1[0], Wrapper))
self.assertTrue(isinstance(o_1[1], Wrapper))
self.assertEqual(o_1[0]._data, a)
self.assertEqual(o_1[1]._data, c)
o_2 = broadcast_all(a_w, b)
self.assertTrue(isinstance(o_2[0], Wrapper))
self.assertTrue(isinstance(o_2[1], Wrapper))
self.assertEqual(o_2[0]._data, a)
self.assertEqual(o_2[1]._data, c)
class TestWrapTorchFunction(TestCase):
def test_wrap_torch_function(self):
class A:
@classmethod
def __torch_function__(cls, func, types, args, kwargs):
return -1
def dispatcher(a):
return (a,)
@torch.overrides.wrap_torch_function(dispatcher)
def f(a):
return a
self.assertEqual(f(A()), -1)
class TestIndexing(TestCase):
""" Regression tests for gh-46277 """
def test_getitem(self):
class A:
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
return -1
t = torch.tensor([5])
self.assertEqual(t[A()], -1)
self.assertEqual(t, torch.tensor([5]))
def test_getitem_subclass(self):
class A(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
return -1
t = torch.tensor([5])
self.assertEqual(t[A()], -1)
self.assertEqual(t[5, A()], -1)
self.assertEqual(t, torch.tensor([5]))
def test_setitem(self):
triggered = set()
class A:
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
triggered.add(func)
return -1
t = torch.tensor([5])
t[A()] = 1
t[5, A()] = 1
self.assertIn(Tensor.__setitem__, triggered)
self.assertEqual(t, torch.tensor([5]))
def test_setitem_val(self):
triggered = set()
class A:
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
triggered.add(func)
return -1
t = torch.tensor([5])
t[0] = A()
self.assertIn(Tensor.__setitem__, triggered)
self.assertEqual(t, torch.tensor([5]))
def test_setitem_subclass(self):
triggered = set()
class A(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
triggered.add(func)
return -1
t = torch.tensor([5])
t[A()] = 1
t[5, A()] = 1
self.assertIn(Tensor.__setitem__, triggered)
self.assertEqual(t, torch.tensor([5]))
if __name__ == '__main__':
run_tests()