Add support for torch.cond in vmap (#114523)
Fixes: https://github.com/pytorch/pytorch/issues/114136
This patch enables conversion of a BatchedTensor into a FakeTensor and
implements torch.cond vmap support using torch.where.
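
A minimal usage sketch of what this enables (it mirrors the new
test_cond_vmap_multiple_inputs test added below):

    import torch

    def fn(x, y):
        # The predicate depends on the batched inputs, so under vmap both
        # branches are evaluated and the results are selected per batch
        # element via torch.where.
        return torch.cond(
            pred=x.sum() < y.sum(),
            true_fn=lambda x, y: x + 100,
            false_fn=lambda x, y: y,
            operands=(x, y),
        )

    a = torch.arange(15).reshape(3, 5)
    b = torch.ones_like(a) + 3
    res = torch.vmap(fn, in_dims=(0, 0))(a, b)  # row-wise choice between branches
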
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114523
Approved by: https://github.com/zou3519
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 8098408..a8e3067 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -3312,13 +3312,12 @@
actual = fn(x, y)
expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
- self.assertEqual(len(counters["graph_break"]), 2)
+ self.assertEqual(len(counters["graph_break"]), 1)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{
".*torch.vmap with body that accepts non-Tensors as input": 2,
- "Unsupported: meta converter nyi with fake tensor propagation.": 1,
},
)
self.assertEqual(actual, expected)
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index da5596d..8dc4925 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -1564,11 +1564,11 @@
def false_fn(*operands):
return inner_most_fn(*operands)
- def fn(pred, operands):
+ def fn(*operands):
if len(operands) == 0 and len(closure_list) == 0:
return torch.zeros(1)
return cond(pred, true_fn, false_fn, operands)
- return (pred, operands), fn
+ return operands, fn
else:
args, inner_fn = self._create_test_fns_for_cond(pred <= 0, inner_most_fn, operands, closure_list, nested_level - 1)
@@ -1578,11 +1578,11 @@
def false_fn(*operands):
return inner_most_fn(*operands) - inner_fn(*args)
- def fn(pred, operands):
+ def fn(*operands):
if len(operands) == 0 and len(closure_list) == 0:
return torch.ones(1)
return cond(pred, true_fn, false_fn, operands)
- return (pred, operands), fn
+ return operands, fn
def _init_predicate(self, pred_type):
if pred_type == "bool":
@@ -1624,6 +1624,142 @@
gm = make_fx(fn, tracing_mode=tracing_mode, _allow_non_fake_inputs=True)(*args)
self.assertEqual(gm(*args), eager_res)
+ @parametrize("predType", ["boolTensor"])
+ @parametrize("innerFnType", ["function", "module", "object"])
+ @parametrize("nOperands", [1, 2])
+ @parametrize("nClosure", [0, 1])
+ @parametrize("nesting", [0])
+ def test_cond_vmap(self, predType, innerFnType, nOperands, nClosure, nesting):
+ pred = self._init_predicate(predType)
+ inner_fn = self._init_fn(innerFnType)
+ operands = [torch.ones(2, 3) + i for i in range(nOperands)]
+ closure = [torch.ones(2, 3) - i for i in range(nClosure)]
+ args, fn = self._create_test_fns_for_cond(pred, inner_fn, operands, closure, nesting)
+ eager_res = fn(*args)
+ out = torch.vmap(fn)(*args)
+ if nClosure == 0:
+ self.assertEqual(eager_res, out)
+ else:
+ self.assertEqual(eager_res, out[0])
+ self.assertEqual(eager_res, out[1])
+
+ def test_cond_vmap_simple(self):
+
+ def fn(x):
+ return torch.cond(
+ pred=torch.tensor([True]),
+ true_fn=lambda x: x + 100,
+ false_fn=lambda x: x,
+ operands=(x,)
+ )
+
+ a = torch.arange(15).reshape((3, 5))
+ res = torch.vmap(fn, in_dims=(0,))(a)
+ self.assertEqual(res.shape, (3, 5))
+ self.assertEqual(res, a + 100)
+
+ def test_cond_vmap_multiple_inputs(self):
+
+ def fn(x, y):
+ return torch.cond(
+ pred=x.sum() < y.sum(),
+ true_fn=lambda x, y: x + 100,
+ false_fn=lambda x, y: y,
+ operands=(x, y)
+ )
+
+ a = torch.arange(15).reshape(3, 5)
+ b = torch.ones_like(a) + 3
+ res = torch.vmap(fn, in_dims=(0, 0))(a, b)
+ expected = torch.tensor(
+ [
+ [100, 101, 102, 103, 104],
+ [4, 4, 4, 4, 4],
+ [4, 4, 4, 4, 4]
+ ]
+ )
+ self.assertEqual(res.shape, (3, 5))
+ self.assertEqual(expected, res)
+
+ def test_cond_vmap_single_input_with_closure(self):
+
+ a = torch.ones((3, 5)) + 3
+ c = torch.arange(5)
+
+ def fn(x):
+ return torch.cond(
+ pred=torch.tensor([True]),
+ true_fn=lambda x: x + c,
+ false_fn=lambda x: x - c,
+ operands=(x,)
+ )
+
+ res = torch.vmap(fn, in_dims=(0,))(a,)
+ self.assertEqual(a + c, res)
+
+ def test_cond_vmap_multiple_args_with_closure(self):
+
+ a = torch.ones((3, 5), dtype=torch.int64) + 3
+ b = torch.arange(15).reshape(3, 5)
+ c = torch.arange(5)
+
+ def fn(x, y):
+ return torch.cond(
+ pred=torch.tensor([False]),
+ true_fn=lambda x, y: x + c,
+ false_fn=lambda x, y: y - c,
+ operands=(x, y)
+ )
+
+ res = torch.vmap(fn)(a, b)
+ self.assertEqual(b - c, res)
+
+ @parametrize("nClosure", [0, 1])
+ def test_cond_vmap_multiple_outputs(self, nClosure):
+
+ if nClosure:
+ c = torch.ones(5, dtype=torch.int64) + 5
+
+ def fn(x):
+ return torch.cond(
+ pred=torch.tensor([True]),
+ true_fn=lambda x: (x + c, x - c),
+ false_fn=lambda x: (x, x),
+ operands=(x,)
+ )
+ else:
+ def fn(x):
+ return torch.cond(
+ pred=torch.tensor([True]),
+ true_fn=lambda x: (x + 1, x - 1),
+ false_fn=lambda x: (x, x),
+ operands=(x,)
+ )
+
+ a = torch.arange(15).reshape(3, 5)
+ res = torch.vmap(fn)(a,)
+ self.assertEqual(len(res), 2)
+ if nClosure:
+ self.assertEqual(res, (a + c, a - c))
+ else:
+ self.assertEqual(res, (a + 1, a - 1))
+
+ def test_vmap_vmap(self):
+ def fn(x):
+ return torch.cond(
+ pred=torch.tensor([True]),
+ true_fn=lambda x: x + 1,
+ false_fn=lambda x: x - 1,
+ operands=(x,)
+ )
+
+ def wrapper(x):
+ return torch.vmap(fn)(x)
+
+ a = torch.ones((3, 4, 5))
+ res = torch.vmap(wrapper)(a)
+ self.assertEqual(res, a + 1)
+
instantiate_parametrized_tests(TestControlFlowTraced)
if __name__ == '__main__':
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index f49a6b0..ad56670 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -10,7 +10,7 @@
from torch.testing._internal.common_utils import (
TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests,
IS_FBCODE, freeze_rng_state, skipIfTorchDynamo, IS_WINDOWS, IS_MACOS, IS_ARM64,
- markDynamoStrictTest, xfailIfTorchDynamo
+ markDynamoStrictTest, xfailIfTorchDynamo, TEST_WITH_TORCHDYNAMO
)
import torch
import torch.nn as nn
@@ -1645,7 +1645,6 @@
result = vmap(partial(grad(compute_loss), weights))(data, targets)
self._compare_expected_and_result(expected, result, mechanism)
- @xfailIfTorchDynamo
def test_log_softmax(self, device):
x = torch.randn(3, 5, device=device)
v = torch.randn(5, device=device)
@@ -1667,6 +1666,10 @@
FIXME_jacrev_only = parametrize("jacapi", [subtest(jacrev, name='jacrev')])
+FIXME_skip_jacrev_dynamo = parametrize("jacapi", [
+ subtest(jacrev, name='jacrev', decorators=[unittest.expectedFailure] if TEST_WITH_TORCHDYNAMO else None),
+ subtest(jacfwd, name='jacfwd')
+])
@markDynamoStrictTest
class TestJac(TestCase):
@@ -1745,8 +1748,7 @@
expected = torch.diagflat(x)
assert torch.allclose(z, expected)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_multiple_outputs_multiple_argnums(self, device, jacapi):
def f(x, y):
return 2 * x + 3 * y, 4 * x + 5 * y
@@ -1768,8 +1770,7 @@
self.assertEqual(z[1][0], expected_out1_x)
self.assertEqual(z[1][1], expected_out1_y)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_multiple_outputs_single_argnums(self, device, jacapi):
def f(x, y):
return 2 * x + 3 * y, 4 * x + 5 * y
@@ -1812,8 +1813,7 @@
self.assertTrue(isinstance(z['right'], tuple))
self.assertEqual(z, expected)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_multiple_inputs_pytree(self, device, jacapi):
def f(a, b, c):
a0, a1 = a
@@ -1838,8 +1838,7 @@
expected = (torch.tensor(1., device=device), torch.tensor(2., device=device))
self.assertEqual(result, expected)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_dimensionality(self, device, jacapi):
def f(x):
return x
@@ -1867,8 +1866,7 @@
self.assertEqual(result, torch.eye(3, 3, device=device))
self.assertEqual(aux, x.cos())
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_aux_pytree(self, device, jacapi):
def f(x):
y = x.clone()
@@ -1887,8 +1885,7 @@
with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"):
_ = jacapi(lambda x: (x, [x, aux]), has_aux=True)(x)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_outputs_can_any_pytree(self, device, jacapi):
x = torch.randn(2, 3, device=device)
@@ -1922,8 +1919,7 @@
assert isinstance(out, list)
assert isinstance(out[0], tuple) and isinstance(out[0][1], dict)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_multiple_inputs_outputs_pytree(self, device, jacapi):
def f(a, b, c):
a0, a1 = a
@@ -1972,8 +1968,7 @@
}
self.assertEqual(result, expected)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_unrelated_input(self, device, jacapi):
def f(x, y):
return x
@@ -1988,8 +1983,7 @@
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_unrelated_output(self, device, jacapi):
y = torch.randn(2, 3, device=device)
@@ -2002,8 +1996,7 @@
expected = x.new_zeros(2, 3, 2, 3)
self.assertEqual(result, expected)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_empty_output(self, device, jacapi):
x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
@@ -2044,8 +2037,7 @@
assert isinstance(z, torch.Tensor)
assert torch.allclose(z, expected0)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_argnums_defaults_to_zero(self, device, jacapi):
def f(x, y):
return x * 2 + y * 3
@@ -2093,7 +2085,6 @@
with self.assertRaisesRegex(RuntimeError, "must be int"):
jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
- @xfailIfTorchDynamo
def test_hessian_simple(self, device):
def f(x):
return x.sin()
@@ -2109,8 +2100,7 @@
result = jacapi(foo)(inputs)
self.assertEqual(result, expected)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_against_reference_simple(self, device, jacapi):
def f(x):
return 3 * x ** 2
@@ -2118,8 +2108,7 @@
x = torch.randn(2, 3, 5, device=device)
self._test_against_reference(f, (x,), jacapi)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_against_reference_multi_input(self, device, jacapi):
def f(x, y):
return (x.cos() * x) @ y.sin()
@@ -2128,8 +2117,7 @@
y = torch.randn(3, 5, device=device)
self._test_against_reference(f, (x, y), jacapi)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_against_reference_multi_input_multi_output(self, device, jacapi):
def f(x, y):
return (x * x) @ y, x @ (x.sum(1) * y), y.sum()
@@ -2138,8 +2126,7 @@
y = torch.randn(3, 5, device=device)
self._test_against_reference(f, (x, y), jacapi)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_against_reference_unrelated_outputs(self, device, jacapi):
def f(x, y):
return x, y, x, y
@@ -2148,8 +2135,7 @@
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y), jacapi)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_against_reference_zero_dim(self, device, jacapi):
# zero-dim output
def f(x, y):
@@ -2174,8 +2160,7 @@
y = torch.randn(1, device=device)
self._test_against_reference(h, (x, y), jacapi)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_against_reference_correctness_different_devices(self, device, jacapi):
def f(x, y):
return x * y, (x * y).to(device=device)
@@ -2184,8 +2169,7 @@
y = torch.randn(3)
self._test_against_reference(f, (x, y), jacapi)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_against_reference_default_arg(self, device, jacapi):
def f(x, y, z=3.):
return x * y * z
@@ -2194,8 +2178,7 @@
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y), jacapi)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_inplace(self, device, jacapi):
def f(x, y):
y.copy_(x)
@@ -2243,7 +2226,6 @@
with self.assertRaisesRegex(ValueError, err_msg):
jacrev(f, argnums=(0, ), chunk_size=-2)(x, y)
- @xfailIfTorchDynamo
@parametrize('_preallocate_and_copy', (True, False))
def test_chunk_jacrev_composition(self, device, _preallocate_and_copy):
x = torch.randn(10, 2, device=device)
@@ -2323,8 +2305,7 @@
with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all outputs"):
jacfwd(fn)(x)
- @xfailIfTorchDynamo
- @jacrev_and_jacfwd
+ @FIXME_skip_jacrev_dynamo
def test_jac_with_non_tensor_args(self, device, jacapi):
def f(t, int_x):
return t + int_x
@@ -2345,7 +2326,6 @@
result = hessian(foo)(inputs)
self.assertEqual(result, expected)
- @xfailIfTorchDynamo
def test_hessian_vectorize_correctness_simple(self, device):
def f(x):
return (3 * x ** 2).sum()
@@ -2353,7 +2333,6 @@
x = torch.randn(2, 3, 5, device=device)
self._test_against_reference(f, (x,))
- @xfailIfTorchDynamo
def test_hessian_vectorize_correctness_multi_input(self, device):
def f(x, y, z):
return ((x.relu() * x) @ y.sin() @ z).sum()
@@ -2363,7 +2342,6 @@
z = torch.randn(5, 5, device=device)
self._test_against_reference(f, (x, y, z))
- @xfailIfTorchDynamo
def test_hessian_vectorize_correctness_unrelated_outputs(self, device):
# output unrelated to one input
def f(x, y):
@@ -2381,7 +2359,6 @@
y = torch.randn(3, device=device)
self._test_against_reference(f, (x, y))
- @xfailIfTorchDynamo
def test_jacfwd_different_levels(self, device):
# Test case from:
# https://github.com/pytorch/functorch/issues/597
@@ -2852,7 +2829,6 @@
@markDynamoStrictTest
class TestVmapJvpInplaceView(TestCase):
# Case 1 in [Forward Grad View/inplace]
- @xfailIfTorchDynamo
def test_all_dual_no_view(self, device):
B = 2
@@ -2881,7 +2857,6 @@
self.assertEqual(out_tangent, yt.expand(B, 3))
# Case 2 in [Forward Grad View/inplace]
- @xfailIfTorchDynamo
def test_all_dual_base_view_inplace(self, device):
B = 2
@@ -2919,7 +2894,6 @@
self.assertEqual(x.movedim(2, 0), expected)
# Case 3 in [Forward Grad View/inplace]
- @xfailIfTorchDynamo
def test_all_dual_base_inplace(self, device):
B = 2
@@ -2948,7 +2922,6 @@
self.assertEqual(out_tangent[1], yt.movedim(1, 0)[:, 0, ::2])
# Case 4 in [Forward Grad View/inplace]
- @xfailIfTorchDynamo
def test_right_dual_view_prop(self, device):
B = 2
@@ -2980,7 +2953,6 @@
self.assertEqual(tangents[1], expected_tangent_1)
# Case 5 in [Forward Grad View/inplace]
- @xfailIfTorchDynamo
def test_right_dual_base_prop(self, device):
B = 2
@@ -4156,7 +4128,6 @@
self.assertEqual(result, expected)
- @xfailIfTorchDynamo
@parametrize('mechanism', ["make_functional", "functional_call"])
def test_ensemble_regression(self, device, mechanism):
def make_spirals(n_samples, noise_std=0., rotations=1.):
@@ -4241,7 +4212,6 @@
self.assertEqual(result_loss, expected_loss)
self.assertEqual(result_weights, expected_weights)
- @xfailIfTorchDynamo
@parametrize("dropout_layer", [
subtest(nn.Dropout, 'Dropout'),
subtest(nn.AlphaDropout, 'AlphaDropout'),
@@ -4388,7 +4358,6 @@
self.assertEqual(inpt1, inpt2)
self.assertEqual(inpt1, inpt3)
- @xfailIfTorchDynamo
def test_simple_view(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
@@ -4398,7 +4367,6 @@
return x
self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
- @xfailIfTorchDynamo
def test_multioutput_view(self, device):
def f(x: torch.Tensor) -> torch.Tensor:
@@ -4488,7 +4456,6 @@
self.assertEqual(out1, out2)
self.assertEqual(inpt1, inpt2)
- @xfailIfTorchDynamo
@unittest.skipIf(IS_FBCODE, 'fails in fbcode')
def test_vmap_functionalize_jvp(self, device):
@@ -4815,7 +4782,6 @@
ggx = grad(grad_f_sum)(x)
self.assertEqual(ggx, -x.sin())
- @xfailIfTorchDynamo
def test_vmap_grad_sum(self, device):
x = torch.randn(2, 3, device=device)
gx = vmap(grad(sum_pyop), (0, None))(x, 0)
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index dacadec..95107ef 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -25,6 +25,7 @@
from torch.testing._internal.common_utils import (
parametrize,
instantiate_parametrized_tests,
+ IS_WINDOWS,
subtest,
skipIfRocm,
TEST_WITH_TORCHDYNAMO,
@@ -2155,6 +2156,8 @@
self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
+ @unittest.skipIf(IS_WINDOWS,
+ reason="Windows not yet supported for torch.compile")
def test_is_contiguous(self):
def foo(x):
if x.is_contiguous():
@@ -2211,6 +2214,16 @@
with self.assertRaisesRegex(RuntimeError, msg):
vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
+ for mf in (torch.channels_last, torch.channels_last_3d):
+ @torch.compile(backend="eager", fullgraph=True)
+ def f(x):
+ if x.is_contiguous(memory_format=mf):
+ return x.sin()
+ return x.cos()
+
+ with self.assertRaisesRegex(RuntimeError, msg):
+ vmap(f)(torch.randn(3, 3))
+
def test_unsqueeze(self):
op = torch.unsqueeze
test = self._vmap_view_test
@@ -2840,7 +2853,6 @@
res = vmap(foo)(x)
self.assertEqual(res, x.conj())
- @xfailIfTorchDynamo
def test_mode_key(self):
def vmap_f(x):
return x + torch.randn(())
@@ -5092,8 +5104,9 @@
@skipIfTorchDynamo
@parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd'])
def test_fails_with_autograd_function(self, device, transform):
+ failed_build_envs = ('linux-focal-py3.8-clang10', 'linux-focal-py3.11-clang10')
if (device == 'cpu' and transform in ['grad', 'vmap'] and
- TEST_WITH_TORCHDYNAMO and os.getenv('BUILD_ENVIRONMENT', '') == 'linux-focal-py3.8-clang10'):
+ TEST_WITH_TORCHDYNAMO and os.getenv('BUILD_ENVIRONMENT', '') in failed_build_envs):
raise unittest.SkipTest("Unexpected successes on focal with dynamo," +
" see https://github.com/pytorch/pytorch/issues/107173")
@@ -5352,7 +5365,6 @@
RuntimeError, "Nested tensors can only be vmapped over dim=0"):
vmap(f, in_dims=2)(x)
- @xfailIfTorchDynamo
def test_nt_with_nonzero_out_dim_raises(self, device):
def f(x):
return x
@@ -5362,8 +5374,6 @@
RuntimeError, "Nested tensors can only be vmapped over dim=0"):
vmap(f, out_dims=2)(x)
- @xfailIfTorchDynamo
- @allowVmapFallbackUsage
def test_fallback_with_nt_and_batched_dense_with_nonzero_bdim_raises(self, device):
def f(x, y):
return x @ y
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 14a5965..2f56229 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -22,6 +22,7 @@
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._dynamo.testing import rand_strided
+from torch._C._functorch import is_batchedtensor, _add_batch_dim, get_unwrapped
from torch.testing import FileCheck
import unittest
import torch._prims as prims
@@ -201,6 +202,23 @@
FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))
+ def test_batch_tensor(self):
+ x = torch.rand((3, 4, 5))
+ b = _add_batch_dim(x, 0, 0)
+ mode = FakeTensorMode()
+ fake_b = mode.from_tensor(b)
+ prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)
+
+ b1 = _add_batch_dim(x, 1, 1)
+ b2 = _add_batch_dim(b1, 0, 2)
+ fake_b2 = mode.from_tensor(b2)
+ prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
+ self.assertTrue(is_batchedtensor(fake_b2))
+ fake_b1 = get_unwrapped(fake_b2)
+ self.assertTrue(is_batchedtensor(fake_b1))
+ fake_tensor = get_unwrapped(fake_b1)
+ self.assertIsInstance(fake_tensor, FakeTensor)
+
def test_constructor(self):
with FakeTensorMode():
x = torch.rand([4, 4], device="cpu")
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index be76fa6..1e0754b 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -546,8 +546,12 @@
all_diffs = []
for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
# We check the meta data associated with meta["example_value"]
- meta1 = _extract_tensor_metadata(var1.proxy.node.meta["example_value"])
- meta2 = _extract_tensor_metadata(var2.proxy.node.meta["example_value"])
+ meta1 = _extract_tensor_metadata(
+ var1.proxy.node.meta["example_value"], include_contiguity=False
+ )
+ meta2 = _extract_tensor_metadata(
+ var2.proxy.node.meta["example_value"], include_contiguity=False
+ )
if meta1 != meta2:
all_diffs.append((f"pair{i}:", meta1, meta2))
return all_diffs
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 2f68e77..bac958f 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -158,13 +158,18 @@
[int(s) if is_symbolic(s) else s for s in value.size()]
)
props["stride"] = tuple(value.stride())
- props["is_contiguous"] = tuple(
- [
- x
- for x in torch._prims_common._memory_formats
- if value.is_contiguous(memory_format=x)
- ]
- )
+ if torch._C._functorch.is_batchedtensor(value):
+ # Batched tensors do not support contiguity patterns, so
+ # we refrain from computing the `is_contiguous` property
+ props["is_contiguous"] = None
+ else:
+ props["is_contiguous"] = tuple(
+ [
+ x
+ for x in torch._prims_common._memory_formats
+ if value.is_contiguous(memory_format=x)
+ ]
+ )
return props
def dynamic_getattr(self, tx, name):
diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py
index 0ba631a..69bc77f 100644
--- a/torch/_functorch/vmap.py
+++ b/torch/_functorch/vmap.py
@@ -274,10 +274,17 @@
return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
args_spec, out_dims, randomness, **kwargs)
- # If chunk_size is not specified.
- return _flat_vmap(
- func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
- )
+ from torch._dynamo import disable
+
+ # remove @disable once #114306 is fixed
+ @disable
+ def wrapper():
+ # If chunk_size is not specified.
+ return _flat_vmap(
+ func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
+ )
+
+ return wrapper()
def get_chunk_sizes(total_elems, chunk_size):
n_chunks = n_chunks = total_elems // chunk_size
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 96b522b..98c25d2 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -8,6 +8,12 @@
import torch.utils._pytree as pytree
from torch._C import DispatchKey
+from torch._C._functorch import (
+ _add_batch_dim,
+ get_unwrapped,
+ is_batchedtensor,
+ maybe_get_bdim,
+)
from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import autograd_not_implemented
@@ -404,3 +410,50 @@
unwrapped_pred, functional_true, functional_false, unwrapped_inputs
)
return ctx.wrap_tensors(cond_return)
+
+
+@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
+def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
+ assert isinstance(
+ inputs, (list, tuple)
+ ), "Cond inputs must be a list or tuple of tensors"
+ assert all(
+ isinstance(i, torch.Tensor) for i in inputs
+ ), "Cond inputs must be a list of tensors"
+
+ pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred
+
+ # unbatched tensors are not vmapped
+ tensors, in_dims = zip(
+ *[
+ (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
+ for t in inputs
+ ]
+ )
+
+ if is_batchedtensor(pred):
+ # prepend "pred" and vmap everything
+ tensors = (pred_,) + tensors
+ in_dims = (0,) + in_dims
+
+ def fn(p, *args):
+ t = true_fn(*args)
+ f = false_fn(*args)
+ return torch.where(p, t[0], f[0])
+
+ with interpreter.lower():
+ result = torch.vmap(fn, in_dims=in_dims)(*tensors)
+
+ else:
+ # predicate is known at this stage and it is a boolean expression or a
+ # tensor with one element.
+ true_fn = torch.vmap(true_fn, in_dims=in_dims)
+ false_fn = torch.vmap(false_fn, in_dims=in_dims)
+
+ with interpreter.lower():
+ result = cond_op(pred, true_fn, false_fn, tensors)
+
+ if not isinstance(result, tuple):
+ result = (result,)
+ lvl = interpreter.level()
+ return tuple([_add_batch_dim(r, 0, lvl) for r in result])
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index ebc1440..dbb122a 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -188,6 +188,9 @@
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
return is_fake(unwrapped)
+ elif isinstance(x, torch.Tensor) and torch._C._functorch.is_batchedtensor(x):
+ unwrapped = torch._C._functorch.get_unwrapped(x)
+ return is_fake(unwrapped)
return False
@@ -206,6 +209,9 @@
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
return maybe_get_fake_mode(unwrapped)
+ elif isinstance(t, torch.Tensor) and torch._C._functorch.is_batchedtensor(t):
+ unwrapped = torch._C._functorch.get_unwrapped(t)
+ return maybe_get_fake_mode(unwrapped)
return None
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index ba7e0d2..40442bf 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -5,9 +5,14 @@
import torch
from torch._C._functorch import (
+ _add_batch_dim,
_unwrap_functional_tensor,
_wrap_functional_tensor,
current_level,
+ get_unwrapped,
+ is_batchedtensor,
+ maybe_get_bdim,
+ maybe_get_level,
peek_interpreter_stack,
TransformType,
)
@@ -128,7 +133,7 @@
# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)
- if t.is_sparse or t.is_mkldnn:
+ if t.is_sparse or t.is_mkldnn or is_batchedtensor(t):
weak_st = None
else:
weak_st = StorageWeakRef(t._typed_storage())
@@ -313,6 +318,30 @@
if t.requires_grad and not is_leaf:
with torch.enable_grad():
r = r.clone()
+ elif is_batchedtensor(t):
+ # Wraps a BatchedTensor in a FakeTensor
+ def _to_fake_tensor(t):
+ if is_batchedtensor(t):
+ ft = _to_fake_tensor(get_unwrapped(t))
+ lvl = maybe_get_level(t)
+ bdim = maybe_get_bdim(t)
+ r = _add_batch_dim(ft, bdim, lvl)
+ else:
+ # regular tensor
+ sizes = t.size()
+ strides = t.stride()
+ r = callback(
+ lambda: torch.empty_strided(
+ sizes,
+ strides,
+ dtype=t.dtype,
+ device="meta",
+ )
+ )
+ return r
+
+ r = _to_fake_tensor(t)
+
elif t._is_view():
# Construct views in two steps: recursively meta-fy their
# base, and then create view(s) off that. NB: doing it
@@ -511,7 +540,9 @@
r = r.clone(memory_format=torch.preserve_format)
# Graph-Break for wrapped tensors
- if torch._C._functorch.is_functorch_wrapped_tensor(t):
+ if not is_batchedtensor(
+ t
+ ) and torch._C._functorch.is_functorch_wrapped_tensor(t):
return NotImplemented
s = t.untyped_storage()
diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py
index 69260a5..6208980 100644
--- a/torch/fx/passes/shape_prop.py
+++ b/torch/fx/passes/shape_prop.py
@@ -26,7 +26,7 @@
is_quantized : bool
qparams: Dict[str, Any]
-def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
+def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
@@ -35,18 +35,18 @@
requires_grad = result.requires_grad
stride = result.stride()
- memory_formats = {
- torch.contiguous_format,
- torch.channels_last,
- torch.channels_last_3d,
- }
-
memory_format = None
- for query_format in memory_formats:
- if result.is_contiguous(memory_format=query_format):
- memory_format = query_format
- break
+ if include_contiguity:
+ memory_formats = {
+ torch.contiguous_format,
+ torch.channels_last,
+ torch.channels_last_3d,
+ }
+ for query_format in memory_formats:
+ if result.is_contiguous(memory_format=query_format):
+ memory_format = query_format
+ break
is_quantized = result.is_quantized
qparams: Dict[str, Any] = {}