[lite interpreter][hack] Add batch_norm_update_stats if batchnorm and training are present (#100134)
Summary: It's unclear how the train bool passed to batch_norm gets set, but it is not the module-level is_training flag. Because of that, teams trying to do on-device training run into unexpected behavior: operators that only fire in training mode, such as batch_norm_update_stats, never make it into the traced op set.
Test Plan: ci
Differential Revision: D45335791
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100134
Approved by: https://github.com/larryliu0820
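
For context, here is a minimal standalone sketch of the underlying issue (not part of this diff; it assumes only ATen, and the input shape, momentum, and eps values are illustrative). The train argument to at::batch_norm is a plain call-site boolean, so a trace driven by eval-mode bundled inputs never exercises the branch that updates running statistics, and the training-only operators stay out of the traced op set:

#include <ATen/ATen.h>

int main() {
  at::Tensor input = at::ones({2, 2});      // N=2, C=2; shape is illustrative
  at::Tensor running_mean = at::zeros({2});
  at::Tensor running_var = at::ones({2});

  // Eval-style call: running stats are only read, never updated, so the
  // update-stats path is absent from any trace of this call.
  at::batch_norm(
      input,
      /*weight=*/c10::nullopt,
      /*bias=*/c10::nullopt,
      running_mean,
      running_var,
      /*training=*/false,
      /*momentum=*/0.1,
      /*eps=*/1e-5,
      /*cudnn_enabled=*/false);

  // Training-style call: takes the branch that updates running statistics,
  // which is what pulls ops like batch_norm_update_stats into a trace.
  at::batch_norm(
      input,
      /*weight=*/c10::nullopt,
      /*bias=*/c10::nullopt,
      running_mean,
      running_var,
      /*training=*/true,
      /*momentum=*/0.1,
      /*eps=*/1e-5,
      /*cudnn_enabled=*/false);
  return 0;
}

The change below applies the same idea inside the tracer: when the root ops suggest both batch norm and training, it issues one training-mode batch_norm call (and likewise for dropout) so the tracer records those extra operators.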
diff --git a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp
index 20424df..e665713 100644
--- a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp
+++ b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp
@@ -94,6 +94,50 @@
}
/**
+ * Similar to the setup methods, there is a suite of functions that often
+ * appear under certain conditions but may avoid getting called in the trace
+ * due to the narrow nature of bundled inputs.
+ */
+void call_dependent_methods(std::set<std::string>& root_ops) {
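+ // Heuristic: infer from the root operator names whether the model is
+ // training and whether it uses batch norm or dropout.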
+ bool is_training = false;
+ bool has_batchnorm = false;
+ bool has_dropout = false;
+ for (const std::string& op : root_ops) {
+ if (op.find("backward") != std::string::npos ||
+ op.find("requires_grad_") != std::string::npos) {
+ is_training = true;
+ }
+ if (op.find("batch_norm") != std::string::npos) {
+ has_batchnorm = true;
+ }
+ if (op.find("dropout") != std::string::npos) {
+ has_dropout = true;
+ }
+ }
+ if (is_training && has_batchnorm) {
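+ // Run batch_norm once in training mode so the tracer records the
+ // training-only ops (e.g. batch_norm_update_stats) behind it.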
+ at::batch_norm(
+ at::ones({2, 2}),
+ /*weight=*/c10::nullopt,
+ /*bias=*/c10::nullopt,
+ /*running_mean=*/c10::nullopt,
+ /*running_var=*/c10::nullopt,
+ /*training=*/true,
+ /*momentum=*/0.1,
+ /*eps=*/0.1,
+ /*cudnn_enabled=*/false);
+ }
+ if (is_training && has_dropout) {
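+ // Run dropout with train=true so its training path is traced as well.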
+ at::dropout(at::ones({20, 20, 20}), /*p=*/0.2, /*train=*/true);
+ }
+}
+
+/**
* Call methods on the Tensor object that we expect to be called
* in production on this Tensor.
*/
@@ -307,6 +351,8 @@
}
}
+ call_dependent_methods(root_ops);
+
op_tracer.getCalledOperators().withLock(
[&](std::set<std::string>& called_operators) {
traced_operators = called_operators;