Use nested variant of getValueTrace to allow more flexible tracing script modules (#13597)
Summary:
When tracing scripted functions, we used to only allow Tensor arguments.
This enables tracing script modules with List[Tensor] or Tuple[Tensor, Tensor] arguments (passing
tuples).
Fixes: #13566
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13597
Differential Revision: D12990464
Pulled By: soumith
fbshipit-source-id: fdce3afcb1e09f3c26d6ce834c01bf18d261f47c
diff --git a/test/test_jit.py b/test/test_jit.py
index 4e51d67..00d6cff 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1493,6 +1493,30 @@
x = torch.randn(5, 5)
self.assertEqual(foo(x), x + x + x)
+ def test_trace_script(self):
+ @torch.jit.script
+ def func1(x):
+ # type: (Tuple[Tensor, Tensor]) -> Tensor
+ return x[0] + x[1]
+
+ @torch.jit.script
+ def func2(x):
+ # type: (List[Tensor]) -> Tensor
+ return x[0] + x[1]
+
+ a = torch.randn(5)
+ b = torch.randn(5)
+
+ expected = func1((a, b))
+ traced = torch.jit.trace(func1, ((a, b),))
+ result = traced((a, b))
+ self.assertEqual(expected, result)
+
+ expected = func2((a, b))
+ traced = torch.jit.trace(func2, ((a, b),))
+ result = traced((a, b))
+ self.assertEqual(expected, result)
+
def test_einsum(self):
def outer(x, y):
return torch.einsum('i,j->ij', (x, y))
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index ad72216..4a06572 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -497,7 +497,7 @@
auto state = tracer::getTracingState();
auto inputs = last(stack, num_inputs);
auto input_values = fmap(inputs, [](const IValue & v) {
- return tracer::getValueTrace(v.toTensor());
+ return tracer::getNestedValueTrace(v);
});
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 2f5e85d..ed88273 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -80,6 +80,27 @@
return it->second;
}
+// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
+// One might merge getValueTrace and getNestedValueTrace after checking that
+// casting to IValue instead of Variable is OK
+inline Value* getNestedValueTrace(const IValue &v) {
+ auto &state = getTracingState();
+ if (v.isTensorList()) {
+ return state->graph->insertNode(state->graph->createList(
+ DynamicType::get(),
+ fmap(v.toTensorListRef(), [](const IValue &val) {
+ return getNestedValueTrace(val);
+ })))->output();
+ } else if (v.isTuple()) {
+ return state->graph->insertNode(state->graph->createTuple(
+ fmap(v.toTuple()->elements(), [](const IValue &val) {
+ return getNestedValueTrace(val);
+ })))->output();
+ }
+ return getValueTrace(v.toTensor());
+}
+
+
inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const Variable& var, size_t output_no) {
if (!var.defined()) {
Node *n = state->graph->createUndefined();