[functorch] Added ability to enforce dynamic shapes in decompositions
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index ddd7e68..54e6263 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -11,6 +11,7 @@
 import functools
 import unittest
 import itertools
+from contextlib import contextmanager
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_device_type import ops
 from torch.testing._internal.common_dtype import integral_types
@@ -1039,6 +1040,17 @@
         # then run the decomposition, and ensure they're identical.
         # The way this is implemented, there could .... technically be an exponential blow up,
         # but it's probably fine for now.
+        IN_DECOMPOSITION = None
+
+        @contextmanager
+        def in_decomposition(name):
+            nonlocal IN_DECOMPOSITION
+            IN_DECOMPOSITION = name
+            try:
+                yield True
+            finally:
+                IN_DECOMPOSITION = None
+
         class DecompositionTensor(torch.Tensor):
             elem: torch.Tensor
 
@@ -1058,6 +1070,13 @@
             def __repr__(self):
                 return f"DecompositionTensor(elem={self.elem})"
 
+            @property
+            def shape(self):
+                # Uncomment this if you want to enforce dynamic shapes in decompositions
+                # if IN_DECOMPOSITION is not None:
+                #     raise RuntimeError(f"Trying to query shape in decomposition {IN_DECOMPOSITION}")
+                return super().shape
+
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 global run_ops
@@ -1111,13 +1130,16 @@
                         kwargs = tree_map(map_fn, kwargs)
                         return tree_flatten(func(*args, **kwargs))[0]
 
-                    if DO_RELATIVE_CHECK:
-                        decomp_out = call_op(decomposition, upcast_tensor, *args, **kwargs)
-                        real_out_double = call_op(func, lambda x: upcast_tensor(unwrap_tensor(x), dtype=torch.float64),
-                                                  *args, **kwargs)
-                    else:
-                        decomp_out = call_op(decomposition, lambda x: x, *args, **kwargs)
-                        real_out_double = decomp_out
+                    with in_decomposition(str(func)):
+                        if DO_RELATIVE_CHECK:
+                            decomp_out = call_op(decomposition, upcast_tensor, *args, **kwargs)
+                            real_out_double = call_op(func,
+                                                      lambda x: upcast_tensor(unwrap_tensor(x), dtype=torch.float64),
+                                                      *args,
+                                                      **kwargs)
+                        else:
+                            decomp_out = call_op(decomposition, lambda x: x, *args, **kwargs)
+                            real_out_double = decomp_out
 
                     real_out = call_op(func, unwrap_tensor, *args, **kwargs)
                     assert(len(real_out) == len(decomp_out))
@@ -1199,14 +1221,12 @@
             for op in get_names(run_decompositions):
                 f.write(f'{op}\n')
 
-
     def test_decompositions_torchscriptable(self, device):
         skip_list = []
         for op, decomposition in decomposition_table.items():
             if op in skip_list:
                 continue
-            f = torch.jit.script(decomposition)
-
+            torch.jit.script(decomposition)
 
     def test_group_norm_backward(self, device):
         # group norm will hit the decomposable ``infinitely_differentiable_group_norm_backward`` when