Revert D13548303: [pytorch][PR] Add support for batch_norm fusion to the JIT

Differential Revision:
D13548303

Original commit changeset: a2e2e5abc383

fbshipit-source-id: 5b70cdbcbd1cac06eeefb2a939773358c061183c
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 5e9144b..40f9b39 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -71,7 +71,6 @@
   _(prim, MMBatchSide)             \
   _(prim, min)                     \
   _(prim, max)                     \
-  _(aten, _ncf_unsqueeze)          \
   _(aten, warn)                    \
   _(aten, floordiv)                \
   _(aten, __round_to_zero_floordiv)\
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 85f517f..1843f78 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -36,35 +36,24 @@
   return t.accessor<scalar_t, 1>();
 }
 
-template<typename T>
-struct InvStd {
-  T operator()(T var, double epsilon) const {
-    T invstd = 0;
-    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
-      invstd = static_cast<T>(1) / std::sqrt(var + epsilon);
-    }
-    return invstd;
-  }
-};
-
-template<typename T>
-struct Var {
-  T operator()(T var, double epsilon) const {
-    return var;
-  }
-};
 
 template<typename scalar_t>
-std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
-    const Tensor& input, const Tensor& weight, const Tensor& bias,
-    const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
-    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
-    bool train, double eps) {
+std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_template(const Tensor& input, const Tensor& weight, const Tensor& bias,
+                               const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) {
 
+  using accscalar_t = at::acc_type<scalar_t, false>;
   Tensor output = at::empty_like(input);
 
   int64_t n_input = input.size(1);
+  int64_t n = input.numel() / n_input;
 
+  Tensor save_mean;
+  Tensor save_invstd;
+  const int64_t zero = 0;
+  if (train) {
+    save_mean = at::empty({n_input}, input.options());
+    save_invstd = at::empty({n_input}, input.options());
+  }
   auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
   auto save_invstd_a = conditional_accessor_1d<scalar_t>(save_invstd);
 
@@ -72,81 +61,60 @@
   auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
 
   parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
-    for (int64_t f = b_begin; f < b_end; ++f) {
-      Tensor in = input.select(1, f);
-      Tensor out = output.select(1, f);
+      for (int64_t f = b_begin; f < b_end; ++f) {
+        Tensor in = input.select(1, f);
+        Tensor out = output.select(1, f);
 
-      scalar_t mean, invstd;
-      if (train) {
-        mean = save_mean_a[f];
-        invstd = save_invstd_a[f];
-      } else {
-        mean = running_mean_a[f];
-        invstd = 1 / std::sqrt(running_var_a[f] + eps);
+        scalar_t mean, invstd;
+
+        if (train) {
+          // compute mean per input
+          accscalar_t sum = 0;
+          CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+              sum += i;
+            });
+
+          mean = (scalar_t) (sum / n);
+          save_mean_a[f] = mean;
+
+          // compute variance per input
+          sum = 0;
+          CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+              sum += (i - mean) * (i - mean);
+            });
+
+          if (sum == 0 && eps == 0.0) {
+            invstd = 0;
+          } else {
+            invstd = (scalar_t) (1 / std::sqrt(sum/n + eps));
+          }
+          save_invstd_a[f] = invstd;
+
+          // update running averages
+          if (running_mean.defined()) {
+            running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
+          }
+          if (running_var.defined()) {
+            accscalar_t unbiased_var = sum / (n - 1);
+            running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
+          }
+        } else {
+          mean = running_mean_a[f];
+          invstd = 1 / std::sqrt(running_var_a[f] + eps);
+        }
+
+        // compute output
+        scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
+        scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
+
+        CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
+            o = ((i - mean) * invstd) * w + b;
+          });
       }
-
-      // compute output
-      scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
-      scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
-
-      CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
-        o = ((i - mean) * invstd) * w + b;
-      });
-    }
-  });
+    });
   return std::make_tuple(output, save_mean, save_invstd);
 }
 
-template<typename scalar_t, template<typename T> class VarTransform>
-std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
-    const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
-    double momentum, double eps) {
-
-  using accscalar_t = at::acc_type<scalar_t, false>;
-
-  int64_t n_input = input.size(1);
-  int64_t n = input.numel() / n_input;
-
-  Tensor save_mean = at::empty({n_input}, input.options());
-  Tensor save_var_transform = at::empty({n_input}, input.options());
-  auto save_mean_a = save_mean.accessor<scalar_t, 1>();
-  auto save_var_transform_a = save_var_transform.accessor<scalar_t, 1>();
-
-  auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
-  auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
-
-  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
-    for (int64_t f = b_begin; f < b_end; ++f) {
-      Tensor in = input.select(1, f);
-
-      // compute mean per input
-      accscalar_t sum = 0;
-      CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
-          sum += i;
-        });
-      scalar_t mean = sum / n;
-      save_mean_a[f] = mean;
-
-      // compute variance per input
-      accscalar_t var_sum = 0;
-      CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
-        var_sum += (i - mean) * (i - mean);
-      });
-      save_var_transform_a[f] = VarTransform<accscalar_t>{}(var_sum / n, eps);
-
-      // update running averages
-      if (running_mean.defined()) {
-        running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
-      }
-      if (running_var.defined()) {
-        accscalar_t unbiased_var = var_sum / (n - 1);
-        running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
-      }
-    }
-  });
-  return std::make_tuple(save_mean, save_var_transform);
-}
-
 
 template<typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
@@ -451,23 +419,11 @@
     }
 }
 
-std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
-        const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
-  return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_update_stats", [&] {
-      return batch_norm_cpu_update_stats_template<scalar_t, Var>(self, running_mean, running_var, momentum, 0);
-    });
-}
-
 std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias,
                                                   const Tensor& running_mean, const Tensor& running_var,
                                                   bool train, double momentum, double eps) {
   return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] {
-      if (!train) {
-        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
-      } else {
-        auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, InvStd>(self, running_mean, running_var, momentum, eps);
-        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps);
-      }
+      return batch_norm_cpu_template<scalar_t>(self, weight, bias, running_mean, running_var, train, momentum, eps);
     });
 }
 
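For reference, the per-channel arithmetic that the restored batch_norm_cpu_template performs in training mode can be sketched outside the diff in NumPy; this is only an illustration of the math shown above (biased variance for normalization, unbiased variance for the running stats), and the function name and signature are illustrative, not part of the codebase:

    # Rough NumPy sketch of the training-mode math in batch_norm_cpu_template.
    import numpy as np

    def batch_norm_ref(x, weight, bias, running_mean, running_var,
                       momentum=0.1, eps=1e-5):
        # x has shape (N, C, ...); statistics are computed per channel f.
        n_input = x.shape[1]
        n = x.size // n_input
        out = np.empty_like(x)
        save_mean = np.empty(n_input, dtype=x.dtype)
        save_invstd = np.empty(n_input, dtype=x.dtype)
        for f in range(n_input):
            ch = x[:, f]
            mean = ch.sum() / n
            var_sum = ((ch - mean) ** 2).sum()
            # invstd uses the biased variance (divide by n), guarded against 0/0
            invstd = 0.0 if (var_sum == 0 and eps == 0) else 1.0 / np.sqrt(var_sum / n + eps)
            save_mean[f], save_invstd[f] = mean, invstd
            # running stats use the unbiased variance (divide by n - 1)
            running_mean[f] = momentum * mean + (1 - momentum) * running_mean[f]
            running_var[f] = momentum * (var_sum / (n - 1)) + (1 - momentum) * running_var[f]
            w = 1.0 if weight is None else weight[f]
            b = 0.0 if bias is None else bias[f]
            out[:, f] = (ch - mean) * invstd * w + b
        return out, save_mean, save_invstd
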
diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index 551e603..3e76d27 100644
--- a/aten/src/ATen/native/cuda/Normalization.cu
+++ b/aten/src/ATen/native/cuda/Normalization.cu
@@ -24,15 +24,4 @@
     });
 }
 
-std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
-        const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
-      if (cuda::detail::canUse32BitIndexMath(self)) {
-        return batch_norm_update_stats_cuda_template<scalar_t, int32_t>(self, running_mean, running_var, momentum);
-      } else {
-        return batch_norm_update_stats_cuda_template<scalar_t, int64_t>(self, running_mean, running_var, momentum);
-      }
-    });
-}
-
 } } // namespace at::native
diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh
index bf3e9a5..e186ef3 100644
--- a/aten/src/ATen/native/cuda/Normalization.cuh
+++ b/aten/src/ATen/native/cuda/Normalization.cuh
@@ -200,25 +200,8 @@
   }
 }
 
-template<typename T>
-struct InvStd {
-  __device__ __forceinline__ T operator()(T var, double epsilon) const {
-    T invstd = 0;
-    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
-      invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
-    }
-    return invstd;
-  }
-};
 
-template<typename T>
-struct Var {
-  __device__ __forceinline__ T operator()(T var, double epsilon) const {
-    return var;
-  }
-};
-
-template <template<typename T> class VarTransform, typename scalar_t, typename accscalar_t, typename index_t>
+template <typename scalar_t, typename accscalar_t, typename index_t>
 __global__ void batch_norm_collect_statistics_kernel(
     const PackedTensorAccessor<scalar_t, 3, RestrictPtrTraits, index_t> input,
     const accscalar_t epsilon,
@@ -226,7 +209,7 @@
     PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
     PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
     PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
-    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
+    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_invstd) {
 
   __shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE];
 
@@ -269,7 +252,7 @@
 
   // this writes each warps  item into shared memory
   // there are at most WARP_SIZE items left because
-  // there are at most WARP_SIZE**2 threads at the beginning
+  // there are at most WARP_SIZE**2 threads at the beginning  
   __syncthreads();
   if (tid % WARP_SIZE == 0) {
     shared_n[tid / WARP_SIZE] = n;
@@ -297,8 +280,12 @@
 
   // Save the mean, variance, and moving averages
   if (tid == 0) {
+    accscalar_t invstd = 0;
+    if (var_n != static_cast<accscalar_t>(0) || epsilon != static_cast<accscalar_t>(0)) {
+      invstd = static_cast<accscalar_t>(1) / device_sqrt(var_n / N + epsilon);
+    }
     save_mean[plane] = avg;
-    save_transformed_var[plane] = VarTransform<accscalar_t>{}(var_n / N, epsilon);
+    save_invstd[plane] = invstd;
     if (running_mean.data() != NULL) {
       running_mean[plane] = static_cast<scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
     }
@@ -444,7 +431,7 @@
     dim3 blocks(input.size(1));
     tf = getNumThreads(input.size(2));
     dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
-    batch_norm_collect_statistics_kernel<InvStd, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+    batch_norm_collect_statistics_kernel<scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
       (input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd);
     batch_norm_transform_input_kernel<scalar_t, accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
       (input, output, save_mean, save_invstd, weight, bias, epsilon);
@@ -501,39 +488,4 @@
   return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
 }
 
-template<typename scalar_t, typename index_t>
-std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
-        const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
-
-  using accscalar_t = at::acc_type<scalar_t, true>;
-  int64_t n_channels = input_.size(1);
-  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
-
-  auto input_options = input_.options();
-  if (input_.type().scalarType() == at::ScalarType::Half) {
-    input_options = input_options.dtype(ScalarType::Float);
-  }
-  Tensor save_mean_ = at::empty({n_channels}, input_options);
-  Tensor save_var_ = at::empty({n_channels}, input_options);
-
-  auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
-  auto running_mean = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
-  auto running_var = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
-  auto save_mean = save_mean_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
-  auto save_var = save_var_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
-  auto stream = at::cuda::getCurrentCUDAStream();
-
-  // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
-  // the feature dimension, we'll use some threads for blocks
-  dim3 blocks(input.size(1));
-  int tf = getNumThreads(input.size(2));
-  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
-  // NB: epsilon is unused by the Var transform, so we set it to 0
-  batch_norm_collect_statistics_kernel<Var, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
-    (input, 0., momentum, running_mean, running_var, save_mean, save_var);
-  THCudaCheck(cudaGetLastError());
-  return std::make_tuple(save_mean_, save_var_);
-
-}
-
 } } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7c6cbe9..fa01c99 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1247,11 +1247,6 @@
     CPU: batch_norm_backward_cpu
     CUDA: batch_norm_backward_cuda
 
-- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, double momentum) -> (Tensor, Tensor)
-  dispatch:
-    CPU: batch_norm_update_stats_cpu
-    CUDA: batch_norm_update_stats_cuda
-
 - func: ones(IntList size, TensorOptions options={}) -> Tensor
 
 - func: ones_out(Tensor result, IntList size) -> Tensor
diff --git a/test/test_jit.py b/test/test_jit.py
index 0f43f48..5d959bd 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10409,58 +10409,6 @@
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_fuse_batch_norm(self):
-
-        class ResLike(torch.jit.ScriptModule):
-            def __init__(self, optimize=True):
-                super(ResLike, self).__init__(optimize)
-                self.bn = nn.BatchNorm2d(16)
-
-            @torch.jit.script_method
-            def forward(self, x, y):
-                return y + torch.relu(self.bn(x))
-
-        model = ResLike().cuda()
-        model_noopt = ResLike(optimize=False).cuda()
-        model_noopt.load_state_dict(model.state_dict())
-        x = torch.randn(2, 16, 8, 8, device='cuda')
-        y = torch.randn(2, 16, 8, 8, device='cuda')
-        # FIXME: We need differentiation for CNNs for this optimization to trigger
-        with torch.no_grad():
-            out = model(x, y)
-            graph = model.graph_for(x, y)
-            rep = str(graph)
-
-            out_noopt = model_noopt(x, y)
-            rep_noopt = str(model_noopt.graph_for(x, y))
-            self.assertEqual(out, out_noopt, prec=3e-5)
-
-        # Check that batch_norm has really been decomposed
-        self.assertIn('aten::batch_norm_update_stats', rep)
-        self.assertNotIn('aten::batch_norm(', rep)
-        self.assertIn('aten::batch_norm(', rep_noopt)
-
-        # Make sure the fusion group is big, and contains aten::sqrt, which could
-        # originate only from decomposing batch_norm in this case
-        fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
-        self.assertEqual(len(fusion_groups), 1)
-        fused_graph = fusion_groups[0].g('Subgraph')
-        self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
-
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_threshold(self):
-        def f(x):
-            return torch.threshold(x, 0, -10) + x + x + x
-
-        x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
-        scripted = torch.jit.script(f)
-
-        self.assertEqual(f(x), scripted(x))
-        self.assertAllFused(scripted.graph_for(x))
-
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
     @skipIfRocm
     @enable_cpu_fuser
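The removed test_fuse_batch_norm above also shows the general recipe for inspecting what the fuser produced. A condensed sketch of that recipe, using a plain script function instead of the ResLike module (whether a prim::FusionGroup actually appears depends on the build and on which passes run, and the fuser here requires CUDA):

    import torch

    @torch.jit.script
    def fn(x, y):
        return y + torch.relu(x)

    x = torch.randn(2, 16, 8, 8, device='cuda')
    y = torch.randn(2, 16, 8, 8, device='cuda')
    fn(x, y)                            # run once so an optimized graph exists
    graph = fn.graph_for(x, y)
    fusion_groups = [n for n in graph.nodes() if n.kind() == 'prim::FusionGroup']
    for g in fusion_groups:
        print(g.g('Subgraph'))          # the fused subgraph, as in the removed test
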
diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp
index 907f2d1..e1ed130 100644
--- a/torch/csrc/jit/fuser/codegen.cpp
+++ b/torch/csrc/jit/fuser/codegen.cpp
@@ -140,8 +140,6 @@
       {aten::abs, "fabs(${0})"},
       {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
       {aten::relu, "${0} < 0 ? 0.f : ${0} "},
-      {aten::threshold,
-       "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
       {aten::log, "logf(${0})"},
       {aten::log10, "log10f(${0})"},
       {aten::log1p, "log1pf(${0})"},
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index cb69295..0157000 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -3,14 +3,11 @@
 #include <ATen/ExpandUtils.h>
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/autodiff.h>
-#include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
-#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
-#include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 #include <unordered_map>
 
@@ -71,7 +68,6 @@
       //"aten::rand_like(Tensor self) -> Tensor",
       "aten::reciprocal(Tensor self) -> Tensor",
       "aten::relu(Tensor self) -> Tensor",
-      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
       "aten::remainder(Tensor self, Tensor other) -> Tensor",
       "aten::round(Tensor self) -> Tensor",
       "aten::rsqrt(Tensor self) -> Tensor",
@@ -120,44 +116,6 @@
   return true;
 }
 
-RegisterOperators reg_bn_unsqueeze({Operator(
-    "aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
-    [](const Node* node) {
-      return [](Stack& stack) {
-        const int64_t ndim = pop(stack).toInt();
-        auto self = pop(stack).toTensor();
-        c10::SmallVector<int64_t, 8> sizes(ndim, 1);
-        JIT_ASSERT(self.dim() == 1);
-        sizes.at(1) = self.size(0);
-        push(stack, self.reshape(sizes));
-        return 0;
-      };
-    })});
-
-// Yes, no, or no value if we can't tell
-c10::optional<bool> isDefined(Value* tensor) {
-  if (tensor->type()->isSubtypeOf(DynamicType::get())) {
-    return true;
-  }
-  if (tensor->node()->kind() == prim::None ||
-      tensor->node()->kind() == prim::Undefined) {
-    return false;
-  }
-  return {};
-}
-
-bool isFusableBatchNorm(Node* batch_norm) {
-  if (!batch_norm->matches(
-          "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
-    return false;
-  }
-  // If we can't determine if weight and bias is defined statically there's
-  // really no point in decomposing batch norm into simpler ops, since it won't
-  // get fused into a single kernel.
-  return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
-      isDefined(batch_norm->namedInput(attr::bias)).has_value();
-}
-
 Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
   JIT_ASSERT(!sizes.empty());
   Graph* graph = sizes[0]->owningGraph();
@@ -169,7 +127,6 @@
 
 struct GraphFuser {
   Block* block_;
-  c10::optional<AliasDb> aliasDb_;
   std::shared_ptr<Graph> graph_;
 
   GraphFuser(Block* block, std::shared_ptr<Graph> graph)
@@ -182,10 +139,6 @@
   }
 
   bool isFusable(Node* node) {
-    return isFusableMap(node) || isFusableBatchNorm(node);
-  }
-
-  bool isFusableMap(Node* node) {
     // We don't want to bother with cross-block node movements, as they
     // are not necessarily correct.
     if (node->owningBlock() != block_)
@@ -216,7 +169,7 @@
   // cannot be fused because it is not a simple map, can be put in a fusion
   // group as long as no items in the group read the output of concat
   bool isFusableAsExitNode(Node* node) {
-    return isFusableMap(node) || isFusableOnlyAsExitNode(node);
+    return isFusable(node) || isFusableOnlyAsExitNode(node);
   }
 
   bool isFusableOnlyAsExitNode(Node* node) {
@@ -252,59 +205,6 @@
     return *n->g(attr::Subgraph);
   }
 
-  void decomposeBatchNorm(Node* batch_norm) {
-    static std::shared_ptr<Graph> bn_graph;
-    static std::once_flag flag;
-    std::call_once(
-        flag,
-        [](std::shared_ptr<Graph>* graph_ptr) {
-          static const char* source = R"SCRIPT(
-        def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
-            if training:
-                norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
-            else:
-                norm_mean = torch._unwrap_optional(running_mean)
-                norm_var = torch._unwrap_optional(running_var)
-            norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
-            norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
-            norm_invstd = 1 / (eps + torch.sqrt(norm_var))
-            return ((input - norm_mean) * norm_invstd)
-      )SCRIPT";
-          auto module = std::make_shared<script::Module>();
-          defineMethodsInModule(
-              module, source, script::nativeResolver, /*self=*/nullptr);
-          *graph_ptr = module->get_method("batch_norm").graph();
-        },
-        &bn_graph);
-
-    JIT_ASSERT(isFusableBatchNorm(batch_norm));
-    WithInsertPoint insert_guard{batch_norm};
-    Value* input = batch_norm->namedInput(attr::input);
-    Value* input_dim = graph_->insert(aten::dim, {input});
-    std::vector<Value*> inputs{input,
-                               batch_norm->namedInput(attr::running_mean),
-                               batch_norm->namedInput(attr::running_var),
-                               batch_norm->namedInput(attr::training),
-                               batch_norm->namedInput(attr::momentum),
-                               batch_norm->namedInput(attr::eps)};
-    Value* new_output =
-        SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
-    auto weight = batch_norm->namedInput(attr::weight);
-    auto bias = batch_norm->namedInput(attr::bias);
-    if (isDefined(weight).value()) {
-      Value* expanded_weight =
-          graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
-      new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
-    }
-    if (isDefined(bias).value()) {
-      Value* expanded_bias =
-          graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
-      new_output = graph_->insert(aten::add, {new_output, expanded_bias});
-    }
-    batch_norm->output()->replaceAllUsesWith(new_output);
-    batch_norm->destroy();
-  }
-
   void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
     // Now we have two fusion groups!
     // Revert the fusion - place all inner nodes of producer back in the outer
@@ -449,7 +349,10 @@
     *insertion_point = n;
   }
 
-  at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
+  at::optional<Node*> tryFuse(
+      Node* consumer,
+      Value* producer,
+      const AliasDb& aliasDb) {
     // this handles cases where producer can be moved _into_ the fusion group of
     // consumer.
     // TODO: extend to fusion of consumer into _producer's_ fusion blob
@@ -465,8 +368,7 @@
         // consumer. Fusion will rewrite those later uses to use the version of
         // producer generated by the fused blob. In this case, producer becomes
         // an output of the fusion group.
-        producer->node()->moveBeforeTopologicallyValid(
-            real_consumer, aliasDb_.value());
+        producer->node()->moveBeforeTopologicallyValid(real_consumer, aliasDb);
 
     if (!shouldFuse) {
       return at::nullopt;
@@ -494,14 +396,6 @@
     } else if (consumer->kind() != prim::FusionGroup) {
       group = createSingletonFusionGroup(consumer);
     }
-    if (producer->node()->matches(
-            "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
-      // We don't do any fusions in here, but simply decompose the batch norm
-      // into a kernel that computes the stats + pointwise ops which will be
-      // considered in this fusion next.
-      decomposeBatchNorm(producer->node());
-      return group;
-    }
     if (producer->node()->kind() == prim::FusionGroup) {
       mergeFusionGroups(group, producer->node());
       return group;
@@ -755,7 +649,7 @@
         chunk->inputs().begin(),
         chunk->inputs().end(),
         [&](Value* producer_for_chunk) {
-          return isFusableMap(producer_for_chunk->node()) &&
+          return isFusable(producer_for_chunk->node()) &&
               allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
         });
     if (it == chunk->inputs().end()) {
@@ -883,7 +777,9 @@
   }
 
   // returns where to continue scanning, and whether any fusion was made
-  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
+  std::pair<graph_node_list::iterator, bool> scanNode(
+      Node* consumer,
+      const AliasDb& aliasDb) {
     if (isFusableAsExitNode(consumer)) {
       auto consumer_inputs = consumer->kind() == aten::cat
           ? consumer->namedInput(attr::tensors)->node()->inputs()
@@ -901,7 +797,7 @@
           // we scan this consumer again to perform the fusion
           return std::make_pair(consumer->reverseIterator(), true);
         }
-        auto fusion_group = tryFuse(consumer, producer);
+        auto fusion_group = tryFuse(consumer, producer, aliasDb);
         if (fusion_group) {
           // after fusion, consumer moves into a FusionGroup, so inputs is no
           // longer valid so we rescan the new FusionGroup for more fusions...
@@ -1052,10 +948,6 @@
     }
   }
 
-  void refreshAliasDb() {
-    aliasDb_ = AliasAnalysis(graph_);
-  }
-
   void run() {
     // Run the pass until no changes are made.
     // This is neccessary, because the algorithm can miss out on certain fusion
@@ -1076,10 +968,10 @@
     bool any_changed = true;
     while (any_changed) {
       any_changed = false;
-      refreshAliasDb();
+      auto aliasDb = AliasAnalysis(graph_);
       for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
         bool changed;
-        std::tie(it, changed) = scanNode(*it);
+        std::tie(it, changed) = scanNode(*it, aliasDb);
         any_changed |= changed;
       }
     }
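The removed decomposeBatchNorm inlined a small TorchScript graph that split aten::batch_norm into a stats op plus pointwise ops. Below is an eager-mode sketch of that pointwise part using only standard ops, since the batch_norm_update_stats and _ncf_unsqueeze helpers it relied on are removed by this revert; the function name is illustrative. Note that the removed script computes 1 / (eps + sqrt(var)), i.e. it places eps outside the square root:

    import torch

    def decomposed_batch_norm(x, weight, bias, eps=1e-5):
        # The removed script obtained mean/var from batch_norm_update_stats in
        # training mode (which also updated the running stats); here they are
        # computed directly over all dims except the channel dim.
        dims = [0] + list(range(2, x.dim()))
        mean = x.mean(dim=dims)
        var = x.var(dim=dims, unbiased=False)
        shape = [1, -1] + [1] * (x.dim() - 2)   # what _ncf_unsqueeze produced
        invstd = 1 / (eps + torch.sqrt(var))
        out = (x - mean.reshape(shape)) * invstd.reshape(shape)
        if weight is not None:
            out = out * weight.reshape(shape)
        if bias is not None:
            out = out + bias.reshape(shape)
        return out
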
diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp
index 77f3dbf..5266cf4 100644
--- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp
+++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp
@@ -4,50 +4,72 @@
 namespace jit {
 namespace SubgraphUtils {
 namespace {
+bool isSubgraphNodeKind(Symbol s) {
+  return s == prim::DifferentiableGraph || s == prim::FusionGroup;
+}
 
-bool hasSubgraph(Node* n) {
-  return n->hasAttribute(attr::Subgraph);
+bool isSubgraphNodeKind(Node* n) {
+  return isSubgraphNodeKind(n->kind());
 }
 
 // Combine the nodes in two subgraph together. The nodes will end up in
 // `mergeTo`, and `mergeFrom` is destroyed.
 void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
-  Node* nodeBeforeMergeFrom = mergeFrom->prev();
-  Node* nodeAfterMergeFrom = mergeFrom->next();
-  unmergeSubgraph(mergeFrom);
-  std::vector<Node*> nodes;
-  const auto end_it = nodeBeforeMergeFrom->reverseIterator();
-  auto it = nodeAfterMergeFrom->reverseIterator();
-  ++it;
-  while (it != end_it) {
-    // NB: mergeNodeIntoSubgraph destroys node, hence the complications
-    Node* node = *it;
-    ++it;
-    mergeNodeIntoSubgraph(node, mergeTo);
+  const auto nodes = unmergeSubgraph(mergeFrom);
+  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
+    mergeNodeIntoSubgraph(*it, mergeTo);
   }
 }
 } // namespace
 
 std::shared_ptr<Graph> getSubgraph(Node* n) {
+  JIT_ASSERT(isSubgraphNodeKind(n));
   return n->g(attr::Subgraph);
 }
 
-void unmergeSubgraph(Node* subgraphNode) {
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode) {
   JIT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
+  auto outerGraph = subgraphNode->owningGraph();
 
-  // Inline the graph, replace uses of node outputs and destroy the node
-  const auto subgraphOutputs = inlineGraph(
-      getSubgraph(subgraphNode), subgraphNode->inputs(), subgraphNode);
+  std::vector<Node*> temporary_nodes;
+  auto subgraph = getSubgraph(subgraphNode);
+
+  // Initialize a map of inner graph values to outer graph values
+  std::unordered_map<const Value*, Value*> innerToOuter;
+  const auto innerInputs = subgraph->inputs();
+  const auto outerInputs = subgraphNode->inputs();
+  for (size_t i = 0; i < innerInputs.size(); ++i) {
+    innerToOuter[innerInputs[i]] = outerInputs[i];
+  }
+
+  // Clone all nodes
+  for (auto inner : subgraph->nodes()) {
+    Node* outer = outerGraph->createClone(
+        inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
+    outer->insertBefore(subgraphNode);
+    temporary_nodes.emplace_back(outer);
+    const auto innerOutputs = inner->outputs();
+    const auto outerOutputs = outer->outputs();
+    for (size_t i = 0; i < innerOutputs.size(); ++i) {
+      innerToOuter[innerOutputs[i]] = outerOutputs[i];
+    }
+  }
+
+  // Replace uses of group outputs and destroy the group
+  const auto subgraphOutputs = subgraph->outputs();
   JIT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
   for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
-    subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]);
+    const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
+    subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
   }
   subgraphNode->destroy();
+
+  return temporary_nodes;
 }
 
 void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
-  JIT_ASSERT(hasSubgraph(subgraphNode));
-  if (hasSubgraph(toMerge)) {
+  JIT_ASSERT(isSubgraphNodeKind(subgraphNode));
+  if (isSubgraphNodeKind(toMerge)) {
     return mergeSubgraph(subgraphNode, toMerge);
   }
 
@@ -128,40 +150,8 @@
   toMerge->destroy();
 }
 
-// Invariant we depend on in mergeSubgraph: All inlined nodes are created
-// between the node preceding insertBefore and insertBefore.
-std::vector<Value*> inlineGraph(
-    const std::shared_ptr<Graph>& subgraph,
-    at::ArrayRef<Value*> outerInputs,
-    Node* insertBefore) {
-  auto outerGraph = insertBefore->owningGraph();
-
-  // Initialize a map of inner graph values to outer graph values
-  std::unordered_map<const Value*, Value*> innerToOuter;
-  const auto innerInputs = subgraph->inputs();
-  JIT_ASSERT(outerInputs.size() == innerInputs.size());
-  for (size_t i = 0; i < innerInputs.size(); ++i) {
-    innerToOuter[innerInputs[i]] = outerInputs[i];
-  }
-
-  // Clone all nodes
-  for (auto inner : subgraph->nodes()) {
-    Node* outer = outerGraph->createClone(
-        inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
-    outer->insertBefore(insertBefore);
-    const auto innerOutputs = inner->outputs();
-    const auto outerOutputs = outer->outputs();
-    for (size_t i = 0; i < innerOutputs.size(); ++i) {
-      innerToOuter[innerOutputs[i]] = outerOutputs[i];
-    }
-  }
-
-  return fmap(subgraph->outputs(), [&](Value* output) {
-    return innerToOuter.at(output);
-  });
-}
-
 Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
+  JIT_ASSERT(isSubgraphNodeKind(subgraphKind));
   auto graph = n->owningGraph();
   auto subgraph = graph->create(subgraphKind, 0);
   subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
@@ -169,7 +159,6 @@
   mergeNodeIntoSubgraph(n, subgraph);
   return subgraph;
 }
-
 } // namespace SubgraphUtils
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.h b/torch/csrc/jit/passes/utils/subgraph_utils.h
index 4d0c449..dc81902 100644
--- a/torch/csrc/jit/passes/utils/subgraph_utils.h
+++ b/torch/csrc/jit/passes/utils/subgraph_utils.h
@@ -26,16 +26,11 @@
 
 // Move nodes from a subgraph node to the outer graph.
 // `subgraphNode` is destroyed.
-void unmergeSubgraph(Node* subgraphNode);
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode);
 
 // Convenience function
 std::shared_ptr<Graph> getSubgraph(Node* n);
 
-std::vector<Value*> inlineGraph(
-    const std::shared_ptr<Graph>& subgraph,
-    at::ArrayRef<Value*> outerInputs,
-    Node* insertBefore);
-
 } // namespace SubgraphUtils
 } // namespace jit
 } // namespace torch