move instance_norm to aten (#10792)
Summary:
This moves the instance_norm implementation into ATen as a native function and removes the usage of torch.onnx.symbolic_override in torch.nn.functional.instance_norm. Fixes #8439.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10792
Differential Revision: D9800643
Pulled By: li-roy
fbshipit-source-id: fa13a57de5a31fbfa2d4d02639d214c867b9e1f1
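
The key observation behind both the old Python implementation and the new ATen kernel is that instance norm over an (N, C, ...) input is batch norm over a (1, N*C, ...) view of the same data. The following sketch (not part of the patch; plain public PyTorch API) illustrates that equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(5, 3, 8, 8)

# Reference: functional instance norm without affine params or running stats.
ref = F.instance_norm(x, use_input_stats=True, eps=1e-5)

# Same result via batch norm on the (1, N*C, H, W) view.
n, c = x.shape[:2]
reshaped = x.contiguous().view(1, n * c, *x.shape[2:])
out = F.batch_norm(reshaped, None, None, training=True, eps=1e-5).view_as(x)

print(torch.allclose(ref, out, atol=1e-6))  # expected: True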
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 750e425..ed0a94a 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -14,6 +14,13 @@
AT_CHECK(actual == expected,
arg_name, " should contain ", expected, " elements not ", actual);
}
+
+ static inline Tensor repeat_if_defined(const Tensor& t, int64_t repeat) {
+ if (t.defined()) {
+ return t.repeat(repeat);
+ }
+ return t;
+ }
}
Tensor batch_norm(
@@ -80,6 +87,38 @@
running_mean, running_var, training, momentum, eps);
}
+Tensor instance_norm(
+ const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
+ const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
+ bool use_input_stats, double momentum, double eps, bool cudnn_enabled) {
+ AT_CHECK(use_input_stats || (running_mean.defined() && running_var.defined()),
+ "Expected running_mean and running_var to be defined when use_input_stats is false");
+ std::vector<int64_t> shape = input.sizes().vec();
+ int64_t b = input.size(0);
+ int64_t c = input.size(1);
+ shape[1] = b * c;
+ shape[0] = 1;
+
+ Tensor weight_ = repeat_if_defined(weight, b);
+ Tensor bias_ = repeat_if_defined(bias, b);
+ Tensor running_mean_ = repeat_if_defined(running_mean, b);
+ Tensor running_var_ = repeat_if_defined(running_var, b);
+
+ auto input_reshaped = input.contiguous().view(shape);
+ auto out = at::batch_norm(input_reshaped, weight_, bias_, running_mean_, running_var_,
+ use_input_stats, momentum, eps, cudnn_enabled);
+
+ // we alias running_mean and running_var because they are const but we want to modify their data
+ if (running_mean.defined()) {
+ at::alias(running_mean).copy_(running_mean_.view({ b, c }).mean(0, false));
+ }
+ if (running_var.defined()) {
+ at::alias(running_var).copy_(running_var_.view({ b, c }).mean(0, false));
+ }
+
+ return out.view(input.sizes());
+}
+
Tensor layer_norm(const Tensor& input, IntList normalized_shape,
const Tensor& weight /* optional */, const Tensor& bias /* optional */,
double eps, bool cudnn_enabled) {
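
For readers more comfortable in Python, the new ATen instance_norm above boils down to the following sketch (an illustration only, not the actual ATen code): affine params and running stats are repeated per instance, batch_norm runs on the (1, b*c, ...) view, and the updated stats are averaged back over the batch, mirroring the at::alias(...).copy_(...) writes above.

import torch
import torch.nn.functional as F

def instance_norm_sketch(x, weight=None, bias=None,
                         running_mean=None, running_var=None,
                         use_input_stats=True, momentum=0.1, eps=1e-5):
    b, c = x.shape[:2]
    rep = lambda t: t.repeat(b) if t is not None else None
    w, bi = rep(weight), rep(bias)
    rm, rv = rep(running_mean), rep(running_var)

    out = F.batch_norm(x.contiguous().view(1, b * c, *x.shape[2:]),
                       rm, rv, w, bi, use_input_stats, momentum, eps)

    # Write the averaged per-instance stats back into the caller's tensors,
    # as the C++ code does through at::alias.
    if running_mean is not None:
        running_mean.copy_(rm.view(b, c).mean(0))
    if running_var is not None:
        running_var.copy_(rv.view(b, c).mean(0))
    return out.view_as(x)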
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index af58e45..c15aefa 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -873,6 +873,9 @@
- func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor
variants: function, method
+- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) -> Tensor
+ variants: function
+
- func: inverse(Tensor self) -> Tensor
variants: function, method
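
With the schema registered in native_functions.yaml, the op becomes callable as torch.instance_norm using the argument order above. A hedged usage sketch:

import torch

x = torch.randn(2, 4, 8, 8)
weight = torch.ones(4)
bias = torch.zeros(4)
running_mean = torch.zeros(4)
running_var = torch.ones(4)

y = torch.instance_norm(x, weight, bias, running_mean, running_var,
                        True,                           # use_input_stats
                        0.1,                            # momentum
                        1e-5,                           # eps
                        torch.backends.cudnn.enabled)   # cudnn_enabled
print(y.shape)  # torch.Size([2, 4, 8, 8])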
diff --git a/test/test_jit.py b/test/test_jit.py
index 61a6b94..f6a4347 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -6966,9 +6966,8 @@
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_dcgan_models(self, device='cuda', check_export_import=False)
- # XXX: When this is fixed, write a CUDA test for this.
- @unittest.skip('https://github.com/pytorch/pytorch/issues/8439 InstanceNormalization bug')
- def test_neural_style(self):
+ @staticmethod
+ def _test_neural_style(self, device, check_export_import=True):
class TransformerNet(torch.nn.Module):
def __init__(self):
super(TransformerNet, self).__init__()
@@ -7065,7 +7064,15 @@
out = self.conv2d(out)
return out
- self.checkTrace(TransformerNet(), (torch.rand(5, 3, 224, 224),))
+ self.checkTrace(TransformerNet(), (torch.rand(5, 3, 64, 64),), export_import=check_export_import)
+
+ def test_neural_style(self):
+ self._test_neural_style(self, device='cpu')
+
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ def test_neural_style_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_neural_style(self, device='cuda', check_export_import=False)
@staticmethod
def _test_mnist(self, device, check_export_import=True):
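
The re-enabled test follows the existing _test_mnist pattern: a staticmethod parameterized by device plus thin CPU/CUDA wrappers. Outside the test suite, the behavior it exercises can be reproduced with a small sketch like the one below (module names and sizes are illustrative, not taken from the test): tracing a module that uses InstanceNorm2d, which previously failed under the symbolic_override path.

import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.norm = nn.InstanceNorm2d(8, affine=True)

    def forward(self, x):
        return torch.relu(self.norm(self.conv(x)))

traced = torch.jit.trace(SmallNet(), torch.rand(2, 3, 64, 64))
print(traced(torch.rand(2, 3, 64, 64)).shape)  # torch.Size([2, 8, 64, 64])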
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 7823d44..6901124 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1357,46 +1357,10 @@
See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
:class:`~torch.nn.InstanceNorm3d` for details.
"""
- if not use_input_stats and (running_mean is None or running_var is None):
- raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False')
-
- b, c = input.size(0), input.size(1)
- if weight is not None:
- weight = weight.repeat(b)
- if bias is not None:
- bias = bias.repeat(b)
-
- import torch.onnx.symbolic
-
- @torch.onnx.symbolic_override(torch.onnx.symbolic.instance_norm)
- def _instance_norm(input, running_mean=None, running_var=None, weight=None,
- bias=None, use_input_stats=None, momentum=None, eps=None):
- # Repeat stored stats and affine transform params if necessary
- if running_mean is not None:
- running_mean_orig = running_mean
- running_mean = running_mean_orig.repeat(b)
- if running_var is not None:
- running_var_orig = running_var
- running_var = running_var_orig.repeat(b)
-
- # Apply instance norm
- input_reshaped = input.contiguous().view(1, b * c, *input.size()[2:])
-
- out = batch_norm(
- input_reshaped, running_mean, running_var, weight=weight, bias=bias,
- training=use_input_stats, momentum=momentum, eps=eps)
-
- # Reshape and copy back
- if running_mean is not None:
- running_mean_orig.copy_(running_mean.view(b, c).mean(0, keepdim=False))
- if running_var is not None:
- running_var_orig.copy_(running_var.view(b, c).mean(0, keepdim=False))
-
- return out.view(b, c, *input.size()[2:])
- return _instance_norm(input, running_mean=running_mean,
- running_var=running_var, weight=weight, bias=bias,
- use_input_stats=use_input_stats, momentum=momentum,
- eps=eps)
+ return torch.instance_norm(
+ input, weight, bias, running_mean, running_var,
+ use_input_stats, momentum, eps, torch.backends.cudnn.enabled
+ )
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
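
F.instance_norm keeps its public signature; it is now just a thin wrapper over the native op. A hedged usage sketch of the two stats modes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 16, 16)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

# Training-style call: normalizes with per-instance stats and updates
# running_mean / running_var in place.
y = F.instance_norm(x, running_mean, running_var,
                    use_input_stats=True, momentum=0.1)

# Eval-style call: normalizes with the (now updated) running statistics.
y_eval = F.instance_norm(x, running_mean, running_var, use_input_stats=False)
print(y.shape, y_eval.shape)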
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index d9c5964..30e8672 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -716,6 +716,22 @@
return res
+@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
+def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
+ input_sizes = input.type().sizes()
+ if weight is None or weight.node().kind() == "prim::Undefined":
+ assert len(input_sizes) > 1
+ weight_value = torch.tensor([1.] * input_sizes[1]).type(
+ 'torch.' + input.type().scalarType() + 'Tensor')
+ weight = g.op("Constant", value_t=weight_value)
+ if bias is None or bias.node().kind() == "prim::Undefined":
+ assert len(input_sizes) > 1
+ bias_value = torch.tensor([0.] * input_sizes[1]).type(
+ 'torch.' + input.type().scalarType() + 'Tensor')
+ bias = g.op("Constant", value_t=bias_value)
+ return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
+
+
@parse_args('v', 'i', 'i', 'i')
def unfold(g, input, dimension, size, step):
return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
@@ -1027,22 +1043,6 @@
return g.op("Tile", self, repeats)
-def instance_norm(g, input, **kwargs):
- input_type = input.type().scalarType()
- weight = kwargs.get("weight", None)
- bias = kwargs.get("bias", None)
- eps = kwargs.get("eps", 1e-5)
- if weight is None:
- weight = g.constant(1.0, [input.type().sizes()[1]], input_type)
- else:
- weight = g.op('Constant', value_t=weight)
- if bias is None:
- bias = g.constant(0.0, [input.type().sizes()[1]], input_type)
- else:
- bias = g.op('Constant', value_t=bias)
- return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
-
-
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
weights_per_layer = 4 if has_biases else 2
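
The net effect on export: the new @parse_args symbolic emits an ONNX InstanceNormalization node directly, synthesizing a unit weight and zero bias when the module is not affine. A hedged end-to-end sketch (assumes the separate onnx package is installed):

import io
import torch
import torch.nn as nn
import torch.onnx

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.InstanceNorm2d(8, affine=False))
buf = io.BytesIO()
torch.onnx.export(model, torch.randn(1, 3, 32, 32), buf)

import onnx
graph = onnx.load_from_string(buf.getvalue()).graph
print(any(node.op_type == "InstanceNormalization" for node in graph.node))  # expected: True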