Use nested variant of getValueTrace to allow more flexible tracing script modules (#13597)

Summary:
When tracing scripted functions, we used to only allow Tensor arguments.
This enables tracing script modules with List[Tensor] or Tuple[Tensor, Tensor] arguments (passing
tuples).

Fixes: #13566
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13597

Differential Revision: D12990464

Pulled By: soumith

fbshipit-source-id: fdce3afcb1e09f3c26d6ce834c01bf18d261f47c
diff --git a/test/test_jit.py b/test/test_jit.py
index 4e51d67..00d6cff 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1493,6 +1493,30 @@
         x = torch.randn(5, 5)
         self.assertEqual(foo(x), x + x + x)
 
+    def test_trace_script(self):
+        @torch.jit.script
+        def func1(x):
+            # type: (Tuple[Tensor, Tensor]) -> Tensor
+            return x[0] + x[1]
+
+        @torch.jit.script
+        def func2(x):
+            # type: (List[Tensor]) -> Tensor
+            return x[0] + x[1]
+
+        a = torch.randn(5)
+        b = torch.randn(5)
+
+        expected = func1((a, b))
+        traced = torch.jit.trace(func1, ((a, b),))
+        result = traced((a, b))
+        self.assertEqual(expected, result)
+
+        expected = func2((a, b))
+        traced = torch.jit.trace(func2, ((a, b),))
+        result = traced((a, b))
+        self.assertEqual(expected, result)
+
     def test_einsum(self):
         def outer(x, y):
             return torch.einsum('i,j->ij', (x, y))
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index ad72216..4a06572 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -497,7 +497,7 @@
     auto state = tracer::getTracingState();
     auto inputs = last(stack, num_inputs);
     auto input_values = fmap(inputs, [](const IValue & v) {
-      return tracer::getValueTrace(v.toTensor());
+      return tracer::getNestedValueTrace(v);
     });
 
     ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 2f5e85d..ed88273 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -80,6 +80,27 @@
   return it->second;
 }
 
+// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
+// One might merge getValueTrace and getNestedValueTrace after checking that
+// casting to IValue instead of Variable is OK
+inline Value* getNestedValueTrace(const IValue &v) {
+  auto &state = getTracingState();
+  if (v.isTensorList()) {
+    return state->graph->insertNode(state->graph->createList(
+        DynamicType::get(),
+        fmap(v.toTensorListRef(), [](const IValue &val) {
+          return getNestedValueTrace(val);
+        })))->output();
+  } else if (v.isTuple()) {
+    return state->graph->insertNode(state->graph->createTuple(
+        fmap(v.toTuple()->elements(), [](const IValue &val) {
+          return getNestedValueTrace(val);
+        })))->output();
+  }
+  return getValueTrace(v.toTensor());
+}
+
+
 inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const Variable& var, size_t output_no) {
   if (!var.defined()) {
     Node *n = state->graph->createUndefined();