[JIT] Introduce a fake Tensor creation node for IR unit tests (#33914)
Summary:
**Summary**
There is often a need to create a Tensor when writing IR by hand for JIT
optimisation pass unit tests. The only options for this today are real
Tensor creation functions like `aten::ones`. Any test that uses these functions
must also use the same default arguments as the Python/C++ API, which means
that all of the tests have to be updated when the API is updated. This commit
introduces a new primitive, `prim::MakeTestTensor` with schema `() -> Tensor` that
should be used in unit tests instead of real Tensor creation functions. This new
primitive has no public-facing API, so the maintenance burden is much lower.
**Testing**
This commit updates the alias analysis and DCE tests to use `prim::MakeTestTensor` instead of
`aten::rand`, `aten::ones`, and `aten::zeros`.
```
$ ./bin/test_jit
CUDA not available. Disabling CUDA and MultiCUDA tests
Note: Google Test filter = *-*_CUDA:*_MultiCUDA
[==========] Running 75 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 75 tests from JitTest
[ RUN ] JitTest.ADFormulas
[ OK ] JitTest.ADFormulas (82 ms)
[ RUN ] JitTest.Attributes
[ OK ] JitTest.Attributes (0 ms)
...
...
...
[ RUN ] JitTest.LiteInterpreterPrim
[ OK ] JitTest.LiteInterpreterPrim (0 ms)
[ RUN ] JitTest.LiteInterpreterLoadOrigJit
[ OK ] JitTest.LiteInterpreterLoadOrigJit (2 ms)
[----------] 75 tests from JitTest (150 ms total)
[----------] Global test environment tear-down
[==========] 75 tests from 1 test case ran. (150 ms total)
[ PASSED ] 75 tests.
```
**Fixes**
This pull request fixes https://github.com/pytorch/pytorch/issues/33500.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33914
Differential Revision: D20150304
Pulled By: SplitInfinity
fbshipit-source-id: c88f5289055a02dc20b7a5dcdf87469f9816d020
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 8f48640..e868985 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -79,6 +79,7 @@
_(prim, dtype) \
_(prim, shape) \
_(prim, requires_grad) \
+ _(prim, MakeTestTensor) /* test */ \
_(prim, AutogradAdd) \
_(prim, GradOf) \
_(aten, grad) \
diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp
index d00cb9f..e2ad5c8 100644
--- a/test/cpp/jit/test_alias_analysis.cpp
+++ b/test/cpp/jit/test_alias_analysis.cpp
@@ -659,13 +659,9 @@
script::parseIR(
R"IR(
graph():
- %4 : Device? = prim::Constant()
- %2 : int? = prim::Constant()
- %0 : float = prim::Constant[value=1]()
- %20 : bool = prim::Constant[value=0]()
- %a : Tensor = aten::tensor(%0, %2, %4, %20)
+ %a : Tensor = prim::MakeTestTensor()
%a_list : Tensor[] = prim::ListConstruct(%a)
- %b : Tensor = aten::tensor(%0, %2, %4, %20)
+ %b : Tensor = prim::MakeTestTensor()
%b_list : Tensor[] = prim::ListConstruct(%b)
%13 : (Tensor[], Tensor[]) = prim::TupleConstruct(%a_list, %b_list)
return (%13)
@@ -746,19 +742,16 @@
script::parseIR(
R"IR(
graph():
- %10 : bool? = prim::Constant()
- %8 : Device? = prim::Constant()
- %4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%2 : int[] = prim::ListConstruct(%0, %1)
- %x : Tensor = aten::rand(%2, %4, %4, %8, %10)
+ %x : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
- %y : Tensor = aten::rand(%12, %4, %4, %8, %10)
+ %y : Tensor = prim::MakeTestTensor()
%22 : int[] = prim::ListConstruct(%0, %1)
- %z : Tensor = aten::rand(%22, %4, %4, %8, %10)
+ %z : Tensor = prim::MakeTestTensor()
%32 : int[] = prim::ListConstruct(%0, %1)
- %fresh : Tensor = aten::rand(%32, %4, %4, %8, %10)
+ %fresh : Tensor = prim::MakeTestTensor()
%foo : Tensor[] = prim::ListConstruct(%x, %y)
%43 : Tensor[] = aten::append(%foo, %z)
return ()
@@ -791,13 +784,10 @@
script::parseIR(
R"IR(
graph():
- %10 : bool? = prim::Constant()
- %8 : Device? = prim::Constant()
- %4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%2 : int[] = prim::ListConstruct(%0, %1)
- %11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
+ %11 : Tensor = prim::MakeTestTensor()
%12 : Tensor[] = prim::ListConstruct(%11)
%out : Tensor[] = custom::conservative(%12)
%ret.2 : Tensor = aten::div(%11, %11)
@@ -826,20 +816,17 @@
R"IR(
graph():
%35 : int = prim::Constant[value=1]()
- %10 : bool? = prim::Constant()
- %8 : Device? = prim::Constant()
- %4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%23 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
- %11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
+ %11 : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
- %21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
+ %21 : Tensor = prim::MakeTestTensor()
%l : Tensor[] = prim::ListConstruct(%11, %21)
%24 : Tensor = aten::select(%l, %23)
%25 : int[] = prim::ListConstruct(%0, %1)
- %34 : Tensor = aten::rand(%25, %4, %4, %8, %10)
+ %34 : Tensor = prim::MakeTestTensor()
%36 : Tensor = aten::add_(%24, %34, %35)
%37 : Tensor = uses::list(%l)
return (%37)
@@ -868,21 +855,18 @@
R"IR(
graph():
%38 : int = prim::Constant[value=1]()
- %10 : bool? = prim::Constant()
- %8 : Device? = prim::Constant()
- %4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%1 : int = prim::Constant[value=3]()
%24 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %1)
- %11 : Tensor = aten::rand(%2, %4, %4, %8, %10)
+ %11 : Tensor = prim::MakeTestTensor()
%12 : int[] = prim::ListConstruct(%0, %1)
- %21 : Tensor = aten::rand(%12, %4, %4, %8, %10)
+ %21 : Tensor = prim::MakeTestTensor()
%l : Tensor[] = prim::ListConstruct(%11, %21)
%25 : Tensor = aten::select(%l, %24)
%27 : Tensor = aten::select(%25, %24, %24)
%28 : int[] = prim::ListConstruct(%0, %1)
- %37 : Tensor = aten::rand(%28, %4, %4, %8, %10)
+ %37 : Tensor = prim::MakeTestTensor()
%39 : Tensor = aten::add_(%27, %37, %38)
%40 : Tensor = uses::list(%l)
return (%40)
diff --git a/test/cpp/jit/test_base.cpp b/test/cpp/jit/test_base.cpp
new file mode 100644
index 0000000..8655586
--- /dev/null
+++ b/test/cpp/jit/test_base.cpp
@@ -0,0 +1,29 @@
+#include <test/cpp/jit/test_base.h>
+#include <test/cpp/jit/test_utils.h>
+
+#include "torch/csrc/jit/runtime/custom_operator.h"
+
+namespace torch {
+namespace jit {
+inline c10::OperatorOptions aliasAnalysisFromSchema() {
+ c10::OperatorOptions result;
+ result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
+ return result;
+}
+
+RegisterOperators reg({
+ // This operator is intended to be used in JIT analysis and transformation
+ // pass unit tests in which Values with type Tensor are often required. It
+ // should not be used in situations in which the graph is actually executed
+ // because it always produces empty Tensors.
+ Operator(
+ "prim::MakeTestTensor() -> Tensor",
+ [](Stack& stack) {
+ push(stack, at::Tensor());
+ return 0;
+ },
+ aliasAnalysisFromSchema()),
+});
+
+} // namespace jit
+} // namespace torch
diff --git a/test/cpp/jit/test_dce.cpp b/test/cpp/jit/test_dce.cpp
index a5fd200..097c300 100644
--- a/test/cpp/jit/test_dce.cpp
+++ b/test/cpp/jit/test_dce.cpp
@@ -22,17 +22,14 @@
graph():
%48 : None = prim::Constant()
%50 : bool = prim::Constant[value=1]()
- %10 : bool? = prim::Constant()
- %8 : Device? = prim::Constant()
- %4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%12 : int = prim::Constant[value=1]()
%24 : int = prim::Constant[value=3]()
%31 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %0)
- %a.1 : Tensor = aten::ones(%2, %4, %4, %8, %10)
+ %a.1 : Tensor = prim::MakeTestTensor()
%14 : int[] = prim::ListConstruct(%12)
- %tot.1 : Tensor = aten::zeros(%14, %4, %4, %8, %10)
+ %tot.1 : Tensor = prim::MakeTestTensor()
%tot : Tensor = prim::Loop(%24, %50, %tot.1)
block0(%i : int, %tot.6 : Tensor):
%33 : Tensor = aten::select(%a.1, %31, %31)