Revert D13548303: [pytorch][PR] Add support for batch_norm fusion to the JIT
Differential Revision: D13548303
Original commit changeset: a2e2e5abc383
fbshipit-source-id: 5b70cdbcbd1cac06eeefb2a939773358c061183c
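
Note (not part of the patch, names and test tensors illustrative): the training-mode math that the restored batch_norm_cpu_template below computes per channel is the batch mean, the biased variance for normalization, the unbiased variance for the running-stat update, and then ((x - mean) * invstd) * w + b. A minimal PyTorch sketch of that computation, checked against torch.nn.functional.batch_norm:

import torch

def batch_norm_reference(x, weight, bias, running_mean, running_var,
                         momentum=0.1, eps=1e-5):
    c = x.size(1)
    xc = x.transpose(0, 1).contiguous().view(c, -1)   # one row per channel
    n = xc.size(1)
    mean = xc.mean(dim=1)                              # sum / n
    var = ((xc - mean.view(c, 1)) ** 2).mean(dim=1)    # biased variance, sum / n
    invstd = 1.0 / torch.sqrt(var + eps)               # zero-var/zero-eps corner case omitted
    # running stats use the unbiased variance, matching the C++ template
    running_mean.mul_(1 - momentum).add_(momentum * mean)
    running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    shape = [1, c] + [1] * (x.dim() - 2)
    return ((x - mean.view(shape)) * invstd.view(shape)) * weight.view(shape) + bias.view(shape)

x = torch.randn(2, 16, 8, 8)
w, b = torch.rand(16), torch.rand(16)
out = batch_norm_reference(x, w, b, torch.zeros(16), torch.ones(16))
expected = torch.nn.functional.batch_norm(
    x, torch.zeros(16), torch.ones(16), w, b, training=True)
print(torch.allclose(out, expected, atol=1e-5))        # True
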
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 5e9144b..40f9b39 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -71,7 +71,6 @@
_(prim, MMBatchSide) \
_(prim, min) \
_(prim, max) \
- _(aten, _ncf_unsqueeze) \
_(aten, warn) \
_(aten, floordiv) \
_(aten, __round_to_zero_floordiv)\
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 85f517f..1843f78 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -36,35 +36,24 @@
return t.accessor<scalar_t, 1>();
}
-template<typename T>
-struct InvStd {
- T operator()(T var, double epsilon) const {
- T invstd = 0;
- if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
- invstd = static_cast<T>(1) / std::sqrt(var + epsilon);
- }
- return invstd;
- }
-};
-
-template<typename T>
-struct Var {
- T operator()(T var, double epsilon) const {
- return var;
- }
-};
template<typename scalar_t>
-std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
- const Tensor& input, const Tensor& weight, const Tensor& bias,
- const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
- const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
- bool train, double eps) {
+std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_template(const Tensor& input, const Tensor& weight, const Tensor& bias,
+ const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) {
+ using accscalar_t = at::acc_type<scalar_t, false>;
Tensor output = at::empty_like(input);
int64_t n_input = input.size(1);
+ int64_t n = input.numel() / n_input;
+ Tensor save_mean;
+ Tensor save_invstd;
+ const int64_t zero = 0;
+ if (train) {
+ save_mean = at::empty({n_input}, input.options());
+ save_invstd = at::empty({n_input}, input.options());
+ }
auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
auto save_invstd_a = conditional_accessor_1d<scalar_t>(save_invstd);
@@ -72,81 +61,60 @@
auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
- for (int64_t f = b_begin; f < b_end; ++f) {
- Tensor in = input.select(1, f);
- Tensor out = output.select(1, f);
+ for (int64_t f = b_begin; f < b_end; ++f) {
+ Tensor in = input.select(1, f);
+ Tensor out = output.select(1, f);
- scalar_t mean, invstd;
- if (train) {
- mean = save_mean_a[f];
- invstd = save_invstd_a[f];
- } else {
- mean = running_mean_a[f];
- invstd = 1 / std::sqrt(running_var_a[f] + eps);
+ scalar_t mean, invstd;
+
+ if (train) {
+ // compute mean per input
+ accscalar_t sum = 0;
+ CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+ sum += i;
+ });
+
+ mean = (scalar_t) (sum / n);
+ save_mean_a[f] = mean;
+
+ // compute variance per input
+ sum = 0;
+ CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+ sum += (i - mean) * (i - mean);
+ });
+
+ if (sum == 0 && eps == 0.0) {
+ invstd = 0;
+ } else {
+ invstd = (scalar_t) (1 / std::sqrt(sum/n + eps));
+ }
+ save_invstd_a[f] = invstd;
+
+ // update running averages
+ if (running_mean.defined()) {
+ running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
+ }
+ if (running_var.defined()) {
+ accscalar_t unbiased_var = sum / (n - 1);
+ running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
+ }
+ } else {
+ mean = running_mean_a[f];
+ invstd = 1 / std::sqrt(running_var_a[f] + eps);
+ }
+
+ // compute output
+ scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
+ scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
+
+ CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
+ o = ((i - mean) * invstd) * w + b;
+ });
}
-
- // compute output
- scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
- scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
-
- CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
- o = ((i - mean) * invstd) * w + b;
- });
- }
- });
+ });
return std::make_tuple(output, save_mean, save_invstd);
}
-template<typename scalar_t, template<typename T> class VarTransform>
-std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
- const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
- double momentum, double eps) {
-
- using accscalar_t = at::acc_type<scalar_t, false>;
-
- int64_t n_input = input.size(1);
- int64_t n = input.numel() / n_input;
-
- Tensor save_mean = at::empty({n_input}, input.options());
- Tensor save_var_transform = at::empty({n_input}, input.options());
- auto save_mean_a = save_mean.accessor<scalar_t, 1>();
- auto save_var_transform_a = save_var_transform.accessor<scalar_t, 1>();
-
- auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
- auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
-
- parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
- for (int64_t f = b_begin; f < b_end; ++f) {
- Tensor in = input.select(1, f);
-
- // compute mean per input
- accscalar_t sum = 0;
- CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
- sum += i;
- });
- scalar_t mean = sum / n;
- save_mean_a[f] = mean;
-
- // compute variance per input
- accscalar_t var_sum = 0;
- CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
- var_sum += (i - mean) * (i - mean);
- });
- save_var_transform_a[f] = VarTransform<accscalar_t>{}(var_sum / n, eps);
-
- // update running averages
- if (running_mean.defined()) {
- running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
- }
- if (running_var.defined()) {
- accscalar_t unbiased_var = var_sum / (n - 1);
- running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
- }
- }
- });
- return std::make_tuple(save_mean, save_var_transform);
-}
-
template<typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
@@ -451,23 +419,11 @@
}
}
-std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
- const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
- return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_update_stats", [&] {
- return batch_norm_cpu_update_stats_template<scalar_t, Var>(self, running_mean, running_var, momentum, 0);
- });
-}
-
std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias,
const Tensor& running_mean, const Tensor& running_var,
bool train, double momentum, double eps) {
return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] {
- if (!train) {
- return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
- } else {
- auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, InvStd>(self, running_mean, running_var, momentum, eps);
- return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps);
- }
+ return batch_norm_cpu_template<scalar_t>(self, weight, bias, running_mean, running_var, train, momentum, eps);
});
}
diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu
index 551e603..3e76d27 100644
--- a/aten/src/ATen/native/cuda/Normalization.cu
+++ b/aten/src/ATen/native/cuda/Normalization.cu
@@ -24,15 +24,4 @@
});
}
-std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
- const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
- return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
- if (cuda::detail::canUse32BitIndexMath(self)) {
- return batch_norm_update_stats_cuda_template<scalar_t, int32_t>(self, running_mean, running_var, momentum);
- } else {
- return batch_norm_update_stats_cuda_template<scalar_t, int64_t>(self, running_mean, running_var, momentum);
- }
- });
-}
-
} } // namespace at::native
diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh
index bf3e9a5..e186ef3 100644
--- a/aten/src/ATen/native/cuda/Normalization.cuh
+++ b/aten/src/ATen/native/cuda/Normalization.cuh
@@ -200,25 +200,8 @@
}
}
-template<typename T>
-struct InvStd {
- __device__ __forceinline__ T operator()(T var, double epsilon) const {
- T invstd = 0;
- if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
- invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
- }
- return invstd;
- }
-};
-template<typename T>
-struct Var {
- __device__ __forceinline__ T operator()(T var, double epsilon) const {
- return var;
- }
-};
-
-template <template<typename T> class VarTransform, typename scalar_t, typename accscalar_t, typename index_t>
+template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
const PackedTensorAccessor<scalar_t, 3, RestrictPtrTraits, index_t> input,
const accscalar_t epsilon,
@@ -226,7 +209,7 @@
PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
- PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
+ PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_invstd) {
__shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE];
@@ -269,7 +252,7 @@
// this writes each warps item into shared memory
// there are at most WARP_SIZE items left because
- // there are at most WARP_SIZE**2 threads at the beginning
+ // there are at most WARP_SIZE**2 threads at the beginning
__syncthreads();
if (tid % WARP_SIZE == 0) {
shared_n[tid / WARP_SIZE] = n;
@@ -297,8 +280,12 @@
// Save the mean, variance, and moving averages
if (tid == 0) {
+ accscalar_t invstd = 0;
+ if (var_n != static_cast<accscalar_t>(0) || epsilon != static_cast<accscalar_t>(0)) {
+ invstd = static_cast<accscalar_t>(1) / device_sqrt(var_n / N + epsilon);
+ }
save_mean[plane] = avg;
- save_transformed_var[plane] = VarTransform<accscalar_t>{}(var_n / N, epsilon);
+ save_invstd[plane] = invstd;
if (running_mean.data() != NULL) {
running_mean[plane] = static_cast<scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
}
@@ -444,7 +431,7 @@
dim3 blocks(input.size(1));
tf = getNumThreads(input.size(2));
dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
- batch_norm_collect_statistics_kernel<InvStd, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+ batch_norm_collect_statistics_kernel<scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
(input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd);
batch_norm_transform_input_kernel<scalar_t, accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
(input, output, save_mean, save_invstd, weight, bias, epsilon);
@@ -501,39 +488,4 @@
return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}
-template<typename scalar_t, typename index_t>
-std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
- const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
-
- using accscalar_t = at::acc_type<scalar_t, true>;
- int64_t n_channels = input_.size(1);
- auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
-
- auto input_options = input_.options();
- if (input_.type().scalarType() == at::ScalarType::Half) {
- input_options = input_options.dtype(ScalarType::Float);
- }
- Tensor save_mean_ = at::empty({n_channels}, input_options);
- Tensor save_var_ = at::empty({n_channels}, input_options);
-
- auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
- auto running_mean = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
- auto running_var = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
- auto save_mean = save_mean_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
- auto save_var = save_var_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
- auto stream = at::cuda::getCurrentCUDAStream();
-
- // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
- // the feature dimension, we'll use some threads for blocks
- dim3 blocks(input.size(1));
- int tf = getNumThreads(input.size(2));
- dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
- // NB: epsilon is unused by the Var transform, so we set it to 0
- batch_norm_collect_statistics_kernel<Var, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
- (input, 0., momentum, running_mean, running_var, save_mean, save_var);
- THCudaCheck(cudaGetLastError());
- return std::make_tuple(save_mean_, save_var_);
-
-}
-
} } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7c6cbe9..fa01c99 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1247,11 +1247,6 @@
CPU: batch_norm_backward_cpu
CUDA: batch_norm_backward_cuda
-- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, double momentum) -> (Tensor, Tensor)
- dispatch:
- CPU: batch_norm_update_stats_cpu
- CUDA: batch_norm_update_stats_cuda
-
- func: ones(IntList size, TensorOptions options={}) -> Tensor
- func: ones_out(Tensor result, IntList size) -> Tensor
diff --git a/test/test_jit.py b/test/test_jit.py
index 0f43f48..5d959bd 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10409,58 +10409,6 @@
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_fuse_batch_norm(self):
-
- class ResLike(torch.jit.ScriptModule):
- def __init__(self, optimize=True):
- super(ResLike, self).__init__(optimize)
- self.bn = nn.BatchNorm2d(16)
-
- @torch.jit.script_method
- def forward(self, x, y):
- return y + torch.relu(self.bn(x))
-
- model = ResLike().cuda()
- model_noopt = ResLike(optimize=False).cuda()
- model_noopt.load_state_dict(model.state_dict())
- x = torch.randn(2, 16, 8, 8, device='cuda')
- y = torch.randn(2, 16, 8, 8, device='cuda')
- # FIXME: We need differentiation for CNNs for this optimization to trigger
- with torch.no_grad():
- out = model(x, y)
- graph = model.graph_for(x, y)
- rep = str(graph)
-
- out_noopt = model_noopt(x, y)
- rep_noopt = str(model_noopt.graph_for(x, y))
- self.assertEqual(out, out_noopt, prec=3e-5)
-
- # Check that batch_norm has really been decomposed
- self.assertIn('aten::batch_norm_update_stats', rep)
- self.assertNotIn('aten::batch_norm(', rep)
- self.assertIn('aten::batch_norm(', rep_noopt)
-
- # Make sure the fusion group is big, and contains aten::sqrt, which could
- # originate only from decomposing batch_norm in this case
- fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
- self.assertEqual(len(fusion_groups), 1)
- fused_graph = fusion_groups[0].g('Subgraph')
- self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_threshold(self):
- def f(x):
- return torch.threshold(x, 0, -10) + x + x + x
-
- x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
- scripted = torch.jit.script(f)
-
- self.assertEqual(f(x), scripted(x))
- self.assertAllFused(scripted.graph_for(x))
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
@skipIfRocm
@enable_cpu_fuser
diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp
index 907f2d1..e1ed130 100644
--- a/torch/csrc/jit/fuser/codegen.cpp
+++ b/torch/csrc/jit/fuser/codegen.cpp
@@ -140,8 +140,6 @@
{aten::abs, "fabs(${0})"},
{aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
{aten::relu, "${0} < 0 ? 0.f : ${0} "},
- {aten::threshold,
- "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
{aten::log, "logf(${0})"},
{aten::log10, "log10f(${0})"},
{aten::log1p, "log1pf(${0})"},
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index cb69295..0157000 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -3,14 +3,11 @@
#include <ATen/ExpandUtils.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/autodiff.h>
-#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
-#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
-#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <unordered_map>
@@ -71,7 +68,6 @@
//"aten::rand_like(Tensor self) -> Tensor",
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
- "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
"aten::remainder(Tensor self, Tensor other) -> Tensor",
"aten::round(Tensor self) -> Tensor",
"aten::rsqrt(Tensor self) -> Tensor",
@@ -120,44 +116,6 @@
return true;
}
-RegisterOperators reg_bn_unsqueeze({Operator(
- "aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
- [](const Node* node) {
- return [](Stack& stack) {
- const int64_t ndim = pop(stack).toInt();
- auto self = pop(stack).toTensor();
- c10::SmallVector<int64_t, 8> sizes(ndim, 1);
- JIT_ASSERT(self.dim() == 1);
- sizes.at(1) = self.size(0);
- push(stack, self.reshape(sizes));
- return 0;
- };
- })});
-
-// Yes, no, or no value if we can't tell
-c10::optional<bool> isDefined(Value* tensor) {
- if (tensor->type()->isSubtypeOf(DynamicType::get())) {
- return true;
- }
- if (tensor->node()->kind() == prim::None ||
- tensor->node()->kind() == prim::Undefined) {
- return false;
- }
- return {};
-}
-
-bool isFusableBatchNorm(Node* batch_norm) {
- if (!batch_norm->matches(
- "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
- return false;
- }
- // If we can't determine if weight and bias is defined statically there's
- // really no point in decomposing batch norm into simpler ops, since it won't
- // get fused into a single kernel.
- return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
- isDefined(batch_norm->namedInput(attr::bias)).has_value();
-}
-
Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
JIT_ASSERT(!sizes.empty());
Graph* graph = sizes[0]->owningGraph();
@@ -169,7 +127,6 @@
struct GraphFuser {
Block* block_;
- c10::optional<AliasDb> aliasDb_;
std::shared_ptr<Graph> graph_;
GraphFuser(Block* block, std::shared_ptr<Graph> graph)
@@ -182,10 +139,6 @@
}
bool isFusable(Node* node) {
- return isFusableMap(node) || isFusableBatchNorm(node);
- }
-
- bool isFusableMap(Node* node) {
// We don't want to bother with cross-block node movements, as they
// are not necessarily correct.
if (node->owningBlock() != block_)
@@ -216,7 +169,7 @@
// cannot be fused because it is not a simple map, can be put in a fusion
// group as long as no items in the group read the output of concat
bool isFusableAsExitNode(Node* node) {
- return isFusableMap(node) || isFusableOnlyAsExitNode(node);
+ return isFusable(node) || isFusableOnlyAsExitNode(node);
}
bool isFusableOnlyAsExitNode(Node* node) {
@@ -252,59 +205,6 @@
return *n->g(attr::Subgraph);
}
- void decomposeBatchNorm(Node* batch_norm) {
- static std::shared_ptr<Graph> bn_graph;
- static std::once_flag flag;
- std::call_once(
- flag,
- [](std::shared_ptr<Graph>* graph_ptr) {
- static const char* source = R"SCRIPT(
- def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
- if training:
- norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
- else:
- norm_mean = torch._unwrap_optional(running_mean)
- norm_var = torch._unwrap_optional(running_var)
- norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
- norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
- norm_invstd = 1 / (eps + torch.sqrt(norm_var))
- return ((input - norm_mean) * norm_invstd)
- )SCRIPT";
- auto module = std::make_shared<script::Module>();
- defineMethodsInModule(
- module, source, script::nativeResolver, /*self=*/nullptr);
- *graph_ptr = module->get_method("batch_norm").graph();
- },
- &bn_graph);
-
- JIT_ASSERT(isFusableBatchNorm(batch_norm));
- WithInsertPoint insert_guard{batch_norm};
- Value* input = batch_norm->namedInput(attr::input);
- Value* input_dim = graph_->insert(aten::dim, {input});
- std::vector<Value*> inputs{input,
- batch_norm->namedInput(attr::running_mean),
- batch_norm->namedInput(attr::running_var),
- batch_norm->namedInput(attr::training),
- batch_norm->namedInput(attr::momentum),
- batch_norm->namedInput(attr::eps)};
- Value* new_output =
- SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
- auto weight = batch_norm->namedInput(attr::weight);
- auto bias = batch_norm->namedInput(attr::bias);
- if (isDefined(weight).value()) {
- Value* expanded_weight =
- graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
- new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
- }
- if (isDefined(bias).value()) {
- Value* expanded_bias =
- graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
- new_output = graph_->insert(aten::add, {new_output, expanded_bias});
- }
- batch_norm->output()->replaceAllUsesWith(new_output);
- batch_norm->destroy();
- }
-
void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
// Now we have two fusion groups!
// Revert the fusion - place all inner nodes of producer back in the outer
@@ -449,7 +349,10 @@
*insertion_point = n;
}
- at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
+ at::optional<Node*> tryFuse(
+ Node* consumer,
+ Value* producer,
+ const AliasDb& aliasDb) {
// this handles cases where producer can be moved _into_ the fusion group of
// consumer.
// TODO: extend to fusion of consumer into _producer's_ fusion blob
@@ -465,8 +368,7 @@
// consumer. Fusion will rewrite those later uses to use the version of
// producer generated by the fused blob. In this case, producer becomes
// an output of the fusion group.
- producer->node()->moveBeforeTopologicallyValid(
- real_consumer, aliasDb_.value());
+ producer->node()->moveBeforeTopologicallyValid(real_consumer, aliasDb);
if (!shouldFuse) {
return at::nullopt;
@@ -494,14 +396,6 @@
} else if (consumer->kind() != prim::FusionGroup) {
group = createSingletonFusionGroup(consumer);
}
- if (producer->node()->matches(
- "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
- // We don't do any fusions in here, but simply decompose the batch norm
- // into a kernel that computes the stats + pointwise ops which will be
- // considered in this fusion next.
- decomposeBatchNorm(producer->node());
- return group;
- }
if (producer->node()->kind() == prim::FusionGroup) {
mergeFusionGroups(group, producer->node());
return group;
@@ -755,7 +649,7 @@
chunk->inputs().begin(),
chunk->inputs().end(),
[&](Value* producer_for_chunk) {
- return isFusableMap(producer_for_chunk->node()) &&
+ return isFusable(producer_for_chunk->node()) &&
allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
});
if (it == chunk->inputs().end()) {
@@ -883,7 +777,9 @@
}
// returns where to continue scanning, and whether any fusion was made
- std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
+ std::pair<graph_node_list::iterator, bool> scanNode(
+ Node* consumer,
+ const AliasDb& aliasDb) {
if (isFusableAsExitNode(consumer)) {
auto consumer_inputs = consumer->kind() == aten::cat
? consumer->namedInput(attr::tensors)->node()->inputs()
@@ -901,7 +797,7 @@
// we scan this consumer again to perform the fusion
return std::make_pair(consumer->reverseIterator(), true);
}
- auto fusion_group = tryFuse(consumer, producer);
+ auto fusion_group = tryFuse(consumer, producer, aliasDb);
if (fusion_group) {
// after fusion, consumer moves into a FusionGroup, so inputs is no
// longer valid so we rescan the new FusionGroup for more fusions...
@@ -1052,10 +948,6 @@
}
}
- void refreshAliasDb() {
- aliasDb_ = AliasAnalysis(graph_);
- }
-
void run() {
// Run the pass until no changes are made.
// This is neccessary, because the algorithm can miss out on certain fusion
@@ -1076,10 +968,10 @@
bool any_changed = true;
while (any_changed) {
any_changed = false;
- refreshAliasDb();
+ auto aliasDb = AliasAnalysis(graph_);
for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
bool changed;
- std::tie(it, changed) = scanNode(*it);
+ std::tie(it, changed) = scanNode(*it, aliasDb);
any_changed |= changed;
}
}
diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp
index 77f3dbf..5266cf4 100644
--- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp
+++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp
@@ -4,50 +4,72 @@
namespace jit {
namespace SubgraphUtils {
namespace {
+bool isSubgraphNodeKind(Symbol s) {
+ return s == prim::DifferentiableGraph || s == prim::FusionGroup;
+}
-bool hasSubgraph(Node* n) {
- return n->hasAttribute(attr::Subgraph);
+bool isSubgraphNodeKind(Node* n) {
+ return isSubgraphNodeKind(n->kind());
}
// Combine the nodes in two subgraph together. The nodes will end up in
// `mergeTo`, and `mergeFrom` is destroyed.
void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
- Node* nodeBeforeMergeFrom = mergeFrom->prev();
- Node* nodeAfterMergeFrom = mergeFrom->next();
- unmergeSubgraph(mergeFrom);
- std::vector<Node*> nodes;
- const auto end_it = nodeBeforeMergeFrom->reverseIterator();
- auto it = nodeAfterMergeFrom->reverseIterator();
- ++it;
- while (it != end_it) {
- // NB: mergeNodeIntoSubgraph destroys node, hence the complications
- Node* node = *it;
- ++it;
- mergeNodeIntoSubgraph(node, mergeTo);
+ const auto nodes = unmergeSubgraph(mergeFrom);
+ for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
+ mergeNodeIntoSubgraph(*it, mergeTo);
}
}
} // namespace
std::shared_ptr<Graph> getSubgraph(Node* n) {
+ JIT_ASSERT(isSubgraphNodeKind(n));
return n->g(attr::Subgraph);
}
-void unmergeSubgraph(Node* subgraphNode) {
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode) {
JIT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
+ auto outerGraph = subgraphNode->owningGraph();
- // Inline the graph, replace uses of node outputs and destroy the node
- const auto subgraphOutputs = inlineGraph(
- getSubgraph(subgraphNode), subgraphNode->inputs(), subgraphNode);
+ std::vector<Node*> temporary_nodes;
+ auto subgraph = getSubgraph(subgraphNode);
+
+ // Initialize a map of inner graph values to outer graph values
+ std::unordered_map<const Value*, Value*> innerToOuter;
+ const auto innerInputs = subgraph->inputs();
+ const auto outerInputs = subgraphNode->inputs();
+ for (size_t i = 0; i < innerInputs.size(); ++i) {
+ innerToOuter[innerInputs[i]] = outerInputs[i];
+ }
+
+ // Clone all nodes
+ for (auto inner : subgraph->nodes()) {
+ Node* outer = outerGraph->createClone(
+ inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
+ outer->insertBefore(subgraphNode);
+ temporary_nodes.emplace_back(outer);
+ const auto innerOutputs = inner->outputs();
+ const auto outerOutputs = outer->outputs();
+ for (size_t i = 0; i < innerOutputs.size(); ++i) {
+ innerToOuter[innerOutputs[i]] = outerOutputs[i];
+ }
+ }
+
+ // Replace uses of group outputs and destroy the group
+ const auto subgraphOutputs = subgraph->outputs();
JIT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
- subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]);
+ const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
+ subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
}
subgraphNode->destroy();
+
+ return temporary_nodes;
}
void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
- JIT_ASSERT(hasSubgraph(subgraphNode));
- if (hasSubgraph(toMerge)) {
+ JIT_ASSERT(isSubgraphNodeKind(subgraphNode));
+ if (isSubgraphNodeKind(toMerge)) {
return mergeSubgraph(subgraphNode, toMerge);
}
@@ -128,40 +150,8 @@
toMerge->destroy();
}
-// Invariant we depend on in mergeSubgraph: All inlined nodes are created
-// between the node preceding insertBefore and insertBefore.
-std::vector<Value*> inlineGraph(
- const std::shared_ptr<Graph>& subgraph,
- at::ArrayRef<Value*> outerInputs,
- Node* insertBefore) {
- auto outerGraph = insertBefore->owningGraph();
-
- // Initialize a map of inner graph values to outer graph values
- std::unordered_map<const Value*, Value*> innerToOuter;
- const auto innerInputs = subgraph->inputs();
- JIT_ASSERT(outerInputs.size() == innerInputs.size());
- for (size_t i = 0; i < innerInputs.size(); ++i) {
- innerToOuter[innerInputs[i]] = outerInputs[i];
- }
-
- // Clone all nodes
- for (auto inner : subgraph->nodes()) {
- Node* outer = outerGraph->createClone(
- inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
- outer->insertBefore(insertBefore);
- const auto innerOutputs = inner->outputs();
- const auto outerOutputs = outer->outputs();
- for (size_t i = 0; i < innerOutputs.size(); ++i) {
- innerToOuter[innerOutputs[i]] = outerOutputs[i];
- }
- }
-
- return fmap(subgraph->outputs(), [&](Value* output) {
- return innerToOuter.at(output);
- });
-}
-
Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
+ JIT_ASSERT(isSubgraphNodeKind(subgraphKind));
auto graph = n->owningGraph();
auto subgraph = graph->create(subgraphKind, 0);
subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
@@ -169,7 +159,6 @@
mergeNodeIntoSubgraph(n, subgraph);
return subgraph;
}
-
} // namespace SubgraphUtils
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.h b/torch/csrc/jit/passes/utils/subgraph_utils.h
index 4d0c449..dc81902 100644
--- a/torch/csrc/jit/passes/utils/subgraph_utils.h
+++ b/torch/csrc/jit/passes/utils/subgraph_utils.h
@@ -26,16 +26,11 @@
// Move nodes from a subgraph node to the outer graph.
// `subgraphNode` is destroyed.
-void unmergeSubgraph(Node* subgraphNode);
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode);
// Convenience function
std::shared_ptr<Graph> getSubgraph(Node* n);
-std::vector<Value*> inlineGraph(
- const std::shared_ptr<Graph>& subgraph,
- at::ArrayRef<Value*> outerInputs,
- Node* insertBefore);
-
} // namespace SubgraphUtils
} // namespace jit
} // namespace torch