LSTM SymInt-aware changes & meta registration (non-cuDNN CUDA) (#90701)
Adds meta registrations for the vanilla (non-cuDNN) CUDA ops underneath `lstm()` and makes the logic SymInt-aware.
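For illustration only (not part of this PR), here is a minimal sketch of what the new meta registration computes. The fused LSTM cell op can now be shape-propagated without running the CUDA kernel; the sketch uses the meta device directly, and the sizes (batch 5, hidden size 3, so gates are `(batch, 4 * hidden_size)`) are made up for the example. `FakeTensorMode` exercises the same registration, as the new test below does.

```python
import torch

# Illustrative shapes: batch = 5, hidden_size = 3 -> gate tensors are (5, 12).
igates = torch.empty(5, 12, device="meta")
hgates = torch.empty(5, 12, device="meta")
cx = torch.empty(5, 3, device="meta")
b_ih = torch.empty(12, device="meta")
b_hh = torch.empty(12, device="meta")

# With the meta kernel registered, this returns meta tensors with the expected
# shapes: hy and cy match cx, and the workspace matches the gate tensor.
hy, cy, workspace = torch.ops.aten._thnn_fused_lstm_cell(igates, hgates, cx, b_ih, b_hh)
print(hy.shape, cy.shape, workspace.shape)
# torch.Size([5, 3]) torch.Size([5, 3]) torch.Size([5, 12])
```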
TODO:
* The cuDNN side does some [nasty stuff](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp#L1567) with buffers; figuring this out will need a larger redesign
* Allow AOT Autograd to be used when an LSTM is present (i.e. remove the check that currently blocks this, once LSTM support is complete)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90701
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index 43f96cc..a592c18 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -1441,7 +1441,7 @@
   }
 #endif
   // if cells are of different size, that means projections are used
-  bool has_projections = (hx[0].size(2) != hx[1].size(2));
+  bool has_projections = (hx[0].sym_size(2) != hx[1].sym_size(2));
   if (use_miopen(_input, dropout_p)) {
     if (!has_projections) {
       Tensor output, hy, cy;
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 8909f83..ea47f36 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -370,6 +370,36 @@
         self.assertTrue(isinstance(ten, FakeTensor))
         self.assertEqual(ten.device.type, 'cuda')
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_cuda_lstm(self):
+        # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
+        with torch.backends.cudnn.flags(enabled=False):
+            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
+            with fake_tensor_mode:
+                N = 5
+                L = 4
+                H_in = 2
+                hidden_size = 3
+                proj_size = 2
+                num_layers = 2
+                bidir = False
+                D = 2 if bidir else 1
+                H_out = proj_size if proj_size > 0 else hidden_size
+
+                lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
+                                     num_layers=num_layers, proj_size=proj_size, batch_first=False,
+                                     bias=True, bidirectional=bidir, device='cuda')
+
+                h_0 = torch.randn((num_layers * D, N, H_out), device='cuda')
+                c_0 = torch.randn((num_layers * D, N, hidden_size), device='cuda')
+                inp = torch.randn((L, N, H_in), device='cuda')
+                (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
+                output.sum().backward()
+
+                self.assertEqual(output.shape, (L, N, D * H_out))
+                self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
+                self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))
+
     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_fallback_memory_prop(self):
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index f5686c0..6bcb1b6 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2017,6 +2017,51 @@
     return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
+def rnn_cell_checkSizes(
+    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
+):
+    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
+    check(
+        input_gates.shape == hidden_gates.shape,
+        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
+    )
+    gates_size = input_gates.size(1)
+    if input_bias is not None:
+        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
+        check(
+            input_bias.numel() == gates_size,
+            lambda: f"{input_bias.numel()} != {gates_size}",
+        )
+        check(
+            input_bias.shape == hidden_bias.shape,
+            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
+        )
+    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
+    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
+    check(
+        prev_hidden.numel() == expected_prev_hidden_numel,
+        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
+    )
+    check(
+        all(
+            x.device == input_gates.device
+            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden] if x is not None
+        ),
+        lambda: "expected all inputs to be same device",
+    )
+
+
+@register_meta(aten._thnn_fused_lstm_cell.default)
+def _thnn_fused_lstm_cell_meta(
+    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
+):
+    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
+    workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
+    hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
+    cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
+    return (hy, cy, workspace)
+
+
 def zero_numel_check_dims(self, dim, fn_name):
     if self.ndim == 0:
         check(
@@ -2076,6 +2121,38 @@
     return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
+legacy_contiguous_memory_format = torch.contiguous_format
+
+
+# From aten/src/ATen/native/cuda/RNN.cu
+def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
+    defined_grad = grad_hy if grad_hy is not None else grad_cy
+    check(defined_grad.dim() == 2, lambda: f"{defined_grad.dim()} != 2")
+    exp_size = defined_grad.size()
+    if grad_hy is not None:
+        check(grad_hy.size() == exp_size, lambda: f"{grad_hy.size()} != {exp_size}")
+    if grad_cy is not None:
+        check(grad_cy.size() == exp_size, lambda: f"{grad_cy.size()} != {exp_size}")
+    check(cx.size() == exp_size, lambda: f"{cx.size()} != {exp_size}")
+    check(cy.size() == exp_size, lambda: f"{cy.size()} != {exp_size}")
+    check(workspace.dim() == 2, lambda: f"{workspace.dim()} != 2")
+    check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: f"{workspace.numel()} != {exp_size[0] * exp_size[1] * 4}")
+
+
+# From aten/src/ATen/native/cuda/RNN.cu
+@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
+def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
+    if grad_hy is None and grad_cy is None:
+        return None, None, None
+    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
+    grad_gates = torch.empty_like(
+        workspace, memory_format=legacy_contiguous_memory_format
+    )
+    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
+    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
+    return grad_gates, grad_cx, grad_bias
+
+
 # We must also trigger meta registrations from PrimTorch ref
 # decompositions
 import torch._refs