[functorch] Added ability to enforce dynamic shapes in decompositions
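
Mechanism in brief: a context manager records the name of the decomposition
currently being executed, and the tensor subclass used by the test can then
refuse .shape queries made while a decomposition is running. Below is a
minimal self-contained sketch of the same pattern; the GuardedTensor name,
the as_subclass usage, and the "aten.add" string are illustrative only, not
part of this diff:

    from contextlib import contextmanager

    import torch

    IN_DECOMPOSITION = None  # name of the decomposition currently running, else None

    @contextmanager
    def in_decomposition(name):
        # Record which decomposition is executing so shape queries can be attributed.
        global IN_DECOMPOSITION
        IN_DECOMPOSITION = name
        try:
            yield
        finally:
            IN_DECOMPOSITION = None

    class GuardedTensor(torch.Tensor):
        @property
        def shape(self):
            # Enforce dynamic shapes: a decomposition must not specialize on sizes.
            if IN_DECOMPOSITION is not None:
                raise RuntimeError(f"Trying to query shape in decomposition {IN_DECOMPOSITION}")
            return super().shape

    t = torch.randn(3).as_subclass(GuardedTensor)
    print(t.shape)  # fine outside a decomposition
    with in_decomposition("aten.add"):
        try:
            t.shape
        except RuntimeError as e:
            print(e)  # Trying to query shape in decomposition aten.add

In the diff itself the raise is left commented out, so enforcement stays
opt-in: uncommenting the check in DecompositionTensor.shape turns any size
query made inside a decomposition into a hard error.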
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index ddd7e68..54e6263 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -11,6 +11,7 @@
import functools
import unittest
import itertools
+from contextlib import contextmanager
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_dtype import integral_types
@@ -1039,6 +1040,18 @@
# then run the decomposition, and ensure they're identical.
# The way this is implemented, there could technically be an exponential blow-up,
# but it's probably fine for now.
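+ # Name of the decomposition currently being run, or None when not in one.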
+ IN_DECOMPOSITION = None
+
+ @contextmanager
+ def in_decomposition(name):
+ nonlocal IN_DECOMPOSITION
+ IN_DECOMPOSITION = name
+ try:
+ yield True
+ finally:
+ IN_DECOMPOSITION = None
+
class DecompositionTensor(torch.Tensor):
elem: torch.Tensor
@@ -1058,6 +1070,13 @@
def __repr__(self):
return f"DecompositionTensor(elem={self.elem})"
+ @property
+ def shape(self):
+ # Uncomment this if you want to enforce dynamic shapes in decompositions
+ # if IN_DECOMPOSITION is not None:
+ # raise RuntimeError(f"Trying to query shape in decomposition {IN_DECOMPOSITION}")
+ return super().shape
+
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
global run_ops
@@ -1111,13 +1130,17 @@
kwargs = tree_map(map_fn, kwargs)
return tree_flatten(func(*args, **kwargs))[0]

- if DO_RELATIVE_CHECK:
- decomp_out = call_op(decomposition, upcast_tensor, *args, **kwargs)
- real_out_double = call_op(func, lambda x: upcast_tensor(unwrap_tensor(x), dtype=torch.float64),
- *args, **kwargs)
- else:
- decomp_out = call_op(decomposition, lambda x: x, *args, **kwargs)
- real_out_double = decomp_out
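+ # Run under the flag so DecompositionTensor.shape can detect size queries made inside the decomposition.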
+ with in_decomposition(str(func)):
+ if DO_RELATIVE_CHECK:
+ decomp_out = call_op(decomposition, upcast_tensor, *args, **kwargs)
+ real_out_double = call_op(func,
+ lambda x: upcast_tensor(unwrap_tensor(x), dtype=torch.float64),
+ *args,
+ **kwargs)
+ else:
+ decomp_out = call_op(decomposition, lambda x: x, *args, **kwargs)
+ real_out_double = decomp_out

real_out = call_op(func, unwrap_tensor, *args, **kwargs)
assert(len(real_out) == len(decomp_out))
@@ -1199,14 +1221,12 @@
for op in get_names(run_decompositions):
f.write(f'{op}\n')
-

def test_decompositions_torchscriptable(self, device):
skip_list = []
for op, decomposition in decomposition_table.items():
if op in skip_list:
continue
- f = torch.jit.script(decomposition)
-
torch.jit.script(decomposition)

def test_group_norm_backward(self, device):
# group norm will hit the decomposable ``infinitely_differentiable_group_norm_backward`` when