LSTM SymInt-aware changes & meta registration (non-cuDNN CUDA) (#90701)

Adds meta registrations for the vanilla (non-cuDNN) CUDA ops underneath `lstm()` and makes the relevant logic SymInt-aware.
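
For context, a meta kernel only computes output metadata (shapes, dtypes, devices) and never touches real data; the registrations in the diff below follow this pattern, and the `sym_size()` change keeps the projection check working when sizes are symbolic (SymInt) rather than concrete ints. A minimal sketch of the pattern, using a hypothetical `my_gemm_meta` helper (not part of this PR; the real kernels are registered with `@register_meta(aten.<op>.default)` as in the diff):

```python
import torch

# Hypothetical meta kernel: validate input metadata and return an empty output
# with the right size -- no real computation, no real storage is allocated.
def my_gemm_meta(a, b):
    assert a.ndim == 2 and b.ndim == 2 and a.size(1) == b.size(0)
    # new_empty on a meta tensor produces metadata only
    return a.new_empty((a.size(0), b.size(1)))

a = torch.empty(8, 4, device="meta")
b = torch.empty(4, 16, device="meta")
out = my_gemm_meta(a, b)
print(out.shape, out.device)  # torch.Size([8, 16]) meta
```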
TODO:
* The cuDNN side does some [nasty stuff](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp#L1567) with buffers; this needs a larger redesign to figure out
* Signal that AOT Autograd can be used when an LSTM is present (i.e., remove the existing check once LSTM support there is complete)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90701
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index 43f96cc..a592c18 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -1441,7 +1441,7 @@
   }
 #endif
   // if cells are of different size, that means projections are used
-  bool has_projections = (hx[0].size(2) != hx[1].size(2));
+  bool has_projections = (hx[0].sym_size(2) != hx[1].sym_size(2));
   if (use_miopen(_input, dropout_p)) {
     if (!has_projections) {
       Tensor output, hy, cy;
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 8909f83..ea47f36 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -370,6 +370,36 @@
                         self.assertTrue(isinstance(ten, FakeTensor))
                     self.assertEqual(ten.device.type, 'cuda')
 
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_cuda_lstm(self):
+        # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
+        with torch.backends.cudnn.flags(enabled=False):
+            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
+            with fake_tensor_mode:
+                N = 5
+                L = 4
+                H_in = 2
+                hidden_size = 3
+                proj_size = 2
+                num_layers = 2
+                bidir = False
+                D = 2 if bidir else 1
+                H_out = proj_size if proj_size > 0 else hidden_size
+
+                lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
+                                     num_layers=num_layers, proj_size=proj_size, batch_first=False,
+                                     bias=True, bidirectional=bidir, device='cuda')
+
+                h_0 = torch.randn((num_layers * D, N, H_out), device='cuda')
+                c_0 = torch.randn((num_layers * D, N, hidden_size), device='cuda')
+                inp = torch.randn((L, N, H_in), device='cuda')
+                (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
+                output.sum().backward()
+
+                self.assertEqual(output.shape, (L, N, D * H_out))
+                self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
+                self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))
+
     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_fallback_memory_prop(self):
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index f5686c0..6bcb1b6 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2017,6 +2017,51 @@
     return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
 
 
+def rnn_cell_checkSizes(
+    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
+):
+    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
+    check(
+        input_gates.shape == hidden_gates.shape,
+        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
+    )
+    gates_size = input_gates.size(1)
+    if input_bias is not None:
+        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
+        check(
+            input_bias.numel() == gates_size,
+            lambda: f"{input_bias.numel()} != {gates_size}",
+        )
+        check(
+            input_bias.shape == hidden_bias.shape,
+            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
+        )
+    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
+    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
+    check(
+        prev_hidden.numel() == expected_prev_hidden_numel,
+        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
+    )
+    check(
+        all(
+            x.device == input_gates.device
+            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
+        ),
+        lambda: "expected all inputs to be same device",
+    )
+
+
+@register_meta(aten._thnn_fused_lstm_cell.default)
+def _thnn_fused_lstm_cell_meta(
+    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
+):
+    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
+    workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
+    hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
+    cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
+    return (hy, cy, workspace)
+
+
 def zero_numel_check_dims(self, dim, fn_name):
     if self.ndim == 0:
         check(
@@ -2076,6 +2121,38 @@
     return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
 
 
+legacy_contiguous_memory_format = torch.contiguous_format
+
+
+# From aten/src/ATen/native/cuda/RNN.cu
+def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
+    defined_grad = grad_hy if grad_hy is not None else grad_cy
+    check(defined_grad.dim() == 2, lambda: "")
+    exp_size = defined_grad.size()
+    if grad_hy is not None:
+        check(grad_hy.size() == exp_size, lambda: "")
+    if grad_cy is not None:
+        check(grad_cy.size() == exp_size, lambda: "")
+    check(cx.size() == exp_size, lambda: "")
+    check(cy.size() == exp_size, lambda: "")
+    check(workspace.dim() == 2, lambda: "")
+    check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
+
+
+# From aten/src/ATen/native/cuda/RNN.cu
+@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
+def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
+    if grad_hy is None and grad_cy is None:
+        return None, None, None
+    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
+    grad_gates = torch.empty_like(
+        workspace, memory_format=legacy_contiguous_memory_format
+    )
+    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
+    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
+    return grad_gates, grad_cx, grad_bias
+
+
 # We must also trigger meta registrations from PrimTorch ref
 # decompositions
 import torch._refs