[JIT] Modify is_nondeterministic to utilize tags in SchemaInfo for non-mobile contexts and integrate with ir.cpp (#82253)
- Modified the is_nondeterministic method in the SchemaInfo class to use operator tags in non-mobile contexts (a short sketch follows this summary).
- Modified the isNondeterministic method in ir.cpp to use SchemaInfo when a Node is an aten op.
- Added a check that a node of aten op kind has a schema; this is surfaced as a warning rather than a hard assert so that existing use cases do not break.
- Tested by verifying that all ir.cpp tests pass, and by adding two custom determinism checks covering the special dropout edge case and a general bernoulli case.
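
For context, a minimal sketch of the two moving parts (the free-standing helpers below are illustrative, not part of this PR; the dispatcher and SchemaInfo calls mirror the diff):

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/utils/schema_info.h>

// Non-mobile path: nondeterminism is derived from operator tags rather than
// a hard-coded schema list.
bool isTaggedNondeterministic(const c10::FunctionSchema& schema) {
  const auto op = c10::Dispatcher::singleton().findOp(
      c10::OperatorName(schema.name(), schema.overload_name()));
  // An op counts as nondeterministic iff it carries nondeterministic_seeded.
  return op && op->hasTag(at::Tag::nondeterministic_seeded);
}

// Dropout special case: pinning train=False via addArgumentValue makes
// SchemaInfo report aten::dropout as deterministic.
bool dropoutEvalIsDeterministic(const c10::FunctionSchema& dropout_schema) {
  torch::utils::SchemaInfo schema_info(dropout_schema);
  schema_info.addArgumentValue("train", false);
  return !schema_info.is_nondeterministic();
}
```

The hard-coded op list is retained only under C10_MOBILE, where the dispatcher tag query is not used.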
Differential Revision: [D38179499](https://our.internmc.facebook.com/intern/diff/D38179499)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82253
Approved by: https://github.com/davidberard98
diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp
index f3a4c0a..c42b7a2 100644
--- a/test/cpp/jit/test_alias_analysis.cpp
+++ b/test/cpp/jit/test_alias_analysis.cpp
@@ -1607,5 +1607,57 @@
[&graph] { AliasDb aliasDb(graph); },
"Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
+
+TEST(IRNonDeterminismTest, Basic) {
+ auto graph = std::make_shared<Graph>();
+ auto graph_string = R"IR(
+ graph():
+ %x : Tensor = prim::MakeTestTensor()
+ %0 : int = prim::Constant[value=0]()
+ %1 : NoneType = prim::Constant()
+ %2 : Tensor = aten::bernoulli(%x, %1)
+ %3 : Tensor = aten::add(%x, %2, %0)
+ return (%3))IR";
+ parseIR(graph_string, graph.get());
+
+ for (Node* n : graph->nodes()) {
+ if (n->kind() == aten::bernoulli) {
+ ASSERT_TRUE(n->isNondeterministic());
+ } else {
+ ASSERT_FALSE(n->isNondeterministic());
+ }
+ }
+}
+
+TEST(IRNonDeterminismTest, DropoutSpecialCase) {
+ auto graph = std::make_shared<Graph>();
+ auto graph_string = R"IR(
+ graph():
+ %x : Tensor = prim::MakeTestTensor()
+ %0 : bool = prim::Constant[value=0]()
+ %1 : bool = prim::Constant[value=1]()
+ %2 : int = prim::Constant[value=1]()
+ %3 : float = prim::Constant[value=1.0]()
+ %4 : Tensor = aten::dropout(%x, %3, %0)
+ %5 : Tensor = aten::dropout(%x, %3, %1)
+ %6 : Tensor = aten::add(%4, %5, %2)
+ return (%6))IR";
+ parseIR(graph_string, graph.get());
+
+ bool train = false;
+ for (Node* n : graph->nodes()) {
+ if (n->kind() == aten::dropout) {
+ if (!train) {
+ ASSERT_FALSE(n->isNondeterministic());
+ train = true;
+ } else {
+ ASSERT_TRUE(n->isNondeterministic());
+ }
+ } else {
+ ASSERT_FALSE(n->isNondeterministic());
+ }
+ }
+}
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp
index 0ddaebf..2b61b6b 100644
--- a/torch/csrc/jit/ir/ir.cpp
+++ b/torch/csrc/jit/ir/ir.cpp
@@ -1144,40 +1144,25 @@
}
bool Node::isNondeterministic() const {
- static const OperatorSet nondeterministic_ops = {
- "aten::dropout(Tensor input, float p, bool train) -> Tensor",
- "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
- "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
- "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
- "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
- "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
- "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
- "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
- "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
- "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
- "aten::poisson(Tensor self, Generator? generator) -> Tensor",
- "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
- "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
- "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
- "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
- "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
- "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
- "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
- "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
- "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
- "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
- "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
- "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
-
- if (!isMemberOf(nondeterministic_ops)) {
+ const auto schema = maybeSchema();
+ if (!kind().is_aten()) {
return false;
}
- // Dropout with train = False is deterministic
- if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") &&
- is_constant(attr::train) && !get<bool>(attr::train).value()) {
+ // All aten ops are expected to have a schema. However, this is left as a
+ // warning instead of an assert to ensure that previous use cases do not
+ // break.
+ if (!schema) {
+ TORCH_WARN("aten Schema not found.");
return false;
}
- return true;
+ torch::utils::SchemaInfo schema_info(*schema);
+ if (hasNamedInput("train")) {
+ auto value = constant_as<bool>(namedInput("train"));
+ if (value.has_value()) {
+ schema_info.addArgumentValue("train", *value);
+ }
+ }
+ return schema_info.is_nondeterministic();
}
bool Node::hasSideEffects() const {
diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp
index 55b1b55..a19d319 100644
--- a/torch/csrc/utils/schema_info.cpp
+++ b/torch/csrc/utils/schema_info.cpp
@@ -1,3 +1,4 @@
+#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/utils/schema_info.h>
namespace torch {
@@ -107,20 +108,27 @@
}
bool SchemaInfo::is_nondeterministic() const {
- static const std::vector<c10::FunctionSchema> nondeterministic_ops =
- getNonDeterministicOps();
- static const c10::FunctionSchema detach_schema = torch::jit::parseSchema(
+ static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema(
"aten::dropout(Tensor input, float p, bool train) -> Tensor");
- if (detach_schema == this->schema_ && value_map_.count("train") &&
+ if (dropout_schema == schema_ && value_map_.count("train") &&
!value_map_.at("train").toBool()) {
return false;
}
+
+#if defined(C10_MOBILE)
+ static const std::vector<c10::FunctionSchema> nondeterministic_ops =
+ getNonDeterministicOps();
return std::any_of(
nondeterministic_ops.begin(),
nondeterministic_ops.end(),
[this](const c10::FunctionSchema& nondeterministic_op) {
return nondeterministic_op == this->schema_;
});
+#else
+ const auto& op = c10::Dispatcher::singleton().findOp(
+ c10::OperatorName(schema_.name(), schema_.overload_name()));
+ return op && op->hasTag(at::Tag::nondeterministic_seeded);
+#endif
}
bool SchemaInfo::may_alias(
@@ -203,6 +211,21 @@
wildcard_set_.count(rhs);
}
+void SchemaInfo::ensureConservativity(
+ const std::unordered_set<at::Symbol>& duplicates,
+ const std::vector<c10::Argument>& arguments_list,
+ c10::SchemaArgType type) {
+ for (size_t i = 0; i < arguments_list.size(); i++) {
+ if (arguments_list[i].alias_info()) {
+ for (const auto& set : arguments_list[i].alias_info()->afterSets()) {
+ if (duplicates.count(set)) {
+ wildcard_set_.insert({type, i});
+ }
+ }
+ }
+ }
+}
+
std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
// This list of nondeterministic ops is copied from JIT ir.cpp.
static const std::vector<std::string> nondeterministic_op_strings = {
@@ -239,21 +262,6 @@
return nondeterministic_ops;
}
-void SchemaInfo::ensureConservativity(
- const std::unordered_set<at::Symbol>& duplicates,
- const std::vector<c10::Argument>& arguments_list,
- c10::SchemaArgType type) {
- for (size_t i = 0; i < arguments_list.size(); i++) {
- if (arguments_list[i].alias_info()) {
- for (const auto& set : arguments_list[i].alias_info()->afterSets()) {
- if (duplicates.count(set)) {
- wildcard_set_.insert({type, i});
- }
- }
- }
- }
-}
-
std::vector<c10::FunctionSchema> SchemaInfo::getTrainingOps() {
// This is a list of ops where a boolean variable (either "training",
// "train", or "use_input_stats") affects the mutability of running_mean and