[ready] Layer Normalization (#4922)

* at::maybe_data_ptr and Check.h => TensorUtils.h

* THNN support for optional BN running_*

* ATen support for optional BN running_*

* Python nn.* support for optional BN running_*; Improve IN and BN doc

* Add tests for IN and BN new option

* Layer Norm

* Fix LRN doc

* functional interface for LN and IN

* Layer norm tests

* fix BN double backward returning undefined tensors

* fix jit test using wrong dim inputs for BN

* add/improve BN, IN and LN GPU tests with half type

* Update docs to be consistent with Conv notation
Fix onnx
Clarified onnx symbolic wrapper

* fix typo

* Address comments
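
For reviewers, a minimal usage sketch of the new options, based on the constructor arguments exercised in the `test_nn.py` changes below (API as of this PR; treat exact defaults as assumptions rather than the final interface):

```python
import torch
from torch.autograd import Variable
import torch.nn as nn

# BatchNorm without running statistics: batch stats are used in eval mode too
bn = nn.BatchNorm1d(10, track_running_stats=False)

# InstanceNorm can now optionally track running statistics
inorm = nn.InstanceNorm2d(3, track_running_stats=True)

# LayerNorm normalizes each sample over the trailing `normalized_shape`
# dimensions, with an optional elementwise affine transform
ln = nn.LayerNorm([2, 2, 5])

x = Variable(torch.randn(4, 2, 2, 5))
y = ln(x)  # zero mean / unit variance over the last three dims of each sample
```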
diff --git a/aten/src/ATen/CPUApplyUtils.h b/aten/src/ATen/CPUApplyUtils.h
index 93773c0..123332e 100644
--- a/aten/src/ATen/CPUApplyUtils.h
+++ b/aten/src/ATen/CPUApplyUtils.h
@@ -1,7 +1,7 @@
 #pragma once
 
 #include <sstream>
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 
 namespace at {
 
diff --git a/aten/src/ATen/Check.cpp b/aten/src/ATen/TensorUtils.cpp
similarity index 96%
rename from aten/src/ATen/Check.cpp
rename to aten/src/ATen/TensorUtils.cpp
index 203915d..09ec0a4 100644
--- a/aten/src/ATen/Check.cpp
+++ b/aten/src/ATen/TensorUtils.cpp
@@ -1,5 +1,5 @@
 #include "ATen/Config.h"
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 
 #include "ATen/ATen.h"
 
@@ -210,4 +210,11 @@
   }
 }
 
+void * maybe_data_ptr(const Tensor& tensor) {
+  return tensor.defined() ? (void *)tensor.data_ptr() : nullptr;
+}
+
+void * maybe_data_ptr(const TensorArg& tensor) {
+  return tensor->defined() ? (void *)tensor->data_ptr() : nullptr;
+}
 }
diff --git a/aten/src/ATen/Check.h b/aten/src/ATen/TensorUtils.h
similarity index 89%
rename from aten/src/ATen/Check.h
rename to aten/src/ATen/TensorUtils.h
index bac47c7..37d99d5 100644
--- a/aten/src/ATen/Check.h
+++ b/aten/src/ATen/TensorUtils.h
@@ -4,14 +4,14 @@
 #include "ATen/TensorGeometry.h"
 #include "ATen/Utils.h"
 
-// This file contains utility functions for checking that arguments
-// make sense.  This is particularly useful for native functions,
-// which do NO argument checking by default.
-//
-// It's NOT in Utils.h, because this file has a dep on Tensor.h
+// These functions are NOT in Utils.h, because this file has a dep on Tensor.h
 
 namespace at {
 
+// The following are utility functions for checking that arguments
+// make sense.  These are particularly useful for native functions,
+// which do NO argument checking by default.
+
 struct TensorArg {
   Tensor tensor;
   const char* name;
@@ -72,4 +72,10 @@
 
 // FixMe: does TensorArg slow things down?
 void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);
+
+// Methods for getting data_ptr if tensor is defined
+void * maybe_data_ptr(const Tensor& tensor);
+void * maybe_data_ptr(const TensorArg& tensor);
+
 }
+
diff --git a/aten/src/ATen/cuda/CUDAApplyUtils.cuh b/aten/src/ATen/cuda/CUDAApplyUtils.cuh
index ce8ed09..431cad5 100644
--- a/aten/src/ATen/cuda/CUDAApplyUtils.cuh
+++ b/aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -1,7 +1,7 @@
 #pragma once
 
 #include "detail/IndexUtils.cuh"
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 
 //
 // This file contains pointwise operation functions and kernels that
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index d39dbe5..a678b62 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -4,7 +4,7 @@
 
 #include "cudnn-wrapper.h"
 #include <ATen/ATen.h>
-#include <ATen/Check.h>
+#include <ATen/TensorUtils.h>
 
 #if CUDNN_VERSION < 7000
 
diff --git a/aten/src/ATen/native/BatchNorm.cpp b/aten/src/ATen/native/BatchNorm.cpp
index 7d38103..532de9b 100644
--- a/aten/src/ATen/native/BatchNorm.cpp
+++ b/aten/src/ATen/native/BatchNorm.cpp
@@ -21,12 +21,20 @@
 
 Tensor batch_norm(
     const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
-    const Tensor& running_mean, const Tensor& running_var,
+    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
 
   auto num_features = input.sizes()[1];
-  check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
-  check_dims_match_num_input_features("running_var", num_features, running_var.numel());
+  if (running_mean.defined()) {
+    check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
+  } else if (!training) {
+    throw std::runtime_error("running_mean must be defined in evaluation mode");
+  }
+  if (running_var.defined()) {
+    check_dims_match_num_input_features("running_var", num_features, running_var.numel());
+  } else if (!training) {
+    throw std::runtime_error("running_var must be defined in evaluation mode");
+  }
   if (weight.defined()) {
     check_dims_match_num_input_features("weight", num_features, weight.numel());
   }
@@ -38,8 +46,10 @@
 #if AT_CUDNN_ENABLED()
   use_cudnn = (input.type().is_cuda()
                && (input.type().scalarType() != at::kHalf
-               || weight.type().scalarType() == at::kFloat)
+                 || weight.type().scalarType() == at::kFloat)
                && weight.defined() && bias.defined()
+               && ((running_mean.defined() && running_var.defined())
+                 || (!running_mean.defined() && !running_var.defined() && training))
                && input.size(0) <= 131070
                && cudnn_enabled && CUDNN_VERSION >= 5110L);
 #endif
diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp
index b54f45b..d0dc06d 100644
--- a/aten/src/ATen/native/Embedding.cpp
+++ b/aten/src/ATen/native/Embedding.cpp
@@ -1,5 +1,5 @@
 #include "ATen/ATen.h"
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 #include "ATen/NativeFunctions.h"
 
 #include <cstring>
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 7a736d5..4e89b1e 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -1,5 +1,5 @@
 #include "ATen/ATen.h"
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 #include "ATen/NativeFunctions.h"
 
 #include <cstring>
diff --git a/aten/src/ATen/native/Pooling.cpp b/aten/src/ATen/native/Pooling.cpp
index 9ce14a4..d9df277 100644
--- a/aten/src/ATen/native/Pooling.cpp
+++ b/aten/src/ATen/native/Pooling.cpp
@@ -1,5 +1,5 @@
 #include "ATen/ATen.h"
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 #include "ATen/NativeFunctions.h"
 
 #include <sstream>
diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index 150dee8..c105171 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -1,5 +1,5 @@
 #include "ATen/ATen.h"
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 #include "ATen/Dispatch.h"
 #include "ATen/NativeFunctions.h"
 
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index c22fa18..3cc1a5c 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -1,5 +1,5 @@
 #include "ATen/ATen.h"
-#include "ATen/Check.h"
+#include "ATen/TensorUtils.h"
 #include "ATen/Dispatch.h"
 #include "ATen/NativeFunctions.h"
 
diff --git a/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp b/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp
index be0ff00..6a12f48 100644
--- a/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp
+++ b/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp
@@ -30,7 +30,7 @@
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
 
-#include <ATen/Check.h>
+#include <ATen/TensorUtils.h>
 
 namespace at { namespace native {
 
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp
index 08df72f..211a9c3 100644
--- a/aten/src/ATen/native/cudnn/BatchNorm.cpp
+++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -31,7 +31,7 @@
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
 
-#include <ATen/Check.h>
+#include <ATen/TensorUtils.h>
 
 namespace at { namespace native {
 
@@ -60,7 +60,10 @@
   CheckedFrom c = "cudnn_batch_norm";
   setCuDNNStreamToCurrent();
 
-  checkAllDefined(c, {input, weight, bias, running_mean, running_var});
+  checkAllDefined(c, {input, weight, bias});
+  if (!training) {
+    checkAllDefined(c, {running_mean, running_var});
+  }
   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
   if (input->type().scalarType() == ScalarType::Half) {
     checkScalarType(c, weight, ScalarType::Float);
@@ -73,7 +76,9 @@
   checkDimRange(c, input, 2, 6 /* exclusive */);
   auto num_features = input->size(1);
   for (auto t : {weight, bias, running_mean, running_var}) {
-    checkNumel(c, t, num_features);
+    if (t->defined()) {
+      checkNumel(c, t, num_features);
+    }
   }
 
   cudnnBatchNormMode_t mode;
@@ -97,16 +102,12 @@
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
-
-  // Though technically we only need to allocate this for training,
-  //  (1) THNN batch normalization expects non-undefined tensors for
-  //  backwards (which we will pass these to, if !training, because
-  //  CuDNN backwards with !training doesn't gradcheck), and
-  //  (2) These are pretty small tensors, no big deal.
-  Tensor save_mean = running_mean_t.type().tensor(running_mean_t.sizes());
-  Tensor save_var = running_var_t.type().tensor(running_var_t.sizes());
+  Tensor save_mean, save_var;
 
   if (training) {
+    int64_t num_features = input_t.size(1);
+    save_mean = weight_t.type().tensor({ num_features });
+    save_var = weight_t.type().tensor({ num_features });
     CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
       handle, mode, &one, &zero,
       idesc.desc(), input->data_ptr(),
@@ -115,8 +116,8 @@
       weight->data_ptr(),
       bias->data_ptr(),
       exponential_average_factor,
-      running_mean->data_ptr(),
-      running_var->data_ptr(),
+      at::maybe_data_ptr(running_mean),
+      at::maybe_data_ptr(running_var),
       epsilon,
       save_mean.data_ptr(),
       save_var.data_ptr()));
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index 9672b2d..a974c9f 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -80,7 +80,7 @@
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
 
-#include <ATen/Check.h>
+#include <ATen/TensorUtils.h>
 
 #include <functional>
 #include <iterator>
diff --git a/aten/src/ATen/native/cudnn/GridSampler.cpp b/aten/src/ATen/native/cudnn/GridSampler.cpp
index c3ec153..0186a94 100644
--- a/aten/src/ATen/native/cudnn/GridSampler.cpp
+++ b/aten/src/ATen/native/cudnn/GridSampler.cpp
@@ -27,7 +27,7 @@
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
 
-#include <ATen/Check.h>
+#include <ATen/TensorUtils.h>
 
 // TODO: descriptor checking
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9c0dc49..1460ca4 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -22,7 +22,7 @@
 - func: addr_out(Tensor result, Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   variants: function
 
-- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor
+- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor
   variants: function
 
 - func: bernoulli_(Tensor self, Tensor p, Generator* generator=nullptr) -> Tensor
@@ -91,11 +91,11 @@
       name: grad_theta
   variants: function
 
-- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon) -> (Tensor, Tensor, Tensor)
+- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double exponential_average_factor, double epsilon) -> (Tensor, Tensor, Tensor)
   variants: function
 
 # NB: You can only use this if you used cudnn_batch_norm training=True
-- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor? save_mean, Tensor? save_var, double epsilon) -> (Tensor, Tensor, Tensor)
+- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, double epsilon) -> (Tensor, Tensor, Tensor)
   variants: function
 
 - func: cudnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
diff --git a/aten/src/THCUNN/BatchNormalization.cu b/aten/src/THCUNN/BatchNormalization.cu
index e6717c7..865323a 100644
--- a/aten/src/THCUNN/BatchNormalization.cu
+++ b/aten/src/THCUNN/BatchNormalization.cu
@@ -141,8 +141,8 @@
 __global__ void BatchNormalizationUpdateOutputInference_kernel(
     const DeviceTensor3 input,
     DeviceTensor3 output,
-    DeviceTensor1 runningMean,
-    DeviceTensor1 runningVar,
+    const DeviceTensor1 runningMean,
+    const DeviceTensor1 runningVar,
     const DeviceTensor1 weight,
     const DeviceTensor1 bias,
     Acctype epsilon) {
@@ -196,8 +196,12 @@
     Acctype unbiasedVar = varN / (N - 1);
     saveMean[plane] = ScalarConvert<Acctype, Dtype>::to(mean);
     saveStd[plane] = ScalarConvert<Acctype, Dtype>::to(invStd);
-    runningMean[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningMean[plane] + momentum * mean);
-    runningVar[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningVar[plane] + momentum * unbiasedVar);
+    if (runningMean.data() != NULL) {
+      runningMean[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningMean[plane] + momentum * mean);
+    }
+    if (runningVar.data() != NULL) {
+      runningVar[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningVar[plane] + momentum * unbiasedVar);
+    }
   }
 
   // Write normalized and update the output
diff --git a/aten/src/THCUNN/generic/BatchNormalization.cu b/aten/src/THCUNN/generic/BatchNormalization.cu
index b407cfd..eb2dc84 100644
--- a/aten/src/THCUNN/generic/BatchNormalization.cu
+++ b/aten/src/THCUNN/generic/BatchNormalization.cu
@@ -39,8 +39,9 @@
 
   THCTensor_(resizeAs)(state, output_, input_);
   if (train) {
-    THCTensor_(resizeAs)(state, saveMean_, runningMean_);
-    THCTensor_(resizeAs)(state, saveStd_, runningVar_);
+    int64_t nInput = THCTensor_(size)(state, input_, 1);
+    THCTensor_(resize1d)(state, saveMean_, nInput);
+    THCTensor_(resize1d)(state, saveStd_, nInput);
   }
   DeviceTensor3 input = devicetensor<3>(state, input_);
   DeviceTensor3 output = devicetensor<3>(state, output_);
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index 9fd3b85..110616a 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -36,8 +36,8 @@
                   THCTensor *output_,
                   THCTensor *weight_,        // [OPTIONAL]
                   THCTensor *bias_,          // [OPTIONAL]
-                  THCTensor *runningMean_,
-                  THCTensor *runningVar_,
+                  THCTensor *runningMean_,   // [OPTIONAL] if train
+                  THCTensor *runningVar_,    // [OPTIONAL] if train
                   THCTensor *saveMean_,
                   THCTensor *saveStd_,
                   bool train,
@@ -52,10 +52,10 @@
                   THCTensor *gradWeight_,       // [OPTIONAL]
                   THCTensor *gradBias_,         // [OPTIONAL]
                   THCTensor *weight_,           // [OPTIONAL]
-                  THCTensor *runningMean_,
-                  THCTensor *runningVar_,
-                  THCTensor *saveMean_,
-                  THCTensor *saveStd_,
+                  THCTensor *runningMean_,      // [OPTIONAL] if train
+                  THCTensor *runningVar_,       // [OPTIONAL] if train
+                  THCTensor *saveMean_,         // [OPTIONAL] if !train
+                  THCTensor *saveStd_,          // [OPTIONAL] if !train
                   bool train,
                   double scale,
                   double eps);
diff --git a/aten/src/THNN/generic/BatchNormalization.c b/aten/src/THNN/generic/BatchNormalization.c
index 2ebac62..1f2aa3c 100644
--- a/aten/src/THNN/generic/BatchNormalization.c
+++ b/aten/src/THNN/generic/BatchNormalization.c
@@ -15,8 +15,8 @@
   ptrdiff_t n = THTensor_(nElement)(input) / nInput;
 
   if (train) {
-    THTensor_(resizeAs)(save_mean, running_mean);
-    THTensor_(resizeAs)(save_std, running_var);
+    THTensor_(resize1d)(save_mean, nInput);
+    THTensor_(resize1d)(save_std, nInput);
   }
 
   #pragma omp parallel for
@@ -47,12 +47,15 @@
       THTensor_(set1d)(save_std, f, (real) invstd);
 
       // update running averages
-      THTensor_(set1d)(running_mean, f,
-        (real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));
-
-      accreal unbiased_var = sum / (n - 1);
-      THTensor_(set1d)(running_var, f,
-        (real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
+      if (running_mean) {
+        THTensor_(set1d)(running_mean, f,
+          (real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));
+      }
+      if (running_var) {
+        accreal unbiased_var = sum / (n - 1);
+        THTensor_(set1d)(running_var, f,
+          (real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
+      }
     } else {
       mean = THTensor_(get1d)(running_mean, f);
       invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index dc0fb95..61fdf76 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -814,8 +814,8 @@
           THTensor *output,
           THTensor *weight,       // [OPTIONAL]
           THTensor *bias,         // [OPTIONAL]
-          THTensor *running_mean,
-          THTensor *running_var,
+          THTensor *running_mean, // [OPTIONAL] if train
+          THTensor *running_var,  // [OPTIONAL] if train
           THTensor *save_mean,
           THTensor *save_std,
           bool train,
@@ -829,10 +829,10 @@
           THTensor *gradWeight,   // [OPTIONAL]
           THTensor *gradBias,     // [OPTIONAL]
           THTensor *weight,       // [OPTIONAL]
-          THTensor *running_mean,
-          THTensor *running_var,
-          THTensor *save_mean,
-          THTensor *save_std,
+          THTensor *running_mean, // [OPTIONAL] if train
+          THTensor *running_var,  // [OPTIONAL] if train
+          THTensor *save_mean,    // [OPTIONAL] if !train
+          THTensor *save_std,     // [OPTIONAL] if !train
           bool train,
           double scale,
           double eps);
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 827ba04..5fcd627 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -378,6 +378,12 @@
 .. autoclass:: InstanceNorm3d
     :members:
 
+:hidden:`LayerNorm`
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: LayerNorm
+    :members:
+
 :hidden:`LocalResponseNorm`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -943,6 +949,16 @@
 
 .. autofunction:: batch_norm
 
+:hidden:`instance_norm`
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: instance_norm
+
+:hidden:`layer_norm`
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: layer_norm
+
 :hidden:`local_response_norm`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/test/common.py b/test/common.py
index 7cfe27f..70e61cb 100644
--- a/test/common.py
+++ b/test/common.py
@@ -4,6 +4,7 @@
 import argparse
 import unittest
 import warnings
+import random
 import contextlib
 from functools import wraps
 from itertools import product
@@ -109,6 +110,7 @@
 
 def set_rng_seed(seed):
     torch.manual_seed(seed)
+    random.seed(seed)
     if TEST_NUMPY:
         numpy.random.seed(seed)
 
diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect
index 0986200..e28957a 100644
--- a/test/expect/TestJit.test_batchnorm.expect
+++ b/test/expect/TestJit.test_batchnorm.expect
@@ -1,8 +1,8 @@
-graph(%0 : Double(2, 2)
+graph(%0 : Double(2, 2, 2, 2)
       %1 : Double(2)
       %2 : Double(2)
       %3 : Double(2)
       %4 : Double(2)) {
-  %5 : Double(2, 2) = batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
+  %5 : Double(2, 2, 2, 2) = batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
   return (%5);
 }
diff --git a/test/test_jit.py b/test/test_jit.py
index 1df99bb..97e4e51 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -887,7 +887,7 @@
         self.assertExpected(torch._C._jit_run_cpp_tests())
 
     def test_batchnorm(self):
-        x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
+        x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True)
         trace, _ = torch.jit.trace(nn.BatchNorm2d(2), x)
         self.assertExpectedTrace(trace)
 
@@ -902,7 +902,7 @@
             pass
 
         bn = MyBatchNorm2d(1)
-        x = Variable(torch.randn(5, 1))
+        x = Variable(torch.randn(5, 1, 2, 1))
         z = bn(x)
         with self.assertCompiled(bn):
             z2 = bn(x)
diff --git a/test/test_nn.py b/test/test_nn.py
index 5f3a4bc..ab18002 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1444,14 +1444,15 @@
             self.assertLess(abs(output.data.std() - std), 0.1)
             output.backward(input)
 
-    def _test_InstanceNorm(self, cls, input):
+    def _test_InstanceNorm_general(self, cls, input, type):
+        # default case track_running_stats=False
         b, c = input.size(0), input.size(1)
-        input_var = Variable(input)
+        input_var = Variable(input.type(type), requires_grad=True)
 
-        IN = cls(c, eps=0)
+        IN = cls(c, eps=0).type(type)
 
         output = IN(input_var)
-        out_reshaped = output.transpose(1, 0).contiguous().view(c, -1)
+        out_reshaped = output.view(b * c, -1)
 
         mean = out_reshaped.mean(1)
         var = out_reshaped.var(1, unbiased=False)
@@ -1459,11 +1460,26 @@
         self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
         self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
 
-        # If momentum==1 running_mean/var should be
-        # equal to mean/var of the input
-        IN = cls(c, momentum=1, eps=0)
+        # check that eval mode doesn't change behavior
+        grad_out = output.data.clone().normal_()
+        res1 = output.data.clone()
+        output.backward(grad_out)
+        grad1 = input_var.grad.data.clone()
 
+        IN.eval()
         output = IN(input_var)
+        input_var.grad = None
+        output.backward(grad_out)
+        res2 = output.data
+        grad2 = input_var.grad.data
+        self.assertEqual(res1, res2)
+        self.assertEqual(grad1, grad2)
+
+        # If track_running_stats=True and momentum=1, running_mean/var should be
+        # equal to mean/var of the input (with unbiased variance correction)
+        IN = cls(c, momentum=1, eps=0, track_running_stats=True).type(type)
+
+        output = IN(input_var.type(type))
 
         input_reshaped = input_var.transpose(1, 0).contiguous().view(c, -1)
         mean = input_reshaped.mean(1)
@@ -1474,32 +1490,156 @@
         self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
         self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
 
-    def test_InstanceNorm2d(self):
-        b = random.randint(3, 5)
-        c = random.randint(1, 5)
-        w = random.randint(2, 5)
-        h = random.randint(2, 5)
+        # in eval mode, adding X * std to a channel in input should make the
+        # corresponding channel in output have mean X
+        IN.eval()
+        delta = (IN.running_var.sqrt() * torch.arange(c).type(type)).view(-1, *[1 for _ in range(2, input.dim())])
+        output = IN(input_var + Variable(delta))
+        self.assertEqual(output.transpose(0, 1).contiguous().view(c, -1).mean(1), torch.arange(c))
 
-        input = torch.Tensor(b, c, h, w).uniform_()
-        self._test_InstanceNorm(nn.InstanceNorm2d, input)
+    def _test_InstanceNorm_cuda_half(self, cls, input):
+        # THNN
+        input = Variable(input.cuda().half().random_(1, 10), requires_grad=True)
+        m = cls(input.size(1), affine=True, track_running_stats=True).cuda().half()
+        thnn_output = m(input)
+        thnn_output.sum().backward()
+        thnn_input_grad = input.grad.data.clone()
+        self.assertEqual(thnn_output.type(), input.type())
+        # cuDNN
+        if TEST_CUDNN:
+            input.grad = None
+            m = m.float()
+            cudnn_output = m(input)
+            cudnn_output.sum().backward()
+            cudnn_input_grad = input.grad.data.clone()
+            self.assertEqual(cudnn_output.type(), input.type())
+            self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4)
+            self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
 
-    def test_InstanceNorm1d(self):
+    def test_InstanceNorm1d_general(self):
         b = random.randint(3, 5)
-        c = random.randint(1, 5)
-        d = random.randint(2, 5)
+        c = random.randint(3, 5)
+        d = random.randint(8, 10)
 
         input = torch.Tensor(b, c, d).uniform_()
-        self._test_InstanceNorm(nn.InstanceNorm1d, input)
+        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, torch.FloatTensor)
 
-    def test_InstanceNorm3d(self):
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_InstanceNorm1d_general_cuda(self):
         b = random.randint(3, 5)
-        c = random.randint(1, 5)
+        c = random.randint(3, 5)
+        d = random.randint(8, 10)
+
+        input = torch.Tensor(b, c, d).uniform_()
+        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, torch.cuda.FloatTensor)
+        self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input)
+
+    def test_InstanceNorm2d_general(self):
+        b = random.randint(3, 5)
+        c = random.randint(3, 5)
+        w = random.randint(3, 6)
+        h = random.randint(6, 8)
+
+        input = torch.Tensor(b, c, h, w).uniform_()
+        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, torch.FloatTensor)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_InstanceNorm2d_general_cuda(self):
+        b = random.randint(3, 5)
+        c = random.randint(3, 5)
+        w = random.randint(3, 6)
+        h = random.randint(6, 8)
+
+        input = torch.Tensor(b, c, h, w).uniform_()
+        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, torch.cuda.FloatTensor)
+        self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input)
+
+    def test_InstanceNorm3d_general(self):
+        b = random.randint(3, 5)
+        c = random.randint(3, 5)
         w = random.randint(2, 5)
         h = random.randint(2, 5)
         d = random.randint(2, 5)
 
         input = torch.Tensor(b, c, h, w, d).uniform_()
-        self._test_InstanceNorm(nn.InstanceNorm3d, input)
+        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, torch.FloatTensor)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_InstanceNorm3d_general_cuda(self):
+        b = random.randint(3, 5)
+        c = random.randint(2, 5)
+        w = random.randint(2, 5)
+        h = random.randint(2, 5)
+        d = random.randint(2, 5)
+
+        input = torch.Tensor(b, c, h, w, d).uniform_()
+        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, torch.cuda.FloatTensor)
+        self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input)
+
+    def _test_LayerNorm_general(self, type):
+        for i in range(2, 6):
+            shape = torch.LongTensor(i).random_(3, 6).tolist()
+            x = Variable(type(*shape).uniform_(0, 10))
+            normalized_ndim = random.randint(1, i - 1)  # inclusive
+            normalized_shape = shape[-normalized_ndim:]
+            unnormalized_shape = shape[:-normalized_ndim]
+
+            # test that LN normalizes to mean 0 and stddev 1
+            ln = nn.LayerNorm(normalized_shape, eps=0).type(type)
+            ln.weight.data.fill_(1)
+            ln.bias.data.fill_(0)
+            output = ln(x)
+            out_reshaped = output.view(*(unnormalized_shape + [-1]))
+            mean = out_reshaped.mean(-1)
+            var = out_reshaped.var(-1, unbiased=False)
+            self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
+            self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
+
+            # test that LN applies weight and bias correctly
+            scale, bias = torch.FloatTensor(2).uniform_(0.2, 2).tolist()
+            ln.weight.data.fill_(scale)
+            ln.bias.data.fill_(bias)
+            output = ln(x)
+            out_reshaped = output.view(*(unnormalized_shape + [-1]))
+            mean = out_reshaped.mean(-1)
+            var = out_reshaped.var(-1, unbiased=False)
+            self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
+            self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)
+
+            # test LN with track_running_stats=True
+            ln = nn.LayerNorm(normalized_shape, momentum=1, eps=0,
+                              elementwise_affine=False, track_running_stats=True).type(type)
+            output_ref = ln(x).data.clone()
+            input_reshaped = x.view(*(unnormalized_shape + [-1]))
+            # make sure that running mean and var update correctly when training
+            mean = input_reshaped.mean(-1).mean()
+            var = input_reshaped.var(-1, unbiased=True).mean()
+            self.assertAlmostEqual(torch.abs(mean.data - ln.running_mean).mean(), 0, delta=1e-5)
+            self.assertAlmostEqual(torch.abs(var.data - ln.running_var).mean(), 0, delta=1e-5)
+            ln.eval()
+            old_running_mean = ln.running_mean.clone()
+            old_running_var = ln.running_var.clone()
+            output_new = ln(x + ln.running_var.sqrt()[0] * scale).data
+            self.assertAlmostEqual((output_new - output_ref).mean(), scale, delta=1e-5)
+            # make sure that running mean and var don't change in eval
+            self.assertEqual(old_running_mean, ln.running_mean)
+            self.assertEqual(old_running_var, ln.running_var)
+
+    def _test_LayerNorm_cuda_half(self):
+        # just THNN, LayerNorm has no cuDNN path
+        input = Variable(torch.rand(2, 3, 3, 2).cuda().half().random_(1, 10), requires_grad=True)
+        m = nn.LayerNorm([3, 2]).cuda().half()
+        output = m(input)
+        output.sum().backward()
+        self.assertEqual(output.type(), input.type())
+
+    def test_LayerNorm_general(self):
+        self._test_LayerNorm_general(torch.FloatTensor)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_LayerNorm_general_cuda(self):
+        self._test_LayerNorm_general(torch.cuda.FloatTensor)
+        self._test_LayerNorm_cuda_half()
 
     def test_pad(self):
         inputs = Variable(torch.randn(1, 3, 4, 4), requires_grad=True)
@@ -3475,13 +3615,51 @@
         gradgradcheck(func, [v])
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    @unittest.skipIf(not TEST_CUDNN, "cuDNN unavailable")
     def test_batchnorm_cudnn_half(self):
-        input = Variable(torch.rand(2, 3, 2, 2).half().cuda())
-        m = nn.BatchNorm2d(3).float().cuda()
-        output = m(input)
-        output.sum().backward()
-        self.assertEqual(output.type(), input.type())
+        # THNN
+        input = Variable(torch.rand(2, 3, 2, 2).half().cuda().random_(1, 10), requires_grad=True)
+        m = nn.BatchNorm2d(3).half().cuda()
+        thnn_output = m(input)
+        thnn_output.sum().backward()
+        thnn_input_grad = input.grad.data.clone()
+        self.assertEqual(thnn_output.type(), input.type())
+        # cuDNN
+        if TEST_CUDNN:
+            input.grad = None
+            m = m.float()
+            cudnn_output = m(input)
+            cudnn_output.sum().backward()
+            cudnn_input_grad = input.grad.data.clone()
+            self.assertEqual(cudnn_output.type(), input.type())
+            self.assertEqual(cudnn_output, thnn_output)
+            self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3)
+
+    def _test_batchnorm_update_stats(self, test_type=torch.FloatTensor):
+        module = nn.BatchNorm1d(3).type(test_type)
+
+        data = Variable(torch.rand(4, 3).type(test_type))
+
+        # training pass
+        old_running_mean = module.running_mean.clone()
+        old_running_var = module.running_var.clone()
+        module(data)
+        self.assertNotEqual(old_running_mean, module.running_mean)
+        self.assertNotEqual(old_running_var, module.running_var)
+
+        # eval pass
+        module.eval()
+        old_running_mean = module.running_mean.clone()
+        old_running_var = module.running_var.clone()
+        module(data)
+        self.assertEqual(old_running_mean, module.running_mean)
+        self.assertEqual(old_running_var, module.running_var)
+
+    def test_batchnorm_update_stats(self):
+        self._test_batchnorm_update_stats()
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_batchnorm_update_stats_cuda(self):
+        self._test_batchnorm_update_stats(torch.cuda.FloatTensor)
 
     def test_batchnorm_raises_error_if_running_mean_is_not_same_size_as_input(self):
         input = Variable(torch.rand(2, 10))
@@ -3539,6 +3717,30 @@
         self.assertEqual(res1, res2)
         self.assertEqual(grad1, grad2)
 
+        # track_running_stats=False
+        module = nn.BatchNorm1d(3, track_running_stats=False).type(test_type)
+
+        data = Variable(torch.rand(4, 3).type(test_type), requires_grad=True)
+        grad = torch.rand(4, 3).type(test_type)
+
+        # 1st pass
+        res1 = module(data)
+        res1.backward(grad)
+        grad1 = data.grad.data.clone()
+
+        # set eval
+        module.eval()
+
+        # 2nd pass
+        if data.grad is not None:
+            data.grad.data.zero_()
+
+        res2 = module(data)
+        res2.backward(grad)
+        grad2 = data.grad.data.clone()
+        self.assertEqual(res1, res2)
+        self.assertEqual(grad1, grad2)
+
     def test_pairwise_distance(self):
         input1 = Variable(torch.randn(4, 4), requires_grad=True)
         input2 = Variable(torch.randn(4, 4), requires_grad=True)
@@ -5034,6 +5236,14 @@
     ),
     dict(
         module_name='BatchNorm1d',
+        constructor_args=(10, 1e-3, 0.3, True, False),
+        input_size=(4, 10),
+        cudnn=True,
+        check_eval=True,
+        desc='not_tracking_stats',
+    ),
+    dict(
+        module_name='BatchNorm1d',
         constructor_args=(5, 1e-3, 0.3, False),
         input_size=(4, 5, 3),
         cudnn=True,
@@ -5064,6 +5274,14 @@
         desc='not_affine',
     ),
     dict(
+        module_name='BatchNorm2d',
+        constructor_args=(3, 1e-3, 0.8, True, False),
+        input_size=(2, 3, 6, 6),
+        cudnn=True,
+        check_eval=True,
+        desc='not_tracking_stats',
+    ),
+    dict(
         module_name='BatchNorm3d',
         constructor_args=(3,),
         input_size=(2, 3, 4, 4, 4),
@@ -5087,6 +5305,107 @@
         desc='not_affine',
     ),
     dict(
+        module_name='BatchNorm3d',
+        constructor_args=(3, 1e-3, 0.7, True, False),
+        input_size=(2, 3, 4, 4, 4),
+        cudnn=True,
+        check_eval=True,
+        desc='not_tracking_stats',
+    ),
+    dict(
+        module_name='InstanceNorm1d',
+        constructor_args=(3, 1e-3, 0.3),
+        input_size=(4, 3, 15),
+        cudnn=True,
+        check_eval=True,
+    ),
+    dict(
+        module_name='InstanceNorm1d',
+        constructor_args=(3, 1e-3, 0.3, False, True),
+        input_size=(4, 3, 15),
+        cudnn=True,
+        check_eval=True,
+        desc='tracking_stats',
+    ),
+    dict(
+        module_name='InstanceNorm2d',
+        constructor_args=(3, 1e-3, 0.3),
+        input_size=(2, 3, 6, 6),
+        cudnn=True,
+        check_eval=True,
+    ),
+    dict(
+        module_name='InstanceNorm2d',
+        constructor_args=(3, 1e-3, 0.3, False, True),
+        input_size=(2, 3, 6, 6),
+        cudnn=True,
+        check_eval=True,
+        desc='tracking_stats',
+    ),
+    dict(
+        module_name='InstanceNorm3d',
+        constructor_args=(3, 1e-3, 0.3),
+        input_size=(2, 3, 4, 4, 4),
+        cudnn=True,
+        check_eval=True,
+    ),
+    dict(
+        module_name='InstanceNorm3d',
+        constructor_args=(3, 1e-3, 0.3, False, True),
+        input_size=(2, 3, 4, 4, 4),
+        cudnn=True,
+        check_eval=True,
+        desc='tracking_stats',
+    ),
+    dict(
+        module_name='LayerNorm',
+        constructor_args=([5], 1e-3, 0.3),
+        input_size=(4, 5, 5),
+        cudnn=True,
+        check_eval=True,
+        desc='1d_elementwise_affine',
+    ),
+    dict(
+        module_name='LayerNorm',
+        constructor_args=([5], 1e-3, 0.3, False),
+        input_size=(4, 5, 5),
+        cudnn=True,
+        check_eval=True,
+        desc='1d_no_elementwise_affine',
+    ),
+    dict(
+        module_name='LayerNorm',
+        constructor_args=([5], 1e-3, 0.3, True, True),
+        input_size=(4, 5, 5),
+        cudnn=True,
+        check_eval=True,
+        desc='1d_elementwise_affine_tracking_stats',
+    ),
+    dict(
+        module_name='LayerNorm',
+        constructor_args=([2, 2, 5], 1e-3, 0.3),
+        input_size=(4, 2, 2, 5),
+        cudnn=True,
+        check_eval=True,
+        desc='3d_elementwise_affine',
+    ),
+    dict(
+        module_name='LayerNorm',
+        constructor_args=([2, 2, 5], 1e-3, 0.3, False),
+        input_size=(4, 2, 2, 5),
+        cudnn=True,
+        check_eval=True,
+        desc='3d_no_elementwise_affine',
+    ),
+    dict(
+        module_name='LayerNorm',
+        constructor_args=([2, 2, 5], 1e-3, 0.3, True, True),
+        input_size=(4, 2, 2, 5),
+        cudnn=True,
+        check_eval=True,
+        desc='3d_elementwise_affine_tracking_stats',
+    ),
+    dict(
         module_name='Conv1d',
         constructor_args=(4, 5, 3),
         input_size=(2, 4, 10),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 576a49f..fbd4efd 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -831,7 +831,7 @@
   self, weight, bias: thnn_batch_norm_backward(grad.contiguous(), self, weight, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)
 
 - name: thnn_batch_norm_backward(Tensor grad_output, Tensor self, Tensor weight, Tensor running_mean, Tensor running_var, bool training, double eps, Tensor save_mean, Tensor save_std, std::array<bool,3> output_mask)
-  self, weight, grad_output: batchnorm_double_backward(self, weight, grads[0], grads[1], grads[2], grad_output, eps, save_mean, save_std, running_mean, running_var, training)
+  self, weight, grad_output: batchnorm_double_backward(self, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)
   save_mean: not_implemented("thnn_batch_norm_backward save_mean")
   save_std: not_implemented("thnn_batch_norm_backward save_std")
 
@@ -1096,7 +1096,7 @@
 - name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
   save_mean: not_implemented("cudnn_batch_norm_backward save_mean")
   save_var: not_implemented("cudnn_batch_norm_backward save_var")
-  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, epsilon, save_mean, save_var, running_mean, running_var, true)
+  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
 
 # nnpack
 
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 7eced4a..3eeb207 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -1053,12 +1053,13 @@
     const Tensor & ggG,
     const Tensor & ggB,
     const Tensor & gO,
+    const Tensor & running_mean_v,
+    const Tensor & running_var_v,
+    bool training,
     double eps,
     const Tensor & save_mean_v,
     const Tensor & save_std_v,
-    const Tensor & running_mean_v,
-    const Tensor & running_var_v,
-    bool training) {
+    std::array<bool,3> output_mask) {
 
   // NB: In the original design of BatchNorm, save_mean, save_std, running_mean
   // and running_var are unconditionally tensor "buffers", and never get wrapped
@@ -1173,6 +1174,13 @@
     ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
   }
 
+  if (output_mask[0] && !ggO.defined()) ggO = at::zeros_like(gO);
+  if (output_mask[1] && !gG.defined()) {
+    AT_ASSERT(affine, "gamma should always be defined when it requires grad");
+    gG = at::zeros_like(gamma);
+  }
+  if (output_mask[2] && !gI.defined()) gI = at::zeros_like(input);
+
   return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
 
 }
diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp
index 95734f4..67f7630 100644
--- a/torch/csrc/utils/tensor_apply.cpp
+++ b/torch/csrc/utils/tensor_apply.cpp
@@ -1,6 +1,6 @@
 #include "tensor_apply.h"
 
-#include <ATen/Check.h>
+#include <ATen/TensorUtils.h>
 #include <ATen/ExpandUtils.h>
 
 #include "torch/csrc/Exceptions.h"
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index dda5df8..62ad0fb 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1171,54 +1171,127 @@
     return ret
 
 
-def instance_norm(input, weight, bias, saved_running_mean, saved_running_var,
-                  training=False, momentum=0.1, eps=1e-5, affine=False):
-    """Applies instance normalization over an input. The implementation is
-    based on batch_norm, in which we do reshape, batchnorm, and reshape again.
+def batch_norm(input, running_mean, running_var, weight=None, bias=None,
+               training=False, momentum=0.1, eps=1e-5):
+    r"""Applies Batch Normalization for each channel across a batch of data.
+
+    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
+    :class:`~torch.nn.BatchNorm3d` for details.
+    """
+    if training:
+        size = list(input.size())
+        if reduce(mul, size[2:], size[0]) == 1:
+            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
+    if running_mean is not None:
+        running_mean = Variable(running_mean)
+    if running_var is not None:
+        running_var = Variable(running_var)
+    return torch._C._VariableFunctions.batch_norm(
+        input, weight, bias, running_mean, running_var,
+        training, momentum, eps, torch.backends.cudnn.enabled
+    )
+
+
+def instance_norm(input, running_mean, running_var, weight=None, bias=None,
+                  use_input_stats=True, momentum=0.1, eps=1e-5):
+    r"""Applies Instance Normalization for each channel in each data sample in a
+    batch.
 
     See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
     :class:`~torch.nn.InstanceNorm3d` for details.
     """
-    import torch
+    if not use_input_stats and (running_mean is None or running_var is None):
+        raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False')
+
+    b, c = input.size(0), input.size(1)
+    if weight is not None:
+        weight = weight.repeat(b)
+    if bias is not None:
+        bias = bias.repeat(b)
+
     import torch.onnx.symbolic
 
     @torch.onnx.symbolic_override_first_arg_based(torch.onnx.symbolic.instance_norm)
-    def _instance_norm(input, weight=None, bias=None, saved_running_mean=None,
-                       saved_running_var=None, training=False, momentum=0.1,
-                       eps=1e-5, affine=False):
-        b, c = input.size(0), input.size(1)
-
-        # Repeat stored stats and affine transform params
-        running_mean = saved_running_mean.repeat(b)
-        running_var = saved_running_var.repeat(b)
+    def _instance_norm(input, running_mean=None, running_var=None, weight=None,
+                       bias=None, use_input_stats=None, momentum=None, eps=None):
+        # Repeat stored stats and affine transform params if necessary
+        if running_mean is not None:
+            running_mean_orig = running_mean
+            running_mean = running_mean_orig.repeat(b)
+        if running_var is not None:
+            running_var_orig = running_var
+            running_var = running_var_orig.repeat(b)
 
         # Apply instance norm
         input_reshaped = input.contiguous().view(1, b * c, *input.size()[2:])
 
         out = batch_norm(
             input_reshaped, running_mean, running_var, weight=weight, bias=bias,
-            training=training, momentum=momentum, eps=eps)
+            training=use_input_stats, momentum=momentum, eps=eps)
 
         # Reshape back
-        saved_running_mean.copy_(running_mean.view(b, c).mean(0, keepdim=False))
-        saved_running_var.copy_(running_var.view(b, c).mean(0, keepdim=False))
+        if running_mean is not None:
+            running_mean_orig.copy_(running_mean.view(b, c).mean(0, keepdim=False))
+        if running_var is not None:
+            running_var_orig.copy_(running_var.view(b, c).mean(0, keepdim=False))
 
         return out.view(b, c, *input.size()[2:])
-    return _instance_norm(input, weight=weight, bias=bias, saved_running_mean=saved_running_mean,
-                          saved_running_var=saved_running_var, training=training,
-                          momentum=momentum, eps=eps, affine=affine)
+    return _instance_norm(input, running_mean=running_mean,
+                          running_var=running_var, weight=weight, bias=bias,
+                          use_input_stats=use_input_stats, momentum=momentum,
+                          eps=eps)
 
 
-def batch_norm(input, running_mean, running_var, weight=None, bias=None,
-               training=False, momentum=0.1, eps=1e-5):
-    if training:
-        size = list(input.size())
-        if reduce(mul, size[2:], size[0]) == 1:
-            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
-    return torch._C._VariableFunctions.batch_norm(
-        input, weight, bias,
-        Variable(running_mean), Variable(running_var), training, momentum, eps, torch.backends.cudnn.enabled
-    )
+def layer_norm(input, normalized_shape, running_mean, running_var,
+               weight=None, bias=None, use_input_stats=True,
+               momentum=0.1, eps=1e-5):
+    r"""Applies Layer Normalization for last certain number of dimensions.
+
+    See :class:`~torch.nn.LayerNorm` for details.
+    """
+    if not use_input_stats and (running_mean is None or running_var is None):
+        raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False')
+
+    normalized_ndim = len(normalized_shape)
+    input_shape = input.size()
+
+    if input_shape[-normalized_ndim:] != torch.Size(normalized_shape):
+        raise ValueError('Expected input with shape [*, {}], but got {} input'
+                         .format(', '.join(map(str, normalized_shape)), list(input_shape)))
+
+    n = reduce(mul, input_shape[:-normalized_ndim], 1)
+
+    # Repeat stored stats if necessary
+    if running_mean is not None:
+        running_mean_orig = running_mean
+        running_mean = running_mean_orig.repeat(n)
+    if running_var is not None:
+        running_var_orig = running_var
+        running_var = running_var_orig.repeat(n)
+
+    # Apply layer norm
+    input_reshaped = input.contiguous().view(1, n, -1)
+
+    out = batch_norm(
+        input_reshaped, running_mean, running_var, None, None,
+        use_input_stats, momentum, eps)
+
+    # Copy back
+    if running_mean is not None:
+        running_mean_orig.fill_(running_mean.mean())
+    if running_var is not None:
+        running_var_orig.fill_(running_var.mean())
+
+    out = out.view(*input_shape)
+
+    if weight is not None and bias is not None:
+        return torch.addcmul(bias, 1, out, weight)
+    elif weight is not None:
+        return torch.mul(out, weight)
+    elif bias is not None:
+        return torch.add(out, bias)
+    else:
+        return out
 
 
 def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1):
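
(Aside for reviewers: the reshape trick used by the new `layer_norm` above can be sanity-checked against a direct mean/variance computation. A rough sketch under the signatures introduced in this patch; the names below are local to the example, not part of the change.)

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from functools import reduce
from operator import mul

x = Variable(torch.randn(4, 2, 2, 5))
normalized_shape = [2, 2, 5]
eps = 1e-5

# number of independent groups to normalize: product of the leading dims
n = reduce(mul, x.size()[:-len(normalized_shape)], 1)

# batch_norm over a (1, n, -1) view with training=True normalizes each of
# the n groups by its own mean/variance, which is exactly layer normalization
out = F.batch_norm(x.contiguous().view(1, n, -1), None, None,
                   training=True, eps=eps).view(*x.size())

# direct computation for comparison (biased variance, as in batch_norm)
flat = x.view(n, -1)
mean = flat.mean(1).view(n, 1)
var = ((flat - mean) ** 2).mean(1).view(n, 1)
ref = ((flat - mean) / (var + eps).sqrt()).view(*x.size())

print((out - ref).abs().max())  # expected to be ~0
```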
diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py
index 52a75a8..8f4a854 100644
--- a/torch/nn/modules/__init__.py
+++ b/torch/nn/modules/__init__.py
@@ -15,10 +15,10 @@
     AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
 from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
 from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
+from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm
 from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout
 from .padding import ReflectionPad1d, ReflectionPad2d, ReplicationPad1d, ReplicationPad2d, \
     ReplicationPad3d, ZeroPad2d, ConstantPad1d, ConstantPad2d, ConstantPad3d
-from .normalization import LocalResponseNorm, CrossMapLRN2d
 from .sparse import Embedding, EmbeddingBag
 from .rnn import RNNBase, RNN, LSTM, GRU, \
     RNNCell, LSTMCell, GRUCell
@@ -39,9 +39,9 @@
     'ParameterList', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
     'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d',
     'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
-    'InstanceNorm2d', 'InstanceNorm3d', 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'ReflectionPad1d',
-    'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d', 'CrossMapLRN2d',
-    'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell',
+    'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout',
+    'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
+    'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell',
     'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'PairwiseDistance',
     'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d',
     'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 58d62b0..53ce5b0 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -8,43 +8,56 @@
 # TODO: use separate backend functions?
 class _BatchNorm(Module):
 
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True):
         super(_BatchNorm, self).__init__()
         self.num_features = num_features
-        self.affine = affine
         self.eps = eps
         self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
         if self.affine:
             self.weight = Parameter(torch.Tensor(num_features))
             self.bias = Parameter(torch.Tensor(num_features))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
-        self.register_buffer('running_mean', torch.zeros(num_features))
-        self.register_buffer('running_var', torch.ones(num_features))
+        if self.track_running_stats:
+            self.register_buffer('running_mean', torch.zeros(num_features))
+            self.register_buffer('running_var', torch.ones(num_features))
+        else:
+            self.register_parameter('running_mean', None)
+            self.register_parameter('running_var', None)
         self.reset_parameters()
 
     def reset_parameters(self):
-        self.running_mean.zero_()
-        self.running_var.fill_(1)
+        if self.track_running_stats:
+            self.running_mean.zero_()
+            self.running_var.fill_(1)
         if self.affine:
             self.weight.data.uniform_()
             self.bias.data.zero_()
 
+    def _check_input_dim(self, input):
+        return NotImplemented
+
     def forward(self, input):
+        self._check_input_dim(input)
+
         return F.batch_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
-            self.training, self.momentum, self.eps)
+            self.training or not self.track_running_stats, self.momentum, self.eps)
 
     def __repr__(self):
         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
-                ' affine={affine})'
+                ' affine={affine}, track_running_stats={track_running_stats})'
                 .format(name=self.__class__.__name__, **self.__dict__))
 
 
 class BatchNorm1d(_BatchNorm):
-    r"""Applies Batch Normalization over a 2d or 3d input that is seen as a
-    mini-batch.
+    r"""Applies Batch Normalization over a 2d or 3d input (a mini-batch of 1d
+    inputs with optional additional channel dimension) as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
 
     .. math::
 
@@ -54,23 +67,39 @@
     the mini-batches and gamma and beta are learnable parameter vectors
     of size C (where C is the input size).
 
-    During training, this layer keeps a running estimate of its computed mean
-    and variance. The running sum is kept with a default momentum of 0.1.
+    By default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
 
-    During evaluation, this running mean/variance is used for normalization.
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from the one used in
+        optimizer classes and the conventional notion of momentum.
+        Mathematically, the update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
 
     Because the BatchNorm is done over the `C` dimension, computing statistics
-    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
 
     Args:
-        num_features: num_features from an expected input of size
-            `batch_size x num_features [x width]`
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
             computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``True``
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses batch
+            statistics in both training and eval modes. Default: ``True``
 
     Shape:
         - Input: :math:`(N, C)` or :math:`(N, C, L)`
@@ -83,18 +112,21 @@
         >>> m = nn.BatchNorm1d(100, affine=False)
         >>> input = autograd.Variable(torch.randn(20, 100))
         >>> output = m(input)
+
+    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+        https://arxiv.org/abs/1502.03167
     """
 
     def _check_input_dim(self, input):
         if input.dim() != 2 and input.dim() != 3:
             raise ValueError('expected 2D or 3D input (got {}D input)'
                              .format(input.dim()))
-        super(BatchNorm1d, self)._check_input_dim(input)
 
 
 class BatchNorm2d(_BatchNorm):
-    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
-    of 3d inputs
+    r"""Applies Batch Normalization over a 4d input (a mini-batch of 2d inputs
+    with additional channel dimension) as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
 
     .. math::
 
@@ -104,23 +136,39 @@
     the mini-batches and gamma and beta are learnable parameter vectors
     of size C (where C is the input size).
 
-    During training, this layer keeps a running estimate of its computed mean
-    and variance. The running sum is kept with a default momentum of 0.1.
+    By default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
 
-    During evaluation, this running mean/variance is used for normalization.
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from the one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
 
     Because the BatchNorm is done over the `C` dimension, computing statistics
-    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
 
     Args:
-        num_features: num_features from an expected input of
-            size batch_size x num_features x height x width
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, H, W)`
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
             computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``True``
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses batch
+            statistics in both training and eval modes. Default: ``True``
 
     Shape:
         - Input: :math:`(N, C, H, W)`
@@ -133,18 +181,21 @@
         >>> m = nn.BatchNorm2d(100, affine=False)
         >>> input = autograd.Variable(torch.randn(20, 100, 35, 45))
         >>> output = m(input)
+
+    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+        https://arxiv.org/abs/1502.03167
     """
 
     def _check_input_dim(self, input):
         if input.dim() != 4:
             raise ValueError('expected 4D input (got {}D input)'
                              .format(input.dim()))
-        super(BatchNorm2d, self)._check_input_dim(input)
 
 
 class BatchNorm3d(_BatchNorm):
-    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
-    of 4d inputs
+    r"""Applies Batch Normalization over a 5d input (a mini-batch of 3d inputs
+    with additional channel dimension) as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
 
     .. math::
 
@@ -154,24 +205,40 @@
     the mini-batches and gamma and beta are learnable parameter vectors
     of size C (where C is the input size).
 
-    During training, this layer keeps a running estimate of its computed mean
-    and variance. The running sum is kept with a default momentum of 0.1.
+    By default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
 
-    During evaluation, this running mean/variance is used for normalization.
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from the one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
 
     Because the BatchNorm is done over the `C` dimension, computing statistics
     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
-    or Spatio-temporal BatchNorm
+    or Spatio-temporal BatchNorm.
 
     Args:
-        num_features: num_features from an expected input of
-            size batch_size x num_features x depth x height x width
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, D, H, W)`
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
             computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``True``
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses batch
+            statistics in both training and eval modes. Default: ``True``
 
     Shape:
         - Input: :math:`(N, C, D, H, W)`
@@ -184,10 +251,12 @@
         >>> m = nn.BatchNorm3d(100, affine=False)
         >>> input = autograd.Variable(torch.randn(20, 100, 35, 45, 10))
         >>> output = m(input)
+
+    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+        https://arxiv.org/abs/1502.03167
     """
 
     def _check_input_dim(self, input):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'
                              .format(input.dim()))
-        super(BatchNorm3d, self)._check_input_dim(input)
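
For reference, a minimal usage sketch (not part of the diff) of the new
``track_running_stats`` flag and of the running-statistics update rule from the
``.. note::`` blocks above; module sizes and input shapes are illustrative:

    import torch
    import torch.nn as nn
    from torch import autograd

    # Default: running estimates are tracked during training and used in eval mode.
    m = nn.BatchNorm1d(100)
    # With track_running_stats=False, batch statistics are used even in eval mode
    # (forward() passes `self.training or not self.track_running_stats` to F.batch_norm).
    m_batch_only = nn.BatchNorm1d(100, track_running_stats=False)

    x = autograd.Variable(torch.randn(20, 100))
    m.train()
    y = m(x)              # updates running_mean / running_var
    m.eval()
    y = m(x)              # normalizes with the running estimates
    m_batch_only.eval()
    y = m_batch_only(x)   # still normalizes with the current batch's statistics

    # Running-estimate update rule described in the notes:
    def update_running_stat(running, observed, momentum=0.1):
        return (1 - momentum) * running + momentum * observed
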
diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py
index 3b6653a..ab748b9 100644
--- a/torch/nn/modules/instancenorm.py
+++ b/torch/nn/modules/instancenorm.py
@@ -3,38 +3,26 @@
 
 
 class _InstanceNorm(_BatchNorm):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
+                 track_running_stats=False):
         super(_InstanceNorm, self).__init__(
-            num_features, eps, momentum, affine)
-        self._use_running_stats = False
+            num_features, eps, momentum, affine, track_running_stats)
+
+    def _check_input_dim(self, input):
+        return NotImplemented
 
     def forward(self, input):
-        b = input.size(0)
+        self._check_input_dim(input)
 
-        weight, bias = None, None
-        if self.affine:
-            weight = self.weight.repeat(b)
-            bias = self.bias.repeat(b)
-
-        training = not self._use_running_stats
-        return F.instance_norm(input, weight=weight, bias=bias,
-                               saved_running_mean=self.running_mean,
-                               saved_running_var=self.running_var,
-                               training=training, momentum=self.momentum,
-                               eps=self.eps, affine=self.affine)
-
-    def use_running_stats(self, mode=True):
-        r"""Set using running statistics or instance statistics.
-
-        Instance normalization usually use instance statistics in both training
-        and evaluation modes. But users can set this method to use running
-        statistics in the fashion similar to batch normalization in eval mode.
-        """
-        self._use_running_stats = mode
+        return F.instance_norm(
+            input, self.running_mean, self.running_var, self.weight, self.bias,
+            self.training or not self.track_running_stats, self.momentum, self.eps)
 
 
 class InstanceNorm1d(_InstanceNorm):
-    r"""Applies Instance Normalization over a 3d input that is seen as a mini-batch.
+    r"""Applies Instance Normalization over a 2d or 3d input (a mini-batch of 1d
+    inputs with optional additional channel dimension) as described in the paper
+    `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
 
     .. math::
 
@@ -42,22 +30,35 @@
 
     The mean and standard-deviation are calculated per-dimension separately
     for each object in a mini-batch. Gamma and beta are learnable parameter vectors
-    of size C (where C is the input size).
+    of size C (where C is the input size) if :attr:`affine` is ``True``.
 
-    During training, this layer keeps a running estimate of its computed mean
-    and variance. The running sum is kept with a default momentum of 0.1.
+    By default, this layer uses instance statistics computed from input data in
+    both training and evaluation modes.
 
-    At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
-    i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.use_running_stats(mode=True)` method, and switch back to normal
-    behavior with `.use_running_stats(mode=False)` method.
+    If :attr:`track_running_stats` is set to ``True``, during training this
+    layer keeps running estimates of its computed mean and variance, which are
+    then used for normalization during evaluation. The running estimates are
+    kept with a default :attr:`momentum` of 0.1.
+
+    .. note::
+        This :attr:`momentum` argument is different from the one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
 
     Args:
-        num_features: num_features from an expected input of size `batch_size x num_features x width`
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         momentum: the value used for the running_mean and running_var computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``False``
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``False``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses instance
+            statistics in both training and eval modes. Default: ``False``
 
     Shape:
         - Input: :math:`(N, C, L)`
@@ -70,17 +71,21 @@
         >>> m = nn.InstanceNorm1d(100, affine=True)
         >>> input = autograd.Variable(torch.randn(20, 100, 40))
         >>> output = m(input)
+
+    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+        https://arxiv.org/abs/1607.08022
     """
 
     def _check_input_dim(self, input):
         if input.dim() != 3:
             raise ValueError('expected 3D input (got {}D input)'
                              .format(input.dim()))
-        super(InstanceNorm1d, self)._check_input_dim(input)
 
 
 class InstanceNorm2d(_InstanceNorm):
-    r"""Applies Instance Normalization over a 4d input that is seen as a mini-batch of 3d inputs
+    r"""Applies Instance Normalization over a 4d input (a mini-batch of 2d inputs
+    with additional channel dimension) as described in the paper
+    `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
 
     .. math::
 
@@ -88,22 +93,35 @@
 
     The mean and standard-deviation are calculated per-dimension separately
     for each object in a mini-batch. Gamma and beta are learnable parameter vectors
-    of size C (where C is the input size).
+    of size C (where C is the input size) if :attr:`affine` is ``True``.
 
-    During training, this layer keeps a running estimate of its computed mean
-    and variance. The running sum is kept with a default momentum of 0.1.
+    By default, this layer uses instance statistics computed from input data in
+    both training and evaluation modes.
 
-    At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
-    i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.use_running_stats(mode=True)` method, and switch back to normal
-    behavior with `.use_running_stats(mode=False)` method.
+    If :attr:`track_running_stats` is set to ``True``, during training this
+    layer keeps running estimates of its computed mean and variance, which are
+    then used for normalization during evaluation. The running estimates are
+    kept with a default :attr:`momentum` of 0.1.
+
+    .. note::
+        This :attr:`momentum` argument is different from the one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
 
     Args:
-        num_features: num_features from an expected input of size batch_size x num_features x height x width
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, H, W)`
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         momentum: the value used for the running_mean and running_var computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``False``
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``False``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses instance
+            statistics in both training and eval modes. Default: ``False``
 
     Shape:
         - Input: :math:`(N, C, H, W)`
@@ -116,41 +134,57 @@
         >>> m = nn.InstanceNorm2d(100, affine=True)
         >>> input = autograd.Variable(torch.randn(20, 100, 35, 45))
         >>> output = m(input)
+
+    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+        https://arxiv.org/abs/1607.08022
     """
 
     def _check_input_dim(self, input):
         if input.dim() != 4:
             raise ValueError('expected 4D input (got {}D input)'
                              .format(input.dim()))
-        super(InstanceNorm2d, self)._check_input_dim(input)
 
 
 class InstanceNorm3d(_InstanceNorm):
-    r"""Applies Instance Normalization over a 5d input that is seen as a mini-batch of 4d inputs
+    r"""Applies Instance Normalization over a 5d input (a mini-batch of 3d inputs
+    with additional channel dimension) as described in the paper
+    `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
 
     .. math::
 
         y = \frac{x - mean[x]}{ \sqrt{Var[x]} + \epsilon} * gamma + beta
 
-    The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch.
-    Gamma and beta are learnable parameter vectors
-    of size C (where C is the input size).
+    The mean and standard-deviation are calculated per-dimension separately
+    for each object in a mini-batch. Gamma and beta are learnable parameter vectors
+    of size C (where C is the input size) if :attr:`affine` is ``True``.
 
-    During training, this layer keeps a running estimate of its computed mean
-    and variance. The running sum is kept with a default momentum of 0.1.
+    By default, this layer uses instance statistics computed from input data in
+    both training and evaluation modes.
 
-    At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
-    i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.use_running_stats(mode=True)` method, and switch back to normal
-    behavior with `.use_running_stats(mode=False)` method.
+    If :attr:`track_running_stats` is set to ``True``, during training this
+    layer keeps running estimates of its computed mean and variance, which are
+    then used for normalization during evaluation. The running estimates are
+    kept with a default :attr:`momentum` of 0.1.
 
+    .. note::
+        This :attr:`momentum` argument is different from the one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
 
     Args:
-        num_features: num_features from an expected input of size batch_size x num_features x depth x height x width
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, D, H, W)`
         eps: a value added to the denominator for numerical stability. Default: 1e-5
         momentum: the value used for the running_mean and running_var computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``False``
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``False``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses instance
+            statistics in both training and eval modes. Default: ``False``
 
     Shape:
         - Input: :math:`(N, C, D, H, W)`
@@ -163,10 +197,12 @@
         >>> m = nn.InstanceNorm3d(100, affine=True)
         >>> input = autograd.Variable(torch.randn(20, 100, 35, 45, 10))
         >>> output = m(input)
+
+    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+        https://arxiv.org/abs/1607.08022
     """
 
     def _check_input_dim(self, input):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'
                              .format(input.dim()))
-        super(InstanceNorm3d, self)._check_input_dim(input)
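
The removed ``use_running_stats()`` toggle is replaced by a constructor flag; a
short migration sketch (shapes illustrative, behaviour inferred from the new
``forward`` above):

    import torch
    import torch.nn as nn
    from torch import autograd

    # Before this change, running statistics were enabled after construction:
    #     m = nn.InstanceNorm2d(100, affine=True)
    #     m.use_running_stats(True)
    # Now the behaviour is chosen at construction and follows train()/eval():
    m = nn.InstanceNorm2d(100, affine=True, track_running_stats=True)

    x = autograd.Variable(torch.randn(20, 100, 35, 45))
    m.train()
    y = m(x)   # instance statistics are used; running estimates are updated
    m.eval()
    y = m(x)   # running estimates are used for normalization
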
diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py
index 2ffb275..f9487a3 100644
--- a/torch/nn/modules/normalization.py
+++ b/torch/nn/modules/normalization.py
@@ -1,4 +1,7 @@
+import torch
+from torch.nn.parameter import Parameter
 from .module import Module
+from .batchnorm import _BatchNorm
 from .. import functional as F
 
 
@@ -8,7 +11,6 @@
     Applies normalization across channels.
 
     .. math::
-
         b_{c} = a_{c}\left(k + \frac{\alpha}{n}
         \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}
 
@@ -21,7 +23,7 @@
     Shape:
         - Input: :math:`(N, C, ...)`
         - Output: :math:`(N, C, ...)` (same shape as input)
-    Examples::
+    Examples:
         >>> lrn = nn.LocalResponseNorm(2)
         >>> signal_2d = autograd.Variable(torch.randn(32, 5, 24, 24))
         >>> signal_4d = autograd.Variable(torch.randn(16, 5, 7, 7, 7, 7))
@@ -69,6 +71,109 @@
             + ', k=' + str(self.k) + ')'
 
 
+class LayerNorm(Module):
+    r"""Applies Layer Normalization over a mini-batch of inputs as described in
+    the paper `Layer Normalization`_ .
+
+    .. math::
+        y = \frac{x - mean[x]}{ \sqrt{Var[x]} + \epsilon} * gamma + beta
+
+    The mean and standard-deviation are calculated separately over the last
+    certain number of dimensions, which have the shape specified by
+    :attr:`normalized_shape`. Gamma and beta are learnable parameters of size
+    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
+
+    .. note::
+        Unlike Batch Normalization and Instance Normalization, which apply
+        scalar scale and bias for each entire channel/plane with the
+        :attr:`affine` option, Layer Normalization applies per-element scale and
+        bias with :attr:`elementwise_affine`.
+
+    By default, this layer uses statistics computed from input data in both
+    training and evaluation modes.
+
+    If :attr:`track_running_stats` is set to ``True``, during training this
+    layer keeps running estimates of its computed mean and variance, which are
+    then used for normalization during evaluation. The running estimates are
+    kept with a default :attr:`momentum` of 0.1.
+
+    .. note::
+        This :attr:`momentum` argument is different from the one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
+
+    Args:
+        normalized_shape (list or torch.Size): input shape from an expected input of size
+            `[* x normalized_shape[0] x normalized_shape[1] x ... x normalized_shape[-1]]`
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        momentum: the value used for the running_mean and running_var computation. Default: 0.1
+        elementwise_affine: a boolean value that when set to ``True``, this module
+            has learnable per-element affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses input
+            statistics in both training and eval modes. Default: ``False``
+
+    Shape:
+        - Input: :math:`(N, *)`
+        - Output: :math:`(N, *)` (same shape as input)
+
+    Examples:
+        >>> input = autograd.Variable(torch.randn(20, 5, 10, 10))
+        >>> # With Learnable Parameters
+        >>> m = nn.LayerNorm(input.size()[1:])
+        >>> # Without Learnable Parameters
+        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
+        >>> output = m(input)
+
+    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
+    """
+    def __init__(self, normalized_shape, eps=1e-5, momentum=0.1,
+                 elementwise_affine=True, track_running_stats=False):
+        super(LayerNorm, self).__init__()
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.momentum = momentum
+        self.elementwise_affine = elementwise_affine
+        self.track_running_stats = track_running_stats
+        if self.elementwise_affine:
+            self.weight = Parameter(torch.Tensor(*normalized_shape))
+            self.bias = Parameter(torch.Tensor(*normalized_shape))
+        else:
+            self.register_parameter('weight', None)
+            self.register_parameter('bias', None)
+        if self.track_running_stats:
+            self.register_buffer('running_mean', torch.zeros(1))
+            self.register_buffer('running_var', torch.ones(1))
+        else:
+            self.register_parameter('running_mean', None)
+            self.register_parameter('running_var', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.track_running_stats:
+            self.running_mean.zero_()
+            self.running_var.fill_(1)
+        if self.elementwise_affine:
+            self.weight.data.uniform_()
+            self.bias.data.zero_()
+
+    def forward(self, input):
+        return F.layer_norm(
+            input, self.normalized_shape, self.running_mean, self.running_var,
+            self.weight, self.bias, self.training or not self.track_running_stats,
+            self.momentum, self.eps)
+
+    def __repr__(self):
+        return ('{name}({normalized_shape}, eps={eps}, momentum={momentum},'
+                ' elementwise_affine={elementwise_affine},'
+                ' track_running_stats={track_running_stats})'
+                .format(name=self.__class__.__name__, **self.__dict__))
+
+
 # TODO: ContrastiveNorm2d
 # TODO: DivisiveNorm2d
 # TODO: SubtractiveNorm2d
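
A small usage sketch of the new ``LayerNorm`` module; normalizing over only the
last dimension (common for sequence models) is shown as an illustration with
made-up sizes:

    import torch
    import torch.nn as nn
    from torch import autograd

    x = autograd.Variable(torch.randn(20, 50, 32))   # (N, timesteps, features)
    ln = nn.LayerNorm(x.size()[-1:])                 # normalized_shape = (32,)
    y = ln(x)                                        # same shape as the input
    # Per-element affine parameters match normalized_shape:
    assert ln.weight.size() == x.size()[-1:]
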
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 9acde5e..5f2ad73 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -1,4 +1,4 @@
-"""
+r"""
 The torch.onnx module contains functions to export models into the ONNX
 IR format.  These models can be loaded with the ONNX library and then
 converted to models which run on other deep learning frameworks.
@@ -22,7 +22,7 @@
 
 @contextlib.contextmanager
 def set_training(model, mode):
-    """
+    r"""
     A context manager to temporarily set the training mode of 'model'
     to 'mode', resetting it when we exit the with-block.  A no-op if
     mode is None.
@@ -42,7 +42,7 @@
 
 def export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False):
-    """
+    r"""
     Export a model into ONNX format.  This exporter runs your model
     once in order to get a trace of its execution to be exported;
     at the moment, it supports a limited set of dynamic models (e.g., RNNs.)
@@ -172,7 +172,7 @@
 
 
 def _run_symbolic_method(op_name, symbolic_fn, args):
-    """
+    r"""
     This trampoline function gets invoked for every symbolic method
     call from C++.
     """
@@ -193,7 +193,7 @@
 
 
 def _add_attribute(node, key, value, aten):
-    """ initializes the right attribute based on type of value """
+    r""" initializes the right attribute based on type of value """
     m = attr_pattern.match(key)
     if m is None:
         raise IndexError((
@@ -233,7 +233,7 @@
 
 
 def _graph_op(g, opname, *raw_args, **kwargs):
-    """
+    r"""
     Create an ONNX operator 'opname', taking 'args' as inputs and attributes
     'kwargs'; returning the node representing the single output of this operator
     (see the `outputs` keyword argument for multi-return nodes).
@@ -364,7 +364,7 @@
 
 
 def _node_getitem(self, k):
-    """
+    r"""
     Accessor for attributes of a node which is polymorphic over
     return type.
 
@@ -418,7 +418,7 @@
 
 
 def symbolic_override(symbolic_fn):
-    """
+    r"""
     Decorator to override ONNX export of the a function with specified subgraph.
 
     Effectively allows to attach symbolic() implementation to an arbitrary
@@ -429,7 +429,8 @@
        them (similar requirement to NestedIOFunction)
      - outputs are similarly Variables/Tensors or (nested) lists or tuples of
        them
-     - keyword arguments are of non-tensor type
+     - non-tensor typed values should be passed as keyword arguments, both in
+       the function definition and at call sites
 
     Example usage:
 
@@ -447,12 +448,12 @@
 
 
 def symbolic_override_first_arg_based(symbolic_fn):
-    """
+    r"""
     Decorator to override ONNX export of the a function with specified subgraph.
 
-    Equivalent to `symbolic_override` but checks only the first argument of the
-    function to figure out whether the tracing is on. Thus the first arg needs
-    to be a Variable.
+    Equivalent to :func:`symbolic_override` but checks only the first argument
+    of the function to figure out whether the tracing is on. Thus the first arg
+    needs to be a Variable.
     """
 
     return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, True)
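
A hypothetical sketch of the keyword-argument requirement spelled out in the
``symbolic_override`` docstring; the ``scale_by`` function, the ONNX ``Scale``
op and the ``scale_f`` attribute name are illustrative assumptions, as is the
wrapper forwarding keyword arguments to the symbolic:

    import torch
    from torch import autograd
    from torch.onnx import symbolic_override

    def _scale_symbolic(g, x, scale=1.0):
        # non-tensor values arrive as keyword arguments (op/attribute names are illustrative)
        return g.op("Scale", x, scale_f=scale)

    @symbolic_override(_scale_symbolic)
    def scale_by(x, scale=1.0):
        return x * scale

    x = autograd.Variable(torch.randn(3, 4))
    y = scale_by(x, scale=2.0)   # non-tensor value is a keyword at the call site, too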