Inductor cpp wrapper: support MkldnnRnnLayer (#107858)

1. Directly use the `codegen` function of the parent class, which already supports both the python and cpp wrappers.
2. The output of the `at::mkldnn_rnn_layer` OP is actually a `std::tuple` (see https://github.com/pytorch/pytorch/blob/1491bae277668fac459937c874a49c3bb8adedcb/aten/src/ATen/native/mkldnn/RNN.cpp#L218), so fix the type passed when calling `MultiOutput`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107858
Approved by: https://github.com/jgong5, https://github.com/jansel
diff --git a/test/inductor/test_cpp_wrapper.py b/test/inductor/test_cpp_wrapper.py
index b1f20d6..3b19ae7 100644
--- a/test/inductor/test_cpp_wrapper.py
+++ b/test/inductor/test_cpp_wrapper.py
@@ -184,6 +184,12 @@
             and torch.ops.mkldnn._is_mkldnn_bf16_supported(),
         ),
         BaseTest("test_linear_packed", "", test_cpu_repro.CPUReproTests()),
+        BaseTest(
+            "test_lstm_packed_change_input_sizes",
+            "cpu",
+            test_cpu_repro.CPUReproTests(),
+            condition=torch.backends.mkldnn.is_available(),
+        ),
         BaseTest("test_mm_views"),
         BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()),
         BaseTest("test_profiler_mark_wrapper_call"),
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 548874b..561f1f1 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -459,7 +459,7 @@
         }
         self._test_lstm_packed(params_dict)
 
-    def test_lstm_packed_change_input_sizes(self):
+    def test_lstm_packed_change_input_sizes_cpu(self):
         params_dict = {
             "unbatched": [False],
             "input_size": [2],
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 3dac4e7..d712d2a 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -221,23 +221,26 @@
     return out
 
 
-@patch.object(config, "debug", True)
 def run_and_get_cpp_code(fn, *args, **kwargs):
-    torch._dynamo.reset()
-    import io
-    import logging
+    # We use the patch context manager instead of using it as a decorator.
+    # In this way, we can ensure that the attribute is patched and unpatched correctly
+    # even if this run_and_get_cpp_code function is called multiple times.
+    with patch.object(config, "debug", True):
+        torch._dynamo.reset()
+        import io
+        import logging
 
-    log_capture_string = io.StringIO()
-    ch = logging.StreamHandler(log_capture_string)
-    from torch._inductor.graph import output_code_log
+        log_capture_string = io.StringIO()
+        ch = logging.StreamHandler(log_capture_string)
+        from torch._inductor.graph import output_code_log
 
-    output_code_log.addHandler(ch)
-    prev_level = output_code_log.level
-    output_code_log.setLevel(logging.DEBUG)
-    fn(*args, **kwargs)
-    s = log_capture_string.getvalue()
-    output_code_log.setLevel(prev_level)
-    output_code_log.removeHandler(ch)
+        output_code_log.addHandler(ch)
+        prev_level = output_code_log.level
+        output_code_log.setLevel(logging.DEBUG)
+        fn(*args, **kwargs)
+        s = log_capture_string.getvalue()
+        output_code_log.setLevel(prev_level)
+        output_code_log.removeHandler(ch)
     return s
 
 
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 260f5c0..4da0cf4 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -4792,18 +4792,18 @@
 
 class MkldnnRnnLayer(ExternKernelAlloc):
     def __init__(
-        self, layout, inputs, constant_args=(), kernel="aten.mkldnn_rnn_layer"
+        self,
+        layout,
+        inputs,
+        constant_args=(),
     ):
         super().__init__(
             layout,
             inputs,
             constant_args,
-        )
-        self.kernel = kernel
-
-    def codegen(self, wrapper):
-        wrapper.writeline(
-            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
+            None,
+            kernel="aten.mkldnn_rnn_layer",
+            cpp_kernel="at::mkldnn_rnn_layer",
         )
 
     @classmethod
@@ -4890,7 +4890,7 @@
                     output_stride,
                 ),
                 packed,
-                indices + [(list, i)],
+                indices + [(tuple, i)],
             )
             for i, (output_size, output_stride) in enumerate(
                 zip(output_sizes, output_strides)