move instance_norm to aten (#10792)

Summary:
In addition to moving the instance_norm implementation from torch/nn/functional.py into a native ATen function, this removes the usage of torch.onnx.symbolic_override in instance_norm and adds a dedicated ONNX symbolic for the new op. Fixes #8439.
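
As a rough illustration of the new call path (illustration only, not part of the change): F.instance_norm now forwards to the native op, which can also be called directly with the same arguments functional.py passes through.

    import torch
    import torch.backends.cudnn
    import torch.nn.functional as F

    x = torch.randn(2, 3, 8, 8)

    # Same user-facing API as before; now a single native ATen op under the hood.
    out = F.instance_norm(x)

    # Direct call mirroring what functional.py now does
    # (weight, bias, running_mean, running_var are all None here).
    out2 = torch.instance_norm(
        x, None, None, None, None,
        True, 0.1, 1e-5,              # use_input_stats, momentum, eps
        torch.backends.cudnn.enabled)
    assert torch.allclose(out, out2)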
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10792

Differential Revision: D9800643

Pulled By: li-roy

fbshipit-source-id: fa13a57de5a31fbfa2d4d02639d214c867b9e1f1
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 750e425..ed0a94a 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -14,6 +14,13 @@
     AT_CHECK(actual == expected,
              arg_name, " should contain ", expected, " elements not ", actual);
   }
+
+  static inline Tensor repeat_if_defined(const Tensor& t, int64_t repeat) {
+    if (t.defined()) {
+      return t.repeat(repeat);
+    }
+    return t;
+  }
 }
 
 Tensor batch_norm(
@@ -80,6 +87,38 @@
             running_mean, running_var, training, momentum, eps);
 }
 
+Tensor instance_norm(
+    const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
+    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
+    bool use_input_stats, double momentum, double eps, bool cudnn_enabled) {
+  AT_CHECK(use_input_stats || (running_mean.defined() && running_var.defined()),
+           "Expected running_mean and running_var to be defined when use_input_stats is false");
+  std::vector<int64_t> shape = input.sizes().vec();
+  int64_t b = input.size(0);
+  int64_t c = input.size(1);
+  shape[1] = b * c;
+  shape[0] = 1;
+
+  Tensor weight_ = repeat_if_defined(weight, b);
+  Tensor bias_ = repeat_if_defined(bias, b);
+  Tensor running_mean_ = repeat_if_defined(running_mean, b);
+  Tensor running_var_ = repeat_if_defined(running_var, b);
+
+  auto input_reshaped = input.contiguous().view(shape);
+  auto out = at::batch_norm(input_reshaped, weight_, bias_, running_mean_, running_var_,
+                            use_input_stats, momentum, eps, cudnn_enabled);
+
+  // we alias running_mean and running_var because they are const but we want to modify their data
+  if (running_mean.defined()) {
+    at::alias(running_mean).copy_(running_mean_.view({ b, c }).mean(0, false));
+  }
+  if (running_var.defined()) {
+    at::alias(running_var).copy_(running_var_.view({ b, c }).mean(0, false));
+  }
+
+  return out.view(input.sizes());
+}
+
 Tensor layer_norm(const Tensor& input, IntList normalized_shape,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     double eps, bool cudnn_enabled) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index af58e45..c15aefa 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -873,6 +873,9 @@
 - func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor
   variants: function, method
 
+- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) -> Tensor
+  variants: function
+
 - func: inverse(Tensor self) -> Tensor
   variants: function, method
 
diff --git a/test/test_jit.py b/test/test_jit.py
index 61a6b94..f6a4347 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -6966,9 +6966,8 @@
         # XXX: export_import on CUDA modules doesn't work (#11480)
         self._test_dcgan_models(self, device='cuda', check_export_import=False)
 
-    # XXX: When this is fixed, write a CUDA test for this.
-    @unittest.skip('https://github.com/pytorch/pytorch/issues/8439 InstanceNormalization bug')
-    def test_neural_style(self):
+    @staticmethod
+    def _test_neural_style(self, device, check_export_import=True):
         class TransformerNet(torch.nn.Module):
             def __init__(self):
                 super(TransformerNet, self).__init__()
@@ -7065,7 +7064,15 @@
                 out = self.conv2d(out)
                 return out
 
-        self.checkTrace(TransformerNet(), (torch.rand(5, 3, 224, 224),))
+        self.checkTrace(TransformerNet(), (torch.rand(5, 3, 64, 64),), export_import=check_export_import)
+
+    def test_neural_style(self):
+        self._test_neural_style(self, device='cpu')
+
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    def test_neural_style_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_neural_style(self, device='cuda', check_export_import=False)
 
     @staticmethod
     def _test_mnist(self, device, check_export_import=True):
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 7823d44..6901124 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1357,46 +1357,10 @@
     See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
     :class:`~torch.nn.InstanceNorm3d` for details.
     """
-    if not use_input_stats and (running_mean is None or running_var is None):
-        raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False')
-
-    b, c = input.size(0), input.size(1)
-    if weight is not None:
-        weight = weight.repeat(b)
-    if bias is not None:
-        bias = bias.repeat(b)
-
-    import torch.onnx.symbolic
-
-    @torch.onnx.symbolic_override(torch.onnx.symbolic.instance_norm)
-    def _instance_norm(input, running_mean=None, running_var=None, weight=None,
-                       bias=None, use_input_stats=None, momentum=None, eps=None):
-        # Repeat stored stats and affine transform params if necessary
-        if running_mean is not None:
-            running_mean_orig = running_mean
-            running_mean = running_mean_orig.repeat(b)
-        if running_var is not None:
-            running_var_orig = running_var
-            running_var = running_var_orig.repeat(b)
-
-        # Apply instance norm
-        input_reshaped = input.contiguous().view(1, b * c, *input.size()[2:])
-
-        out = batch_norm(
-            input_reshaped, running_mean, running_var, weight=weight, bias=bias,
-            training=use_input_stats, momentum=momentum, eps=eps)
-
-        # Reshape and copy back
-        if running_mean is not None:
-            running_mean_orig.copy_(running_mean.view(b, c).mean(0, keepdim=False))
-        if running_var is not None:
-            running_var_orig.copy_(running_var.view(b, c).mean(0, keepdim=False))
-
-        return out.view(b, c, *input.size()[2:])
-    return _instance_norm(input, running_mean=running_mean,
-                          running_var=running_var, weight=weight, bias=bias,
-                          use_input_stats=use_input_stats, momentum=momentum,
-                          eps=eps)
+    return torch.instance_norm(
+        input, weight, bias, running_mean, running_var,
+        use_input_stats, momentum, eps, torch.backends.cudnn.enabled
+    )
 
 
 def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index d9c5964..30e8672 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -716,6 +716,22 @@
         return res
 
 
+@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
+def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
+    input_sizes = input.type().sizes()
+    if weight is None or weight.node().kind() == "prim::Undefined":
+        assert len(input_sizes) > 1
+        weight_value = torch.tensor([1.] * input_sizes[1]).type(
+            'torch.' + input.type().scalarType() + 'Tensor')
+        weight = g.op("Constant", value_t=weight_value)
+    if bias is None or bias.node().kind() == "prim::Undefined":
+        assert len(input_sizes) > 1
+        bias_value = torch.tensor([0.] * input_sizes[1]).type(
+            'torch.' + input.type().scalarType() + 'Tensor')
+        bias = g.op("Constant", value_t=bias_value)
+    return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
+
+
 @parse_args('v', 'i', 'i', 'i')
 def unfold(g, input, dimension, size, step):
     return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
@@ -1027,22 +1043,6 @@
     return g.op("Tile", self, repeats)
 
 
-def instance_norm(g, input, **kwargs):
-    input_type = input.type().scalarType()
-    weight = kwargs.get("weight", None)
-    bias = kwargs.get("bias", None)
-    eps = kwargs.get("eps", 1e-5)
-    if weight is None:
-        weight = g.constant(1.0, [input.type().sizes()[1]], input_type)
-    else:
-        weight = g.op('Constant', value_t=weight)
-    if bias is None:
-        bias = g.constant(0.0, [input.type().sizes()[1]], input_type)
-    else:
-        bias = g.op('Constant', value_t=bias)
-    return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
-
-
 def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
                  num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
     weights_per_layer = 4 if has_biases else 2
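
Note (not part of the diff): the ATen implementation above computes instance norm by folding the batch dimension into the channel dimension and running batch norm on the reshaped tensor. A minimal Python sketch of that equivalence, with a hypothetical helper name and with affine parameters and running stats omitted for simplicity:

    import torch
    import torch.nn.functional as F

    def instance_norm_via_batch_norm(x, eps=1e-5):
        # Fold (N, C, ...) into (1, N*C, ...) and normalize each of the N*C
        # "channels" with batch norm -- the same trick the ATen op uses.
        n, c = x.shape[0], x.shape[1]
        reshaped = x.contiguous().view(1, n * c, *x.shape[2:])
        out = F.batch_norm(reshaped, None, None, training=True, eps=eps)
        return out.view_as(x)

    x = torch.randn(5, 3, 64, 64)
    assert torch.allclose(instance_norm_via_batch_norm(x),
                          F.instance_norm(x), atol=1e-5)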