[lite interpreter][hack] Add batch_norm_update_stats if batchnorm and training are present (#100134)

Summary: not sure how the train bool to batch_norm gets set. But its not the is_training module level flag. We get weird behavior for teams trying to do on device training because of this

Test Plan: ci

Differential Revision: D45335791

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100134
Approved by: https://github.com/larryliu0820
diff --git a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp
index 20424df..e665713 100644
--- a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp
+++ b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp
@@ -94,6 +94,44 @@
 }
 
 /**
+ * Similar to setup methods there are a suite a functions that often appear
+ * under certain conditions but may avoid getting called in the trace due to the
+ * narrow nature of bundled inputs
+ */
+void call_dependent_methods(std::set<std::string>& root_ops) {
+  bool is_training = false;
+  bool has_batchnorm = false;
+  bool has_dropout = false;
+  for (const std::string& op : root_ops) {
+    if (op.find("backward") != std::string::npos ||
+        op.find("requires_grad_") != std::string::npos) {
+      is_training = true;
+    }
+    if (op.find("batch_norm") != std::string::npos) {
+      has_batchnorm = true;
+    }
+    if (op.find("dropout") != std::string::npos) {
+      has_dropout = true;
+    }
+  }
+  if (is_training && has_batchnorm) {
+    at::batch_norm(
+        at::ones({2, 2}),
+        c10::nullopt,
+        c10::nullopt,
+        c10::nullopt,
+        c10::nullopt,
+        true,
+        0.1,
+        0.1,
+        false);
+  }
+  if (is_training && has_dropout) {
+    at::dropout(at::ones({20, 20, 20}), 0.2, true);
+  }
+}
+
+/**
  * Call methods on the Tensor object that we expect to be called
  * in production on this Tensor.
  */
@@ -307,6 +345,8 @@
     }
   }
 
+  call_dependent_methods(root_ops);
+
   op_tracer.getCalledOperators().withLock(
       [&](std::set<std::string>& called_operators) {
         traced_operators = called_operators;