Export BatchNorm functional and module, add necessary JIT support (#14016)

Summary:
This PR does three things:

1. It exports the BatchNorm functional and module, rewriting some of the components to stay aligned with the currently supported JIT features
2. In the process of exporting, it adds the necessary compiler support for in-place op augmented assignment (see the sketch after this list)
3. It changes the test_jit behavior in add_module_test to use a single RNG state during module initialization
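
As an illustration, the following sketch (mirroring the new test case added to test_jit.py below) shows the module-buffer augmented assignment that the compiler change enables; the class and buffer names are only examples:

    import torch

    class ModuleBufferMutate(torch.jit.ScriptModule):
        __constants__ = ['training']

        def __init__(self):
            super(ModuleBufferMutate, self).__init__(False)
            # a tensor buffer; augmented assignment on it is now compiled
            # into the corresponding in-place tensor op
            self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))

        @torch.jit.script_method
        def forward(self):
            if self.training:
                self.running_var += 1
            return self.running_var

    m = ModuleBufferMutate()
    assert m() == 1  # training defaults to True, so the buffer is incremented once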
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14016

Differential Revision: D13112064

Pulled By: wanchaol

fbshipit-source-id: 31e3aee5fbb509673c781e7dbb6d8884cfa55d91
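
The functional batch_norm change below replaces the functools.reduce-based size check with a plain loop so it can be scripted. A standalone sketch of the equivalence, assuming a 4D input, for reference:

    import torch
    from functools import reduce
    from operator import mul

    x = torch.randn(1, 10, 1, 1)
    size = list(x.size())

    # old, non-scriptable check: product of all dims except the channel dim
    old_check = reduce(mul, size[2:], size[0]) == 1

    # new, scriptable loop computing the same product
    size_prods = size[0]
    for i in range(len(size) - 2):
        size_prods *= size[i + 2]
    new_check = size_prods == 1

    assert old_check == new_check  # both flag a single value per channel
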
diff --git a/test/test_jit.py b/test/test_jit.py
index fd09c5b..0531c58 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4997,6 +4997,24 @@
 
         f = Foo()
 
+    def test_script_module_param_buffer_mutation(self):
+        # TODO: add a param mutation test case once the JIT supports it
+        class ModuleBufferMutate(torch.jit.ScriptModule):
+            __constants__ = ['training']
+
+            def __init__(self):
+                super(ModuleBufferMutate, self).__init__(False)
+                self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
+
+            @torch.jit.script_method
+            def forward(self):
+                if self.training:
+                    self.running_var += 1
+                return self.running_var
+
+        m = ModuleBufferMutate()
+        self.assertEqual(m(), 1)
+
     def test_script_module_for(self):
         class M(torch.jit.ScriptModule):
             __constants__ = ['b']
@@ -9653,6 +9671,9 @@
 # )
 nn_module_tests = [
     ('AlphaDropout', (), ((S,),)),
+    ('BatchNorm1d', (10,), ((S, 10),)),
+    ('BatchNorm2d', (10,), ((S, 10, S, S),)),
+    ('BatchNorm3d', (10,), ((S, 10, S, S, S),)),
     ('Dropout', (), ((S,),)),
     ('Dropout2d', (), ((S, S),)),
     ('Dropout3d', (), ((S, S, S),)),
@@ -9947,7 +9968,7 @@
 
         # Construct a script module that passes arguments through
         # to self.submodule
-        def create_module(*args, **kwargs):
+        def create_script_module(*args, **kwargs):
             formals, tensors, actuals = get_script_args(args)
 
             method_args = ', '.join(['self'] + actuals)
@@ -9972,15 +9993,20 @@
 
             # Check there are no Python ops by exporting
             self.assertExportImportModule(module, tensors)
-            create_module.last_graph = module.graph
+            create_script_module.last_graph = module.graph
+            return module(*args)
+
+        # Construct a plain nn module, mirroring create_script_module, so that
+        # module initialization makes use of a single global rng_state
+        def create_nn_module(*args, **kwargs):
+            module = nn_module(*constructor_args)
             return module(*args)
 
         # Check against Python module as reference
         args_variable, kwargs_variable = create_input(call_args)
         f_args_variable = deepcopy(unpack_variables(args_variable))
-        reference = nn_module(*constructor_args)
 
-        check_against_reference(self, create_module, reference, f_args_variable)
+        check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
 
     test_name = 'test_nn_{}'.format(module_name)
     post_add_test(test_name, skipTestIf, do_test)
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index c9cfcb9..766d4d9 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -1414,6 +1414,9 @@
       case TK_VAR: {
         emitAugAssignmentToVar(stmt);
       } break;
+      case '.': {
+        emitAugAssignmentToSelectVar(stmt);
+      } break;
       case TK_SUBSCRIPT: {
         emitAugAssignmentToSubscript(stmt);
       } break;
@@ -1424,6 +1427,45 @@
     }
   }
 
+  // This is called when the LHS of an augmented assignment is a select
+  // expression, i.e. a class parameter or module buffer mutation.
+  //
+  // Example:
+  // class A(Module):
+  //   def __init__(self):
+  //     self.register_buffer("running_var", torch.zeros(1))
+  //
+  //   def forward(self):
+  //     self.running_var += 1
+  //
+  // For now we only handle the case where the module buffer is a tensor:
+  // we emit the corresponding in-place tensor op and throw an error for
+  // other (unsupported) types.
+  void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
+    const auto lhs = Select(stmt.lhs());
+    const auto lhsSugaredVar = environment_stack->getSugaredVar(Var(lhs.value()).name());
+    const auto lhsValue = lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())->asValue(lhs.range(), method);
+    if (lhsValue->type()->isSubtypeOf(DynamicType::get())) {
+      // for module parameter/buffer assignment, only consider tensor types,
+      // emit the corresponding in-place op
+      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
+      const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
+      emitBuiltinCall(
+          stmt.range(),
+          *method.graph(),
+          getAugOp(stmt, /*isTensor=*/true),
+          self,
+          {rhs},
+          {},
+          /*required=*/true);
+
+    } else {
+        throw ErrorReport(stmt.lhs())
+            << "left-hand side of augmented assignment to module "
+            << "parameters/buffers can only be tensor types";
+    }
+  }
+
   void emitAugAssignmentToVar(const AugAssign& stmt) {
     const auto lhs = Var(stmt.lhs());
     const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 3e411dd..4fecbd5 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1379,8 +1379,6 @@
 
 
 _register_builtin(len, 'aten::len')
-
-
 _register_builtin(_wait, 'aten::wait')
 
 # torch.jit.Error
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 35393c8..cac9a38 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4,8 +4,6 @@
 import warnings
 import math
 import types
-from operator import mul
-from functools import reduce
 
 import torch
 from torch._C import _infer_size, _add_docstr
@@ -1413,17 +1411,32 @@
     return ret
 
 
+@torch._jit_internal.weak_script
 def batch_norm(input, running_mean, running_var, weight=None, bias=None,
                training=False, momentum=0.1, eps=1e-5):
+    # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor
     r"""Applies Batch Normalization for each channel across a batch of data.
 
     See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
     :class:`~torch.nn.BatchNorm3d` for details.
     """
     if training:
-        size = list(input.size())
-        if reduce(mul, size[2:], size[0]) == 1:
+        size = input.size()
+        # XXX: JIT script does not support functools.reduce, and the builtin mul
+        # op cannot yet be passed as a value to a function, so rewrite this size
+        # check as a simple, equivalent for loop
+        #
+        # TODO: use reduce as below once the JIT supports the missing features:
+        #   from operator import mul
+        #   from functools import reduce
+        #
+        #   if reduce(mul, size[2:], size[0]) == 1
+        size_prods = size[0]
+        for i in range(len(size) - 2):
+            size_prods *= size[i + 2]
+        if size_prods == 1:
             raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
+
     return torch.batch_norm(
         input, weight, bias, running_mean, running_var,
         training, momentum, eps, torch.backends.cudnn.enabled
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index deb280a..ec3385e 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -1,14 +1,19 @@
+from __future__ import division
+
 import torch
 from .module import Module
 from torch.nn.parameter import Parameter
 from .. import functional as F
 from .. import init
+from ..._jit_internal import weak_module, weak_script_method
 
 
 # TODO: check contiguous in THNN
 # TODO: use separate backend functions?
+@weak_module
 class _BatchNorm(Module):
     _version = 2
+    __constants__ = ['training', 'track_running_stats', 'momentum', 'eps']
 
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                  track_running_stats=True):
@@ -49,6 +54,7 @@
     def _check_input_dim(self, input):
         raise NotImplementedError
 
+    @weak_script_method
     def forward(self, input):
         self._check_input_dim(input)
 
@@ -57,7 +63,7 @@
         if self.training and self.track_running_stats:
             self.num_batches_tracked += 1
             if self.momentum is None:  # use cumulative moving average
-                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
             else:  # use exponential moving average
                 exponential_average_factor = self.momentum
 
@@ -86,6 +92,7 @@
             missing_keys, unexpected_keys, error_msgs)
 
 
+@weak_module
 class BatchNorm1d(_BatchNorm):
     r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
     inputs with optional additional channel dimension) as described in the paper
@@ -158,6 +165,7 @@
                              .format(input.dim()))
 
 
+@weak_module
 class BatchNorm2d(_BatchNorm):
     r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
     with additional channel dimension) as described in the paper
@@ -230,6 +238,7 @@
                              .format(input.dim()))
 
 
+@weak_module
 class BatchNorm3d(_BatchNorm):
     r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
     with additional channel dimension) as described in the paper