[nvFuser] Working towards reductions, codegen improvements (#40864)
Summary:
Basic reduction fusion is now working, and the code generator has been improved to approach the performance of eager-mode reductions. Coming soon are pointwise-reduction fusions, implemented in a way that should prevent the possibility of hitting regressions. We are also working on performant softmax kernels in the code generator, which may be our next fusion target.
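As a rough illustration of what these reduction fusions look like, here is a minimal sketch (using the test helpers exercised in the new C++ tests below, e.g. makeDummyTensor and reductionOp; compare testGPU_FusionGridReduction1 in test/cpp/jit/test_gpu.cpp) of a single-axis sum that is split, rFactored into a partial reduction, and bound to block/thread parallelism:

    Fusion fusion;
    FusionGuard fg(&fusion);
    TensorView* tv0 = makeDummyTensor(2);          // tv0[I0, I1]
    fusion.addInput(tv0);
    // tv1[I0, R1] = sum of tv0 over I1
    TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
    fusion.addOutput(tv1);
    tv1->split(1, 128);                            // tv1[I0, R1o, R1i{128}]
    TensorView* tv2 = tv1->rFactor({1});           // tv2 holds the partial sums over R1o
    tv0->computeAt(tv1, 1);
    tv1->axis(0)->parallelize(ParallelType::BIDx);
    tv1->axis(-1)->parallelize(ParallelType::TIDx);
    tv2->axis(-1)->parallelize(ParallelType::TIDx);

The new tests below check this pattern (and grid-reduction variants of it) against the equivalent eager-mode sums.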
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40864
Reviewed By: ngimel
Differential Revision: D22392877
Pulled By: soumith
fbshipit-source-id: 457448a807d628b1035f6d90bc0abe8a87bf8447
diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh
index 1fffbe2..7ec76d2 100755
--- a/.jenkins/pytorch/macos-test.sh
+++ b/.jenkins/pytorch/macos-test.sh
@@ -63,7 +63,7 @@
# Increase default limit on open file handles from 256 to 1024
ulimit -n 1024
- python test/run_test.py --verbose --exclude test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --determine-from="$DETERMINE_FROM"
+ python test/run_test.py --verbose --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --determine-from="$DETERMINE_FROM"
assert_git_not_dirty
}
diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh
index c1e67b8..72d1cd4 100755
--- a/.jenkins/pytorch/test.sh
+++ b/.jenkins/pytorch/test.sh
@@ -150,17 +150,17 @@
}
test_python_ge_config_profiling() {
- time python test/run_test.py --include test_jit_profiling test_jit_fuser_te --verbose --determine-from="$DETERMINE_FROM"
+ time python test/run_test.py --include test_jit_cuda_fuser_profiling test_jit_profiling test_jit_fuser_te --verbose --determine-from="$DETERMINE_FROM"
assert_git_not_dirty
}
test_python_ge_config_legacy() {
- time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose --determine-from="$DETERMINE_FROM"
+ time python test/run_test.py --include test_jit_cuda_fuser_legacy test_jit_legacy test_jit_fuser_legacy --verbose --determine-from="$DETERMINE_FROM"
assert_git_not_dirty
}
test_python_all_except_nn_and_cpp_extensions() {
- time python test/run_test.py --exclude test_nn test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="$DETERMINE_FROM"
+ time python test/run_test.py --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_nn test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="$DETERMINE_FROM"
assert_git_not_dirty
}
diff --git a/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat b/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat
index dd4339b..4bfb5bc 100644
--- a/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat
+++ b/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat
@@ -1,3 +1,3 @@
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
-cd test && python run_test.py --exclude test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="%1" && cd ..
+cd test && python run_test.py --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="%1" && cd ..
if ERRORLEVEL 1 exit /b 1
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 7693f62..2e36a97 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -448,12 +448,14 @@
${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/arith.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/compute_at.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/dispatch.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/expr_evaluator.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/fusion.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/graph_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/index_compute.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_cloner.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_graphviz.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_nodes.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -463,13 +465,16 @@
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/shape_inference.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_index.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_unroll.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_validation.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/predicate_compute.cpp
- ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_meta.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_view.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_iter.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_replay.cpp
diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 84fa69d..0edcac5 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -10,7 +10,6 @@
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
@@ -38,7 +37,7 @@
static void checkIntValue(
const EvaluationContext* eval_context,
- const Val* val,
+ Val* val,
Int::ScalarType expected_value) {
TORCH_CHECK(val->isAnInt());
const auto actual_value = ExpressionEvaluator::evaluate(val, eval_context);
@@ -66,27 +65,24 @@
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = mul(tv0, new Float(-1.0));
- TensorView* tv2 = add(tv0, new Float(3.0));
- TensorView* tv3 = mul(tv0, new Float(2.5));
- TensorView* tv4 = add(tv2, tv1);
- TensorView* tv5 = add(tv4, tv3);
- TensorView* tv6 = add(tv0, tv3);
+ TensorView* tv2 = add(tv0, new Float(3.141));
+ TensorView* tv3 = broadcast(tv0, {false, true, false, true});
+ TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Float(0), tv3);
+ TensorView* tv5 = clamp(tv4, new Float(0.f), new Float(1.f));
+ TensorView* tv6 = add(tv2, tv2);
// Another checkpoint before adding outputs
TORCH_CHECK(!IrGraphGenerator::toGraphviz(
&fusion, IrGraphGenerator::DetailLevel::Explicit)
.empty());
- fusion.addOutput(tv5);
fusion.addOutput(tv6);
tv6->merge(0);
tv6->split(0, 4);
tv6->axis(0)->parallelize(ParallelType::BIDx);
tv5->reorder({{-1, 0}});
- tv0->computeAt(tv3, 1);
- tv0->computeAt(tv6, 1);
+ tv2->computeAt(tv6, 1);
// Another checkpoint with more node types
TORCH_CHECK(!IrGraphGenerator::toGraphviz(
@@ -148,11 +144,22 @@
auto* a = new Int();
auto* b = new Int();
auto* c = add(a, b);
- auto* d = neg(ceilDiv(add(a, b), b));
+ auto* d = neg(ceilDiv(c, b));
+ auto* e = new Int(0);
+
+ // trying to evaluate before binding should give empty results
+ TORCH_CHECK(!ExpressionEvaluator::evaluate(a, &eval_context).has_value());
+ TORCH_CHECK(!ExpressionEvaluator::evaluate(d, &eval_context).has_value());
eval_context.bind(a, 7);
eval_context.bind(b, 3);
+ // can't bind to the results of expressions
+ ASSERT_ANY_THROW(eval_context.bind(c, 100));
+
+ // can't bind to concrete values
+ ASSERT_ANY_THROW(eval_context.bind(e, 100));
+
checkIntValue(&eval_context, c, 10);
checkIntValue(&eval_context, sub(a, b), 4);
checkIntValue(&eval_context, mod(a, b), 1);
@@ -277,6 +284,290 @@
checkIntValue(&eval_context, tv6->axis(2)->rawExtent(), 127);
}
+// Evaluate expressions post lowering
+void testGPU_FusionExprEvalPostLower() {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
+
+ // Create a non-trivial IR
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
+
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+
+ TensorView* tv2 = add(tv1, new Float(2.0));
+ TensorView* tv3 = add(tv0, tv2);
+
+ fusion.addOutput(tv3);
+
+ tv3->split(0, 4);
+
+ tv0->computeAt(tv3, 1);
+ tv1->computeAt(tv3, 1);
+
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(1)->parallelize(ParallelType::Unroll);
+ tv3->axis(1)->parallelize(ParallelType::Unroll);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+ auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
+ auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));
+
+ // Lower
+ GPULower gpulw(&fusion);
+ std::stringstream kernel;
+ gpulw.printKernel(kernel);
+
+ // 1. Create an evaluation context
+ EvaluationContext eval_context(&fusion);
+
+ // 2. Bind values
+ eval_context.bind(tv0->getRootDomain()[0]->extent(), 6);
+ eval_context.bind(tv0->getRootDomain()[1]->extent(), 128);
+ eval_context.bind(tv1->getRootDomain()[0]->extent(), 6);
+ eval_context.bind(tv1->getRootDomain()[1]->extent(), 128);
+
+ // 3. Evaluate and check result values
+ TORCH_CHECK(tv2->domain()->nDims() == 3);
+ checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2);
+ checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4);
+ checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128);
+
+ TORCH_CHECK(tv3->domain()->nDims() == 3);
+ checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2);
+ checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4);
+ checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128);
+
+ checkIntValue(&eval_context, bid_x, 2);
+ checkIntValue(&eval_context, tid_x, 128);
+}
+
+void testGPU_FusionClear() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // 1. Create a dummy IR
+
+ {
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
+
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+
+ TensorView* tv2 = add(tv1, new Float(2.0));
+ TensorView* tv3 = add(tv0, tv2);
+
+ fusion.addOutput(tv3);
+
+ tv3->split(0, 4);
+ tv0->computeAt(tv3, 1);
+ tv1->computeAt(tv3, 1);
+
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(1)->parallelize(ParallelType::Unroll);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ }
+
+ // 2. Clear the IR
+
+ fusion.clear();
+
+ TORCH_CHECK(fusion.exprs().empty());
+ TORCH_CHECK(fusion.vals().empty());
+
+ TORCH_CHECK(fusion.inputs().empty());
+ TORCH_CHECK(fusion.outputs().empty());
+
+ TORCH_CHECK(!fusion.hasReduction());
+ TORCH_CHECK(!fusion.hasBlockReduction());
+ TORCH_CHECK(!fusion.hasGridReduction());
+
+ // 3. Rebuild the IR
+
+ {
+ TensorView* tv0 = makeDummyTensor(3);
+ TensorView* tv1 = makeDummyTensor(3);
+ TensorView* tv2 = add(tv1, new Float(2.0));
+ TensorView* tv3 = add(tv0, tv2);
+
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+ fusion.addOutput(tv3);
+
+ tv3->reorder({{0, 2}, {2, 0}});
+ tv3->split(-1, 4);
+ tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
+ tv0->computeAt(tv3, -1);
+ tv1->computeAt(tv3, -1);
+ }
+
+ prog.device_ = 0;
+ prog.grid(4);
+ prog.block(8);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+ at::Tensor input1 = at::randn({16, 8, 8}, options);
+ at::Tensor input2 = at::randn_like(input1);
+ at::Tensor output = at::empty_like(input1);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input1, input2}, {output});
+
+ at::Tensor tv2_ref = input2 + 2.0;
+ at::Tensor output_ref = input1 + tv2_ref;
+
+ TORCH_CHECK(output_ref.equal(output));
+}
+
+void testGPU_FusionCopy() {
+ Fusion original_fusion;
+
+ // Create the test IR
+ {
+ FusionGuard fg(&original_fusion);
+
+ auto tv0 = makeDummyTensor(3);
+ auto tv1 = makeDummyTensor(3);
+ auto tv2 = add(tv1, new Float(2.0));
+ auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
+
+ original_fusion.addInput(tv0);
+ original_fusion.addInput(tv1);
+ original_fusion.addOutput(tv3);
+
+ tv3->reorder({{0, 2}, {2, 0}});
+ tv3->split(-1, 4);
+ tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
+
+ tv0->computeAt(tv3, -1);
+ tv1->computeAt(tv3, -1);
+
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ }
+
+ // Test copy before lowering
+ Fusion clone = original_fusion;
+
+ // Compare IR dumps
+ std::stringstream original_ir;
+ std::stringstream clone_ir;
+ original_ir << original_fusion;
+ clone_ir << clone;
+ ASSERT_EQ(original_ir.str(), clone_ir.str());
+
+ // Lower original fusion
+ std::stringstream original_kernel;
+ {
+ GPULower lower(&original_fusion);
+ lower.printKernel(original_kernel);
+ }
+
+ // Make sure the "before lowering" clone was not mutated
+ // while lowering the original fusion IR
+ std::stringstream before_lowering_ir;
+ before_lowering_ir << clone;
+ ASSERT_EQ(original_ir.str(), before_lowering_ir.str());
+
+ // Test copy after lowering (including assignment operator)
+ Fusion before_lowering = clone;
+ clone = original_fusion;
+
+ // Compare IR dumps
+ std::stringstream original_lowered_ir;
+ std::stringstream clone_lowered_ir;
+ original_lowered_ir << original_fusion;
+ clone_lowered_ir << clone;
+ ASSERT_EQ(original_lowered_ir.str(), clone_lowered_ir.str());
+
+ // Lower the "before lowering" and compare kernels
+ std::stringstream clone_kernel;
+ {
+ GPULower lower(&before_lowering);
+ lower.printKernel(clone_kernel);
+ }
+ ASSERT_EQ(original_kernel.str(), clone_kernel.str());
+}
+
+void testGPU_FusionMove() {
+ Fusion fusion;
+
+ // Create the test IR
+ {
+ FusionGuard fg(&fusion);
+
+ auto tv0 = makeDummyTensor(3);
+ auto tv1 = makeDummyTensor(3);
+ auto tv2 = add(tv1, new Float(2.0));
+ auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
+
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+ fusion.addOutput(tv3);
+
+ tv3->reorder({{0, 2}, {2, 0}});
+ tv3->split(-1, 4);
+ tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
+
+ tv0->computeAt(tv3, -1);
+ tv1->computeAt(tv3, -1);
+
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ }
+
+ std::stringstream original_ir;
+ original_ir << fusion;
+
+ // Test move before lowering
+ Fusion another_fusion = std::move(fusion);
+
+ // Check that the original fusion is "empty"
+ //
+ // IMPORTANT: these checks assume knowledge of the internal
+ // implementation of the move operations. General uses
+ // should only assume that the moved-from object is in
+ // a valid, but unspecified state. This is similar to the
+ // standard library containers:
+ // https://en.cppreference.com/w/cpp/utility/move
+ //
+ TORCH_CHECK(fusion.exprs().empty());
+ TORCH_CHECK(fusion.vals().empty());
+ TORCH_CHECK(fusion.inputs().empty());
+ TORCH_CHECK(fusion.outputs().empty());
+
+ // clear() has no pre-conditions so it's valid to call on a moved-from object
+ fusion.clear();
+
+ // Compare IR dumps
+ std::stringstream another_ir;
+ another_ir << another_fusion;
+ ASSERT_EQ(original_ir.str(), another_ir.str());
+
+ // Lower the fusion IR
+ std::stringstream kernel;
+ {
+ GPULower lower(&another_fusion);
+ lower.printKernel(kernel);
+ }
+
+ std::stringstream lowered_ir;
+ lowered_ir << another_fusion;
+
+ // Test move assignment after lowering
+ fusion = std::move(another_fusion);
+
+ // Compare IR dumps
+ std::stringstream moved_lowered_ir;
+ moved_lowered_ir << fusion;
+ ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str());
+}
+
void testGPU_FusionSimpleArith() {
std::stringstream ss1, ss2;
@@ -460,139 +751,6 @@
TORCH_CHECK(fuser_tensor->domain() != nullptr);
}
-void testGPU_FusionTensorContiguity() {
- {
- // NCHW memory layout
- auto tensor = at::randn({2, 3, 4, 5});
- auto sizes = tensor.sizes().vec();
- auto strides = tensor.strides().vec();
- TensorContiguity t_c(sizes, strides);
- TORCH_CHECK(t_c.rank() == 4);
- TORCH_CHECK(t_c.getBroadcastDims().size() == 0);
- for (int i = 0; i < 4; i++) {
- TORCH_CHECK(!t_c.isBroadcastDim(i));
- if (i < 3) {
- TORCH_CHECK(t_c.canCollapseToHigher(i));
- }
- }
- }
-
- {
- // NHWC memory layout
- TensorContiguity t_c({2, 3, 4, 5}, {60, 1, 15, 3});
- TORCH_CHECK(t_c.rank() == 4);
- TORCH_CHECK(t_c.getBroadcastDims().size() == 0);
- for (int i = 0; i < 4; i++) {
- TORCH_CHECK(!t_c.isBroadcastDim(i));
- if (i < 3) {
- TORCH_CHECK((t_c.canCollapseToHigher(i) ^ (i != 2)));
- }
- }
- }
-
- {
- // NHWC memory layout with broadcast
- TensorContiguity t_c({2, 3, 4, 5}, {120, 0, 30, 3});
- TORCH_CHECK(t_c.rank() == 4);
- auto b_dims = t_c.getBroadcastDims();
- TORCH_CHECK(b_dims.size() == 1 && b_dims[0] == 1);
- for (int i = 0; i < 4; i++) {
- TORCH_CHECK(!(t_c.isBroadcastDim(i)) ^ (i == 1));
- if (i < 3) {
- TORCH_CHECK(!(t_c.canCollapseToHigher(i)));
- }
- }
- }
-
- {
- // contiguity across size-1 dimension
- auto tensor = at::randn({4, 1, 4});
- auto sizes = tensor.sizes().vec();
- auto strides = tensor.strides().vec();
- auto dim = sizes.size();
- TensorContiguity t_c(sizes, strides);
- TORCH_CHECK(t_c.rank() == (int)sizes.size());
- auto b_dims = t_c.getBroadcastDims();
- TORCH_CHECK(b_dims.size() == 0);
- TORCH_CHECK(t_c.getFCD() == 2);
- TORCH_CHECK(t_c.hasContiguousFCD());
- for (decltype(dim) i = 0; i < dim; i++) {
- TORCH_CHECK(!t_c.isBroadcastDim(i));
- if (i < dim - 1) {
- TORCH_CHECK(t_c.canCollapseToHigher(i));
- }
- }
- }
-
- {
- // no contiguity across size-1 dimension
- auto tensor = at::randn({4, 4, 4}).split(1, 1)[0];
- auto sizes = tensor.sizes().vec();
- auto strides = tensor.strides().vec();
- TensorContiguity t_c(sizes, strides);
- TORCH_CHECK(!(t_c.canCollapseToHigher(0)));
- TORCH_CHECK((t_c.canCollapseToHigher(1)));
- }
-
- {
- // no contiguity across size-1 dimension
- auto tensor = at::randn({4, 1, 8}).split(4, 2)[0];
- auto sizes = tensor.sizes().vec();
- auto strides = tensor.strides().vec();
- TensorContiguity t_c(sizes, strides);
- TORCH_CHECK((t_c.canCollapseToHigher(0)));
- TORCH_CHECK((!t_c.canCollapseToHigher(1)));
- }
-
- {
- // no contiguity across size-1 dimension
- auto tensor = at::randn({8, 1, 4}).split(4, 0)[0];
- auto sizes = tensor.sizes().vec();
- auto strides = tensor.strides().vec();
- TensorContiguity t_c(sizes, strides);
- TORCH_CHECK((t_c.canCollapseToHigher(0)));
- TORCH_CHECK((t_c.canCollapseToHigher(1)));
- }
-
- {
- // test merge
- TensorContiguity t_c_l({4, 4, 4}, {16, 4, 1});
- TensorContiguity t_c_r({4, 4, 4}, {16, 4, 1});
- t_c_l.merge(t_c_r);
- TORCH_CHECK((t_c_l.isIdentical(t_c_r)));
- }
-
- {
- TensorContiguity t_c_l({4, 4, 4, 4}, {16, 0, 4, 1});
- TensorContiguity t_c_r({4, 4, 4, 4}, {64, 16, 4, 1});
- t_c_l.merge(t_c_r);
- TORCH_CHECK(t_c_l.getFCD() == 3);
- TORCH_CHECK(t_c_l.getAxisByStride(0) == 0);
- }
-
- {
- // NHWC + NCHW
- TensorContiguity t_c_l({4, 4, 4, 4}, {64, 16, 4, 1});
- TensorContiguity t_c_r({4, 4, 4, 4}, {64, 1, 16, 4});
- t_c_l.merge(t_c_r);
- TORCH_CHECK(!t_c_l.hasContiguousFCD());
- TORCH_CHECK(t_c_l.getFCD() == -1);
- TORCH_CHECK(t_c_l.getAxisByStride(0) == 0);
- TORCH_CHECK(t_c_l.getAxisByStride(1) == -1);
- TORCH_CHECK(t_c_l.getAxisByStride(2) == -1);
- TORCH_CHECK(t_c_l.getAxisByStride(3) == -1);
- }
-
- {
- // NCHW + NCHW with broadcasting
- TensorContiguity t_c_l({4, 4, 4, 4}, {4, 1, 4, 0});
- TensorContiguity t_c_r({4, 4, 4, 4}, {64, 1, 16, 4});
- t_c_l.merge(t_c_r);
- TORCH_CHECK(t_c_l.getFCD() == 1);
- TORCH_CHECK(t_c_l.getAxisByStride(0) == 0);
- }
-}
-
void testGPU_FusionTVSplit() {
Fusion fusion;
FusionGuard fg(&fusion);
@@ -853,51 +1011,56 @@
prog.device_ = 0;
fuser::cuda::parseJitIR(g, &prog);
- std::stringstream ref;
- ref << "__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){\n"
- << " float T2[4];\n"
- << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
- << " for(size_t i60 = 0; i60 < 4; ++i60 ) {\n"
- << " T2[ i60 ]\n"
- << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
- << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
- << " }\n"
- << " } else { \n"
- << " for(size_t i60 = 0; i60 < 4; ++i60 ) {\n"
- << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
- << " T2[ i60 ]\n"
- << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
- << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
- << " }\n"
- << " }\n"
- << " }\n"
- << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
- << " for(size_t i61 = 0; i61 < 4; ++i61 ) {\n"
- << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
- << " = T2[ i61 ]\n"
- << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
- << " }\n"
- << " } else { \n"
- << " for(size_t i61 = 0; i61 < 4; ++i61 ) {\n"
- << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
- << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
- << " = T2[ i61 ]\n"
- << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
- << " }\n"
- << " }\n"
- << " }\n"
- << "}\n";
+ // CONSIDER:
+ // 1. this can be moved to a dedicated "golden" file
+  // 2. use a fuzzy compare (e.g. ignore insignificant whitespace)
+ const std::string expected_kernel = R"(
+__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){
+ float T2[4];
+ if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
+ for(size_t i40 = 0; i40 < 4; ++i40 ) {
+ T2[ i40 ]
+ = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
+ * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
+ }
+ } else {
+ for(size_t i40 = 0; i40 < 4; ++i40 ) {
+ if ( ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
+ T2[ i40 ]
+ = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
+ * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
+ }
+ }
+ }
+ if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
+ for(size_t i41 = 0; i41 < 4; ++i41 ) {
+ T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
+ = T2[ i41 ]
+ * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
+ }
+ } else {
+ for(size_t i41 = 0; i41 < 4; ++i41 ) {
+ if ( ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
+ T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
+ = T2[ i41 ]
+ * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
+ }
+ }
+ }
+}
+)";
GPULower gpulw(&fusion);
- std::stringstream cdg;
- gpulw.printKernel(cdg);
- if (ref.str().size() != cdg.str().size() ||
- ref.str().compare(cdg.str()) != 0) {
+ std::stringstream actual_kernel;
+ actual_kernel << "\n";
+ gpulw.printKernel(actual_kernel);
+ if (expected_kernel.size() != actual_kernel.str().size() ||
+ expected_kernel.compare(actual_kernel.str()) != 0) {
std::cerr
<< " Codegen mismatch, codegen possibly changed, or is incorrect. "
- << " \n ========= REF ========= \n"
- << ref.str() << "\n========= RESULT ========== \n"
- << cdg.str() << "\n=================" << std::endl;
+ << " \n ========= EXPECTED ========= \n"
+ << expected_kernel << "\n========= ACTUAL ========== \n"
+ << actual_kernel.str() << "\n=================" << std::endl;
TORCH_CHECK(false);
}
}
@@ -1198,7 +1361,7 @@
tv0->computeAt(tv6, 1);
- TORCH_CHECK(tv0->getComputeAtView() == tv3 && tv0->nDims() == 3);
+ TORCH_CHECK(tv0->getComputeAtView() == tv6 && tv0->nDims() == 3);
TORCH_CHECK(tv1->getComputeAtView() == tv4 && tv1->nDims() == 3);
TORCH_CHECK(tv2->getComputeAtView() == tv4 && tv2->nDims() == 3);
TORCH_CHECK(tv3->getComputeAtView() == tv6 && tv3->nDims() == 3);
@@ -1272,6 +1435,7 @@
fusion.addOutput(tv6);
tv2->computeAt(tv4, 1);
+
TORCH_CHECK(!tv0->hasComputeAt());
TORCH_CHECK(!tv1->hasComputeAt());
TORCH_CHECK(tv2->getComputeAtView() == tv4);
@@ -1323,10 +1487,10 @@
&prog, {t0}, {kernel_tv5, kernel_tv6});
GPULower gpulw(&fusion);
- std::stringstream cdg;
- gpulw.printKernel(cdg);
+ std::stringstream actual_kernel;
+ gpulw.printKernel(actual_kernel);
- TORCH_CHECK(at::allclose(kernel_tv5, t5), cdg.str());
+ TORCH_CHECK(at::allclose(kernel_tv5, t5), actual_kernel.str());
TORCH_CHECK(at::allclose(kernel_tv6, t6));
}
@@ -1390,10 +1554,10 @@
torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {kernel_tv3});
GPULower gpulw(&fusion);
- std::stringstream cdg;
- gpulw.printKernel(cdg);
+ std::stringstream actual_kernel;
+ gpulw.printKernel(actual_kernel);
- TORCH_CHECK(at::allclose(kernel_tv3, t3), cdg.str());
+ TORCH_CHECK(at::allclose(kernel_tv3, t3), actual_kernel.str());
}
// Case 4
@@ -1470,10 +1634,10 @@
&prog, {t0, t1, t2, t3}, {kernel_tv6});
GPULower gpulw(&fusion);
- std::stringstream cdg;
- gpulw.printKernel(cdg);
+ std::stringstream actual_kernel;
+ gpulw.printKernel(actual_kernel);
- TORCH_CHECK(at::allclose(kernel_tv6, t6), cdg.str());
+ TORCH_CHECK(at::allclose(kernel_tv6, t6), actual_kernel.str());
}
}
@@ -1570,10 +1734,10 @@
{kernel_tv4});
GPULower gpulw(&fusion);
- std::stringstream cdg;
- gpulw.printKernel(cdg);
+ std::stringstream actual_kernel;
+ gpulw.printKernel(actual_kernel);
- TORCH_CHECK(at::allclose(kernel_tv4, t4), cdg.str());
+ TORCH_CHECK(at::allclose(kernel_tv4, t4), actual_kernel.str());
}
void testGPU_FusionLoopUnroll() {
@@ -1619,9 +1783,6 @@
int inp_size = 129 * 13 * 3;
- // GPULower lower(&fusion);
- // lower.printKernel(std::cout);
-
prog.device_ = 0;
prog.grid((inp_size + 63) / 64);
prog.block(block_size);
@@ -2192,7 +2353,8 @@
// Replay casp, replay new_domain2 as new_domain
// reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
- TensorDomain* casp = TransformReplay::replayCasP(new_domain2, new_domain, 2);
+ auto replay_casp = TransformReplay::replayCasP(new_domain2, new_domain, 2);
+ TensorDomain* casp = replay_casp.first;
// new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
// casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}]
@@ -2201,7 +2363,8 @@
// new_domain[I0oi{16}, I0oo*I0i{32} , ir1oi{4}rf,
// R(R1oo*R1i{8})rf]
- TensorDomain* pasc = TransformReplay::replayPasC(new_domain, casp, 2);
+ auto replay_pasc = TransformReplay::replayPasC(new_domain, casp, 2);
+ TensorDomain* pasc = replay_pasc.first;
// pasc [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf,
// R(R1oo*R1i{8})rf]
@@ -2287,11 +2450,6 @@
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
- // for(auto expr : fusion.exprs(true))
- // std::cout<<expr<<std::endl;
- // GPULower lower(&fusion);
- // lower.printKernel(std::cout);
-
int numel_x = 65000;
int numel_y = 1025;
@@ -2527,28 +2685,189 @@
}
}
-void testGPU_FusionSimpleBCast() {
- {
- Fusion fusion;
- FusionGuard fg(&fusion);
+void testGPU_FusionReduction4() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
- // Set up your input tensor views
- TensorView* tv0 = makeDummyTensor(2);
- TensorView* tv1 = makeDummyTensor(2);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3);
- TensorView* tv2 = add(tv0, tv1);
+ fusion.addInput(tv0);
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv3 = broadcast(tv2, {false, true, true, false});
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
- Val* tv4 = mul(tv3, makeDummyTensor(4));
- fusion.addOutput(tv4);
+ fusion.addOutput(tv1);
+
+ int bidy = 2;
+ int tidy = 4;
+ int tidx = 5;
+
+ int dim1 = 11;
+
+ tv1->split(-2, tidy);
+
+ TensorView* tv2 = tv1->rFactor({-3});
+
+ tv0->computeAt(tv1, 1);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDy);
+
+ for (auto* val : fusion.vals()) {
+ if (val->getValType().value() == ValType::TensorView)
+ val->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
}
+ tv2->axis(-2)->parallelize(ParallelType::TIDy);
+ tv1->axis(-2)->parallelize(ParallelType::TIDy);
+
+ prog.device_ = 0;
+ prog.grid(1, bidy);
+ prog.block(tidx, tidy);
+ torch::jit::fuser::cuda::compileKernel(&prog);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::randn({bidy, dim1, tidx}, options);
+
+ at::Tensor cg_output = at::empty({bidy, tidx}, options);
+
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionReduction5() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ const int bdimx = 64;
+ const int bdimy = 8;
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(2, bdimx);
+ // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
+ tv1->split(1, bdimy);
+ // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+ TensorView* tv2 = tv1->rFactor({3});
+ // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
+ // tv1[I0, R1o, R1i{8}, R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
+
+ TensorView* tv3 = tv1->rFactor({1});
+ // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
+ // tv3[I0, R1o, I1i{8}, I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
+ // tv1[I0, R1i{8}, R2i{128}] = tv3[I0, R1o, I1i{8}, I2i{128}]
+
+ tv3->computeAt(tv1, 1);
+ tv2->computeAt(tv3, 2);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(0)->parallelize(ParallelType::BIDx);
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+ tv1->axis(-2)->parallelize(ParallelType::TIDy);
+ tv3->axis(-2)->parallelize(ParallelType::TIDy);
+ tv2->axis(-3)->parallelize(ParallelType::TIDy);
+
+ int numel_x = 650;
+ int numel_y = 1000;
+ int numel_z = 1000;
+
+ prog.device_ = 0;
+ prog.grid(numel_x);
+ prog.block(bdimx, bdimy);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1, 2});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionReductionTFT() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+
+ fusion.addOutput(tv1);
+
+ int numel_x = 1025;
+ int numel_y = 129;
+ int tidx = 16;
+ int tidy = 8;
+ int tidz = 8;
+
+ tv1->split(1, tidx);
+ // tv1[I0, R1o, R1i{tidx}]
+
+ tv1->split(1, tidz);
+ // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}]
+
+ tv1->split(0, tidy);
+ // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}]
+
+ TensorView* tv2 = tv1->rFactor({2});
+ // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}]
+ // tv1[I0o, I0i, R1Oi{tidz}, R1R1i{tidx}]
+
+ tv2->computeAt(tv1, 2);
+
+ tv1->axis(1)->parallelize(ParallelType::TIDy);
+
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+
+ tv1->axis(-2)->parallelize(ParallelType::TIDz);
+ tv2->axis(-2)->parallelize(ParallelType::TIDz);
+
+ prog.device_ = 0;
+ prog.grid(1);
+ prog.block(tidx, tidy, tidz);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+ AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionSimpleBCast() {
{
- Fusion fusion;
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
@@ -2557,18 +2876,1008 @@
fusion.addInput(tv0);
fusion.addInput(tv1);
- TensorView* tv2 = broadcast(tv0, {true, false, false});
- TensorView* tv3 = broadcast(tv1, {false, false, true});
+ TensorView* tv2 = broadcast(tv0, {false, false, true});
+ TensorView* tv3 = broadcast(tv1, {true, false, false});
- TensorView* tv4 = mul(tv3, tv2);
+ TensorView* tv4 = add(tv2, tv3);
+ tv4->split(-1, 4);
+ tv4->split(0, 8);
fusion.addOutput(tv4);
tv0->computeAt(tv4, -1);
tv1->computeAt(tv4, -1);
- // GPULower lower(&fusion);
- // lower.printKernel(std::cout);
+ tv4->axis(0)->parallelize(ParallelType::BIDx);
+ tv4->axis(-1)->parallelize(ParallelType::TIDx);
+
+ constexpr int x = 63, y = 33, z = 15;
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+ at::Tensor t0 = at::randn({x, y}, options);
+ at::Tensor t1 = at::randn({y, z}, options);
+
+ at::Tensor cg_output = at::empty({x, y, z}, options);
+
+ prog.device_ = 0;
+ prog.grid(ceilDiv_(x, 8));
+ prog.block(4);
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+ auto t2 = t0.unsqueeze(-1).expand({x, y, z});
+ auto t3 = t1.expand({x, y, z});
+ auto t4 = t2.add(t3);
+
+ TORCH_CHECK(t4.allclose(cg_output));
}
+
+ {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+
+  // TODO: add pointwise ops at the beginning, before the bcast.
+
+ TensorView* tv2 = broadcast(tv0, {false, false, true});
+ TensorView* tv3 = broadcast(tv1, {true, false, false});
+
+ TensorView* tv4 = add(tv2, tv3);
+
+ tv4->merge(0, 1);
+
+ fusion.addOutput(tv4);
+
+ tv0->computeAt(tv4, -1);
+ tv1->computeAt(tv4, -1);
+
+ tv4->axis(0)->parallelize(ParallelType::BIDx);
+
+ constexpr int x = 63, y = 33, z = 15;
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+ at::Tensor t0 = at::randn({x, y}, options);
+ at::Tensor t1 = at::randn({y, z}, options);
+
+ at::Tensor cg_output = at::empty({x, y, z}, options);
+
+ prog.device_ = 0;
+ prog.grid(x * y);
+ prog.block(1);
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+ auto t2 = t0.unsqueeze(-1).expand({x, y, z});
+ auto t3 = t1.expand({x, y, z});
+ auto t4 = t2.add(t3);
+
+ TORCH_CHECK(t4.allclose(cg_output));
+ }
+}
+
+void testGPU_FusionSimpleGemm() {
+ {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2); // M, K
+ TensorView* tv1 = makeDummyTensor(2); // K, N
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+
+ TensorView* tv2 = broadcast(tv0, {false, false, true});
+ // tv2[I0, I1, B] = tv0[I0, I1]
+
+ TensorView* tv3 = broadcast(tv1, {true, false, false});
+ // tv3[B, I1, I2] = tv1[I1, I2]
+
+ // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
+ TensorView* tv4 = mul(tv2, tv3);
+ // tv5[I0, R1, I2] = tv4[I0, I1, I2]
+ TensorView* tv5 = sum(tv4, {1});
+ fusion.addOutput(tv5);
+
+ tv5->split(1, 32);
+ // tv5[I0, R1o, R1i{32}, I2]
+
+ auto tv6 = tv5->rFactor({1});
+ // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
+ // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
+
+ tv5->split(0, 4);
+ tv5->split(-1, 4);
+ // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+ // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+
+ tv0->computeAt(tv5, -1);
+ tv1->computeAt(tv5, -1);
+
+ // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
+ // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}]
+  // --> (the '|' below marks the computeAt position)
+ // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
+ // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
+ // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+
+ tv0->computeAt(tv6, -1);
+ tv1->computeAt(tv6, -1);
+ // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
+ // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
+ // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+
+ tv5->axis(0)->parallelize(ParallelType::BIDz);
+ tv5->axis(1)->parallelize(ParallelType::TIDz);
+
+ tv5->axis(-2)->parallelize(ParallelType::BIDy);
+ tv5->axis(-1)->parallelize(ParallelType::TIDy);
+
+ tv5->axis(2)->parallelize(ParallelType::TIDx);
+ tv6->axis(2)->parallelize(ParallelType::TIDx);
+
+ constexpr int M = 65, K = 33, N = 17;
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+ at::Tensor t0 = at::randn({M, K}, options);
+ at::Tensor t1 = at::randn({K, N}, options);
+
+ at::Tensor cg_output = at::empty({M, N}, options);
+
+ prog.device_ = 0;
+ prog.grid(1, ceilDiv_(N, 4), ceilDiv_(M, 4));
+
+ prog.block(32, 4, 4);
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+ auto t2 = t0.matmul(t1);
+ TORCH_CHECK(
+ t2.allclose(cg_output, 1e-5, 1e-5),
+ "Error of: ",
+ t2.sub(cg_output).abs().max());
+ }
+}
+
+// This test currently requires a combination of broadcast and reduction
+// operations and a parallelization strategy that is not yet supported.
+// Getting this example working is a goal, and this test is added so we
+// can keep working on fixing it. Right now it produces an incorrect
+// result. Either we need to error out coherently on the unsupported
+// optimization strategy and switch this test to one we do support, or
+// we need to get this schedule working correctly.
+void testGPU_FusionSoftmax() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* input_tv0 = makeDummyTensor(3);
+ fusion.addInput(input_tv0);
+
+ TensorView* max_val_tv1 =
+ reductionOp(BinaryOpType::Max, {2}, new Float(0), input_tv0);
+ TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
+ TensorView* exp_tv3 = sub(input_tv0, bcast_max_tv2);
+ TensorView* sum_exp_tv4 =
+ reductionOp(BinaryOpType::Add, {2}, new Float(0), exp_tv3);
+ TensorView* bcast_sum_tv5 = broadcast(sum_exp_tv4, {false, false, true});
+ TensorView* output_tv6 = div(exp_tv3, bcast_sum_tv5);
+
+ max_val_tv1->split(-1, 32);
+ TensorView* max_val_rf_tv7 = max_val_tv1->rFactor({-2});
+ sum_exp_tv4->split(-1, 32);
+ TensorView* sum_exp_rf_tv8 = sum_exp_tv4->rFactor({-2});
+
+ exp_tv3->computeAt(sum_exp_rf_tv8, 2);
+
+ max_val_rf_tv7->axis(0)->parallelize(ParallelType::BIDx);
+ max_val_tv1->axis(0)->parallelize(ParallelType::BIDx);
+ bcast_max_tv2->axis(0)->parallelize(ParallelType::BIDx);
+ sum_exp_rf_tv8->axis(0)->parallelize(ParallelType::BIDx);
+ sum_exp_tv4->axis(0)->parallelize(ParallelType::BIDx);
+ bcast_sum_tv5->axis(0)->parallelize(ParallelType::BIDx);
+ output_tv6->axis(0)->parallelize(ParallelType::BIDx);
+
+ max_val_rf_tv7->axis(1)->parallelize(ParallelType::BIDy);
+ max_val_tv1->axis(1)->parallelize(ParallelType::BIDy);
+ bcast_max_tv2->axis(1)->parallelize(ParallelType::BIDy);
+ sum_exp_rf_tv8->axis(1)->parallelize(ParallelType::BIDy);
+ sum_exp_tv4->axis(1)->parallelize(ParallelType::BIDy);
+ bcast_sum_tv5->axis(1)->parallelize(ParallelType::BIDy);
+ output_tv6->axis(1)->parallelize(ParallelType::BIDy);
+
+ max_val_rf_tv7->axis(-1)->parallelize(ParallelType::TIDx);
+ max_val_tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ bcast_max_tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ exp_tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ sum_exp_rf_tv8->axis(-1)->parallelize(ParallelType::TIDx);
+ sum_exp_tv4->axis(-1)->parallelize(ParallelType::TIDx);
+ bcast_sum_tv5->axis(-1)->parallelize(ParallelType::TIDx);
+ output_tv6->axis(-1)->parallelize(ParallelType::TIDx);
+
+ fusion.addOutput(output_tv6);
+
+ prog.device_ = 0;
+ prog.grid(32, 32);
+ prog.block(32);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::randn({32, 32, 128}, options);
+ at::Tensor cg_output = at::empty({32, 32, 128}, options);
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output});
+
+ auto t2 = at::_softmax(t0, -1, false);
+ // TORCH_CHECK(
+ // t2.allclose(cg_output, 1e-5, 1e-5),
+ // "Error of: ",
+ // t2.sub(cg_output).abs().max());
+}
+// Similar to FusionReduction but uses grid reduction
+void testGPU_FusionGridReduction1() {
+ const int gdimx = 32;
+ const int bdimx = 128;
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(1, bdimx);
+ // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+ tv1->split(1, gdimx);
+ // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
+
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+ // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
+
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDy);
+ tv1->axis(1)->parallelize(ParallelType::BIDx);
+ tv2->axis(2)->parallelize(ParallelType::BIDx);
+
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+
+ int numel_x = 10000;
+ int numel_y = 65000;
+
+ prog.device_ = 0;
+ prog.grid(gdimx, numel_x);
+ prog.block(bdimx);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Same test as the above but uses BIDy and TIDx for reduction
+void testGPU_FusionGridReduction2() {
+ const int gdimy = 32;
+ const int bdimx = 128;
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(1, bdimx);
+ // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+ tv1->split(1, gdimy);
+ // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
+
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+ // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
+
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(1)->parallelize(ParallelType::BIDy);
+ tv2->axis(2)->parallelize(ParallelType::BIDy);
+
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+
+ int numel_x = 10000;
+ int numel_y = 65000;
+
+ prog.device_ = 0;
+ prog.grid(numel_x, gdimy);
+ prog.block(bdimx);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Same test but uses BIDy and BIDz for reduction. No TID used.
+void testGPU_FusionGridReduction3dim1() {
+ const int gdimz = 32;
+ const int gdimy = 128;
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(1, gdimy);
+ // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+ tv1->split(1, gdimz);
+ // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
+
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+ // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
+
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(1)->parallelize(ParallelType::BIDz);
+ tv2->axis(2)->parallelize(ParallelType::BIDz);
+
+ tv1->axis(-1)->parallelize(ParallelType::BIDy);
+ tv2->axis(-1)->parallelize(ParallelType::BIDy);
+
+ int numel_x = 100;
+ int numel_y = 6500;
+
+ prog.device_ = 0;
+ prog.grid(numel_x, gdimy, gdimz);
+ // This number should not affect the output as TIDx is not
+  // used. All threads in a thread block redundantly compute the
+ // same value.
+ prog.block(128);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0
+void testGPU_FusionGridReduction3dim0() {
+ const int rdim = 0;
+ const int gdimy = 128;
+ const int gdimz = 32;
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[R0, I1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {rdim}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(rdim, gdimy);
+ // tv1[R0o, R0i{128}, I1] = tv0[I0, I1]
+ tv1->split(rdim, gdimz);
+ // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1]
+
+ TensorView* tv2 = tv1->rFactor({rdim});
+ // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1]
+ // tv1[ R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1]
+
+ // Note that computeAt isn't going to make anything better as there
+ // is no dynamically sized dimension.
+
+ // Map parallelism as [Serial, BIDz, BIDy, BIDx]
+ tv1->axis(-1)->parallelize(ParallelType::BIDx);
+ tv2->axis(-1)->parallelize(ParallelType::BIDx);
+ tv1->axis(-2)->parallelize(ParallelType::BIDy);
+ tv2->axis(-2)->parallelize(ParallelType::BIDy);
+ tv1->axis(-3)->parallelize(ParallelType::BIDz);
+ tv2->axis(-3)->parallelize(ParallelType::BIDz);
+
+ int numel_x = 6500;
+ int numel_y = 100;
+
+ prog.device_ = 0;
+ prog.grid(numel_y, gdimy, gdimz);
+ // This number should not affect the output as TIDx is not
+  // used. All threads in a thread block redundantly compute the
+ // same value.
+ prog.block(1);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_y}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({0});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// This is similar to the FusionReduction, but swaps BIDx and TIDx
+void testGPU_FusionGridReduction4() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ const int bdimx = 128;
+ const int gdimx = 1024;
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(1, gdimx);
+ // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1]
+ tv1->split(1, 4);
+ // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
+
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+ // tv1[I0, R1oi{4}, R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
+
+ TensorView* tv3 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+ // tv3[I0, R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
+ // tv1[I0, R1i{1024}] = tv3[I0, R1oi{4}, Ir1i{1024}]
+
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv3, 1);
+ tv3->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
+
+ tv2->axis(2)->parallelize(ParallelType::Unroll);
+ tv1->axis(0)->parallelize(ParallelType::TIDx);
+
+ tv1->axis(-1)->parallelize(ParallelType::BIDx);
+ tv2->axis(-1)->parallelize(ParallelType::BIDx);
+ tv3->axis(-1)->parallelize(ParallelType::BIDx);
+
+ int numel_x = bdimx;
+ int numel_y = 65000;
+
+ prog.device_ = 0;
+ prog.grid(gdimx);
+ prog.block(bdimx);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Grid reduction with 2D thread blocks but only TIDx and BIDx are
+// mapped to a reduction dim
+void testGPU_FusionGridReduction5() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ const int bdimx = 64;
+ const int bdimy = 16;
+ const int gdimx = 4;
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(1, bdimx);
+ // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
+ tv1->split(1, gdimx);
+ // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1]
+
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1]
+ // tv1[I0, R1oi{4}, R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}]
+
+ tv0->computeAt(tv1, 1);
+
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv2->axis(-2)->parallelize(ParallelType::BIDx);
+
+ tv1->axis(0)->parallelize(ParallelType::TIDy);
+
+ int numel_x = bdimy;
+ int numel_y = 6500;
+
+ prog.device_ = 0;
+ prog.grid(gdimx);
+ prog.block(bdimx, bdimy);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Similar to FusionGridReduction1 but with 3D tensors
+void testGPU_FusionGridReduction6() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ // Splitting for TID
+ tv1->split(2, 128);
+ // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+ // Splitting for BID
+ tv1->split(1, 128);
+
+ // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+ TensorView* tv2 = tv1->rFactor({3});
+ // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+ // tv1[I0, R1o, R1i{128}, R2i{128}]
+
+ TensorView* tv3 = tv1->rFactor({1});
+ // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+ // tv3[I0, R1o, I1i{128}, I2i{128}]
+ // tv1[I0, R1i{128}, R2i{128}]
+
+ tv3->computeAt(tv1, 1);
+ tv2->computeAt(tv3, 3);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDy);
+
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv2->axis(-3)->parallelize(ParallelType::BIDx);
+ tv3->axis(-2)->parallelize(ParallelType::BIDx);
+
+ int numel_x = 6500;
+ int numel_y = 200;
+ int numel_z = numel_y;
+
+ prog.device_ = 0;
+ prog.grid(128, numel_x);
+ prog.block(128);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({1, 2});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionNonRedAxisBind() {
+ int bid_x = 3;
+ int tid_x = 2;
+ int red_dim = 0;
+
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ TensorView* tv1 =
+ reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ tv1->split(-1, tid_x);
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+
+ prog.device_ = 0;
+ prog.grid(bid_x);
+ prog.block(tid_x);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({16, bid_x * tid_x}, options);
+ at::Tensor cg_output = at::empty({bid_x * tid_x}, options);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+ auto aten_output = input.sum({red_dim});
+
+ TORCH_CHECK(
+ aten_output.allclose(cg_output),
+ "Error of: ",
+ aten_output.sub(cg_output).abs().max());
+}
+
+void testGPU_FusionSplitBCast() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* input_tv0 = makeDummyTensor(3);
+ TensorView* input_tv1 = makeDummyTensor(3);
+ fusion.addInput(input_tv0);
+ fusion.addInput(input_tv1);
+
+ TensorView* sum_tv2 =
+ reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
+ TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
+ TensorView* output_tv4 = div(input_tv1, bcast_tv3);
+
+ sum_tv2->split(-1, 32);
+ TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
+
+ bcast_tv3->split(-1, 32);
+ output_tv4->split(-1, 32);
+
+ sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
+ sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
+ bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
+ output_tv4->axis(0)->parallelize(ParallelType::BIDx);
+
+ sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
+ sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
+ bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
+ output_tv4->axis(1)->parallelize(ParallelType::BIDy);
+
+ sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
+ sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
+
+ fusion.addOutput(output_tv4);
+
+ prog.device_ = 0;
+ prog.grid(32, 32);
+ prog.block(32);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::randn({32, 32, 128}, options);
+ at::Tensor t1 = at::randn({32, 32, 128}, options);
+ at::Tensor cg_output = at::empty({32, 32, 128}, options);
+ torch::jit::fuser::cuda::compileKernel(&prog);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+}
+
+void testGPU_FusionBCastInnerDim() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // reduce then broadcast
+ auto tv1 = sum(tv0, {0});
+ auto tv2 = broadcast(tv1, {false, true});
+
+ TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
+}
+
+void testGPU_FusionBCastReduce() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+
+ auto tv1 = broadcast(tv0, {true, false, false});
+ auto tv2 = sum(tv1, {1});
+ TORCH_CHECK(
+ tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() &&
+ !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction());
+}
+
+// Multiple consumer reduction with computeAt
+// https://github.com/csarofeen/pytorch/issues/110
+void testGPU_FusionReductionMultiConsumer() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+ auto tv1 = unaryOp(UnaryOpType::Exp, tv0);
+ auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), tv1);
+ auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Float(0), tv1);
+ auto tv4 = add(tv2, tv3);
+ fusion.addOutput(tv4);
+ tv1->computeAt(tv2, -1);
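+ // tv1 feeds two reduction consumers (tv2 and tv3); the pass must still pick
+ // a single computeAt view for it, which the check below verifies.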
+
+ TORCH_CHECK(
+ (tv1->getComputeAtView() == tv2 || tv1->getComputeAtView() == tv3) &&
+ tv1->getThisComputeAtAxis() == 2 && tv1->getRelativeComputeAtAxis() == 2);
+}
+
+void testGPU_FusionComputeAtExprOrder() {
+ {
+ for (int i = 0; i < 2; ++i) {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
+
+ auto tv1 = add(tv0, new Float(1));
+ auto tv2 = add(tv0, new Float(1));
+ TensorView* tv3 = add(tv1, tv2);
+ if (i == 0) {
+ tv1->computeAt(tv3, -1);
+ fusion.addOutput(tv2);
+ } else {
+ tv2->computeAt(tv3, -1);
+ fusion.addOutput(tv1);
+ }
+ fusion.addOutput(tv3);
+
+ prog.device_ = 0;
+ prog.grid(1);
+ prog.block(1);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({100}, options);
+ at::Tensor output2 = at::empty_like(input, options);
+ at::Tensor output3 = at::empty_like(input, options);
+ torch::jit::fuser::cuda::runTestKernel(
+ &prog, {input}, {output2, output3});
+ auto aten_output = (input + 1) * 2;
+ TORCH_CHECK(
+ aten_output.allclose(output3),
+ "Error of: ",
+ aten_output.sub(output3).abs().max());
+ }
+ }
+ {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ auto tv1 = add(tv0, new Float(1));
+ auto tv2 = add(tv0, new Float(1));
+ TensorView* tv3 = add(tv1, tv2);
+ fusion.addOutput(tv3);
+
+ tv3->split(-1, 32);
+
+ tv1->computeAt(tv3, -1);
+ tv2->computeAt(tv3, -2);
+
+ prog.device_ = 0;
+ prog.grid(1);
+ prog.block(1);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({100, 100}, options);
+ at::Tensor output = at::empty_like(input, options);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {output});
+ auto aten_output = (input + 1) * 2;
+ TORCH_CHECK(
+ aten_output.allclose(output),
+ "Error of: ",
+ aten_output.sub(output).abs().max());
+ }
+}
+
+void testGPU_FusionZeroDimComputeAt() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
+
+ auto tv1 = sum(tv0, {0});
+ auto tv2 = add(tv1, new Float(1));
+ fusion.addOutput(tv2);
+ TORCH_CHECK(tv2->nDims() == 0);
+ tv1->computeAt(tv2, 0);
+
+ prog.device_ = 0;
+ prog.grid(1);
+ prog.block(1);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({100}, options);
+ at::Tensor output = at::empty({}, options);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {output});
+ auto aten_output = input.sum() + 1;
+ TORCH_CHECK(
+ aten_output.allclose(output),
+ "Error of: ",
+ aten_output.sub(output).abs().max());
+}
+
+void testGPU_FusionZeroDimBroadcast() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ TensorView* tv0 = makeDummyTensor(0);
+ fusion.addInput(tv0);
+
+ auto tv1 = broadcast(tv0, {true, true});
+ TORCH_CHECK(tv1->nDims() == 2);
+
+ TensorView* tv2 = makeDummyTensor(2);
+ fusion.addInput(tv2);
+
+ auto tv3 = add(tv1, tv2);
+ auto tv4 = sum(tv3, {0, 1});
+ fusion.addOutput(tv4);
+
+ tv3->computeAt(tv4, -1);
+
+ prog.device_ = 0;
+ prog.grid(1);
+ prog.block(1);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input1 = at::rand({}, options);
+ at::Tensor input2 = at::rand({10, 10}, options);
+ at::Tensor output = at::empty({}, options);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input1, input2}, {output});
+ auto aten_output =
+ (input1.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + input2).sum();
+ TORCH_CHECK(
+ aten_output.allclose(output),
+ "Error of: ",
+ aten_output.sub(output).abs().max());
+}
+
+void testGPU_FusionZeroDimReduction() {
+ torch::jit::fuser::cuda::CudaKernel prog;
+ Fusion& fusion = *prog.fusion_;
+ FusionGuard fg(&fusion);
+
+ const int bdimx = 32;
+ const int gdimx = 32;
+
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
+
+ auto tv1 = sum(tv0, {0});
+ fusion.addOutput(tv1);
+
+ tv1->split(0, bdimx);
+ tv1->split(0, gdimx);
+ auto tv2 = tv1->rFactor({0});
+
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv2->axis(-2)->parallelize(ParallelType::BIDx);
+
+ prog.device_ = 0;
+ prog.grid(gdimx);
+ prog.block(bdimx);
+
+ torch::jit::fuser::cuda::compileKernel(&prog);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({1000}, options);
+ at::Tensor output = at::empty({}, options);
+ torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {output});
+ auto aten_output = input.sum();
+ TORCH_CHECK(
+ aten_output.allclose(output),
+ "Error of: ",
+ aten_output.sub(output).abs().max());
}
} // namespace jit
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 381108f..b6f464f 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -97,50 +97,75 @@
_(FusionAliasing)
#if defined(USE_CUDA)
-#define TH_FORALL_TESTS_CUDA(_) \
- _(ArgumentSpec) \
- _(CompleteArgumentSpec) \
- _(Fusion) \
- _(GraphExecutor) \
- _(ModuleConversion) \
- _(Interp) \
- _(GPU_IrGraphGenerator) \
- _(GPU_FusionDispatch) \
- _(GPU_FusionSimpleArith) \
- _(GPU_FusionExprEvalConstants) \
- _(GPU_FusionExprEvalBindings) \
- _(GPU_FusionExprEvalBasic) \
- _(GPU_FusionExprEvalComplex) \
- _(GPU_FusionSimpleTypePromote) \
- _(GPU_FusionMutator) \
- _(GPU_FusionRegister) \
- _(GPU_FusionTopoSort) \
- _(GPU_FusionTensor) \
- _(GPU_FusionTensorContiguity) \
- _(GPU_FusionTVSplit) \
- _(GPU_FusionTVMerge) \
- _(GPU_FusionTVReorder) \
- _(GPU_FusionEquality) \
- _(GPU_FusionReplaceAll) \
- _(GPU_FusionParser) \
- _(GPU_FusionDependency) \
- _(GPU_FusionCodeGen) \
- _(GPU_FusionCodeGen2) \
- _(GPU_FusionSimplePWise) \
- _(GPU_FusionExecKernel) \
- _(GPU_FusionForLoop) \
- _(GPU_FusionLoopUnroll) \
- _(GPU_FusionUnaryOps) \
- _(GPU_FusionBinaryOps) \
- _(GPU_FusionTernaryOps) \
- _(GPU_FusionCompoundOps) \
- _(GPU_FusionCastOps) \
- _(GPU_FusionAdvancedComputeAt) \
- _(GPU_FusionScalarInputs) \
- _(GPU_FusionRFactorReplay) \
- _(GPU_FusionReduction) \
- _(GPU_FusionReduction2) \
- _(GPU_FusionSimpleBCast)
+#define TH_FORALL_TESTS_CUDA(_) \
+ _(ArgumentSpec) \
+ _(CompleteArgumentSpec) \
+ _(Fusion) \
+ _(GraphExecutor) \
+ _(ModuleConversion) \
+ _(Interp) \
+ _(GPU_IrGraphGenerator) \
+ _(GPU_FusionDispatch) \
+ _(GPU_FusionClear) \
+ _(GPU_FusionCopy) \
+ _(GPU_FusionMove) \
+ _(GPU_FusionSimpleArith) \
+ _(GPU_FusionExprEvalConstants) \
+ _(GPU_FusionExprEvalBindings) \
+ _(GPU_FusionExprEvalBasic) \
+ _(GPU_FusionExprEvalComplex) \
+ _(GPU_FusionExprEvalPostLower) \
+ _(GPU_FusionSimpleTypePromote) \
+ _(GPU_FusionMutator) \
+ _(GPU_FusionRegister) \
+ _(GPU_FusionTopoSort) \
+ _(GPU_FusionTensor) \
+ _(GPU_FusionTVSplit) \
+ _(GPU_FusionTVMerge) \
+ _(GPU_FusionTVReorder) \
+ _(GPU_FusionEquality) \
+ _(GPU_FusionReplaceAll) \
+ _(GPU_FusionParser) \
+ _(GPU_FusionDependency) \
+ _(GPU_FusionCodeGen) \
+ _(GPU_FusionCodeGen2) \
+ _(GPU_FusionSimplePWise) \
+ _(GPU_FusionExecKernel) \
+ _(GPU_FusionForLoop) \
+ _(GPU_FusionLoopUnroll) \
+ _(GPU_FusionUnaryOps) \
+ _(GPU_FusionBinaryOps) \
+ _(GPU_FusionTernaryOps) \
+ _(GPU_FusionCompoundOps) \
+ _(GPU_FusionCastOps) \
+ _(GPU_FusionAdvancedComputeAt) \
+ _(GPU_FusionScalarInputs) \
+ _(GPU_FusionRFactorReplay) \
+ _(GPU_FusionReduction) \
+ _(GPU_FusionReduction2) \
+ _(GPU_FusionReduction3) \
+ _(GPU_FusionReduction4) \
+ _(GPU_FusionReduction5) \
+ _(GPU_FusionReductionTFT) \
+ _(GPU_FusionSimpleBCast) \
+ _(GPU_FusionSimpleGemm) \
+ _(GPU_FusionSoftmax) \
+ _(GPU_FusionGridReduction1) \
+ _(GPU_FusionGridReduction2) \
+ _(GPU_FusionGridReduction3dim1) \
+ _(GPU_FusionGridReduction3dim0) \
+ _(GPU_FusionGridReduction4) \
+ _(GPU_FusionGridReduction5) \
+ _(GPU_FusionGridReduction6) \
+ _(GPU_FusionNonRedAxisBind) \
+ _(GPU_FusionBCastInnerDim) \
+ _(GPU_FusionBCastReduce) \
+ _(GPU_FusionSplitBCast) \
+ _(GPU_FusionComputeAtExprOrder) \
+ _(GPU_FusionZeroDimComputeAt) \
+ _(GPU_FusionZeroDimBroadcast) \
+ _(GPU_FusionZeroDimReduction) \
+ _(GPU_FusionReductionMultiConsumer)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
diff --git a/test/run_test.py b/test/run_test.py
index f912c51..53c155b 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -30,6 +30,8 @@
'distributed/test_c10d_spawn',
'test_cuda',
'test_jit_cuda_fuser',
+ 'test_jit_cuda_fuser_legacy',
+ 'test_jit_cuda_fuser_profiling',
'test_cuda_primary_ctx',
'test_dataloader',
'distributed/test_data_parallel',
diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 2a7421f..cb74a74 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -8,10 +8,12 @@
import torch
-from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR
+from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, skipIfRocm
from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
from test_jit import JitTestCase, RUN_CUDA
+import itertools
+import numpy as np
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_profiling_executor(True)
@@ -321,14 +323,14 @@
where_jit = torch.jit.script(where)
self._run_helper(where_jit, where, x, y, cond)
- def lerp(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+ def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.rand_like(x)
o = o * torch.lerp(x, y, z)
return o
lerp_jit = torch.jit.script(lerp)
self._run_helper(lerp_jit, lerp, x, y, z)
- def lerp_scale(x : torch.Tensor, y : torch.Tensor, z: float):
+ def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float):
o = torch.rand_like(x)
o = o * torch.lerp(x, y, z)
return o
@@ -342,21 +344,21 @@
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
- def addcmul(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, value : float):
+ def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=value)
return o
addcmul_jit = torch.jit.script(addcmul)
self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)
- def addcmul_no_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+ def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z)
return o
addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)
- def addcmul_const_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+ def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=0.75)
return o
@@ -393,6 +395,109 @@
os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
self.assertTrue(runDefaultTestWithSeed(28449))
+ def _compare(self, desc, inp1, inp2, error):
+ a = inp1.clone().detach().cpu().numpy()
+ b = inp2.clone().detach().cpu().numpy()
+ close = np.allclose(a, b, error, error)
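+ # np.allclose(a, b, rtol, atol): the same tolerance is used for both the
+ # relative and the absolute term.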
+ if not close:
+ print(desc, close)
+ z = a - b
+ index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
+ print("dif : ", z[index])
+ print("inp1 : ", a[index])
+ print("inp2 : ", b[index])
+ return close
+
+ def _reduction_helper(self, sizes, reduction_axis, dtype, device):
+ class MyReduction(torch.nn.Module):
+ __constants__ = ['reduction_axis']
+
+ def __init__(self):
+ super(MyReduction, self).__init__()
+ self.reduction_axis = reduction_axis
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ o = torch.add(x, y)
+ o = torch.sum(o, dim=self.reduction_axis)
+ return o
+
+ t = MyReduction()
+ x = torch.randn(sizes, dtype=dtype, device=device)
+ y = torch.randn(sizes, dtype=dtype, device=device)
+ t_jit = torch.jit.script(t)
+ jit_o = t_jit(x, y)
+ jit_o = t_jit(x, y)
+ o = t(x, y)
+ for oo, jit_oo in zip(o, jit_o):
+ self.assertEqual(oo.dtype, jit_oo.dtype)
+ # numerical issues here due to our scheduling.
+ # can't use `self.assertEqual(oo, jit_oo)`
+ self.assertTrue(self._compare("comparing output failed", oo, jit_oo, 1e-4))
+ self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP)
+
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+ ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
+ @skipIfRocm
+ def test_reduction(self):
+ for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]):
+ # note that num_reduce_dim stays below len(x), so we never reduce all the
+ # way down to a single element (a codegen limitation at this time)
+ for num_reduce_dim in range(1, len(x)):
+ for axes in itertools.combinations(range(len(x)), num_reduce_dim):
+ self._reduction_helper(x, axes, torch.float32, "cuda")
+
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+ ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
+ @skipIfRocm
+ def test_pw_single_reduction_partition(self):
+ sizes = [8, 8, 8]
+ dtype = torch.float
+ device = "cuda"
+ x = torch.randn(sizes, dtype=dtype, device=device)
+ y = torch.randn(sizes, dtype=dtype, device=device)
+ z = torch.randn(sizes, dtype=dtype, device=device)
+
+ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
+ o = torch.add(x, y)
+ o = torch.sum(o, dim=[0])
+ o = torch.add(o, z)
+ return o
+ t_jit = torch.jit.script(t)
+ jit_o = t_jit(x, y, z)
+ jit_o = t_jit(x, y, z)
+ o = t(x, y, z)
+ for oo, jit_oo in zip(o, jit_o):
+ self.assertEqual(oo.dtype, jit_oo.dtype)
+ self.assertEqual(oo, jit_oo)
+ self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
+
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+ ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
+ @skipIfRocm
+ def test_single_reduction_broadcast(self):
+ dtype = torch.float
+ device = "cuda"
+ x = torch.randn([7, 4, 8], dtype=dtype, device=device)
+ y = torch.randn([4, 8], dtype=dtype, device=device)
+ z = torch.randn([1, 4, 8], dtype=dtype, device=device)
+
+ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
+ o = torch.add(x, y)
+ o = torch.add(o, z)
+ o = torch.sum(o, dim=[0])
+ return o
+ t_jit = torch.jit.script(t)
+ jit_o = t_jit(x, y, z)
+ jit_o = t_jit(x, y, z)
+ o = t(x, y, z)
+ for oo, jit_oo in zip(o, jit_o):
+ self.assertEqual(oo.dtype, jit_oo.dtype)
+ self.assertEqual(oo, jit_oo)
+ self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
+
class TestPassManagerCudaFuser(JitTestCase):
@@ -441,5 +546,6 @@
self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
self.assertFalse(torch._C._jit_nvfuser_enabled())
+
if __name__ == '__main__':
run_tests()
diff --git a/test/test_jit_cuda_fuser_legacy.py b/test/test_jit_cuda_fuser_legacy.py
new file mode 100644
index 0000000..4b9959c
--- /dev/null
+++ b/test/test_jit_cuda_fuser_legacy.py
@@ -0,0 +1,6 @@
+import sys
+sys.argv.append("--ge_config=legacy")
+from test_jit_cuda_fuser import *
+
+if __name__ == '__main__':
+ run_tests()
diff --git a/test/test_jit_cuda_fuser_profiling.py b/test/test_jit_cuda_fuser_profiling.py
new file mode 100644
index 0000000..e2869ec
--- /dev/null
+++ b/test/test_jit_cuda_fuser_profiling.py
@@ -0,0 +1,6 @@
+import sys
+sys.argv.append("--ge_config=profiling")
+from test_jit_cuda_fuser import *
+
+if __name__ == '__main__':
+ run_tests()
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index bdc0b1c..b194e4f 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -331,20 +331,26 @@
"torch/csrc/autograd/profiler_cuda.cpp",
"torch/csrc/autograd/functions/comm.cpp",
"torch/csrc/jit/codegen/cuda/arith.cpp",
+ "torch/csrc/jit/codegen/cuda/compute_at.cpp",
"torch/csrc/jit/codegen/cuda/dispatch.cpp",
"torch/csrc/jit/codegen/cuda/expr_evaluator.cpp",
"torch/csrc/jit/codegen/cuda/fusion.cpp",
"torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
"torch/csrc/jit/codegen/cuda/index_compute.cpp",
"torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
+ "torch/csrc/jit/codegen/cuda/ir_cloner.cpp",
"torch/csrc/jit/codegen/cuda/ir_graphviz.cpp",
"torch/csrc/jit/codegen/cuda/ir_nodes.cpp",
"torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
"torch/csrc/jit/codegen/cuda/kernel.cpp",
"torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_index.cpp",
"torch/csrc/jit/codegen/cuda/lower_loops.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_unroll.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp",
"torch/csrc/jit/codegen/cuda/lower_utils.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_validation.cpp",
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/shape_inference.cpp",
@@ -352,7 +358,6 @@
"torch/csrc/jit/codegen/cuda/parser.cpp",
"torch/csrc/jit/codegen/cuda/partition.cpp",
"torch/csrc/jit/codegen/cuda/predicate_compute.cpp",
- "torch/csrc/jit/codegen/cuda/tensor_meta.cpp",
"torch/csrc/jit/codegen/cuda/tensor_view.cpp",
"torch/csrc/jit/codegen/cuda/transform_iter.cpp",
"torch/csrc/jit/codegen/cuda/transform_replay.cpp",
diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 0b3785b..8cd7ab9 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -311,6 +311,19 @@
TORCH_CUDA_API TensorView* lt(TensorView* v1, TensorView* v2) {
return arithOpOverloads(lt, v1, v2);
}
+// eq
+TORCH_CUDA_API Val* eq(Val* v1, Val* v2) {
+ return binaryOp(BinaryOpType::Eq, v1, v2);
+}
+TORCH_CUDA_API TensorView* eq(TensorView* v1, Val* v2) {
+ return arithOpOverloads(eq, v1, v2);
+}
+TORCH_CUDA_API TensorView* eq(Val* v1, TensorView* v2) {
+ return arithOpOverloads(eq, v1, v2);
+}
+TORCH_CUDA_API TensorView* eq(TensorView* v1, TensorView* v2) {
+ return arithOpOverloads(eq, v1, v2);
+}
// ceilDiv
TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::CeilDiv, v1, v2);
@@ -348,9 +361,10 @@
// REDUCTION OPERATIONS
-namespace {
// TODO: How do we adjust this so we can reduce to a single scalar value?
-TensorView* newForReduction(TensorView* tv, std::vector<unsigned int> axes) {
+static TensorView* newForReduction(
+ TensorView* tv,
+ const std::vector<unsigned int>& axes) {
auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
std::set<unsigned int> axes_set(axes.begin(), axes.end());
@@ -363,25 +377,35 @@
(*(axes_set.rbegin())) < orig_domain.size(),
"Error setting up reduction, reduction axis is outside nDims. Keep in mind reductions are relative to root domains, not modified views.");
- for (decltype(orig_domain.size()) dim = 0; dim < orig_domain.size(); dim++) {
- IterDomain* id = orig_domain[dim];
-
+ for (size_t dim = 0; dim < orig_domain.size(); dim++) {
bool isReduction = false;
- if ((*axes_set.begin()) == dim) {
+ if (!axes_set.empty() && *axes_set.begin() == dim) {
isReduction = true;
axes_set.erase(axes_set.begin());
}
+ const IterDomain* id = orig_domain[dim];
+
+ TORCH_CHECK(
+ !(isReduction && id->isBroadcast()),
+ "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ",
+ id,
+ " of tensor ",
+ tv);
+
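+ // Rebuild the iteration domain, preserving the broadcast flag of the
+ // original ID on the reduction output.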
new_domain.push_back(new IterDomain(
- id->start(), id->extent(), ParallelType::Serial, isReduction));
+ id->start(),
+ id->extent(),
+ ParallelType::Serial,
+ isReduction,
+ false,
+ id->isBroadcast()));
}
TensorDomain* td = new TensorDomain(new_domain);
return new TensorView(td, tv->getDataType().value());
}
-} // namespace
-
TensorView* reductionOp(
BinaryOpType reduction_op_type,
const std::vector<int>& axes,
@@ -395,6 +419,8 @@
TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
"Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
+ TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");
+
std::vector<unsigned int> uint_axes;
for (int axis : axes) {
if (axis < 0)
@@ -447,9 +473,9 @@
if (ent)
n_broadcasts++;
TORCH_CHECK(
- nBCastDims - n_broadcasts == inp->nDims(),
+ nBCastDims - n_broadcasts == inp->domain()->noReductions().size(),
"Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
- inp->nDims(),
+ inp->domain()->noReductions().size(),
" but received ",
nBCastDims - n_broadcasts);
@@ -468,7 +494,8 @@
out_domain.push_back(new IterDomain(
new Int(0), new Int(1), ParallelType::Serial, false, false, true));
} else {
- out_domain.push_back(inp->axis(iinp));
+ // Don't propagate reduction IDs through arith ops.
+ out_domain.push_back(inp->domain()->noReductions()[iinp]);
iinp++;
}
ibdim++;
diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h
index 4961e41..4c61f2d 100644
--- a/torch/csrc/jit/codegen/cuda/arith.h
+++ b/torch/csrc/jit/codegen/cuda/arith.h
@@ -5,7 +5,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
-struct Val;
+class Val;
/*
* The operations defined in this header is intended as user facing functions.
@@ -60,7 +60,8 @@
TensorView* inp,
const std::vector<bool>& is_broadcast_dim);
-// BINARY OPAERATIONS
+// BINARY OPERATIONS
+// add
TORCH_CUDA_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* add(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* add(Val* v1, TensorView* v2);
@@ -90,6 +91,11 @@
TORCH_CUDA_API TensorView* lt(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* lt(Val* v1, TensorView* v2);
TORCH_CUDA_API TensorView* lt(TensorView* v1, TensorView* v2);
+// eq
+TORCH_CUDA_API Val* eq(Val* v1, Val* v2);
+TORCH_CUDA_API TensorView* eq(TensorView* v1, Val* v2);
+TORCH_CUDA_API TensorView* eq(Val* v1, TensorView* v2);
+TORCH_CUDA_API TensorView* eq(TensorView* v1, TensorView* v2);
// ceilDiv
TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* ceilDiv(TensorView* v1, Val* v2);
diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp
new file mode 100644
index 0000000..eea3f4a
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -0,0 +1,459 @@
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
+
+#include <torch/csrc/jit/codegen/cuda/compute_at.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+// Actually applies transformation
+void ComputeAt::computeAt_impl(
+ TensorView* producer,
+ TensorView* consumer,
+ unsigned int consumer_compute_at_axis) {
+ // Reset the view, otherwise it will conflict with the replay.
+ producer->clearComputeAt();
+ // Replay the producer as the consumer (replayPasC).
+ auto replay = TransformReplay::replayPasC(
+ producer, consumer, (int)consumer_compute_at_axis);
+ producer->setComputeAt(consumer, replay.second);
+}
+
+// Runs the replay and checks the resulting computeAt position. Applies the
+// transformation only if the new position is higher than the producer's
+// current one.
+void ComputeAt::maybe_computeAt_impl(
+ TensorView* producer,
+ TensorView* consumer,
+ unsigned int consumer_compute_at_axis) {
+ unsigned int prev_pos = 0;
+ if (producer->hasComputeAt())
+ prev_pos = producer->getThisComputeAtAxis();
+
+ auto replay = TransformReplay::replayPasC(
+ producer->domain(), consumer->domain(), (int)consumer_compute_at_axis);
+
+ if (replay.second > prev_pos) {
+ producer->setDomain(replay.first);
+ producer->setComputeAt(consumer, replay.second);
+ }
+}
+
+// Actually applies transformation
+void ComputeAt::forwardComputeAt_impl(
+ TensorView* producer,
+ TensorView* consumer,
+ unsigned int producer_compute_at_axis) {
+ // Reset the view, otherwise it will conflict with the replay (this may no
+ // longer be necessary).
+ producer->clearComputeAt();
+ auto replay = TransformReplay::replayCasP(
+ consumer, producer, (int)producer_compute_at_axis);
+ producer->setComputeAt(consumer, replay.second);
+}
+
+namespace {
+// Wrapper around set_intersection
+template <typename T>
+std::set<T> set_intersection(const std::set<T>& set1, const std::set<T>& set2) {
+ std::set<T> intersection;
+ std::set_intersection(
+ set1.begin(),
+ set1.end(),
+ set2.begin(),
+ set2.end(),
+ std::inserter(intersection, intersection.begin()));
+ return intersection;
+}
+
+// convert an iterable of Val* to be an iterable of TensorView*
+template <typename T1, typename T2>
+T1 tv_iterable(const T2& val_iterable) {
+ T1 tv_iterable = T1();
+ std::transform(
+ val_iterable.begin(),
+ val_iterable.end(),
+ std::back_inserter(tv_iterable),
+ [](Val* v) {
+ TORCH_INTERNAL_ASSERT(
+ v->getValType().value() == ValType::TensorView,
+ "When following the computeAt dependency chain, a non TensorView value was found.");
+ return static_cast<TensorView*>(v);
+ });
+ return tv_iterable;
+}
+
+std::deque<std::deque<TensorView*>> getAllTVUseChains(TensorView* tv) {
+ // Grab all use chains of the producer in the fusion.
+ auto val_all_use_chains = DependencyCheck::getAllUseChains(tv);
+
+ // Convert dep chains to tensor view chains.
+ std::deque<std::deque<TensorView*>> producer_use_chains_;
+ for (const auto& val_dep_chain : val_all_use_chains)
+ producer_use_chains_.push_back(
+ tv_iterable<std::deque<TensorView*>>(val_dep_chain));
+ return producer_use_chains_;
+}
+} // namespace
+
+void ComputeAt::setCommonConsumer() {
+ // Convert the first chain to a set.
+ std::set<TensorView*> common_consumers(
+ producer_use_chains_.front().begin(), producer_use_chains_.front().end());
+
+ // Run through all use chains of producer, and intersect them to find common
+ // TVs
+ for (auto dep_chain : producer_use_chains_)
+ common_consumers = set_intersection(
+ common_consumers,
+ std::set<TensorView*>(dep_chain.begin(), dep_chain.end()));
+
+ auto all_chains =
+ DependencyCheck::getAllDependencyChains(producer_, consumer_);
+
+ // Right now we only support computeAt if the consumer depends on the
+ // producer somewhere in the graph.
+ TORCH_CHECK(
+ !all_chains.empty(),
+ "Compute At expects ",
+ producer_,
+ " is a dependency of ",
+ consumer_,
+ ", however it is not.");
+
+ // Remove all TVs from producer to consumer as common consumer must be at or
+ // after consumer
+ for (const auto& dep_chain : all_chains) {
+ auto tv_chain = tv_iterable<std::deque<TensorView*>>(dep_chain);
+ for (auto tv : tv_chain) {
+ if (tv != consumer_)
+ common_consumers.erase(tv);
+ }
+ }
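+ // What remains are TVs that appear on every use chain of producer and are
+ // not strictly between producer and consumer (consumer itself stays a
+ // candidate).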
+
+ // If there is a common consumer, grab the first one at or after consumer
+ common_consumer_ = nullptr;
+ if (!common_consumers.empty()) {
+ for (TensorView* tv : producer_use_chains_.front())
+ if (common_consumers.find(tv) != common_consumers.end()) {
+ common_consumer_ = tv;
+ break;
+ }
+ TORCH_INTERNAL_ASSERT(
+ common_consumer_ != nullptr,
+ "Hit a logical inconsistency in the computeAt pass.");
+ }
+}
+
+void ComputeAt::traverseAllKnown() {
+ std::deque<std::deque<Val*>> chains;
+
+ // propagate backwards through all dep chains from producer to consumer
+
+ // Grab all dependency chains from producer to consumer.
+ chains = DependencyCheck::getAllDependencyChains(producer_, consumer_);
+
+ TORCH_CHECK(
+ !chains.empty(),
+ "Producer and consumer in a computeAt call must have a dependency between them even if indirect.");
+
+ for (const auto& val_chain : chains) {
+ auto tv_chain = tv_iterable<std::deque<TensorView*>>(val_chain);
+ TensorView* running_consumer = nullptr;
+ TensorView* running_producer = tv_chain.back();
+ unsigned int running_consumer_pos = consumer_position_;
+
+ tv_chain.pop_back();
+
+ while (!tv_chain.empty()) {
+ running_consumer = running_producer;
+ running_producer = tv_chain.back();
+ tv_chain.pop_back();
+
+ if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+ known_positions.find(running_producer) != known_positions.end()) {
+ running_consumer_pos = known_positions.at(running_producer);
+ continue;
+ }
+
+ computeAt_impl(running_producer, running_consumer, running_consumer_pos);
+ running_consumer_pos = running_producer->getThisComputeAtAxis();
+
+ // Update both compute_at_ed and compute_at_axis_lookup
+ compute_at_ed.emplace(running_producer);
+
+ if (known_positions.find(running_producer) != known_positions.end()) {
+ TORCH_INTERNAL_ASSERT(
+ known_positions.at(running_producer) ==
+ running_producer->getThisComputeAtAxis(),
+ "Hit a logical inconsistency in the computeAt pass.");
+ } else {
+ known_positions[running_producer] =
+ running_producer->getThisComputeAtAxis();
+ }
+ }
+ }
+
+ // Propagate forward through all of consumer's use chains, or from consumer
+ // to common_consumer if one exists, and mark these TVs as finished.
+
+ if (common_consumer_ == nullptr) {
+ chains = DependencyCheck::getAllUseChains(consumer_);
+ } else if (common_consumer_ != consumer_) {
+ chains =
+ DependencyCheck::getAllDependencyChains(consumer_, common_consumer_);
+ }
+
+ // propagate forward through all chains
+ unsigned int running_producer_compute_at = consumer_position_;
+
+ for (const auto& dep_chain : chains) {
+ TORCH_INTERNAL_ASSERT(
+ !dep_chain.empty(), "Computed an invalid common_consumer.");
+
+ std::deque<TensorView*> tv_dep_chain =
+ tv_iterable<std::deque<TensorView*>>(dep_chain);
+
+ TensorView* running_consumer = tv_dep_chain.front();
+ tv_dep_chain.pop_front();
+
+ TensorView* running_producer = nullptr;
+
+ while (!tv_dep_chain.empty()) {
+ running_producer = running_consumer;
+ running_consumer = tv_dep_chain.front();
+ tv_dep_chain.pop_front();
+
+ if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+ known_positions.find(running_consumer) != known_positions.end()) {
+ running_producer_compute_at = known_positions.at(running_consumer);
+ continue;
+ }
+
+ forwardComputeAt_impl(
+ running_producer, running_consumer, running_producer_compute_at);
+
+ compute_at_ed.emplace(running_producer);
+
+ if (known_positions.find(running_consumer) != known_positions.end()) {
+ TORCH_INTERNAL_ASSERT(
+ known_positions.at(running_consumer) ==
+ running_producer->getRelativeComputeAtAxis(),
+ "Hit a logical inconsistency in computeAt pass.");
+ } else {
+ known_positions[running_consumer] =
+ running_producer->getRelativeComputeAtAxis();
+ }
+ }
+ }
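+ // Positions recorded in known_positions during this traversal are trusted;
+ // the later forward and backward traversals reuse them and assert
+ // consistency against them.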
+}
+
+// Similar to forward traversal in traverseAllKnown but we don't know if the
+// positions are actually correct
+void ComputeAt::traverseForward() {
+ // propagate forward through all *producer* use_chains or from *producer* to
+ // common_consumer if common_consumer exists.
+ std::deque<std::deque<Val*>> chains;
+ if (common_consumer_ == nullptr) {
+ chains = DependencyCheck::getAllUseChains(producer_);
+ } else if (common_consumer_ != consumer_) {
+ chains =
+ DependencyCheck::getAllDependencyChains(producer_, common_consumer_);
+ }
+
+ // propagate forward through all chains
+ for (const auto& dep_chain : chains) {
+ int running_producer_compute_at = known_positions.at(producer_);
+ TORCH_INTERNAL_ASSERT(
+ !dep_chain.empty(), "Computed an invalid common_consumer.");
+
+ std::deque<TensorView*> tv_dep_chain =
+ tv_iterable<std::deque<TensorView*>>(dep_chain);
+
+ TensorView* running_consumer = tv_dep_chain.front();
+ tv_dep_chain.pop_front();
+
+ TensorView* running_producer = nullptr;
+
+ while (!tv_dep_chain.empty()) {
+ running_producer = running_consumer;
+ running_consumer = tv_dep_chain.front();
+ tv_dep_chain.pop_front();
+
+ if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+ known_positions.find(running_consumer) != known_positions.end()) {
+ running_producer_compute_at = known_positions.at(running_consumer);
+ continue;
+ }
+
+ forwardComputeAt_impl(
+ running_producer, running_consumer, running_producer_compute_at);
+
+ compute_at_ed.emplace(running_producer);
+
+ if (known_positions.find(running_consumer) != known_positions.end()) {
+ TORCH_INTERNAL_ASSERT(
+ known_positions.at(running_consumer) ==
+ running_producer->getRelativeComputeAtAxis(),
+ "Hit a logical inconsistency in computeAt pass.");
+ }
+ }
+ }
+}
+
+// Similar to backward traversal in traverseAllKnown but we should only apply
+// computeAt if it will increase computeAt positions.
+void ComputeAt::traverseBackward() {
+ // Propagate *backward* through all *producer* use chains, or from *producer*
+ // to common_consumer if it exists. Only apply the transform if it increases
+ // the computeAt position.
+ std::deque<std::deque<Val*>> chains;
+ if (common_consumer_ == nullptr) {
+ chains = DependencyCheck::getAllUseChains(producer_);
+ } else if (common_consumer_ != consumer_) {
+ chains =
+ DependencyCheck::getAllDependencyChains(producer_, common_consumer_);
+ }
+
+ for (const auto& val_chain : chains) {
+ auto tv_chain = tv_iterable<std::deque<TensorView*>>(val_chain);
+ TensorView* running_consumer = nullptr;
+ TensorView* running_producer = tv_chain.back();
+ auto it = known_positions.find(running_producer);
+
+ if (it == known_positions.end()) {
+ TORCH_INTERNAL_ASSERT(
+ common_consumer_ == nullptr,
+ "Hit a logical inconsistency in computeAt pass.");
+ continue;
+ }
+
+ unsigned int running_consumer_pos = it->second;
+
+ tv_chain.pop_back();
+
+ while (!tv_chain.empty()) {
+ running_consumer = running_producer;
+ running_producer = tv_chain.back();
+ tv_chain.pop_back();
+
+ if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+ known_positions.find(running_producer) != known_positions.end()) {
+ running_consumer_pos = known_positions.at(running_producer);
+ continue;
+ }
+
+ // If we're already at consumer_position_ that's the max position we could
+ // hope for, don't bother running again.
+ if (running_producer->getThisComputeAtAxis() != consumer_position_) {
+ maybe_computeAt_impl(
+ running_producer, running_consumer, running_consumer_pos);
+ }
+ running_consumer_pos = running_producer->getThisComputeAtAxis();
+
+ if (known_positions.find(running_producer) != known_positions.end()) {
+ TORCH_INTERNAL_ASSERT(
+ known_positions.at(running_producer) ==
+ running_producer->getThisComputeAtAxis(),
+ "Hit a logical inconsistency in the computeAt pass.");
+ }
+ }
+ }
+}
+
+void ComputeAt::runPass() {
+ // Make sure the correct fusion is setup between this and consumer.
+ TORCH_CHECK(
+ producer_->fusion() == consumer_->fusion(),
+ producer_,
+ " and ",
+ consumer_,
+ " are not in the same fusion.");
+
+ // Make sure Fusion Guard is set appropriately
+ FusionGuard fg(producer_->fusion());
+
+ // Look through all the use chains of producer. Check if there's a single
+ // consumer for all chains at or after the consumer specified in the computeAt
+ // call.
+ setCommonConsumer();
+
+ // Propagate in a way we know the result will be correct: backward from
+ // consumer to producer, then forward from consumer.
+ traverseAllKnown();
+
+ TORCH_INTERNAL_ASSERT(
+ producer_->hasComputeAt(),
+ "Hit a logical inconsistency in the computeAt pass.");
+
+ // Start at producer and traverse forward
+ traverseForward();
+
+ // Propagate backward from the consumer (or common consumer) and check
+ // whether it increases the computeAt position on tensors; if so, take it.
+ traverseBackward();
+}
+
+void ComputeAt::setupOutputs() {
+ if (common_consumer_ != nullptr)
+ return;
+
+ // output and its compute at position
+ std::unordered_map<TensorView*, int> touched_outputs;
+ for (auto tv : compute_at_ed) {
+ TORCH_INTERNAL_ASSERT(
+ tv->hasComputeAt(),
+ "Hit a logical inconsistency in the computeAt pass.");
+ auto ca_view = tv->getComputeAtView();
+ if (FusionGuard::getCurFusion()->hasOutput(ca_view)) {
+ touched_outputs[ca_view] = tv->getRelativeComputeAtAxis();
+ }
+ }
+
+ std::vector<TensorView*> touched_output_order(touched_outputs.size());
+
+ {
+ size_t i = 0;
+ for (auto out : FusionGuard::getCurFusion()->outputs()) {
+ if (out->getValType() == ValType::TensorView) {
+ if (touched_outputs.find(out->as<TensorView>()) !=
+ touched_outputs.end()) {
+ touched_output_order[i++] = out->as<TensorView>();
+ }
+ }
+ }
+ TORCH_INTERNAL_ASSERT(
+ i == touched_output_order.size(),
+ "Hit a logical inconsistency in the computeAt pass.");
+ }
+
+ for (size_t i = 0; i + 1 < touched_output_order.size(); i++) {
+ touched_output_order[i]->setComputeAt(
+ touched_output_order[i + 1],
+ touched_outputs.at(touched_output_order[i]),
+ touched_outputs.at(touched_output_order[i + 1]));
+ }
+}
+
+ComputeAt::ComputeAt(
+ TensorView* _producer,
+ TensorView* _consumer,
+ unsigned int _consumer_position)
+ : producer_(_producer),
+ consumer_(_consumer),
+ consumer_position_(_consumer_position) {}
+
+void ComputeAt::run(
+ TensorView* producer,
+ TensorView* consumer,
+ unsigned int consumer_position) {
+ ComputeAt ca(producer, consumer, consumer_position);
+ ca.producer_use_chains_ = getAllTVUseChains(ca.producer_);
+ ca.setCommonConsumer();
+ ca.runPass();
+ ca.setupOutputs();
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
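For orientation, here is a minimal usage sketch of this pass, pieced together from the test patterns elsewhere in this patch. The helper name computeAtSketch is hypothetical, and the assumption that TensorView::computeAt forwards into ComputeAt::run is inferred from the friend relationship noted in compute_at.h rather than shown directly in this diff.

    #include <torch/csrc/jit/codegen/cuda/arith.h>
    #include <torch/csrc/jit/codegen/cuda/fusion.h>
    #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>

    using namespace torch::jit::fuser;

    // Hypothetical sketch: build a small pointwise chain and inline tv1 into tv2.
    void computeAtSketch(Fusion& fusion, TensorView* tv0 /* a 2-D fusion input */) {
      FusionGuard fg(&fusion);
      TensorView* tv1 = add(tv0, new Float(1));
      TensorView* tv2 = add(tv1, new Float(2));
      fusion.addOutput(tv2);
      // -1 requests the innermost position; this is expected to end up calling
      // ComputeAt::run(tv1, tv2, <normalized position>).
      tv1->computeAt(tv2, -1);
    }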
diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h
new file mode 100644
index 0000000..3c2434d
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/compute_at.h
@@ -0,0 +1,98 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <deque>
+#include <unordered_map>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class TensorView;
+
+class ComputeAt {
+ public:
+ static void run(
+ TensorView* _producer,
+ TensorView* _consumer,
+ unsigned int _consumer_position);
+
+ private:
+ TensorView* producer_;
+ TensorView* consumer_;
+ unsigned int consumer_position_;
+
+ // These are kept as member functions only because ComputeAt is a friend of
+ // TensorView; we don't want to keep expanding the set of TensorView friends.
+ // Runs replayPasC and sets producer computeAt settings
+ void computeAt_impl(
+ TensorView* producer,
+ TensorView* consumer,
+ unsigned int consumer_compute_at_axis);
+
+ // Runs the replay and checks the producer's computeAt position. Only runs
+ // the operation if the new position would be higher.
+ void maybe_computeAt_impl(
+ TensorView* producer,
+ TensorView* consumer,
+ unsigned int consumer_compute_at_axis);
+
+ // Runs replayCasP and sets producer computeAt settings
+ void forwardComputeAt_impl(
+ TensorView* producer,
+ TensorView* consumer,
+ unsigned int producer_compute_at_axis);
+
+ // Look through all the use chains of producer. Check if there's a single
+ // consumer for all chains at or after the consumer specified in the computeAt
+ // call.
+ void setCommonConsumer();
+
+ // Propagate in a way we know the result will be correct: backward from
+ // consumer to producer, then forward from consumer.
+ void traverseAllKnown();
+
+ // Traverse from producer to common_consumer if it exists, otherwise through
+ // all uses of producer.
+ void traverseForward();
+
+ // Propagate backward from the consumer (or common consumer) and check
+ // whether it increases the computeAt position on tensors; if so, take it.
+ void traverseBackward();
+
+ // Run the computeAt pass
+ void runPass();
+
+ // Set outputs relative to each other if there is no common consumer.
+ void setupOutputs();
+
+ // Common consumer if it exists
+ TensorView* common_consumer_ = nullptr;
+
+ // Use chains of the producer, set in run(), used in a few spots.
+ std::deque<std::deque<TensorView*>> producer_use_chains_;
+
+ // Order for forward computeAt pass
+ std::vector<std::pair<TensorView*, TensorView*>> forward_compute_at_order;
+
+ // Order for backward computeAt pass
+ std::vector<std::pair<TensorView*, TensorView*>> backward_compute_at_order;
+
+ // TensorViews whose computeAt we have set during this pass.
+ std::unordered_set<TensorView*> compute_at_ed;
+
+ // TensorViews whose correct computeAt position is known.
+ std::unordered_map<TensorView*, unsigned int> known_positions;
+
+ ComputeAt(
+ TensorView* _producer,
+ TensorView* _consumer,
+ unsigned int _consumer_position);
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index a3a5522..197c5e6 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -36,7 +36,7 @@
* }
*
* And therefore dispatch should never call:
- * ptr(mutator)->handle(static_cast<Statement*>(this));
+ * ptr(mutator)->handle(this->as<Statement>());
*/
template <typename T>
@@ -45,35 +45,35 @@
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Bool:
- ptr(handler)->handle(static_cast<Bool*>(val));
+ ptr(handler)->handle(val->as<Bool>());
return;
case DataType::Float:
- ptr(handler)->handle(static_cast<Float*>(val));
+ ptr(handler)->handle(val->as<Float>());
return;
case DataType::Half:
- ptr(handler)->handle(static_cast<Half*>(val));
+ ptr(handler)->handle(val->as<Half>());
return;
case DataType::Int:
- ptr(handler)->handle(static_cast<Int*>(val));
+ ptr(handler)->handle(val->as<Int>());
return;
default:
break;
}
break;
case ValType::IterDomain:
- ptr(handler)->handle(static_cast<IterDomain*>(val));
+ ptr(handler)->handle(val->as<IterDomain>());
return;
case ValType::TensorDomain:
- ptr(handler)->handle(static_cast<TensorDomain*>(val));
+ ptr(handler)->handle(val->as<TensorDomain>());
return;
case ValType::TensorView:
- ptr(handler)->handle(static_cast<TensorView*>(val));
+ ptr(handler)->handle(val->as<TensorView>());
return;
case ValType::TensorIndex:
- ptr(handler)->handle(static_cast<TensorIndex*>(val));
+ ptr(handler)->handle(val->as<TensorIndex>());
return;
case ValType::NamedScalar:
- ptr(handler)->handle(static_cast<NamedScalar*>(val));
+ ptr(handler)->handle(val->as<NamedScalar>());
return;
default:
break;
@@ -85,34 +85,34 @@
void Expr::dispatch(T handler, Expr* expr) {
switch (*(expr->getExprType())) {
case ExprType::Split:
- ptr(handler)->handle(static_cast<Split*>(expr));
+ ptr(handler)->handle(expr->as<Split>());
return;
case ExprType::Merge:
- ptr(handler)->handle(static_cast<Merge*>(expr));
+ ptr(handler)->handle(expr->as<Merge>());
return;
case ExprType::UnaryOp:
- ptr(handler)->handle(static_cast<UnaryOp*>(expr));
+ ptr(handler)->handle(expr->as<UnaryOp>());
return;
case ExprType::BinaryOp:
- ptr(handler)->handle(static_cast<BinaryOp*>(expr));
+ ptr(handler)->handle(expr->as<BinaryOp>());
return;
case ExprType::TernaryOp:
- ptr(handler)->handle(static_cast<TernaryOp*>(expr));
+ ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::ReductionOp:
- ptr(handler)->handle(static_cast<ReductionOp*>(expr));
+ ptr(handler)->handle(expr->as<ReductionOp>());
return;
case ExprType::BroadcastOp:
- ptr(handler)->handle(static_cast<BroadcastOp*>(expr));
+ ptr(handler)->handle(expr->as<BroadcastOp>());
return;
case ExprType::ForLoop:
- ptr(handler)->handle(static_cast<ForLoop*>(expr));
+ ptr(handler)->handle(expr->as<ForLoop>());
return;
case ExprType::IfThenElse:
- ptr(handler)->handle(static_cast<IfThenElse*>(expr));
+ ptr(handler)->handle(expr->as<IfThenElse>());
return;
case ExprType::Allocate:
- ptr(handler)->handle(static_cast<Allocate*>(expr));
+ ptr(handler)->handle(expr->as<Allocate>());
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -122,9 +122,9 @@
template <typename T>
void Statement::dispatch(T handler, Statement* stmt) {
if (stmt->isVal()) {
- ptr(handler)->handle(static_cast<Val*>(stmt));
+ ptr(handler)->handle(stmt->as<Val>());
} else if (stmt->isExpr()) {
- ptr(handler)->handle(static_cast<Expr*>(stmt));
+ ptr(handler)->handle(stmt->as<Expr>());
} else
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
@@ -135,35 +135,35 @@
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Bool:
- ptr(handler)->handle(static_cast<const Bool*>(val));
+ ptr(handler)->handle(val->as<Bool>());
return;
case DataType::Float:
- ptr(handler)->handle(static_cast<const Float*>(val));
+ ptr(handler)->handle(val->as<Float>());
return;
case DataType::Half:
- ptr(handler)->handle(static_cast<const Half*>(val));
+ ptr(handler)->handle(val->as<Half>());
return;
case DataType::Int:
- ptr(handler)->handle(static_cast<const Int*>(val));
+ ptr(handler)->handle(val->as<Int>());
return;
default:
break;
}
break;
case ValType::IterDomain:
- ptr(handler)->handle(static_cast<const IterDomain*>(val));
+ ptr(handler)->handle(val->as<IterDomain>());
return;
case ValType::TensorDomain:
- ptr(handler)->handle(static_cast<const TensorDomain*>(val));
+ ptr(handler)->handle(val->as<TensorDomain>());
return;
case ValType::TensorView:
- ptr(handler)->handle(static_cast<const TensorView*>(val));
+ ptr(handler)->handle(val->as<TensorView>());
return;
case ValType::TensorIndex:
- ptr(handler)->handle(static_cast<const TensorIndex*>(val));
+ ptr(handler)->handle(val->as<TensorIndex>());
return;
case ValType::NamedScalar:
- ptr(handler)->handle(static_cast<const NamedScalar*>(val));
+ ptr(handler)->handle(val->as<NamedScalar>());
return;
default:
break;
@@ -175,34 +175,34 @@
void Expr::constDispatch(T handler, const Expr* expr) {
switch (*(expr->getExprType())) {
case ExprType::Split:
- ptr(handler)->handle(static_cast<const Split*>(expr));
+ ptr(handler)->handle(expr->as<Split>());
return;
case ExprType::Merge:
- ptr(handler)->handle(static_cast<const Merge*>(expr));
+ ptr(handler)->handle(expr->as<Merge>());
return;
case ExprType::UnaryOp:
- ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
+ ptr(handler)->handle(expr->as<UnaryOp>());
return;
case ExprType::BinaryOp:
- ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
+ ptr(handler)->handle(expr->as<BinaryOp>());
return;
case ExprType::TernaryOp:
- ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
+ ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::ReductionOp:
- ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
+ ptr(handler)->handle(expr->as<ReductionOp>());
return;
case ExprType::BroadcastOp:
- ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
+ ptr(handler)->handle(expr->as<BroadcastOp>());
return;
case ExprType::ForLoop:
- ptr(handler)->handle(static_cast<const ForLoop*>(expr));
+ ptr(handler)->handle(expr->as<ForLoop>());
return;
case ExprType::IfThenElse:
- ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
+ ptr(handler)->handle(expr->as<IfThenElse>());
return;
case ExprType::Allocate:
- ptr(handler)->handle(static_cast<const Allocate*>(expr));
+ ptr(handler)->handle(expr->as<Allocate>());
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -212,9 +212,9 @@
template <typename T>
void Statement::constDispatch(T handler, const Statement* stmt) {
if (stmt->isVal()) {
- ptr(handler)->handle(static_cast<const Val*>(stmt));
+ ptr(handler)->handle(stmt->as<Val>());
} else if (stmt->isExpr()) {
- ptr(handler)->handle(static_cast<const Expr*>(stmt));
+ ptr(handler)->handle(stmt->as<Expr>());
} else
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
@@ -228,7 +228,7 @@
* implement Statement* mutate(Statement* stmt){ stmt->mutatorDispatch(this);
* }
* And therefore dispatch should never call:
- * ptr(mutator)->mutate(static_cast<Statement*>(this));
+ * ptr(mutator)->mutate(this->as<Statement>());
*/
template <typename T>
Statement* Val::mutatorDispatch(T mutator, Val* val) {
@@ -236,27 +236,27 @@
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Bool:
- return ptr(mutator)->mutate(static_cast<Bool*>(val));
+ return ptr(mutator)->mutate(val->as<Bool>());
case DataType::Float:
- return ptr(mutator)->mutate(static_cast<Float*>(val));
+ return ptr(mutator)->mutate(val->as<Float>());
case DataType::Half:
- return ptr(mutator)->mutate(static_cast<Half*>(val));
+ return ptr(mutator)->mutate(val->as<Half>());
case DataType::Int:
- return ptr(mutator)->mutate(static_cast<Int*>(val));
+ return ptr(mutator)->mutate(val->as<Int>());
default:
break;
}
break;
case ValType::IterDomain:
- return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
+ return ptr(mutator)->mutate(val->as<IterDomain>());
case ValType::TensorDomain:
- return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
+ return ptr(mutator)->mutate(val->as<TensorDomain>());
case ValType::TensorView:
- return ptr(mutator)->mutate(static_cast<TensorView*>(val));
+ return ptr(mutator)->mutate(val->as<TensorView>());
case ValType::TensorIndex:
- return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
+ return ptr(mutator)->mutate(val->as<TensorIndex>());
case ValType::NamedScalar:
- return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
+ return ptr(mutator)->mutate(val->as<NamedScalar>());
default:
break;
}
@@ -267,25 +267,25 @@
Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
switch (*(expr->getExprType())) {
case ExprType::Split:
- return ptr(mutator)->mutate(static_cast<Split*>(expr));
+ return ptr(mutator)->mutate(expr->as<Split>());
case ExprType::Merge:
- return ptr(mutator)->mutate(static_cast<Merge*>(expr));
+ return ptr(mutator)->mutate(expr->as<Merge>());
case ExprType::UnaryOp:
- return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
+ return ptr(mutator)->mutate(expr->as<UnaryOp>());
case ExprType::BinaryOp:
- return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
+ return ptr(mutator)->mutate(expr->as<BinaryOp>());
case ExprType::TernaryOp:
- return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
+ return ptr(mutator)->mutate(expr->as<TernaryOp>());
case ExprType::ReductionOp:
- return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
+ return ptr(mutator)->mutate(expr->as<ReductionOp>());
case ExprType::BroadcastOp:
- return ptr(mutator)->mutate(static_cast<BroadcastOp*>(expr));
+ return ptr(mutator)->mutate(expr->as<BroadcastOp>());
case ExprType::ForLoop:
- return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
+ return ptr(mutator)->mutate(expr->as<ForLoop>());
case ExprType::IfThenElse:
- return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
+ return ptr(mutator)->mutate(expr->as<IfThenElse>());
case ExprType::Allocate:
- return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
+ return ptr(mutator)->mutate(expr->as<Allocate>());
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
@@ -294,10 +294,10 @@
template <typename T>
Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
if (stmt->isVal()) {
- return ptr(mutator)->mutate(static_cast<Val*>(stmt));
+ return ptr(mutator)->mutate(stmt->as<Val>());
}
if (stmt->isExpr()) {
- return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
+ return ptr(mutator)->mutate(stmt->as<Expr>());
}
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
@@ -321,27 +321,19 @@
template void Expr::dispatch(OptInDispatch, Expr*);
template void Expr::dispatch(OptInDispatch*, Expr*);
-template void Statement::constDispatch(
- OptOutConstDispatch,
- const Statement* const);
-template void Statement::constDispatch(
- OptOutConstDispatch*,
- const Statement* const);
-template void Val::constDispatch(OptOutConstDispatch, const Val* const);
-template void Val::constDispatch(OptOutConstDispatch*, const Val* const);
-template void Expr::constDispatch(OptOutConstDispatch, const Expr* const);
-template void Expr::constDispatch(OptOutConstDispatch*, const Expr* const);
+template void Statement::constDispatch(OptOutConstDispatch, const Statement*);
+template void Statement::constDispatch(OptOutConstDispatch*, const Statement*);
+template void Val::constDispatch(OptOutConstDispatch, const Val*);
+template void Val::constDispatch(OptOutConstDispatch*, const Val*);
+template void Expr::constDispatch(OptOutConstDispatch, const Expr*);
+template void Expr::constDispatch(OptOutConstDispatch*, const Expr*);
-template void Statement::constDispatch(
- OptInConstDispatch,
- const Statement* const);
-template void Statement::constDispatch(
- OptInConstDispatch*,
- const Statement* const);
-template void Val::constDispatch(OptInConstDispatch, const Val* const);
-template void Val::constDispatch(OptInConstDispatch*, const Val* const);
-template void Expr::constDispatch(OptInConstDispatch, const Expr* const);
-template void Expr::constDispatch(OptInConstDispatch*, const Expr* const);
+template void Statement::constDispatch(OptInConstDispatch, const Statement*);
+template void Statement::constDispatch(OptInConstDispatch*, const Statement*);
+template void Val::constDispatch(OptInConstDispatch, const Val*);
+template void Val::constDispatch(OptInConstDispatch*, const Val*);
+template void Expr::constDispatch(OptInConstDispatch, const Expr*);
+template void Expr::constDispatch(OptInConstDispatch*, const Expr*);
template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*);
template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*);
@@ -360,9 +352,11 @@
void OptOutDispatch::handle(Statement* s) {
Statement::dispatch(this, s);
}
+
void OptOutDispatch::handle(Expr* e) {
Expr::dispatch(this, e);
}
+
void OptOutDispatch::handle(Val* v) {
Val::dispatch(this, v);
}
@@ -370,30 +364,36 @@
void OptInDispatch::handle(Statement* s) {
Statement::dispatch(this, s);
}
+
void OptInDispatch::handle(Expr* e) {
Expr::dispatch(this, e);
}
+
void OptInDispatch::handle(Val* v) {
Val::dispatch(this, v);
}
-void OptOutConstDispatch::handle(const Statement* const s) {
+void OptOutConstDispatch::handle(const Statement* s) {
Statement::constDispatch(this, s);
}
-void OptOutConstDispatch::handle(const Expr* const e) {
+
+void OptOutConstDispatch::handle(const Expr* e) {
Expr::constDispatch(this, e);
}
-void OptOutConstDispatch::handle(const Val* const v) {
+
+void OptOutConstDispatch::handle(const Val* v) {
Val::constDispatch(this, v);
}
-void OptInConstDispatch::handle(const Statement* const s) {
+void OptInConstDispatch::handle(const Statement* s) {
Statement::constDispatch(this, s);
}
-void OptInConstDispatch::handle(const Expr* const e) {
+
+void OptInConstDispatch::handle(const Expr* e) {
Expr::constDispatch(this, e);
}
-void OptInConstDispatch::handle(const Val* const v) {
+
+void OptInConstDispatch::handle(const Val* v) {
Val::constDispatch(this, v);
}
@@ -415,9 +415,11 @@
Statement* OptOutMutator::mutate(Statement* s) {
return Statement::mutatorDispatch(this, s);
}
+
Statement* OptOutMutator::mutate(Expr* e) {
return Expr::mutatorDispatch(this, e);
}
+
Statement* OptOutMutator::mutate(Val* v) {
// If value is already mutated, return the mutation
if (mutations.find(v) != mutations.end())
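The dispatch.cpp changes above are formatting plus dropping the redundant top-level const on pointer parameters (for example, const Statement* const becomes const Statement*). That cleanup cannot change behavior: top-level const on a parameter is ignored when the function type is formed, so the explicit template instantiations and virtual overrides still match. A tiny standalone check of that language rule, using a hypothetical Statement stub rather than the nvFuser class:

#include <type_traits>

struct Statement {};  // hypothetical stub, not the nvFuser class

// Top-level const on a parameter is ignored when forming the function
// type, so these two declarations name the same function (no overload).
void handle(const Statement* s);
void handle(const Statement* const s);

static_assert(
    std::is_same<void (*)(const Statement*),
                 void (*)(const Statement* const)>::value,
    "parameter-level const does not change the signature");

void handle(const Statement*) {}  // the single definition

int main() {
  Statement s;
  handle(&s);
  return 0;
}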
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 5a3aa89..50f4451 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -48,41 +48,42 @@
namespace jit {
namespace fuser {
-struct Fusion;
+class Fusion;
// Hierarchal dispatch functions for handle
-struct Statement;
-struct Expr;
-struct Val;
+class Statement;
+class Expr;
+class Val;
// Vals
-struct IterDomain;
-struct TensorDomain;
-struct TensorView;
-struct TensorIndex;
-struct Bool;
-struct Float;
-struct Half;
-struct Int;
-struct NamedScalar;
+class IterDomain;
+class TensorDomain;
+class TensorView;
+class TensorIndex;
+class Bool;
+class Float;
+class Half;
+class Int;
+class NamedScalar;
// Exprs
-struct Split;
-struct Merge;
-struct UnaryOp;
-struct BinaryOp;
-struct TernaryOp;
-struct ReductionOp;
-struct BroadcastOp;
-struct ForLoop;
-struct IfThenElse;
-struct Allocate;
+class Split;
+class Merge;
+class UnaryOp;
+class BinaryOp;
+class TernaryOp;
+class ReductionOp;
+class BroadcastOp;
+class ForLoop;
+class IfThenElse;
+class Allocate;
/*
* By default, all IR nodes are handled in this dispatch, and will call an empty
* function on all nodes.
*/
-struct TORCH_CUDA_API OptOutConstDispatch {
+class TORCH_CUDA_API OptOutConstDispatch {
+ public:
virtual ~OptOutConstDispatch() = default;
OptOutConstDispatch() = default;
@@ -93,35 +94,36 @@
OptOutConstDispatch& operator=(OptOutConstDispatch&& other) = default;
// Hierarchal dispatch functions for handle
- virtual void handle(const Statement* const);
- virtual void handle(const Expr* const);
- virtual void handle(const Val* const);
+ virtual void handle(const Statement*);
+ virtual void handle(const Expr*);
+ virtual void handle(const Val*);
// Vals
- virtual void handle(const IterDomain* const) {}
- virtual void handle(const TensorDomain* const) {}
- virtual void handle(const TensorView* const) {}
- virtual void handle(const TensorIndex* const) {}
- virtual void handle(const Bool* const) {}
- virtual void handle(const Float* const) {}
- virtual void handle(const Half* const) {}
- virtual void handle(const Int* const) {}
- virtual void handle(const NamedScalar* const) {}
+ virtual void handle(const IterDomain*) {}
+ virtual void handle(const TensorDomain*) {}
+ virtual void handle(const TensorView*) {}
+ virtual void handle(const TensorIndex*) {}
+ virtual void handle(const Bool*) {}
+ virtual void handle(const Float*) {}
+ virtual void handle(const Half*) {}
+ virtual void handle(const Int*) {}
+ virtual void handle(const NamedScalar*) {}
// Exprs
- virtual void handle(const Split* const) {}
- virtual void handle(const Merge* const) {}
- virtual void handle(const UnaryOp* const) {}
- virtual void handle(const BinaryOp* const) {}
- virtual void handle(const TernaryOp* const) {}
- virtual void handle(const ReductionOp* const) {}
- virtual void handle(const BroadcastOp* const) {}
- virtual void handle(const ForLoop* const) {}
- virtual void handle(const IfThenElse* const) {}
- virtual void handle(const Allocate* const) {}
+ virtual void handle(const Split*) {}
+ virtual void handle(const Merge*) {}
+ virtual void handle(const UnaryOp*) {}
+ virtual void handle(const BinaryOp*) {}
+ virtual void handle(const TernaryOp*) {}
+ virtual void handle(const ReductionOp*) {}
+ virtual void handle(const BroadcastOp*) {}
+ virtual void handle(const ForLoop*) {}
+ virtual void handle(const IfThenElse*) {}
+ virtual void handle(const Allocate*) {}
};
-struct TORCH_CUDA_API OptOutDispatch {
+class TORCH_CUDA_API OptOutDispatch {
+ public:
virtual ~OptOutDispatch() = default;
OptOutDispatch() = default;
@@ -160,7 +162,8 @@
virtual void handle(Allocate*) {}
};
-struct TORCH_CUDA_API OptInConstDispatch {
+class TORCH_CUDA_API OptInConstDispatch {
+ public:
virtual ~OptInConstDispatch() = default;
OptInConstDispatch() = default;
@@ -171,73 +174,74 @@
OptInConstDispatch& operator=(OptInConstDispatch&& other) = default;
// Hierarchal dispatch functions for handle
- virtual void handle(const Statement* const);
- virtual void handle(const Expr* const);
- virtual void handle(const Val* const);
+ virtual void handle(const Statement*);
+ virtual void handle(const Expr*);
+ virtual void handle(const Val*);
// Vals
- virtual void handle(const IterDomain* const) {
+ virtual void handle(const IterDomain*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain.");
}
- virtual void handle(const TensorDomain* const) {
+ virtual void handle(const TensorDomain*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain.");
}
- virtual void handle(const TensorView* const) {
+ virtual void handle(const TensorView*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView.");
}
- virtual void handle(const TensorIndex* const) {
+ virtual void handle(const TensorIndex*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorIndex.");
}
- virtual void handle(const Bool* const) {
+ virtual void handle(const Bool*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
}
- virtual void handle(const Float* const) {
+ virtual void handle(const Float*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
}
- virtual void handle(const Half* const) {
+ virtual void handle(const Half*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
}
- virtual void handle(const Int* const) {
+ virtual void handle(const Int*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
}
- virtual void handle(const NamedScalar* const) {
+ virtual void handle(const NamedScalar*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar.");
}
// Exprs
- virtual void handle(const Split* const) {
+ virtual void handle(const Split*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split.");
}
- virtual void handle(const Merge* const) {
+ virtual void handle(const Merge*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge.");
}
- virtual void handle(const UnaryOp* const) {
+ virtual void handle(const UnaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp.");
}
- virtual void handle(const BinaryOp* const) {
+ virtual void handle(const BinaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp.");
}
- virtual void handle(const TernaryOp* const) {
+ virtual void handle(const TernaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp.");
}
- virtual void handle(const ReductionOp* const) {
+ virtual void handle(const ReductionOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp.");
}
- virtual void handle(const BroadcastOp* const) {
+ virtual void handle(const BroadcastOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp.");
}
- virtual void handle(const ForLoop* const) {
+ virtual void handle(const ForLoop*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ForLoop.");
}
- virtual void handle(const Allocate* const) {
+ virtual void handle(const Allocate*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Allocate.");
}
- virtual void handle(const IfThenElse* const) {
+ virtual void handle(const IfThenElse*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IfThenElse.");
}
};
-struct TORCH_CUDA_API OptInDispatch {
+class TORCH_CUDA_API OptInDispatch {
+ public:
virtual ~OptInDispatch() = default;
OptInDispatch() = default;
@@ -314,7 +318,8 @@
}
};
-struct TORCH_CUDA_API OptOutMutator {
+class TORCH_CUDA_API OptOutMutator {
+ public:
virtual ~OptOutMutator() = default;
OptOutMutator() = default;
@@ -331,13 +336,11 @@
virtual Statement* mutate(Expr* e);
virtual Statement* mutate(Val* v);
- /*
- * We always want to dispatch through a Val, so we can capture and dispatch
- * correctly members of nodes like Split->TensorDomain If we don't call the
- * below function or manually cast to use mutate(Val* v) we can't intercept
- * and mutate by capturing mutate(Val* v), which is what we do when we want to
- * replace all instances of a value.
- */
+ // We always want to dispatch through a Val, so we can capture and dispatch
+ // correctly members of nodes like Split->TensorDomain. If we don't call the
+ // function below, or manually cast to use mutate(Val* v), we can't intercept
+ // and mutate by capturing mutate(Val* v), which is what we do when we want to
+ // replace all instances of a value.
Statement* mutateAsVal(Val* v) {
return mutate(v);
}
@@ -352,7 +355,8 @@
std::unordered_map<Val*, Val*> mutations;
- //****Functions below defined in mutator.cpp*****///
+ //****Functions below defined in mutator.cpp*****
+
// Vals
virtual Statement* mutate(IterDomain*);
virtual Statement* mutate(TensorDomain*);
@@ -377,7 +381,8 @@
virtual Statement* mutate(Allocate*);
};
-struct TORCH_CUDA_API OptInMutator {
+class TORCH_CUDA_API OptInMutator {
+ public:
virtual ~OptInMutator() = default;
OptInMutator() = default;
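The struct-to-class conversions in dispatch.h keep the dispatch contract intact: OptOut* visitors provide empty default handlers so a pass overrides only the node types it cares about, while OptIn* visitors assert on anything not explicitly overridden. Below is a minimal standalone sketch of that opt-out/opt-in split; Add, Mul, and the visitor names are hypothetical stand-ins, not the nvFuser IR:

#include <cassert>
#include <iostream>

struct Add {};  // hypothetical IR node types
struct Mul {};

// Opt-out: default handlers are empty, so a pass silently ignores
// every node type it does not override.
class OptOutVisitor {
 public:
  virtual ~OptOutVisitor() = default;
  virtual void handle(const Add*) {}
  virtual void handle(const Mul*) {}
};

// Opt-in: anything not explicitly overridden is a hard error, which
// catches passes that forgot to cover a node type.
class OptInVisitor {
 public:
  virtual ~OptInVisitor() = default;
  virtual void handle(const Add*) {
    assert(false && "handle not overridden for Add");
  }
  virtual void handle(const Mul*) {
    assert(false && "handle not overridden for Mul");
  }
};

// An analysis that only counts Add nodes can safely opt out of the rest.
class CountAdds : public OptOutVisitor {
 public:
  int count = 0;
  void handle(const Add*) override { ++count; }
};

int main() {
  Add add;
  Mul mul;
  CountAdds counter;
  counter.handle(&add);
  counter.handle(&mul);  // ignored by the opt-out default
  std::cout << counter.count << "\n";  // 1
}

Choosing opt-out for analyses and opt-in for lowering-style passes is the usual trade-off: the former stays terse, the latter fails loudly when a new IR node is added and a pass has not been taught about it.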
diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
index 4568712..2a333fc 100644
--- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
+++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
@@ -38,40 +38,46 @@
}
c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(
- const Statement* expr,
+ Val* val,
const EvaluationContext* context) {
TORCH_CHECK(context != nullptr);
ExpressionEvaluator evaluator(context);
- evaluator.OptInConstDispatch::handle(expr);
- return evaluator.result_;
+ evaluator.traverseFrom(context->fusion(), {val}, false);
+ return evaluator.value(val);
}
-void ExpressionEvaluator::handle(const Int* i) {
+c10::optional<Int::ScalarType> ExpressionEvaluator::value(
+ const Statement* stmt) const {
+ const auto it = values_.find(stmt);
+ return (it != values_.end()) ? c10::optional<Int::ScalarType>(it->second)
+ : c10::nullopt;
+}
+
+void ExpressionEvaluator::handle(Int* i) {
if (i->value().has_value()) {
- result_ = i->value();
+ values_[i] = *i->value();
} else if (const auto* def = context_->fusion()->origin(i)) {
- result_ = evaluate(def, context_);
+ const auto& def_result = value(def);
+ if (def_result.has_value()) {
+ values_[i] = *def_result;
+ }
} else {
const auto& bound_value = context_->concreteValue(i);
if (bound_value.has_value()) {
- result_ = bound_value;
+ values_[i] = *bound_value;
}
}
}
-void ExpressionEvaluator::handle(const NamedScalar* i) {
- // nothing to do, leave the result "unknown"
-}
-
-void ExpressionEvaluator::handle(const UnaryOp* uop) {
- const auto in = evaluate(uop->in(), context_);
+void ExpressionEvaluator::handle(UnaryOp* uop) {
+ const auto in = value(uop->in());
if (in.has_value()) {
switch (uop->getUnaryOpType()) {
case UnaryOpType::Neg:
- result_ = -*in;
+ values_[uop] = -*in;
break;
case UnaryOpType::Cast:
- result_ = *in;
+ values_[uop] = *in;
break;
default:
TORCH_CHECK(!"Unexpected operator type");
@@ -79,35 +85,34 @@
}
}
-void ExpressionEvaluator::handle(const BinaryOp* bop) {
- TORCH_CHECK(bop->out()->isAnInt()); // not really needed
- const auto lhs = evaluate(bop->lhs(), context_);
- const auto rhs = evaluate(bop->rhs(), context_);
+void ExpressionEvaluator::handle(BinaryOp* bop) {
+ const auto lhs = value(bop->lhs());
+ const auto rhs = value(bop->rhs());
if (lhs.has_value() && rhs.has_value()) {
switch (bop->getBinaryOpType()) {
case BinaryOpType::Add:
- result_ = *lhs + *rhs;
+ values_[bop] = *lhs + *rhs;
break;
case BinaryOpType::Sub:
- result_ = *lhs - *rhs;
+ values_[bop] = *lhs - *rhs;
break;
case BinaryOpType::Mul:
- result_ = *lhs * *rhs;
+ values_[bop] = *lhs * *rhs;
break;
case BinaryOpType::Div:
TORCH_CHECK(*rhs != 0);
- result_ = *lhs / *rhs;
+ values_[bop] = *lhs / *rhs;
break;
case BinaryOpType::Mod:
TORCH_CHECK(*rhs != 0);
- result_ = *lhs % *rhs;
+ values_[bop] = *lhs % *rhs;
break;
case BinaryOpType::CeilDiv:
TORCH_CHECK(*rhs != 0);
- result_ = (*lhs + *rhs - 1) / *rhs;
+ values_[bop] = (*lhs + *rhs - 1) / *rhs;
break;
case BinaryOpType::And:
- result_ = Int::ScalarType(*lhs && *rhs);
+ values_[bop] = Int::ScalarType(*lhs && *rhs);
break;
default:
TORCH_CHECK(!"Unexpected operator type");
diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h
index 5719128..bc29aac 100644
--- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h
+++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h
@@ -2,8 +2,8 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <c10/util/Optional.h>
@@ -20,7 +20,7 @@
//
class TORCH_CUDA_API EvaluationContext {
public:
- explicit EvaluationContext(const Fusion* fusion) : fusion_(fusion) {}
+ explicit EvaluationContext(Fusion* fusion) : fusion_(fusion) {}
// Set the concrete value for a Int*
void bind(const Val* value, Int::ScalarType concrete_value);
@@ -28,7 +28,7 @@
// Retrieves the concrete value, or nullopt if not set
c10::optional<Int::ScalarType> concreteValue(const Val* value) const;
- const Fusion* fusion() const {
+ Fusion* fusion() const {
return fusion_;
}
@@ -37,18 +37,18 @@
private:
std::unordered_map<const Val*, Int::ScalarType> bindings_;
- const Fusion* fusion_ = nullptr;
+ Fusion* fusion_ = nullptr;
};
// Evaluates expressions in a Fusion IR, using the passed in
// context (EvaluationContext) to query for concrete_values. The
// evaluation context may override concrete values in the IR as well.
-class TORCH_CUDA_API ExpressionEvaluator : private OptInConstDispatch {
+class TORCH_CUDA_API ExpressionEvaluator : private IterVisitor {
public:
// Returns the result of the specified expression, or nullopt if
// the result cannot be evaluated
static c10::optional<Int::ScalarType> evaluate(
- const Statement* expr,
+ Val* val,
const EvaluationContext* context);
private:
@@ -57,15 +57,17 @@
~ExpressionEvaluator() override = default;
- void handle(const Int*) override;
- void handle(const NamedScalar*) override;
+ c10::optional<Int::ScalarType> value(const Statement* stmt) const;
- void handle(const UnaryOp*) override;
- void handle(const BinaryOp*) override;
+ using IterVisitor::handle;
+
+ void handle(Int*) override;
+ void handle(UnaryOp*) override;
+ void handle(BinaryOp*) override;
private:
const EvaluationContext* context_ = nullptr;
- c10::optional<Int::ScalarType> result_;
+ std::unordered_map<const Statement*, Int::ScalarType> values_;
};
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index e4faf31..141c12d 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -1,7 +1,10 @@
+
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
+#include <torch/csrc/jit/codegen/cuda/lower2device.h>
namespace torch {
namespace jit {
@@ -34,9 +37,10 @@
std::vector<Expr*> ExprSort::getExprs(
Fusion* fusion,
bool from_outputs_only,
- bool breadth_first) {
+ bool breadth_first,
+ bool respect_compute_at) {
ExprSort es;
- es.traverse(fusion, from_outputs_only, breadth_first);
+ es.traverse(fusion, from_outputs_only, breadth_first, respect_compute_at);
return es.exprs;
}
@@ -56,22 +60,136 @@
return io.inputs;
}
-Fusion::~Fusion() {
- {
- auto it = val_set_.begin();
- while (it != val_set_.end()) {
- auto del = it;
- it = ++it;
- delete (*del);
+void swap(Fusion& a, Fusion& b) noexcept {
+ using std::swap;
+
+ // Swap the content
+ swap(a.val_set_, b.val_set_);
+ swap(a.expr_set_, b.expr_set_);
+ swap(a.val_deque_, b.val_deque_);
+
+ swap(a.val_type_name_map_, b.val_type_name_map_);
+ swap(a.val_name_counter_, b.val_name_counter_);
+ swap(a.expr_name_counter_, b.expr_name_counter_);
+
+ swap(a.origin_, b.origin_);
+ swap(a.uses_, b.uses_);
+ swap(a.values_map_, b.values_map_);
+
+ swap(a.inputs_, b.inputs_);
+ swap(a.outputs_, b.outputs_);
+
+ // Fixup the Statement::fusion_ links for a
+ for (auto val : a.val_set_) {
+ val->fusion_ = &a;
+ }
+ for (auto expr : a.expr_set_) {
+ expr->fusion_ = &a;
+ }
+
+ // Fixup the Statement::fusion_ links for b
+ for (auto val : b.val_set_) {
+ val->fusion_ = &b;
+ }
+ for (auto expr : b.expr_set_) {
+ expr->fusion_ = &b;
+ }
+}
+
+Fusion::Fusion(const Fusion& other) {
+ IrCloner ir_cloner(this);
+
+ for (auto val : other.val_set_) {
+ val_set_.insert(ir_cloner.clone(val));
+ }
+
+ for (auto expr : other.expr_set_) {
+ expr_set_.insert(ir_cloner.clone(expr));
+ }
+
+ for (auto val : other.val_deque_) {
+ val_deque_.push_back(ir_cloner.clone(val));
+ }
+
+ val_type_name_map_ = other.val_type_name_map_;
+ val_name_counter_ = other.val_name_counter_;
+ expr_name_counter_ = other.expr_name_counter_;
+
+ for (const auto& kv : other.origin_) {
+ auto val = ir_cloner.clone(kv.first);
+ auto expr = ir_cloner.clone(kv.second);
+ origin_.insert({val, expr});
+ }
+
+ for (const auto& kv : other.uses_) {
+ auto val = ir_cloner.clone(kv.first);
+ std::unordered_set<Expr*> val_uses;
+ for (auto expr : kv.second) {
+ val_uses.insert(ir_cloner.clone(expr));
}
+ uses_.insert({val, std::move(val_uses)});
}
- auto it = expr_set_.begin();
- while (it != expr_set_.end()) {
- auto del = it;
- it = ++it;
- delete (*del);
+
+ for (const auto& kv : other.values_map_) {
+ auto from_val = ir_cloner.clone(kv.first);
+ auto to_val = ir_cloner.clone(kv.second);
+ values_map_.insert({from_val, to_val});
}
-};
+
+ inputs_ = ir_cloner.clone(other.inputs_);
+ outputs_ = ir_cloner.clone(other.outputs_);
+}
+
+Fusion::Fusion(Fusion&& other) noexcept {
+ swap(*this, other);
+}
+
+Fusion& Fusion::operator=(const Fusion& other) {
+ Fusion copy(other);
+ clear();
+ swap(*this, copy);
+ return *this;
+}
+
+Fusion& Fusion::operator=(Fusion&& other) noexcept {
+ clear();
+ swap(*this, other);
+ return *this;
+}
+
+Fusion::~Fusion() {
+ clear();
+}
+
+void Fusion::clear() noexcept {
+ // Free the owned values
+ for (auto ptr : val_set_) {
+ delete ptr;
+ }
+
+ // Free the owned expressions
+ for (auto ptr : expr_set_) {
+ delete ptr;
+ }
+
+ val_set_.clear();
+ val_deque_.clear();
+ expr_set_.clear();
+
+ for (auto& kv : val_type_name_map_) {
+ kv.second = 0;
+ }
+
+ val_name_counter_ = 0;
+ expr_name_counter_ = 0;
+
+ origin_.clear();
+ uses_.clear();
+ values_map_.clear();
+
+ inputs_.clear();
+ outputs_.clear();
+}
void Fusion::removeExpr(Expr* expr) {
assertInFusion(expr, "Cannot remove expr ");
@@ -138,7 +256,14 @@
" has a reduction axis, but this does nothing in the fusion.");
}
- IRInputOutput::addInput(input);
+ TORCH_CHECK(
+ input->getOrigin() == nullptr,
+ input,
+ " cannot be registered as an input as it is used as an output of an expression (",
+ input->getOrigin(),
+ ").");
+
+ inputs_.push_back(input);
}
void Fusion::addOutput(Val* const output) {
@@ -153,7 +278,7 @@
output,
" cannot be registered as an output as it has a broadcast axis.");
}
- IRInputOutput::addOutput(output);
+ outputs_.push_back(output);
}
bool Fusion::inFusion(const Statement* stmt) const {
@@ -177,10 +302,14 @@
TORCH_CHECK(false, msg, " it was not found in the active fusion.");
}
-std::vector<Expr*> Fusion::exprs(bool from_outputs_only, bool breadth_first) {
+std::vector<Expr*> Fusion::exprs(
+ bool from_outputs_only,
+ bool breadth_first,
+ bool respect_compute_at) {
if (breadth_first)
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
- return ExprSort::getExprs(this, from_outputs_only, breadth_first);
+ return ExprSort::getExprs(
+ this, from_outputs_only, breadth_first, respect_compute_at);
}
std::unordered_set<Val*> Fusion::inputsOf(Val* val) {
@@ -190,20 +319,16 @@
void Fusion::validateInputs() {
std::unordered_set<Val*> all_inputs;
for (Val* out : outputs()) {
- auto outs_inputs = inputsOf(out);
- std::set_union(
- all_inputs.begin(),
- all_inputs.end(),
- outs_inputs.begin(),
- outs_inputs.end(),
- std::inserter(all_inputs, all_inputs.begin()));
+ for (Val* input : inputsOf(out)) {
+ all_inputs.insert(input);
+ }
}
- for (Val* inp : all_inputs) {
- if (!inp->isConstScalar())
+ for (Val* input : all_inputs) {
+ if (!input->isConstScalar())
TORCH_CHECK(
- hasInput(inp),
+ hasInput(input),
"Could not figure out how ",
- inp,
+ input,
" is generated, however it was not specified as an input.");
}
}
@@ -218,10 +343,30 @@
std::cout << "}\n";
}
+void Fusion::printValuesMap() {
+ IRPrinter ir_printer(std::cout);
+ ir_printer.follow_val_map = false;
+ std::cout << "\nValues map\n";
+ std::cout << "--------------------\n";
+ for (const auto& kv : values_map_) {
+ ir_printer.handle(kv.first);
+ std::cout << " -> ";
+ ir_printer.handle(kv.second);
+ std::cout << "\n";
+ }
+ std::cout << "--------------------\n\n";
+}
+
+void Fusion::printKernel() {
+ FusionGuard fg(this);
+ GPULower lower(this);
+ lower.printKernel(std::cout);
+}
+
void Fusion::printMath() {
FusionGuard fg(this);
- IRMathPrinter op_exprs(std::cout);
- op_exprs.handle(this);
+ for (auto expr : exprs(true))
+ std::cout << expr;
}
void Fusion::printTransforms() {
@@ -338,11 +483,28 @@
return it->second;
}
+bool Fusion::hasInput(const Val* val) const {
+ return std::find(inputs_.begin(), inputs_.end(), val) != inputs_.end();
+}
+
+bool Fusion::hasOutput(const Val* val) const {
+ return std::find(outputs_.begin(), outputs_.end(), val) != outputs_.end();
+}
+
+void Fusion::replaceInput(Val* replace, Val* with) {
+ std::replace(inputs_.begin(), inputs_.end(), replace, with);
+}
+
+void Fusion::replaceOutput(Val* replace, Val* with) {
+ std::replace(outputs_.begin(), outputs_.end(), replace, with);
+}
+
StmtNameType Fusion::getValName(ValType vtype) {
- if (val_type_name_map.find(vtype) != val_type_name_map.end())
- return val_type_name_map[vtype]++;
+ if (val_type_name_map_.find(vtype) != val_type_name_map_.end())
+ return val_type_name_map_[vtype]++;
return val_name_counter_++;
}
+
StmtNameType Fusion::getExprName() {
return expr_name_counter_++;
}
@@ -368,6 +530,26 @@
return false;
}
+bool Fusion::hasBlockReduction() {
+ for (auto expr : exprs(true))
+ for (auto out : expr->outputs())
+ if (out->getValType() == ValType::TensorView)
+ if (static_cast<TensorView*>(out)->hasBlockReduction())
+ return true;
+
+ return false;
+}
+
+bool Fusion::hasGridReduction() {
+ for (auto expr : exprs(true))
+ for (auto out : expr->outputs())
+ if (out->getValType() == ValType::TensorView)
+ if (static_cast<TensorView*>(out)->hasGridReduction())
+ return true;
+
+ return false;
+}
+
} // namespace fuser
} // namespace jit
} // namespace torch
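Fusion becomes copyable and movable here via the copy-and-swap idiom: the copy constructor deep-clones the graph through IrCloner, swap() exchanges all members and re-points the cloned nodes' fusion_ back-links, and copy assignment is written as copy-then-swap so a throwing copy leaves the target untouched. A compact sketch of the same idiom on a hypothetical resource-owning type (the real Fusion additionally fixes up back-links inside swap()):

#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

class Graph {
 public:
  Graph() = default;

  Graph(const Graph& other) : nodes_(other.nodes_) {}  // deep-copy stand-in

  Graph(Graph&& other) noexcept {
    swap(*this, other);
  }

  Graph& operator=(const Graph& other) {
    Graph copy(other);  // may throw; *this is untouched if it does
    clear();
    swap(*this, copy);  // commit; old state is destroyed with 'copy'
    return *this;
  }

  Graph& operator=(Graph&& other) noexcept {
    clear();
    swap(*this, other);
    return *this;
  }

  void clear() noexcept { nodes_.clear(); }

  friend void swap(Graph& a, Graph& b) noexcept {
    using std::swap;
    swap(a.nodes_, b.nodes_);
  }

  void add(std::string name) { nodes_.push_back(std::move(name)); }
  std::size_t size() const { return nodes_.size(); }

 private:
  std::vector<std::string> nodes_;
};

int main() {
  Graph a;
  a.add("tv0");
  Graph b = a;             // copy constructor
  Graph c = std::move(a);  // move constructor leaves 'a' empty
  b = c;                   // copy assignment via copy-and-swap
  std::cout << b.size() << " " << c.size() << "\n";  // 1 1
}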
diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h
index e0f240f..684dc91 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.h
+++ b/torch/csrc/jit/codegen/cuda/fusion.h
@@ -48,16 +48,16 @@
* us mechanisms for dependency analysis and DCE including safety checks.
*/
-struct Fusion;
-struct TensorView;
+class Fusion;
+class TensorView;
namespace cuda {
-struct CudaKernel;
+class CudaKernel;
}
// Fusion Guard is our "context manager". It holds the active fusion and allows
// it to be accessed anywhere through FusionGuard::getCurFusion().
-struct TORCH_CUDA_API FusionGuard {
+class TORCH_CUDA_API FusionGuard {
public:
Fusion* prev_fusion;
@@ -72,7 +72,7 @@
// Expr sort will take a fusion and return a topologically sorted list of
// expressions.
-struct ExprSort : public IterVisitor {
+class ExprSort : public IterVisitor {
private:
std::vector<Expr*> exprs;
@@ -82,10 +82,11 @@
static std::vector<Expr*> getExprs(
Fusion* fusion,
bool from_outputs_only,
- bool breadth_first);
+ bool breadth_first,
+ bool respect_compute_at);
};
-struct InputsOf : public IterVisitor {
+class InputsOf : public IterVisitor {
private:
std::unordered_set<Val*> inputs;
@@ -101,21 +102,25 @@
* duplicating all associated values and exprs. Fusion is considered to be SSA,
* though this could also change in the future if there is a good reason to do
* so.
+ *
+ * The Fusion owns the whole IR graph (Vals and Exprs).
*/
+class TORCH_CUDA_API Fusion final {
+ public:
+ Fusion() = default;
-struct TORCH_CUDA_API Fusion : public IRInputOutput {
- Fusion() {}
+ Fusion(const Fusion& other);
+ Fusion(Fusion&& other) noexcept;
- // Not copyable
- Fusion(const Fusion& other) = delete;
- Fusion& operator=(const Fusion& other) = delete;
+ Fusion& operator=(const Fusion& other);
+ Fusion& operator=(Fusion&& other) noexcept;
- Fusion(Fusion&& other) = delete;
- Fusion& operator=(Fusion&& other) = delete;
-
- // When destroyed clean up all IR associated with this fusion
~Fusion();
+ friend void swap(Fusion& a, Fusion& b) noexcept;
+
+ void clear() noexcept;
+
// Break dependency chains associated with Expr, remove references to expr
// delete expr.
void removeExpr(Expr* expr);
@@ -153,7 +158,8 @@
*/
std::vector<Expr*> exprs(
bool from_outputs_only = false,
- bool breadth_first = false);
+ bool breadth_first = false,
+ bool respect_compute_at = false);
std::unordered_set<Val*> inputsOf(Val* val);
@@ -163,11 +169,16 @@
// Print this fusion to cout.
void print();
+ // Print value mapping
+ void printValuesMap();
+
// Print Arith exprs used in outputs
void printMath();
+
// Print transformations used in fusion (can be very verbose)
void printTransforms();
-
+ // Lower the fusion and print a kernel
+ void printKernel();
// Register the Val with this fusion
StmtNameType registerVal(Val* val);
@@ -203,22 +214,54 @@
// Indicate to kernel to set itself up to generate random numbers
bool hasRNG();
- // Indicate to kernel to set itself up to generate random numbers
bool hasReduction();
+ bool hasBlockReduction();
+ bool hasGridReduction();
+ size_t gridReductionTempBufferSize();
+
+ void setValuesMap(std::unordered_map<Val*, Val*> values_map) {
+ values_map_ = std::move(values_map);
+ }
+
+ Val* loweredVal(Val* value) const {
+ auto it = values_map_.find(value);
+ return it != values_map_.end() ? it->second : value;
+ }
+
+ const Val* loweredVal(const Val* value) const {
+ auto it = values_map_.find(const_cast<Val*>(value));
+ return it != values_map_.end() ? it->second : value;
+ }
+
+ const auto& inputs() const {
+ return inputs_;
+ }
+
+ const auto& outputs() const {
+ return outputs_;
+ }
+
+ bool hasInput(const Val* val) const;
+ bool hasOutput(const Val* val) const;
+
+ void replaceInput(Val* replace, Val* with);
+ void replaceOutput(Val* replace, Val* with);
private:
- // Sets of all Vals/Exprs registered with this fusion
- std::unordered_set<Val*> val_set_;
- std::deque<Val*> val_deque_;
- std::unordered_set<Expr*> expr_set_;
-
// Return an int that monotonically increases for each val/expr, some are
// explicitly incremented by type.
StmtNameType getValName(ValType vtype);
StmtNameType getExprName();
+ private:
+ // Sets of all Vals/Exprs registered with this fusion
+ // (val_deque_ is not owning the objects)
+ std::unordered_set<Val*> val_set_;
+ std::deque<Val*> val_deque_;
+ std::unordered_set<Expr*> expr_set_;
+
// map from valtype to individual name counters
- std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map = {
+ std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map_ = {
{ValType::TensorView, 0},
{ValType::TensorDomain, 0},
{ValType::IterDomain, 0},
@@ -231,6 +274,13 @@
// Dependency tracking for Vals. Where did it come from? Where is it used?
std::unordered_map<Val*, Expr*> origin_;
std::unordered_map<Val*, std::unordered_set<Expr*>> uses_;
+
+ // Map a subset of values to the lowered equivalent (ex. sizes)
+ std::unordered_map<Val*, Val*> values_map_;
+
+ // Fusion inputs and outputs
+ std::vector<Val*> inputs_;
+ std::vector<Val*> outputs_;
};
} // namespace fuser
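The new loweredVal() accessor above is a values_map_ lookup that falls back to the identity when a value has no lowered counterpart, so callers never need to special-case unlowered values. A tiny sketch of that fallback-to-identity lookup, with a hypothetical Val stand-in:

#include <iostream>
#include <unordered_map>

struct Val {  // hypothetical stand-in for the IR value type
  int id;
};

// Returns the lowered replacement if one was registered, otherwise the
// original pointer, so the result can be used unconditionally.
const Val* loweredVal(
    const std::unordered_map<const Val*, const Val*>& values_map,
    const Val* value) {
  const auto it = values_map.find(value);
  return it != values_map.end() ? it->second : value;
}

int main() {
  Val a{0}, a_lowered{100}, b{1};
  std::unordered_map<const Val*, const Val*> values_map{{&a, &a_lowered}};
  std::cout << loweredVal(values_map, &a)->id << "\n";  // 100
  std::cout << loweredVal(values_map, &b)->id << "\n";  // 1 (identity fallback)
}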
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp
index ce08528..898df6a 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -15,9 +15,8 @@
auto outer_it = index_map_.find(outer_id);
auto inner_it = index_map_.find(inner_id);
- TORCH_INTERNAL_ASSERT(
- outer_it != index_map_.end() && inner_it != index_map_.end(),
- "Error in index compute, did not compute a necessary intermediate value.");
+ if (outer_it == index_map_.end() || inner_it == index_map_.end())
+ return;
auto outer_ind = outer_it->second;
auto inner_ind = inner_it->second;
@@ -32,9 +31,8 @@
auto inner_id = merge->inner();
auto out_it = index_map_.find(out_id);
- TORCH_INTERNAL_ASSERT(
- out_it != index_map_.end(),
- "Error in index compute, did not compute a necessary intermediate value.");
+ if (out_it == index_map_.end())
+ return;
auto out_ind = out_it->second;
@@ -58,21 +56,21 @@
BackwardVisitor::handle(e);
}
-IndexCompute::IndexCompute(TensorDomain* td, const std::vector<Val*>& indices) {
+IndexCompute::IndexCompute(
+ const TensorDomain* td,
+ const std::vector<Val*>& indices) {
if (td->nDims() == 0 || indices.empty()) {
indices_.push_back(new Int(0));
return;
}
- bool exclude_reduction = td->nDims() > indices.size();
+ const bool exclude_reduction = td->nDims() > indices.size();
TORCH_INTERNAL_ASSERT(
td->noReductions().size() == indices.size() ||
td->nDims() == indices.size(),
"For IndexCompute the number of axes should match the number of dimensions in the TensorDomain.");
- TORCH_INTERNAL_ASSERT(!td->hasRFactor(), "Not implemented yet.");
-
{
size_t i = 0;
for (auto id : td->domain()) {
@@ -82,7 +80,7 @@
}
}
- std::vector<Val*> domain_vals(td->domain().begin(), td->domain().end());
+ const std::vector<Val*> domain_vals(td->domain().begin(), td->domain().end());
// Run the split/merge operations backwards. This will modify the index_map_
// so it can be used to index the root TensorDomain. Each entry in the root
@@ -92,7 +90,6 @@
// map at the rfactor IterDomains.
traverseFrom(indices[0]->fusion(), domain_vals, false);
- std::vector<Val*> inds;
for (auto id : td->rootDomain()) {
if (exclude_reduction && id->isReduction())
continue;
@@ -105,66 +102,60 @@
}
std::vector<Val*> IndexCompute::get(
- TensorDomain* td,
+ const TensorDomain* td,
const std::vector<Val*>& _indices) {
IndexCompute ic(td, _indices);
return ic.indices_;
}
TensorIndex* Index::getGlobalProducerIndex(
- TensorView* producer,
- TensorView* consumer,
+ const TensorView* producer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops) {
- // This replay will ignore reduction dimensions on the producer
- auto pind =
- TransformReplay::replayPasC(producer->domain(), consumer->domain(), -1);
+ // Grab indices from the loops
+ std::vector<Val*> indices(loops.size());
+ std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) {
+ return fl->index();
+ });
- TORCH_INTERNAL_ASSERT(
- loops.size() == consumer->nDims(),
- "Dimensionality error in code generator while computing tensor indexes.");
+ // What would the consumer indices be if it was global, keeping in mind
+ // reduction axes:
+ const std::vector<Val*> c_inds =
+ IndexCompute::get(consumer->domain(), indices);
- std::vector<ForLoop*> loops_adjusted;
- size_t it_c = 0, it_p = 0;
- while (it_c < consumer->nDims() && it_p < pind->noReductions().size()) {
- if (consumer->axis(it_c)->isBroadcast() &&
- !pind->noReductions()[it_p]->isBroadcast()) {
- it_c++;
- } else {
- loops_adjusted.push_back(loops[it_c]);
- it_c++;
- it_p++;
+ // Computed consumer indices should have everything we need for the producer
+ std::vector<Val*> p_inds;
+ auto p_root = TensorDomain::noReductions(producer->getRootDomain());
+ // Number of root dims that are broadcasted
+ size_t bcast_dims = 0;
+ {
+ auto c_root = consumer->getRootDomain();
+ size_t it_c = 0, it_p = 0;
+ while (it_c < c_root.size() && it_p < p_root.size()) {
+ const bool is_bcast = p_root[it_p]->isBroadcast();
+ if (c_root[it_c]->isBroadcast() && !is_bcast) {
+ it_c++;
+ } else {
+ if (!is_bcast) {
+ p_inds.push_back(c_inds[it_c]);
+ } else {
+ bcast_dims++;
+ }
+ it_c++;
+ it_p++;
+ }
}
}
-
TORCH_INTERNAL_ASSERT(
- loops_adjusted.size() == pind->noReductions().size(),
- "Dimensionality error in code generator while computing tensor indexes.");
-
- std::vector<Val*> indices(loops_adjusted.size());
- std::transform(
- loops_adjusted.begin(),
- loops_adjusted.end(),
- indices.begin(),
- [](ForLoop* fl) { return fl->index(); });
- std::vector<Val*> computed_inds = IndexCompute::get(pind, indices);
-
- auto root_domain = producer->getRootDomain();
-
- TORCH_INTERNAL_ASSERT(
- computed_inds.size() == root_domain.size(),
- "Dimensionality error in code generator while computing indexing.");
-
- for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
- if (root_domain[i]->isReduction() || root_domain[i]->isBroadcast())
- computed_inds.erase(computed_inds.begin() + i);
- }
+ p_inds.size() == p_root.size() - bcast_dims,
+ "Dimensionality error in code generator while computing tensor indices.");
std::vector<Val*> strided_inds;
- for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
+ for (size_t i = 0; i < p_inds.size(); i++) {
std::stringstream ss;
ss << "T" << producer->name() << ".stride[" << i << "]";
strided_inds.push_back(
- mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int)));
+ mul(p_inds[i], new NamedScalar(ss.str(), DataType::Int)));
}
// Probably shouldn't ever hit this
@@ -176,8 +167,8 @@
// Producer index for either shared or local memory
TensorIndex* Index::getProducerIndex_impl(
- TensorView* producer,
- TensorView* consumer,
+ const TensorView* producer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops) {
TORCH_INTERNAL_ASSERT(
loops.size() == consumer->nDims(),
@@ -222,7 +213,7 @@
std::vector<Val*> used_inds;
std::vector<IterDomain*> used_ranges;
bool unrolled = false;
- for (decltype(loops_adjusted.size()) i{0}; i < loops_adjusted.size(); i++) {
+ for (size_t i = 0; i < loops_adjusted.size(); i++) {
if (ranges[i]->parallel_method() == ParallelType::Unroll)
unrolled = true;
if (!unrolled && producer->hasComputeAt() &&
@@ -233,16 +224,16 @@
continue;
if (producer->getMemoryType() == MemoryType::Local && ranges[i]->isThread())
continue;
- if (ranges[i]->isBroadcast())
+ if (producer->domain()->noReductions()[i]->isBroadcast())
continue;
used_inds.push_back(indices[i]);
used_ranges.push_back(ranges[i]);
}
- for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) {
+ for (size_t i = 0; i < used_inds.size(); i++) {
Val* ind = used_inds[i];
- for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++)
+ for (size_t j = i + 1; j < used_ranges.size(); j++)
ind = mul(ind, used_ranges[j]->extent());
used_inds[i] = ind;
}
@@ -253,7 +244,7 @@
}
TensorIndex* Index::getGlobalConsumerIndex(
- TensorView* consumer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops) {
// If we're initializing a reduction buffer, we won't have the reduction
// loops. If we're actually performing the reduction, we will.
@@ -273,7 +264,7 @@
"Dimensionality error in code generator while computing indexing.");
if (computed_inds.size() == root_dom.size())
- for (decltype(root_dom.size()) i{0}; i < root_dom.size(); i++) {
+ for (size_t i = 0; i < root_dom.size(); i++) {
// Do this backwards so erase offset will be right
auto axis = root_dom.size() - i - 1;
if (root_dom[axis]->isReduction() || root_dom[i]->isBroadcast())
@@ -281,7 +272,7 @@
}
std::vector<Val*> strided_inds;
- for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
+ for (size_t i = 0; i < computed_inds.size(); i++) {
std::stringstream ss;
ss << "T" << consumer->name() << ".stride[" << i << "]";
strided_inds.push_back(
@@ -297,7 +288,7 @@
// Consumer index for either shared or local memory
TensorIndex* Index::getConsumerIndex_impl(
- TensorView* consumer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops) {
// If we're initializing a reduction buffer, we won't have the reduction
// loops. If we're actually performing the reduction, we will.
@@ -335,29 +326,40 @@
std::vector<Val*> used_inds;
std::vector<IterDomain*> used_ranges;
bool unrolled = false;
- for (decltype(loops.size()) i{0}; i < loops.size(); i++) {
- if (have_reduction_iters && consumer->axis(i)->isReduction())
- continue;
- if (ranges[i]->parallel_method() == ParallelType::Unroll)
- unrolled = true;
- if (!unrolled && consumer->hasComputeAt() &&
- i < consumer->getThisComputeAtAxis())
- continue;
- if (consumer->getMemoryType() == MemoryType::Shared &&
- ranges[i]->isBlockDim())
- continue;
- if (consumer->getMemoryType() == MemoryType::Local && ranges[i]->isThread())
- continue;
- if (ranges[i]->isBroadcast())
- continue;
+ {
+ size_t c_i = 0, l_i = 0;
+ while (c_i < consumer->nDims() && l_i < loops.size()) {
+ if (consumer->axis(c_i)->isReduction()) {
+ c_i++;
+ if (have_reduction_iters)
+ l_i++;
+ continue;
+ }
+ if (ranges[l_i]->parallel_method() == ParallelType::Unroll)
+ unrolled = true;
- used_inds.push_back(indices[i]);
- used_ranges.push_back(ranges[i]);
+ if ((!unrolled && consumer->hasComputeAt() &&
+ c_i < consumer->getThisComputeAtAxis()) ||
+ (consumer->getMemoryType() == MemoryType::Shared &&
+ ranges[l_i]->isBlockDim()) ||
+ (consumer->getMemoryType() == MemoryType::Local &&
+ ranges[l_i]->isThread()) ||
+ (consumer->axis(c_i)->isBroadcast())) {
+ c_i++;
+ l_i++;
+ continue;
+ }
+
+ used_inds.push_back(indices[l_i]);
+ used_ranges.push_back(ranges[l_i]);
+ l_i++;
+ c_i++;
+ }
}
- for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) {
+ for (size_t i = 0; i < used_inds.size(); i++) {
Val* ind = used_inds[i];
- for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++)
+ for (size_t j = i + 1; j < used_ranges.size(); j++)
ind = mul(ind, used_ranges[j]->extent());
used_inds[i] = ind;
}
@@ -370,13 +372,17 @@
// Producer is the inputs of an expression
TensorIndex* Index::getProducerIndex(
- TensorView* producer,
- TensorView* consumer,
+ const TensorView* producer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops) {
TORCH_INTERNAL_ASSERT(
loops.size() == consumer->nDims() ||
loops.size() == consumer->domain()->noReductions().size());
+ if (producer->domain()->noReductions().size() == 0) {
+ return new TensorIndex(producer, {});
+ }
+
if (producer->getMemoryType() == MemoryType::Global)
return getGlobalProducerIndex(producer, consumer, loops);
return getProducerIndex_impl(producer, consumer, loops);
@@ -384,12 +390,16 @@
// Consumer is the output of an expression
TensorIndex* Index::getConsumerIndex(
- TensorView* consumer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops) {
TORCH_INTERNAL_ASSERT(
loops.size() == consumer->nDims() ||
loops.size() == consumer->domain()->noReductions().size());
+ if (consumer->domain()->noReductions().size() == 0) {
+ return new TensorIndex(consumer, {});
+ }
+
if (consumer->getMemoryType() == MemoryType::Global)
return getGlobalConsumerIndex(consumer, loops);
return getConsumerIndex_impl(consumer, loops);
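getGlobalProducerIndex now derives the consumer indices directly from the loop indices and maps them onto the producer's root domain, skipping broadcast and reduction axes, before applying the per-dimension strides. The final step is the usual explicit-stride flattening, offset = sum over i of idx[i] * stride[i]; a small standalone sketch of just that step, using a hypothetical helper rather than the codegen path:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Flattens per-dimension indices into a linear element offset using
// explicit strides: offset = sum_i idx[i] * stride[i].
int64_t stridedOffset(
    const std::vector<int64_t>& idx,
    const std::vector<int64_t>& stride) {
  assert(idx.size() == stride.size());
  int64_t offset = 0;
  for (size_t i = 0; i < idx.size(); ++i) {
    offset += idx[i] * stride[i];
  }
  return offset;
}

int main() {
  // A contiguous 4x5 tensor has strides {5, 1}; element (2, 3) lives at 13.
  std::cout << stridedOffset({2, 3}, {5, 1}) << "\n";  // 13
}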
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h
index 09f1b8e..c677115 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/index_compute.h
@@ -55,7 +55,7 @@
namespace jit {
namespace fuser {
-struct IndexCompute : public BackwardVisitor {
+class IndexCompute : public BackwardVisitor {
private:
using BackwardVisitor::handle;
void handle(Split*) override;
@@ -65,54 +65,54 @@
// Otherwise warning on runBackward as it hides an overloaded virtual
// using TransformIter::runBackward;
- IndexCompute(TensorDomain* td, const std::vector<Val*>& _indices);
+ IndexCompute(const TensorDomain* td, const std::vector<Val*>& _indices);
std::unordered_map<IterDomain*, Val*> index_map_;
std::vector<Val*> indices_;
public:
static std::vector<Val*> get(
- TensorDomain* td,
+ const TensorDomain* td,
const std::vector<Val*>& _indices);
};
// Simple interface for IndexCompute
-struct Index {
+class Index {
private:
// Producer indexing if it's in shared or local memory
static TensorIndex* getProducerIndex_impl(
- TensorView* producer,
- TensorView* consumer,
+ const TensorView* producer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops);
// Consumer indexing if it's in shared or local memory
static TensorIndex* getConsumerIndex_impl(
- TensorView* consumer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops);
- public:
// Producer if it's in global memory
static TensorIndex* getGlobalProducerIndex(
- TensorView* producer,
- TensorView* consumer,
+ const TensorView* producer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops);
// Consumer indexing if it's in global memory
static TensorIndex* getGlobalConsumerIndex(
- TensorView* consumer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops);
+ public:
// Indexing functions
// Consumer = Producer
// i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
// Producer indexing dispatch
static TensorIndex* getProducerIndex(
- TensorView* producer,
- TensorView* consumer,
+ const TensorView* producer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops);
// Consumer index dispatch
static TensorIndex* getConsumerIndex(
- TensorView* consumer,
+ const TensorView* consumer,
const std::vector<ForLoop*>& loops);
// Will run inds through back prop index computation for tv
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index 3027784..c7f5d37 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -1,10 +1,10 @@
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/mutator.h>
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
+#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
+#include <torch/csrc/jit/codegen/cuda/mutator.h>
#include <torch/csrc/jit/ir/ir.h>
@@ -19,6 +19,12 @@
namespace jit {
namespace fuser {
+Statement::Statement(const Statement* src, IrCloner* ir_cloner) {
+ ir_cloner->registerClone(src, this);
+ name_ = src->name_;
+ fusion_ = ir_cloner->fusion();
+}
+
Val* Statement::asVal() {
TORCH_INTERNAL_ASSERT(isVal(), "Cannot cast to Val as this is not a Val.");
return static_cast<Val*>(this);
@@ -29,6 +35,12 @@
return static_cast<Expr*>(this);
}
+void Statement::print() const {
+ IRPrinter ir_printer(std::cout);
+ ir_printer.handle(this);
+ std::cout << std::endl;
+}
+
// When we create a Val we immediately register them with the active fusion.
Val::Val(ValType _vtype, DataType _dtype, bool register_val)
: vtype_{_vtype}, dtype_{_dtype} {
@@ -40,41 +52,45 @@
this->name_ = this->fusion_->registerVal(this);
}
+Val::Val(const Val* src, IrCloner* ir_cloner)
+ : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {}
+
// Traverse origin of all values involved in constructing the provided val.
// Check if all values involved are constant values, meaning the provided
// val is also a constant value.
namespace {
-struct ConstCheck : OptOutConstDispatch {
+class ConstCheck : OptOutConstDispatch {
private:
- bool is_const_ = false;
+ bool is_const_ = true;
- void handle(const Bool* const b) override {
- is_const_ = b->isConst();
+ void handle(const Bool* b) override {
+ is_const_ = is_const_ && b->isConst();
}
- void handle(const Float* const f) override {
- is_const_ = f->isConst();
+ void handle(const Float* f) override {
+ is_const_ = is_const_ && f->isConst();
}
- void handle(const Half* const h) override {
- is_const_ = h->isConst();
+ void handle(const Half* h) override {
+ is_const_ = is_const_ && h->isConst();
}
- void handle(const Int* const i) override {
- is_const_ = i->isConst();
+ void handle(const Int* i) override {
+ is_const_ = is_const_ && i->isConst();
}
- void handle(const Expr* const expr) override {
+ void handle(const NamedScalar* ns) override {
+ is_const_ = is_const_ && false;
+ }
+
+ void handle(const Expr* expr) override {
for (auto inp : expr->inputs()) {
OptOutConstDispatch::handle(inp);
}
}
- void handle(const NamedScalar* const ns) override {
- is_const_ = false;
- }
- void handle(const Val* const val) override {
+ void handle(const Val* val) override {
const Expr* orig = FusionGuard::getCurFusion()->origin(val);
if (orig != nullptr)
handle(orig);
@@ -83,7 +99,7 @@
}
public:
- static bool isConst(const Val* const val) {
+ static bool isConst(const Val* val) {
ConstCheck cc;
cc.handle(val);
return cc.is_const_;
@@ -123,6 +139,9 @@
return (fusion_->origin(this));
}
+Scope::Scope(const Scope* src, IrCloner* ir_cloner)
+ : exprs_(ir_cloner->clone(src->exprs_)) {}
+
void Scope::insert_before(Expr* ref, Expr* expr) {
auto it = exprs_.begin();
while (it != exprs_.end()) {
@@ -176,76 +195,6 @@
this->exprs_ = std::vector<Expr*>();
}
-bool IRInputOutput::hasInput(const Val* const input) const {
- for (auto val : inputs_)
- if (val == input)
- return true;
- return false;
-}
-
-bool IRInputOutput::hasOutput(const Val* const output) const {
- for (auto val : outputs_)
- if (val == output)
- return true;
- return false;
-}
-
-void IRInputOutput::replaceInput(Val* replace, Val* with) {
- bool changed = false;
- for (decltype(inputs_.size()) i{0}; i < inputs_.size(); i++) {
- if (inputs_[i] == replace) {
- inputs_[i] = with;
- changed = true;
- break;
- }
- }
- TORCH_INTERNAL_ASSERT(
- changed,
- "Error detected when trying to replace input ",
- replace,
- " with ",
- with,
- " .");
-}
-
-void IRInputOutput::replaceOutput(Val* replace, Val* with) {
- bool changed = false;
- for (decltype(outputs_.size()) i{0}; i < outputs_.size(); i++) {
- if (outputs_[i] == replace) {
- outputs_[i] = with;
- changed = true;
- break;
- }
- }
- TORCH_INTERNAL_ASSERT(
- changed,
- "Error detected when trying to replace output ",
- replace,
- " with ",
- with,
- " .");
-}
-
-void IRInputOutput::removeInput(Val* val) {
- auto it = inputs_.begin();
- for (; it != inputs_.end(); ++it) {
- if ((*it) == val)
- break;
- }
- TORCH_INTERNAL_ASSERT(it != inputs_.end());
- inputs_.erase(it);
-}
-
-void IRInputOutput::removeOutput(Val* val) {
- auto it = outputs_.begin();
- for (; it != outputs_.end(); ++it) {
- if ((*it) == val)
- break;
- }
- TORCH_INTERNAL_ASSERT(it != outputs_.end());
- outputs_.erase(it);
-}
-
// We don't register with the active fusion in Expr as this needs to be done
// after inputs and outputs are registered with the Expr
Expr::Expr(ExprType _type) : type_{_type} {
@@ -255,6 +204,25 @@
this->fusion_ = fusion;
}
+Expr::Expr(const Expr* src, IrCloner* ir_cloner)
+ : Statement(src, ir_cloner),
+ type_(src->type_),
+ inputs_(ir_cloner->clone(src->inputs_)),
+ outputs_(ir_cloner->clone(src->outputs_)) {}
+
+bool Expr::sameAs(const Expr* const other) const {
+ if (getExprType() != other->getExprType())
+ return false;
+ if (inputs().size() != other->inputs().size() ||
+ outputs().size() != other->outputs().size())
+ return false;
+ for (size_t i = 0; i < inputs().size(); i++) {
+ if (!input(i)->sameAs(other->input(i)))
+ return false;
+ }
+ return true;
+}
+
} // namespace fuser
} // namespace jit
} // namespace torch
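The ConstCheck change above is a behavioral fix, not just a style change: is_const_ now starts out true and is AND-ed across every leaf reached through an expression's inputs (with NamedScalar always forcing false), so a single non-constant input makes the whole value non-constant. Previously the flag was overwritten by whichever leaf happened to be visited last. A minimal sketch of the accumulate-with-AND pattern, using a hypothetical leaf type:

#include <iostream>
#include <vector>

struct Leaf {  // hypothetical stand-in for Bool/Float/Half/Int leaves
  bool is_const;
};

// Wrong: the result reflects only the last leaf visited.
bool lastLeafWins(const std::vector<Leaf>& leaves) {
  bool is_const = false;
  for (const auto& leaf : leaves) {
    is_const = leaf.is_const;  // overwrites prior answers
  }
  return is_const;
}

// Fixed: start at true and AND in every leaf, as ConstCheck now does.
bool allLeavesConst(const std::vector<Leaf>& leaves) {
  bool is_const = true;
  for (const auto& leaf : leaves) {
    is_const = is_const && leaf.is_const;
  }
  return is_const;
}

int main() {
  std::vector<Leaf> leaves = {{false}, {true}};  // symbolic + constant input
  std::cout << lastLeafWins(leaves) << "\n";     // 1: wrongly "constant"
  std::cout << allLeavesConst(leaves) << "\n";   // 0: correctly not constant
}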
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index cc2507a..c0564b2 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -23,7 +23,7 @@
/*
* This file defines the base IR structure. Any IR node in this system will
* inherit from one of the following classes: Statement, Expr, Val,
- * IRInputOutput IR is any information that the code generation stack may need
+ * IrInputOutput. IR is any information that the code generation stack may need
* for analysis. By analysis we're referring to anything done in response to a
* user facing call of this stack. This could be careful tracking of user calls,
* and any transformation including optimizing transformations, user declared
@@ -38,13 +38,14 @@
constexpr StmtNameType UNINITIALIZED_STMTNAMETYPE =
std::numeric_limits<unsigned int>::max();
-struct Fusion;
-struct FusionGuard;
-struct Expr;
-struct Val;
-struct UnaryOp;
-struct BinaryOp;
-struct IterDomain;
+class Fusion;
+class FusionGuard;
+class Expr;
+class Val;
+class UnaryOp;
+class BinaryOp;
+class IterDomain;
+class IrCloner;
/*
* Statement is the highest level node representation. Everything that is
@@ -57,7 +58,15 @@
* Basically being able to succinctly traverse down the inheritance stack of
* a Statement at runtime. This is currently implemented in dispatch.h
*/
-struct TORCH_CUDA_API Statement {
+class TORCH_CUDA_API Statement {
+ friend void swap(Fusion&, Fusion&) noexcept;
+
+ public:
+ Statement() = default;
+
+ // Cloning constructor
+ Statement(const Statement* src, IrCloner* ir_cloner);
+
virtual ~Statement() = default;
// Dispatch functions, definitions in dispatch.cpp
@@ -71,21 +80,21 @@
static Statement* mutatorDispatch(T mutator, Statement*);
// Accessor functions to types. Vals always have a DataType, Exprs never do
- virtual c10::optional<ValType> getValType() const noexcept {
+ virtual c10::optional<ValType> getValType() const {
return c10::nullopt;
}
virtual c10::optional<DataType> getDataType() const {
return c10::nullopt;
}
- virtual c10::optional<ExprType> getExprType() const noexcept {
+ virtual c10::optional<ExprType> getExprType() const {
return c10::nullopt;
}
// Short cut to figure out if it is a value/expression
- bool isVal() const noexcept {
+ bool isVal() const {
return getValType() != c10::nullopt;
}
- bool isExpr() const noexcept {
+ bool isExpr() const {
return getExprType() != c10::nullopt;
}
@@ -119,12 +128,12 @@
}
// Return the fusion this statement belongs to
- Fusion* fusion() const noexcept {
+ Fusion* fusion() const {
return fusion_;
}
// Return the int that represents its name
- StmtNameType name() const noexcept {
+ StmtNameType name() const {
return name_;
}
@@ -142,6 +151,8 @@
return this == other;
}
+ void print() const;
+
protected:
StmtNameType name_ = UNINITIALIZED_STMTNAMETYPE;
Fusion* fusion_ = nullptr;
@@ -165,13 +176,16 @@
* - Accessor functions for members
* - Must call Val constructor, Val constructor registers with fusion
* - Implementation of bool sameAs(...)
+ * - Must implement a "cloning" constructor, ex.
+ * Int::Int(const Int* src, IrCloner* ir_cloner)
* 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
* 3) Default mutator function should be added to mutator.cpp
- * 4) Printing functions should be added to ir_iostream.h/.cpp
+ * 4a) Printing functions should be added to ir_iostream.h/.cpp
+ * 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
* 5) An enum value must be added to ValType in type.h
* 6) A string entry must be added in val_type_string_map
*/
-struct TORCH_CUDA_API Val : public Statement {
+class TORCH_CUDA_API Val : public Statement {
public:
virtual ~Val() = default;
@@ -182,10 +196,13 @@
// to throw, fusion's destructor will get called, but the pointer to this Val
// will be invalid. When fusion tries to delete this value it will cause a seg
// fault, instead of showing the thrown error.
- Val(ValType _vtype,
+ explicit Val(
+ ValType _vtype,
DataType _dtype = DataType::Null,
bool register_val = true);
+ Val(const Val* src, IrCloner* ir_cloner);
+
// TODO: Values are unique and not copyable
Val(const Val& other) = delete;
Val& operator=(const Val& other) = delete;
@@ -193,7 +210,7 @@
Val(Val&& other) = delete;
Val& operator=(Val&& other) = delete;
- c10::optional<ValType> getValType() const noexcept override {
+ c10::optional<ValType> getValType() const override {
return vtype_;
}
@@ -244,13 +261,12 @@
const DataType dtype_;
};
-// TODO: We should use this for the following:
-// Fusion
-// IfThenElse
-// ForLoop
-struct TORCH_CUDA_API Scope {
+class TORCH_CUDA_API Scope {
public:
- const std::vector<Expr*>& exprs() const noexcept {
+ Scope() = default;
+ Scope(const Scope* src, IrCloner* ir_cloner);
+
+ const std::vector<Expr*>& exprs() const {
return exprs_;
}
@@ -274,6 +290,14 @@
return exprs_.size();
}
+ auto& operator[](size_t i) {
+ return exprs_[i];
+ }
+
+ auto& operator[](size_t i) const {
+ return exprs_[i];
+ }
+
// Insert expr before ref
void insert_before(Expr* ref, Expr* expr);
@@ -293,72 +317,6 @@
};
/*
- * IRInputOutput is a function on Vals. Has inputs and outputs that are all
- * Vals. It is used for anything that connects values and therefore would be
- * used during dependency analysis. Typically classes that inherit from
- * IRInputOutput will do so by inheriting from Exprs. Expr's are expected for
- * most dependency based operations like IterVisitor, or DependencyCheck.
- *
- * Examples:
- * binary operations on tensors, scalar values, or a combination, a thread all
- * reduce, for loops
- */
-struct TORCH_CUDA_API IRInputOutput {
- virtual ~IRInputOutput() = default;
-
- // Returns if Val is an input or output of this IRInputOutput instance
- bool hasInput(const Val* const input) const;
- bool hasOutput(const Val* const output) const;
-
- // Input/output accessors
- void addInputAt(std::deque<Val*>::size_type pos, Val* input) {
- inputs_.insert(inputs_.begin() + pos, input);
- }
-
- void addOutputAt(std::deque<Val*>::size_type pos, Val* output) {
- outputs_.insert(outputs_.begin() + pos, output);
- }
-
- const std::deque<Val*>& inputs() const noexcept {
- return inputs_;
- }
- const std::deque<Val*>& outputs() const noexcept {
- return outputs_;
- }
-
- Val* input(std::deque<Val*>::size_type idx) const {
- return inputs_[idx];
- }
- Val* output(std::deque<Val*>::size_type idx) const {
- return outputs_[idx];
- }
-
- void addInput(Val* input) {
- inputs_.push_back(input);
- }
- void addOutput(Val* output) {
- outputs_.push_back(output);
- }
-
- void replaceInput(Val* replace, Val* with);
- void replaceOutput(Val* replace, Val* with);
-
- void removeInput(Val* val);
- void removeOutput(Val* val);
-
- std::deque<Val*>::size_type nInputs() const noexcept {
- return inputs_.size();
- }
- std::deque<Val*>::size_type nOutputs() const noexcept {
- return outputs_.size();
- }
-
- protected:
- std::deque<Val*> inputs_;
- std::deque<Val*> outputs_;
-};
-
-/*
* A Expr represents a "computation." These are functions that takes inputs
* and produce outputs, inputs and outputs all being Vals. There are
* specializations of BinaryOp which takes 2 inputs and produces 1 output, and
@@ -396,11 +354,12 @@
* 6) An enum value must be added to ExprType in type.h 7) A string
* entry must be added in expr_type_string_map
*/
-struct TORCH_CUDA_API Expr : public Statement, IRInputOutput {
+class TORCH_CUDA_API Expr : public Statement {
public:
- virtual ~Expr() = default;
Expr() = delete;
- Expr(ExprType _type);
+ explicit Expr(ExprType _type);
+ Expr(const Expr* src, IrCloner* ir_cloner);
+ virtual ~Expr() = default;
Expr(const Expr& other) = delete;
Expr& operator=(const Expr& other) = delete;
@@ -408,24 +367,31 @@
Expr(Expr&& other) = delete;
Expr& operator=(Expr&& other) = delete;
- c10::optional<ExprType> getExprType() const noexcept override {
- return type_;
- }
- ExprType type() const noexcept {
+ c10::optional<ExprType> getExprType() const override {
return type_;
}
- bool sameAs(const Expr* const other) const {
- if (getExprType() != other->getExprType())
- return false;
- if (inputs().size() != other->inputs().size() ||
- outputs().size() != other->outputs().size())
- return false;
- for (size_t i = 0; i < inputs().size(); i++) {
- if (!input(i)->sameAs(other->input(i)))
- return false;
- }
- return true;
+ ExprType type() const {
+ return type_;
+ }
+
+ bool sameAs(const Expr* const other) const;
+
+ // Input/output accessors
+ const auto& inputs() const {
+ return inputs_;
+ }
+
+ const auto& outputs() const {
+ return outputs_;
+ }
+
+ auto input(size_t index) const {
+ return inputs_[index];
+ }
+
+ auto output(size_t index) const {
+ return outputs_[index];
}
// Dispatch functions, definitions in dispatch.cpp
@@ -438,8 +404,19 @@
template <typename T>
static Statement* mutatorDispatch(T mutator, Expr*);
+ protected:
+ void addInput(Val* input) {
+ inputs_.push_back(input);
+ }
+
+ void addOutput(Val* output) {
+ outputs_.push_back(output);
+ }
+
private:
- ExprType type_;
+ ExprType type_ = ExprType::Invalid;
+ std::vector<Val*> inputs_;
+ std::vector<Val*> outputs_;
};
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
new file mode 100644
index 0000000..49e557d
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -0,0 +1,133 @@
+
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+Statement* IrCloner::clone(const Statement* statement) {
+ if (statement == nullptr) {
+ return nullptr;
+ }
+
+ // Have we already cloned this node?
+ const auto it = clones_map_.find(statement);
+ if (it != clones_map_.end()) {
+ return it->second;
+ } else {
+ // Clone the new node, saving/restoring this->clone_
+ // since the cloning can be reentrant
+ auto saved_clone = clone_;
+ handle(statement);
+ auto new_node = clone_;
+ clone_ = saved_clone;
+
+ // The base cloning constructor (Statement) should have
+ // registered the new node. Failure to do so indicates
+ // that something went horribly wrong.
+ TORCH_INTERNAL_ASSERT(new_node != nullptr);
+ TORCH_INTERNAL_ASSERT(clones_map_[statement] == new_node);
+
+ return new_node;
+ }
+}
+
+void IrCloner::registerClone(const Statement* src, Statement* clone) {
+ TORCH_CHECK(src != nullptr);
+ TORCH_CHECK(clone != nullptr);
+ TORCH_CHECK(clones_map_.insert({src, clone}).second);
+}
+
+void IrCloner::handle(const Statement* s) {
+ OptInConstDispatch::handle(s);
+}
+
+void IrCloner::handle(const Val* v) {
+ OptInConstDispatch::handle(v);
+}
+
+void IrCloner::handle(const Expr* e) {
+ OptInConstDispatch::handle(e);
+}
+
+void IrCloner::handle(const TensorDomain* td) {
+ clone_ = new TensorDomain(td, this);
+}
+
+void IrCloner::handle(const IterDomain* id) {
+ clone_ = new IterDomain(id, this);
+}
+
+void IrCloner::handle(const TensorIndex* ti) {
+ clone_ = new TensorIndex(ti, this);
+}
+
+void IrCloner::handle(const Bool* b) {
+ clone_ = new Bool(b, this);
+}
+
+void IrCloner::handle(const Float* f) {
+ clone_ = new Float(f, this);
+}
+
+void IrCloner::handle(const Half* h) {
+ clone_ = new Half(h, this);
+}
+
+void IrCloner::handle(const Int* i) {
+ clone_ = new Int(i, this);
+}
+
+void IrCloner::handle(const NamedScalar* named_scalar) {
+ clone_ = new NamedScalar(named_scalar, this);
+}
+
+void IrCloner::handle(const TensorView* tv) {
+ clone_ = new TensorView(tv, this);
+}
+
+void IrCloner::handle(const UnaryOp* op) {
+ clone_ = new UnaryOp(op, this);
+}
+
+void IrCloner::handle(const BinaryOp* op) {
+ clone_ = new BinaryOp(op, this);
+}
+
+void IrCloner::handle(const TernaryOp* op) {
+ clone_ = new TernaryOp(op, this);
+}
+
+void IrCloner::handle(const BroadcastOp* op) {
+ clone_ = new BroadcastOp(op, this);
+}
+
+void IrCloner::handle(const ReductionOp* op) {
+ clone_ = new ReductionOp(op, this);
+}
+
+void IrCloner::handle(const ForLoop* for_loop) {
+ clone_ = new ForLoop(for_loop, this);
+}
+
+void IrCloner::handle(const IfThenElse* if_then_else) {
+ clone_ = new IfThenElse(if_then_else, this);
+}
+
+void IrCloner::handle(const Allocate* allocate) {
+ clone_ = new Allocate(allocate, this);
+}
+
+void IrCloner::handle(const Split* split) {
+ clone_ = new Split(split, this);
+}
+
+void IrCloner::handle(const Merge* merge) {
+ clone_ = new Merge(merge, this);
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h
new file mode 100644
index 0000000..4097121
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h
@@ -0,0 +1,91 @@
+
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class Fusion;
+
+// Clones nodes from an existing Fusion
+class TORCH_CUDA_API IrCloner : private OptInConstDispatch {
+ friend class Statement;
+
+ public:
+ explicit IrCloner(Fusion* new_fusion) : fusion_(new_fusion) {}
+
+ Statement* clone(const Statement* statement);
+
+ template <class T>
+ T* clone(const T* node) {
+ return node ? clone(node->template as<Statement>())->template as<T>()
+ : nullptr;
+ }
+
+ template <class T>
+ std::vector<T*> clone(const std::vector<T*>& container) {
+ std::vector<T*> copy;
+ for (auto p : container) {
+ copy.push_back(clone(p));
+ }
+ return copy;
+ }
+
+ Fusion* fusion() const {
+ return fusion_;
+ }
+
+ private:
+ void registerClone(const Statement* src, Statement* clone);
+
+ void handle(const Statement*) override;
+ void handle(const Val*) override;
+ void handle(const Expr*) override;
+
+ void handle(const TensorDomain*) override;
+ void handle(const TensorView*) override;
+ void handle(const IterDomain*) override;
+ void handle(const TensorIndex*) override;
+
+ void handle(const Bool*) override;
+ void handle(const Float*) override;
+ void handle(const Half*) override;
+ void handle(const Int*) override;
+ void handle(const NamedScalar*) override;
+
+ void handle(const UnaryOp*) override;
+ void handle(const BinaryOp*) override;
+ void handle(const TernaryOp*) override;
+ void handle(const BroadcastOp*) override;
+ void handle(const ReductionOp*) override;
+
+ void handle(const ForLoop*) override;
+ void handle(const IfThenElse*) override;
+ void handle(const Allocate*) override;
+
+ void handle(const Split*) override;
+ void handle(const Merge*) override;
+
+ private:
+ // The destination Fusion container
+ Fusion* fusion_ = nullptr;
+
+ // The dispatch interface doesn't allow returning values from
+ // individual `handle()` methods, so they store the
+ // result here
+ Statement* clone_ = nullptr;
+
+ // We keep track of the original -> clone map so we don't
+ // duplicate clones of the same object if referenced multiple times
+ std::unordered_map<const Statement*, Statement*> clones_map_;
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
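// Minimal usage sketch of IrCloner (illustrative, not from this patch).
// Assumes the Statement/Val cloning constructors register the copies into the
// cloner's destination Fusion, as described above; whether a FusionGuard on the
// destination is also required is an assumption of this sketch.
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>

using namespace torch::jit::fuser;

void cloneIntoNewFusion(TensorView* tv) {
  Fusion dst_fusion;
  FusionGuard fg(&dst_fusion); // assumption: new nodes register into this Fusion
  IrCloner cloner(&dst_fusion);

  // The templated overload returns the derived type directly.
  TensorView* tv_copy = cloner.clone(tv);

  // A second request for the same node hits clones_map_ and returns the
  // already-created copy instead of duplicating it.
  TORCH_INTERNAL_ASSERT(cloner.clone(tv) == tv_copy);
}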
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
index 3c3b259..e434b14 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
@@ -31,6 +31,17 @@
~IrNodeLabel() override = default;
+ void handle(const Bool* b) override {
+ if (b->isSymbolic()) {
+ label_ << "b" << b->name();
+ } else {
+ if (detail_level_ >= DetailLevel::Explicit) {
+ label_ << "b" << b->name() << "=";
+ }
+ label_ << *b->value();
+ }
+ }
+
void handle(const Float* f) override {
if (f->isSymbolic()) {
label_ << "f" << f->name();
@@ -42,6 +53,17 @@
}
}
+ void handle(const Half* h) override {
+ if (h->isSymbolic()) {
+ label_ << "h" << h->name();
+ } else {
+ if (detail_level_ >= DetailLevel::Explicit) {
+ label_ << "h" << h->name() << "=";
+ }
+ label_ << *h->value();
+ }
+ }
+
void handle(const Int* i) override {
if (i->isSymbolic()) {
label_ << "i" << i->name();
@@ -90,13 +112,11 @@
}
void handle(const Split* split) override {
- label_ << "Split(IterDomain=" << split->in()
- << ", factor=" << IrNodeLabel::gen(split->factor()) << ")";
+ label_ << "Split(factor=" << IrNodeLabel::gen(split->factor()) << ")";
}
void handle(const Merge* merge) override {
- label_ << "Merge(IterDomainOuter=" << merge->outer()
- << ", IterDomainInner=" << merge->inner() << ")";
+ label_ << "Merge";
}
private:
@@ -286,7 +306,7 @@
void IrGraphGenerator::handle(const Statement* s) {
OptInConstDispatch::handle(s);
-};
+}
void IrGraphGenerator::handle(const Val* v) {
if (!visited(v)) {
@@ -296,14 +316,14 @@
}
OptInConstDispatch::handle(v);
}
-};
+}
void IrGraphGenerator::handle(const Expr* e) {
if (!visited(e)) {
visited_.insert(e);
OptInConstDispatch::handle(e);
}
-};
+}
void IrGraphGenerator::handle(const TensorDomain* td) {
graph_def_ << " " << getid(td) << " [label=\"TensorDomain\", "
@@ -341,10 +361,18 @@
}
}
+void IrGraphGenerator::handle(const Bool* b) {
+ printValue(b, IrNodeLabel::gen(b, detail_level_));
+}
+
void IrGraphGenerator::handle(const Float* f) {
printValue(f, IrNodeLabel::gen(f, detail_level_));
}
+void IrGraphGenerator::handle(const Half* h) {
+ printValue(h, IrNodeLabel::gen(h, detail_level_));
+}
+
void IrGraphGenerator::handle(const Int* i) {
printValue(i, IrNodeLabel::gen(i, detail_level_));
}
@@ -394,7 +422,7 @@
label << uop->getUnaryOpType();
printExpr(uop, label.str());
- // UnaryOp inputs & outputs
+ // inputs & outputs
addArc(uop->in(), uop);
addArc(uop, uop->out());
}
@@ -405,12 +433,43 @@
label << bop->getBinaryOpType();
printExpr(bop, label.str());
- // BinaryOp inputs & outputs
+ // inputs & outputs
addArc(bop->lhs(), bop);
addArc(bop->rhs(), bop, "[color=blue]");
addArc(bop, bop->out());
}
+void IrGraphGenerator::handle(const TernaryOp* op) {
+ // node
+ std::stringstream label;
+ label << op->getTernaryOpType();
+ printExpr(op, label.str());
+
+ // inputs & outputs
+ addArc(op->in1(), op);
+ addArc(op->in2(), op, "[color=blue]");
+ addArc(op->in3(), op, "[color=brown]");
+ addArc(op, op->out());
+}
+
+void IrGraphGenerator::handle(const BroadcastOp* op) {
+ printExpr(op, "Broadcast");
+ addArc(op->in(), op);
+ addArc(op, op->out());
+}
+
+void IrGraphGenerator::handle(const ReductionOp* op) {
+ // node
+ std::stringstream label;
+ label << "Reduction(" << op->getReductionOpType() << ")";
+ printExpr(op, label.str());
+
+ // inputs & outputs
+ addArc(op->in(), op);
+ addArc(op->init(), op, "[color=blue]");
+ addArc(op, op->out());
+}
+
void IrGraphGenerator::handle(const ForLoop* for_loop) {
printExpr(for_loop, "ForLoop");
addArc(for_loop->index(), for_loop);
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
index 386521d..c022ec0 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
@@ -68,12 +68,17 @@
void handle(const IterDomain*) override;
void handle(const TensorIndex*) override;
+ void handle(const Bool*) override;
void handle(const Float*) override;
+ void handle(const Half*) override;
void handle(const Int*) override;
void handle(const NamedScalar*) override;
void handle(const UnaryOp*) override;
void handle(const BinaryOp*) override;
+ void handle(const TernaryOp*) override;
+ void handle(const BroadcastOp*) override;
+ void handle(const ReductionOp*) override;
void handle(const ForLoop*) override;
void handle(const IfThenElse*) override;
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index 64be541..5451c75 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -21,7 +21,8 @@
* This value can be a symbolic value (defined after the kernel
* is compiled) or a constant value (inlined into the kernel definition).
*/
-struct TORCH_CUDA_API Bool : public Val {
+class TORCH_CUDA_API Bool : public Val {
+ public:
~Bool() = default;
Bool() : Val(ValType::Scalar, DataType::Bool), maybe_value_{c10::nullopt} {}
@@ -29,6 +30,8 @@
Bool(bool _value)
: Val(ValType::Scalar, DataType::Bool), maybe_value_{_value} {}
+ Bool(const Bool* src, IrCloner* ir_cloner);
+
Bool(const Bool& other) = delete;
Bool& operator=(const Bool& other) = delete;
@@ -41,7 +44,7 @@
bool isConst() const {
return maybe_value_.has_value();
}
- c10::optional<bool> value() const noexcept {
+ c10::optional<bool> value() const {
return maybe_value_;
}
@@ -56,7 +59,8 @@
* Float32. This value can be a symbolic value (defined after the kernel
* is compiled) or a constant value (inlined into the kernel definition).
*/
-struct TORCH_CUDA_API Float : public Val {
+class TORCH_CUDA_API Float : public Val {
+ public:
using ScalarType = double;
~Float() = default;
@@ -66,6 +70,8 @@
Float(ScalarType _value)
: Val(ValType::Scalar, DataType::Float), maybe_value_{_value} {}
+ Float(const Float* src, IrCloner* ir_cloner);
+
Float(const Float& other) = delete;
Float& operator=(const Float& other) = delete;
@@ -78,7 +84,7 @@
bool isConst() const {
return maybe_value_.has_value();
}
- c10::optional<ScalarType> value() const noexcept {
+ c10::optional<ScalarType> value() const {
return maybe_value_;
}
@@ -93,7 +99,8 @@
* This value can be a symbolic value (defined after the kernel
* is compiled) or a constant value (inlined into the kernel definition).
*/
-struct TORCH_CUDA_API Half : public Val {
+class TORCH_CUDA_API Half : public Val {
+ public:
~Half() = default;
Half() : Val(ValType::Scalar, DataType::Half), maybe_value_{c10::nullopt} {}
@@ -101,6 +108,8 @@
Half(float _value)
: Val(ValType::Scalar, DataType::Half), maybe_value_{_value} {}
+ Half(const Half* src, IrCloner* ir_cloner);
+
Half(const Half& other) = delete;
Half& operator=(const Half& other) = delete;
@@ -113,7 +122,7 @@
bool isConst() const {
return maybe_value_.has_value();
}
- c10::optional<float> value() const noexcept {
+ c10::optional<float> value() const {
return maybe_value_;
}
@@ -125,7 +134,8 @@
// An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
// inlined literal in the kernel.
-struct TORCH_CUDA_API Int : public Val {
+class TORCH_CUDA_API Int : public Val {
+ public:
using ScalarType = int64_t;
~Int() = default;
@@ -135,6 +145,8 @@
Int(ScalarType _value)
: Val(ValType::Scalar, DataType::Int), maybe_value_{_value} {}
+ Int(const Int* src, IrCloner* ir_cloner);
+
Int(const Int& other) = delete;
Int& operator=(const Int& other) = delete;
@@ -147,7 +159,7 @@
bool isConst() const {
return maybe_value_.has_value();
}
- c10::optional<ScalarType> value() const noexcept {
+ c10::optional<ScalarType> value() const {
return maybe_value_;
}
@@ -157,11 +169,13 @@
const c10::optional<ScalarType> maybe_value_;
};
-struct TransformReplay;
-struct TransformIter;
-struct OptOutMutator;
-struct LoopNestGenerator;
-struct GPULower;
+class ComputeAt;
+class TransformReplay;
+class TransformIter;
+class OptOutMutator;
+class LoopNestGenerator;
+class GPULower;
+
/*
* TensorView is our primitive Tensor Type used in code generation. It can be
* thought of as representing physical memory, however, its dimensionality is
@@ -178,7 +192,8 @@
* we iterate over the 3D TensorDomain [I, J, K], where K is the fastest
* changing dimension.
*/
-struct TORCH_CUDA_API TensorView : public Val {
+class TORCH_CUDA_API TensorView : public Val {
+ public:
~TensorView() = default;
TensorView(const TensorView& other) = delete;
@@ -194,11 +209,15 @@
TensorView(const std::shared_ptr<Value>& jit_value)
: TensorView(jit_value->type()->cast<c10::TensorType>()) {}
- TensorDomain* domain() const noexcept {
+ TensorView(const TensorView* src, IrCloner* ir_cloner);
+
+ TensorDomain* domain() const {
return domain_;
}
bool hasReduction() const;
+ bool hasBlockReduction() const;
+ bool hasGridReduction() const;
bool hasBroadcast() const;
// Is there an active computeAt TensorView/Axis
@@ -207,7 +226,7 @@
}
// Return the TensorView we're computing at
- TensorView* getComputeAtView() const noexcept {
+ TensorView* getComputeAtView() const {
return compute_at_view_;
}
@@ -216,18 +235,20 @@
IterDomain* axis(int pos) const;
// Return compute at axis relative to this domain
- unsigned int getThisComputeAtAxis() const noexcept {
+ unsigned int getThisComputeAtAxis() const {
return this_compute_at_axis_;
}
// Return compute at axis relative to compute at view
- unsigned int getRelativeComputeAtAxis() const noexcept {
+ unsigned int getRelativeComputeAtAxis() const {
return relative_compute_at_axis_;
}
// Will check if an axis is inside computeAtAxis and will fetch the reference
// to be used in code generation.
std::pair<IterDomain*, TensorView*> getComputeAtAxis(int pos) {
+ TORCH_INTERNAL_ASSERT(
+ nDims() > 0, "Tried to access a computeAt axis in a 0-dim TensorView");
if (!hasComputeAt() || getThisComputeAtAxis() <= (unsigned int)pos)
return std::pair<IterDomain*, TensorView*>(axis(pos), this);
return compute_at_view_->getComputeAtAxis(getComputeAtRelPos(pos));
@@ -259,26 +280,36 @@
// Reorder axes according to old2new[old_pos] = new_pos
TensorView* reorder(const std::unordered_map<int, int>& old2new);
- /*
- * WARNING: Does not return this TensorView, returns a new tensorview consumed
- * to create this!! Take reduction axes out of this domain, and create a new
- * domain. New domain will be used to create this domain. For example: TV1[I0,
- * I1] = TV0[I0, R0, R1, I1] TV0->rfactor({1}) TV0 is transformed to ->
- * TV0[I0, R1, I1] The TensorView returned is: TV2[I0, R0, I3, I1] The
- * reduction will now beset as: TV1[I0, R1, I1] = TV2[I0, R0, I3, I1] TV0[I0,
- * I1] = TV1[I0, R1, I1]
- */
+ // WARNING: rFactor does not return this TensorView, it returns a new
+ // TensorView consumed by this!
+ //
+ // Take reduction axes out of this domain, and create a new
+ // domain. New domain will be used to create this domain.
+ //
+ // For example:
+ // TV1[I0, R1, R2, I3] = TV0[I0, I1, I2, I3]
+ //
+ // After:
+ // TV1->rFactor({1}), TV1 is transformed to TV1[I0, R2, I3]
+ //
+ // The TensorView returned is: TV2[I0, R1, I2, I3]
+ //
+ // The reduction will now be set as:
+ // TV2[I0, R1, I2, I3] = TV0[I0, I1, I2, I3]
+ // TV1[I0, R2, I3] = TV2[I0, R1, I2, I3]
+ //
TensorView* rFactor(const std::vector<int>& axes);
- MemoryType getMemoryType() const noexcept {
+ MemoryType getMemoryType() const {
return memory_type_;
}
friend TORCH_CUDA_API TransformReplay;
- // friend TORCH_CUDA_API TransformIter;
friend TORCH_CUDA_API OptOutMutator;
- friend TORCH_CUDA_API GPULower;
friend TORCH_CUDA_API LoopNestGenerator;
+ friend ComputeAt;
+ friend void IrFixComputeAt(Fusion*);
+ friend void IrAdjustMemoryTypes(Fusion* fusion);
protected:
// Make an exact copy of this tensor (similar to clone()), however, also grabs
@@ -296,6 +327,10 @@
void setComputeAt(TensorView* computeAtView, int axis);
+ // Set all computeAt members without checking any correctness. Useful for
+ // computeAt with outputs relative to each other
+ void setComputeAt(TensorView* computeAtView, int thisPos, int relPos);
+
void setMemoryType(MemoryType mt) {
memory_type_ = mt;
bool is_inp_or_out =
@@ -307,12 +342,6 @@
}
private:
- // Transform this view like consumer, mark compute_at_(viw,axis)
- void computeAt_impl(TensorView* consumer, int axis);
-
- // Transform this view like producer, mark producer as compute_at_(this, axis)
- void forwardComputeAt_impl(TensorView* producer, int axis);
-
// Make a copy of the domain (used for Tensor based constructor), likely to be
// removed soon.
void copyDomain(const TensorDomain* td);
@@ -321,7 +350,8 @@
int getComputeAtRelPos(int pos);
void setThisComputeAtAxis();
- TensorDomain* domain_;
+ private:
+ TensorDomain* domain_ = nullptr;
TensorView* compute_at_view_ = nullptr;
// compute at axis in compute at view
unsigned int relative_compute_at_axis_ = 0;
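// Usage sketch matching the rFactor() example documented in the comment above
// (illustrative, not from this patch): tv1 is assumed to be a reduction
// TensorView of the form TV1[I0, R1, R2, I3].
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

using namespace torch::jit::fuser;

void rfactorExample(TensorView* tv1) {
  // Pull reduction axis 1 into an intermediate tensor. tv1 itself is rewritten
  // to TV1[I0, R2, I3]; the returned tensor is TV2[I0, R1, I2, I3], which now
  // produces tv1.
  TensorView* tv2 = tv1->rFactor({1});
  (void)tv2;
}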
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index f77f165..9420227 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -5,7 +5,6 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
/*
* Nodes in here should generally not be used by users. They should be behind
@@ -30,24 +29,27 @@
* 3) Reduction across a dimension i.e. val.sum(axis=2)
* 4) split/merge
*/
-struct TORCH_CUDA_API UnaryOp : public Expr {
+class TORCH_CUDA_API UnaryOp : public Expr {
+ public:
~UnaryOp() = default;
UnaryOp(UnaryOpType _type, Val* _out, Val* _in);
+ UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);
+
UnaryOp(const UnaryOp& other) = delete;
UnaryOp& operator=(const UnaryOp& other) = delete;
UnaryOp(UnaryOp&& other) = delete;
UnaryOp& operator=(UnaryOp&& other) = delete;
- Val* out() const noexcept {
+ Val* out() const {
return out_;
}
- Val* in() const noexcept {
+ Val* in() const {
return in_;
}
- UnaryOpType getUnaryOpType() const noexcept {
+ UnaryOpType getUnaryOpType() const {
return unary_op_type_;
}
@@ -55,8 +57,8 @@
private:
const UnaryOpType unary_op_type_;
- Val* const out_;
- Val* const in_;
+ Val* const out_ = nullptr;
+ Val* const in_ = nullptr;
};
/*
@@ -65,27 +67,30 @@
* 1) Add/mul/div/mod/sub (A * B)
* 2) LT (A < B)
*/
-struct TORCH_CUDA_API BinaryOp : public Expr {
+class TORCH_CUDA_API BinaryOp : public Expr {
+ public:
~BinaryOp() = default;
BinaryOp(BinaryOpType _type, Val* _out, Val* _lhs, Val* _rhs);
+ BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);
+
BinaryOp(const BinaryOp& other) = delete;
BinaryOp& operator=(const BinaryOp& other) = delete;
BinaryOp(BinaryOp&& other) = delete;
BinaryOp& operator=(BinaryOp&& other) = delete;
- Val* out() const noexcept {
+ Val* out() const {
return out_;
}
- Val* lhs() const noexcept {
+ Val* lhs() const {
return lhs_;
}
- Val* rhs() const noexcept {
+ Val* rhs() const {
return rhs_;
}
- BinaryOpType getBinaryOpType() const noexcept {
+ BinaryOpType getBinaryOpType() const {
return binary_op_type_;
}
@@ -93,37 +98,40 @@
private:
const BinaryOpType binary_op_type_;
- Val* const out_;
- Val* const lhs_;
- Val* const rhs_;
+ Val* const out_ = nullptr;
+ Val* const lhs_ = nullptr;
+ Val* const rhs_ = nullptr;
};
/*
* Broadcast _in to match _out. broadcast_dims are relative to out. Where
* broadcast_dims.size() + _in->nDims() == _out->nDims().
*/
-struct TORCH_CUDA_API BroadcastOp : public Expr {
+class TORCH_CUDA_API BroadcastOp : public Expr {
+ public:
~BroadcastOp() = default;
BroadcastOp(Val* _out, Val* _in);
+ BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);
+
BroadcastOp(const BroadcastOp& other) = delete;
BroadcastOp& operator=(const BroadcastOp& other) = delete;
BroadcastOp(BroadcastOp&& other) = delete;
BroadcastOp& operator=(BroadcastOp&& other) = delete;
- Val* out() const noexcept {
+ Val* out() const {
return out_;
}
- Val* in() const noexcept {
+ Val* in() const {
return in_;
}
bool sameAs(const BroadcastOp* const other) const;
private:
- Val* const out_;
- Val* const in_;
+ Val* const out_ = nullptr;
+ Val* const in_ = nullptr;
};
/*
@@ -133,64 +141,75 @@
* tensor. The output tensor's size will be the size of all
* non-reduction/non-broadcast dimensions.
*/
-struct TORCH_CUDA_API ReductionOp : public Expr {
+class TORCH_CUDA_API ReductionOp : public Expr {
+ public:
~ReductionOp() = default;
ReductionOp(BinaryOpType _reduction_op_type, Val* _init, Val* _out, Val* _in);
+ ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);
+
ReductionOp(const ReductionOp& other) = delete;
ReductionOp& operator=(const ReductionOp& other) = delete;
ReductionOp(ReductionOp&& other) = delete;
ReductionOp& operator=(ReductionOp&& other) = delete;
- Val* out() const noexcept {
+ Val* out() const {
return out_;
}
- Val* in() const noexcept {
+ Val* in() const {
return in_;
}
- Val* init() const noexcept {
+ Val* init() const {
return init_;
}
- BinaryOpType getReductionOpType() const noexcept {
+ BinaryOpType getReductionOpType() const {
return reduction_op_type_;
}
bool sameAs(const ReductionOp* const other) const;
+ std::vector<IterDomain*> getReductionDomains() const;
+
+ std::unordered_map<ParallelType, IterDomain*> getParallelReductionDomains()
+ const;
+
private:
const BinaryOpType reduction_op_type_;
- Val* const init_;
- Val* const out_;
- Val* const in_;
+ Val* const init_ = nullptr;
+ Val* const out_ = nullptr;
+ Val* const in_ = nullptr;
};
-struct TORCH_CUDA_API TernaryOp : public Expr {
+class TORCH_CUDA_API TernaryOp : public Expr {
+ public:
~TernaryOp() = default;
TernaryOp(TernaryOpType _type, Val* _out, Val* _in1, Val* _in2, Val* _in3);
+ TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);
+
TernaryOp(const TernaryOp& other) = delete;
TernaryOp& operator=(const TernaryOp& other) = delete;
TernaryOp(TernaryOp&& other) = delete;
TernaryOp& operator=(TernaryOp&& other) = delete;
- Val* out() const noexcept {
+ Val* out() const {
return out_;
}
- Val* in1() const noexcept {
+ Val* in1() const {
return in1_;
}
- Val* in2() const noexcept {
+ Val* in2() const {
return in2_;
}
- Val* in3() const noexcept {
+ Val* in3() const {
return in3_;
}
- TernaryOpType getTernaryOpType() const noexcept {
+ TernaryOpType getTernaryOpType() const {
return ternary_op_type_;
}
@@ -198,10 +217,10 @@
private:
const TernaryOpType ternary_op_type_;
- Val* const out_;
- Val* const in1_;
- Val* const in2_;
- Val* const in3_;
+ Val* const out_ = nullptr;
+ Val* const in1_ = nullptr;
+ Val* const in2_ = nullptr;
+ Val* const in3_ = nullptr;
};
/*
@@ -210,7 +229,8 @@
* IterDomains to form an ND iterable. We directly set parallelization strategies
* on IterDomains.
*/
-struct TORCH_CUDA_API IterDomain : public Val {
+class TORCH_CUDA_API IterDomain : public Val {
+ public:
~IterDomain() = default;
IterDomain() = delete;
@@ -223,6 +243,8 @@
bool _rfactor_domain = false,
bool _broadcast_domain = false);
+ IterDomain(const IterDomain* src, IrCloner* ir_cloner);
+
bool sameAs(const IterDomain* const other) const;
// Returns a new IterDomain matching properties of this
@@ -241,15 +263,15 @@
IterDomain* in,
unsigned int factor);
- bool isReduction() const noexcept {
+ bool isReduction() const {
return is_reduction_domain_;
}
- bool isRFactorProduct() const noexcept {
+ bool isRFactorProduct() const {
return is_rfactor_domain_;
}
- bool isBroadcast() const noexcept {
+ bool isBroadcast() const {
return is_broadcast_domain_;
}
@@ -280,10 +302,6 @@
void parallelize(ParallelType t) {
parallel_method_ = t;
- if (isBlockDim())
- TORCH_CHECK(
- !isReduction(),
- "Cannot parallelize reductions across a block dimension.");
// Currently a limitation as we allocate shared memory as static (not based
// off a dynamic size.)
@@ -307,11 +325,11 @@
" .");
}
- ParallelType parallel_method() const noexcept {
+ ParallelType parallel_method() const {
return parallel_method_;
}
- Val* start() const noexcept {
+ Val* start() const {
return start_;
}
Val* extent() const;
@@ -326,13 +344,14 @@
IterDomain& operator=(IterDomain&& other) = delete;
private:
- Val* const start_;
- Val* const extent_;
+ Val* const start_ = nullptr;
+ Val* const extent_ = nullptr;
ParallelType parallel_method_ = ParallelType::Serial;
- bool is_reduction_domain_;
- bool is_rfactor_domain_;
- bool is_broadcast_domain_;
+ bool is_reduction_domain_ = false;
+ bool is_rfactor_domain_ = false;
+ bool is_broadcast_domain_ = false;
};
+
/*
* TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
* logical axis in its associated tensor. TensorDomain does not directly hold
@@ -346,7 +365,9 @@
* operations that take in a TensorDomain, applies a transformation and outputs
* a tensor domain.
*/
-struct TORCH_CUDA_API TensorDomain : public Val {
+class TORCH_CUDA_API TensorDomain : public Val {
+ public:
+ TensorDomain() = delete;
~TensorDomain() = default;
TensorDomain(const TensorDomain& other) = delete;
@@ -355,7 +376,7 @@
TensorDomain(TensorDomain&& other) = delete;
TensorDomain& operator=(TensorDomain&& other) = delete;
- TensorDomain(std::vector<IterDomain*> _domain);
+ explicit TensorDomain(std::vector<IterDomain*> _domain);
TensorDomain(
std::vector<IterDomain*> _root_domain,
@@ -366,6 +387,8 @@
std::vector<IterDomain*> _rfactor_domain,
std::vector<IterDomain*> _domain);
+ TensorDomain(const TensorDomain* src, IrCloner* ir_cloner);
+
std::vector<IterDomain*>::size_type nDims() const {
return domain_.size();
}
@@ -376,27 +399,29 @@
const std::vector<IterDomain*>& lhs,
const std::vector<IterDomain*>& rhs);
- const std::vector<IterDomain*>& domain() const noexcept {
+ const std::vector<IterDomain*>& domain() const {
return domain_;
}
bool hasReduction() const;
+ bool hasBlockReduction() const;
+ bool hasGridReduction() const;
bool hasBroadcast() const;
bool hasRFactor() const;
- const std::vector<IterDomain*>& noReductions() const noexcept {
+ const std::vector<IterDomain*>& noReductions() const {
return no_reduction_domain_;
}
- const std::vector<IterDomain*>& noBroadcasts() const noexcept {
+ const std::vector<IterDomain*>& noBroadcasts() const {
return no_bcast_domain_;
}
- const std::vector<IterDomain*>& rootDomain() const noexcept {
+ const std::vector<IterDomain*>& rootDomain() const {
return root_domain_;
};
- const std::vector<IterDomain*>& rfactorDomain() const noexcept {
+ const std::vector<IterDomain*>& rfactorDomain() const {
return rfactor_domain_;
};
@@ -447,7 +472,8 @@
* Represents a split on an IterDomain by "factor"
* TODO: Implement split by nparts
*/
-struct TORCH_CUDA_API Split : public Expr {
+class TORCH_CUDA_API Split : public Expr {
+ public:
~Split() = default;
Split(const Split& other) = delete;
@@ -458,25 +484,27 @@
Split(IterDomain* _outer, IterDomain* _inner, IterDomain* _in, Int* _factor);
- IterDomain* outer() const noexcept {
+ Split(const Split* src, IrCloner* ir_cloner);
+
+ IterDomain* outer() const {
return outer_;
}
- IterDomain* inner() const noexcept {
+ IterDomain* inner() const {
return inner_;
}
- IterDomain* in() const noexcept {
+ IterDomain* in() const {
return in_;
}
- Int* factor() const noexcept {
+ Int* factor() const {
return factor_;
}
bool sameAs(const Split* const other) const;
private:
- IterDomain* const outer_;
- IterDomain* const inner_;
- IterDomain* const in_;
- Int* const factor_;
+ IterDomain* const outer_ = nullptr;
+ IterDomain* const inner_ = nullptr;
+ IterDomain* const in_ = nullptr;
+ Int* const factor_ = nullptr;
};
/*
@@ -486,32 +514,35 @@
* if there is one.
* TODO: Should this be a unary op type?
*/
-struct TORCH_CUDA_API Merge : public Expr {
+class TORCH_CUDA_API Merge : public Expr {
+ public:
~Merge() = default;
Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner);
+ Merge(const Merge* src, IrCloner* ir_cloner);
+
Merge(const Merge& other) = delete;
Merge& operator=(const Merge& other) = delete;
Merge(Merge&& other) = delete;
Merge& operator=(Merge&& other) = delete;
- IterDomain* out() const noexcept {
+ IterDomain* out() const {
return out_;
}
- IterDomain* outer() const noexcept {
+ IterDomain* outer() const {
return outer_;
}
- IterDomain* inner() const noexcept {
+ IterDomain* inner() const {
return inner_;
}
bool sameAs(const Merge* const other) const;
private:
- IterDomain* const out_;
- IterDomain* const outer_;
- IterDomain* const inner_;
+ IterDomain* const out_ = nullptr;
+ IterDomain* const outer_ = nullptr;
+ IterDomain* const inner_ = nullptr;
};
/*
@@ -523,7 +554,8 @@
* TODO: Change implementation of Exprs contained in the scope to be more similar
* to Fusion where we can do proper dependency analysis.
*/
-struct TORCH_CUDA_API ForLoop : public Expr {
+class TORCH_CUDA_API ForLoop : public Expr {
+ public:
~ForLoop() = default;
ForLoop(
Val* _index,
@@ -531,38 +563,40 @@
const std::vector<Expr*>& _body = {},
Expr* parent_scope = nullptr);
+ ForLoop(const ForLoop* src, IrCloner* ir_cloner);
+
ForLoop(const ForLoop& other) = delete;
ForLoop& operator=(const ForLoop& other) = delete;
ForLoop(ForLoop&& other) = delete;
ForLoop& operator=(ForLoop&& other) = delete;
- Val* index() const noexcept {
+ Val* index() const {
return index_;
}
- IterDomain* iter_domain() const noexcept {
+ IterDomain* iter_domain() const {
return iter_domain_;
}
- Scope& body() noexcept {
+ Scope& body() {
return body_;
}
- const Scope& constBody() const noexcept {
+ const Scope& constBody() const {
return body_;
}
bool sameAs(const ForLoop* other) const;
- Expr* parentScope() const noexcept {
+ Expr* parentScope() const {
return parent_scope_;
}
private:
- Val* const index_;
+ Val* const index_ = nullptr;
IterDomain* const iter_domain_;
Scope body_;
- Expr* parent_scope_;
+ Expr* parent_scope_ = nullptr;
};
/*
@@ -574,7 +608,8 @@
* TODO: Change implementation of Exprs contained in the scope to be more similar
* to Fusion where we can do proper dependency analysis.
*/
-struct TORCH_CUDA_API IfThenElse : public Expr {
+class TORCH_CUDA_API IfThenElse : public Expr {
+ public:
~IfThenElse() = default;
IfThenElse(
Bool* _cond,
@@ -582,47 +617,49 @@
const std::vector<Expr*>& _else_body = {},
Expr* _parent_scope = nullptr);
+ IfThenElse(const IfThenElse* src, IrCloner* ir_cloner);
+
IfThenElse(const IfThenElse& other) = delete;
IfThenElse& operator=(const IfThenElse& other) = delete;
IfThenElse(IfThenElse&& other) = delete;
IfThenElse& operator=(IfThenElse&& other) = delete;
- Bool* cond() const noexcept {
+ Bool* cond() const {
return cond_;
}
- const Scope& constBody() const noexcept {
+ const Scope& constBody() const {
return body_;
}
- const Scope& constElseBody() const noexcept {
+ const Scope& constElseBody() const {
return else_body_;
}
- Scope& body() noexcept {
+ Scope& body() {
return body_;
}
- Scope& elseBody() noexcept {
+ Scope& elseBody() {
return else_body_;
}
- bool hasElse() const noexcept {
+ bool hasElse() const {
return !else_body_.empty();
}
bool sameAs(const IfThenElse* other) const;
- Expr* parentScope() const noexcept {
+ Expr* parentScope() const {
return parent_scope_;
}
private:
- Bool* const cond_;
+ Bool* const cond_ = nullptr;
Scope body_;
Scope else_body_;
- Expr* parent_scope_;
+ Expr* parent_scope_ = nullptr;
};
/*
@@ -630,7 +667,8 @@
* TensorView. It is not the flattened index, which needs to be computed using
* stride information.
*/
-struct TORCH_CUDA_API TensorIndex : public Val {
+class TORCH_CUDA_API TensorIndex : public Val {
+ public:
~TensorIndex() = default;
TensorIndex(const TensorIndex& other) = delete;
@@ -655,6 +693,8 @@
"Cannot index with a value other than an int.");
}
+ TensorIndex(const TensorIndex* src, IrCloner* ir_cloner);
+
std::vector<Val*>::size_type nDims() const {
return indices_.size();
}
@@ -663,18 +703,18 @@
// uint.
Val* index(int i) const;
- const std::vector<Val*>& indices() const noexcept {
+ const std::vector<Val*>& indices() const {
return indices_;
}
- const TensorView* view() const noexcept {
+ const TensorView* view() const {
return view_;
}
bool sameAs(const TensorIndex* const other) const;
private:
- const TensorView* view_;
+ const TensorView* view_ = nullptr;
std::vector<Val*> indices_;
};
@@ -687,7 +727,8 @@
* TODO: The components of Allocate like Type and Name could be separated from
* the associated TensorView. Perhaps that is more appropriate?
*/
-struct TORCH_CUDA_API Allocate : public Expr {
+class TORCH_CUDA_API Allocate : public Expr {
+ public:
~Allocate() = default;
Allocate(const Allocate& other) = delete;
@@ -698,19 +739,21 @@
Allocate(Val* _tv, Val* size);
+ Allocate(const Allocate* src, IrCloner* ir_cloner);
+
DataType buf_type() const;
- Val* extent() const noexcept {
+ Val* extent() const {
return extent_;
}
- Val* buffer() const noexcept {
+ Val* buffer() const {
return buffer_;
}
bool sameAs(const Allocate* other) const;
private:
- Val* buffer_;
- Val* extent_;
+ Val* buffer_ = nullptr;
+ Val* extent_ = nullptr;
};
/*
@@ -720,20 +763,23 @@
* - blockDim.z
* - T3.stride[2]
*/
-struct TORCH_CUDA_API NamedScalar : public Val {
+class TORCH_CUDA_API NamedScalar : public Val {
+ public:
~NamedScalar() = default;
NamedScalar() = delete;
NamedScalar(std::string _name, DataType dtype)
: Val(ValType::NamedScalar, dtype), name_(_name) {}
+ NamedScalar(const NamedScalar* src, IrCloner* ir_cloner);
+
NamedScalar(const NamedScalar& other) = delete;
NamedScalar& operator=(const NamedScalar& other) = delete;
NamedScalar(NamedScalar&& other) = delete;
NamedScalar& operator=(NamedScalar&& other) = delete;
- const std::string& name() const noexcept {
+ const std::string& name() const {
return name_;
}
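// Illustrative sketch (not from this patch): how a consumer such as the code
// generator can use the new ReductionOp helpers declared above to decide
// whether a block- or grid-level reduction helper must be emitted.
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

using namespace torch::jit::fuser;

bool needsBlockReduce(const ReductionOp* rop) {
  const auto par_domains = rop->getParallelReductionDomains();
  // Any reduction IterDomain bound to threadIdx.{x,y,z} requires blockReduce.
  return par_domains.count(ParallelType::TIDx) != 0 ||
      par_domains.count(ParallelType::TIDy) != 0 ||
      par_domains.count(ParallelType::TIDz) != 0;
}

bool needsGridReduce(const ReductionOp* rop) {
  const auto par_domains = rop->getParallelReductionDomains();
  // Any reduction IterDomain bound to blockIdx.{x,y,z} requires gridReduce.
  return par_domains.count(ParallelType::BIDx) != 0 ||
      par_domains.count(ParallelType::BIDy) != 0 ||
      par_domains.count(ParallelType::BIDz) != 0;
}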
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
index 6a9146f..7923218 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -8,30 +8,49 @@
namespace jit {
namespace fuser {
-namespace {
// Make sure we can inline something, before we attempt to.
-void check_inlineable(const IRInputOutput* const irio) {
- for (auto inp : irio->inputs())
+static void checkInlineable(const Expr* expr) {
+ for (auto input : expr->inputs()) {
TORCH_CHECK(
- inp->isScalar(),
+ input->isScalar(),
"Printing inline computations involving values other than scalars is not currently supported.");
+ }
TORCH_CHECK(
- irio->nOutputs() == 1,
+ expr->outputs().size() == 1,
"Cannot print inline computations if there's more than one output.");
TORCH_CHECK(
- irio->output(0)->isScalar(),
+ expr->output(0)->isScalar(),
"Printing inline computations involving values other than scalars is not currently supported.");
}
-} // namespace
+
+void IRPrinter::handle(const Statement* s) {
+ OptInConstDispatch::handle(s);
+}
+
+void IRPrinter::handle(const Val* v) {
+ if (follow_val_map) {
+ // Follow a single mapping (permutation chains are not expected)
+ v = FusionGuard::getCurFusion()->loweredVal(v);
+ TORCH_INTERNAL_ASSERT(v == FusionGuard::getCurFusion()->loweredVal(v));
+ }
+ OptInConstDispatch::handle(v);
+}
+
+void IRPrinter::handle(const Expr* e) {
+ OptInConstDispatch::handle(e);
+}
void IRPrinter::printHeader(Fusion* fusion, const std::string& kernel_name_) {
os << "__global__ void " << kernel_name_ << "(";
- std::deque<Val*> vals;
- for (decltype(fusion->nInputs()) i{0}; i < fusion->nInputs(); i++)
- vals.push_back(fusion->input(i));
- for (decltype(fusion->nOutputs()) i{0}; i < fusion->nOutputs(); i++)
- vals.push_back(fusion->output(i));
+ std::vector<Val*> vals;
+
+ for (auto val : fusion->inputs()) {
+ vals.push_back(val);
+ }
+ for (auto val : fusion->outputs()) {
+ vals.push_back(val);
+ }
for (Val* val : vals) {
switch (val->getValType().value()) {
@@ -57,6 +76,11 @@
if (fusion->hasRNG())
os << ", unsigned long long seed, unsigned long long offset";
+
+ if (fusion->hasGridReduction()) {
+ os << ", void* work_buf, unsigned* sync_flags";
+ }
+
os << "){\n";
indent_size++;
if (fusion->hasRNG()) {
@@ -65,6 +89,12 @@
indent();
os << "Philox rnd(seed, idx, offset);\n";
}
+ if (fusion->hasBlockReduction() || fusion->hasGridReduction()) {
+ indent();
+ // TODO: Dynamic sizing possible? blockReduce originally used 1024
+ // values of a given type
+ os << "__shared__ float shared_mem[1024];\n";
+ }
}
void IRPrinter::handle(Fusion* fusion) {
@@ -74,9 +104,13 @@
}
}
-void IRPrinter::handle(const TensorDomain* const td) {
+void IRPrinter::handle(const TensorDomain* td) {
+ if (td->nDims() == 0) {
+ os << "[ 0 ]";
+ return;
+ }
os << "[ ";
- for (std::vector<const IterDomain*>::size_type i = 0; i < td->nDims(); i++) {
+ for (size_t i = 0; i < td->nDims(); i++) {
handle(td->axis(i));
if (i != td->nDims() - 1)
os << ", ";
@@ -84,7 +118,7 @@
os << " ]";
}
-void IRPrinter::handle(const TensorView* const tv) {
+void IRPrinter::handle(const TensorView* tv) {
os << "T" << tv->name();
handle(tv->domain());
@@ -95,7 +129,7 @@
}
}
-void IRPrinter::handle(const IterDomain* const id) {
+void IRPrinter::handle(const IterDomain* id) {
if (id->isReduction())
os << "r";
else if (id->isBroadcast())
@@ -127,9 +161,14 @@
os << "rf";
}
-void IRPrinter::handle(const TensorIndex* const ti) {
- os << "T" << ti->view()->name() << "[ ";
+void IRPrinter::handle(const TensorIndex* ti) {
+ os << "T" << ti->view()->name();
+ if (ti->nDims() == 0) {
+ os << "[ 0 ]";
+ return;
+ }
+ os << "[ ";
bool first = true;
for (auto* ind : ti->indices()) {
if (!first)
@@ -140,7 +179,7 @@
os << " ]";
}
-void IRPrinter::handle(const Bool* const b) {
+void IRPrinter::handle(const Bool* b) {
if (print_inline_ && FusionGuard::getCurFusion()->origin(b) != nullptr) {
os << "( ";
handle(FusionGuard::getCurFusion()->origin(b));
@@ -155,7 +194,7 @@
}
}
-void IRPrinter::handle(const Float* const f) {
+void IRPrinter::handle(const Float* f) {
if (print_inline_ && FusionGuard::getCurFusion()->origin(f) != nullptr) {
os << "( ";
handle(FusionGuard::getCurFusion()->origin(f));
@@ -173,7 +212,7 @@
}
}
-void IRPrinter::handle(const Half* const h) {
+void IRPrinter::handle(const Half* h) {
if (print_inline_ && FusionGuard::getCurFusion()->origin(h) != nullptr) {
os << "( ";
handle(FusionGuard::getCurFusion()->origin(h));
@@ -188,12 +227,19 @@
}
}
-void IRPrinter::handle(const Int* const i) {
- if (print_inline_ && FusionGuard::getCurFusion()->origin(i) != nullptr) {
- os << "( ";
- handle(FusionGuard::getCurFusion()->origin(i));
- os << " )";
- return;
+void IRPrinter::handle(const Int* i) {
+ // Make sure we didn't bypass the value mapping
+ // (for example calling IRPrinter::handle() with an Int*)
+ TORCH_CHECK(
+ !follow_val_map || i == FusionGuard::getCurFusion()->loweredVal(i));
+
+ if (print_inline_) {
+ if (auto def = FusionGuard::getCurFusion()->origin(i)) {
+ os << "( ";
+ handle(def);
+ os << " )";
+ return;
+ }
}
if (i->isSymbolic()) {
@@ -203,27 +249,21 @@
}
}
-void IRPrinter::handle(const NamedScalar* const i) {
+void IRPrinter::handle(const NamedScalar* i) {
os << i->name();
}
-namespace {
-
-bool isTV(const Val* const val) {
- return (
- val->getValType().value() == ValType::TensorView ||
- val->getValType().value() == ValType::TensorIndex);
+static bool isTV(const Val* val) {
+ return val->getValType().value() == ValType::TensorView ||
+ val->getValType().value() == ValType::TensorIndex;
}
// Check if we're a TensorView op that we can generate code for.
-bool isTVOp(const Expr* const expr) {
- if (expr->nOutputs() == 1 && isTV(expr->output(0)))
- return true;
- return false;
+static bool isTVOp(const Expr* expr) {
+ return expr->outputs().size() == 1 && isTV(expr->outputs().front());
}
-} // namespace
-void IRPrinter::handle(const UnaryOp* const uop) {
+void IRPrinter::handle(const UnaryOp* uop) {
bool istvop = isTVOp(uop);
if (!print_inline_) {
indent();
@@ -235,7 +275,7 @@
}
os << " = ";
} else {
- check_inlineable(uop);
+ checkInlineable(uop);
}
if (auto inline_uop = inline_op_str(uop->getUnaryOpType())) {
@@ -269,7 +309,7 @@
os << ";\n";
}
-void IRPrinter::handle(const BinaryOp* const bop) {
+void IRPrinter::handle(const BinaryOp* bop) {
bool istvop = isTVOp(bop);
if (!print_inline_) {
indent();
@@ -284,7 +324,7 @@
os << " = ";
} else {
- check_inlineable(bop);
+ checkInlineable(bop);
}
if (auto inline_bop = inline_op_str(bop->getBinaryOpType())) {
@@ -314,7 +354,7 @@
os << ";\n";
}
-void IRPrinter::handle(const TernaryOp* const top) {
+void IRPrinter::handle(const TernaryOp* top) {
bool istvop = isTVOp(top);
if (!print_inline_) {
indent();
@@ -329,7 +369,7 @@
os << " = ";
} else {
- check_inlineable(top);
+ checkInlineable(top);
}
os << top->getTernaryOpType() << "(";
@@ -355,7 +395,7 @@
os << ";\n";
}
-void IRPrinter::handle(const ReductionOp* const rop) {
+void IRPrinter::handle(const ReductionOp* rop) {
// Check if we've lowered yet.
bool lowered = rop->out()->getValType() == ValType::TensorIndex;
@@ -367,52 +407,77 @@
return;
}
- TensorIndex* out = static_cast<TensorIndex*>(rop->out());
+ auto out = rop->out()->as<TensorIndex>();
auto vec_domain = out->view()->domain()->domain();
- IterDomain *tidx = nullptr, *tidy = nullptr, *tidz = nullptr;
- bool is_thread_reduce = false;
- for (auto id : vec_domain) {
- if (id->isThreadDim() && id->isReduction()) {
- switch (id->parallel_method()) {
- case (ParallelType::TIDz):
- tidz = id;
- break;
- case (ParallelType::TIDy):
- tidy = id;
- break;
- case (ParallelType::TIDx):
- tidx = id;
- break;
- default:
- TORCH_INTERNAL_ASSERT(
- false, "Did not recognize parallel type for reduction.");
- }
- is_thread_reduce = true;
- }
- }
+ bool has_block_reduce = out->view()->hasBlockReduction();
+ bool has_grid_reduce = out->view()->hasGridReduction();
- if (!is_thread_reduce) {
+ if (!has_block_reduce && !has_grid_reduce) {
handle(new BinaryOp(rop->getReductionOpType(), out, out, rop->in()));
return;
}
+
+ auto par_domains = rop->getParallelReductionDomains();
+ bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
+ bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
+ bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
+ bool bidx = par_domains.find(ParallelType::BIDx) != par_domains.end();
+ bool bidy = par_domains.find(ParallelType::BIDy) != par_domains.end();
+ bool bidz = par_domains.find(ParallelType::BIDz) != par_domains.end();
+
auto d_type = rop->out()->getDataType().value();
auto op_type = rop->getReductionOpType();
- indent();
- // Thread all reduce.
- os << "blockReduce< " << (tidx != nullptr ? "true" : "false") << ", "
- << (tidy != nullptr ? "true" : "false") << ", "
- << (tidz != nullptr ? "true" : "false") << " >"
- << " ( ";
- handle(rop->out());
- os << ", ";
- handle(rop->in());
- os << ", ";
- os << "reduction_" << op_type << "_" << d_type;
- os << ");\n";
+ const std::string block_result = "block_result";
+ if (has_block_reduce) {
+ if (has_grid_reduce) {
+ indent();
+ os << d_type << " " << block_result << ";\n";
+ }
+ indent();
+ // Thread all reduce.
+ os << "blockReduce< " << (tidx ? "true" : "false") << ", "
+ << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false") << " >"
+ << " ( ";
+ if (has_grid_reduce) {
+ os << block_result;
+ } else {
+ handle(rop->out());
+ }
+ os << ", ";
+ handle(rop->in());
+ os << ", ";
+ os << "reduction_" << op_type << "_" << d_type;
+ os << ", threadIdx, blockDim";
+ os << ", reinterpret_cast<" << d_type << "*>(shared_mem)";
+ os << ");\n";
+ }
+ if (has_grid_reduce) {
+ indent();
+ // Since block-level reduction is already done, those dimensions
+ // with tidx/y/z being true do not participate in the grid reduction.
+ os << "reduction::gridReduce< " << (bidx ? "true" : "false") << ", "
+ << (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false") << ", "
+ << (!tidx ? "true" : "false") << ", " << (!tidy ? "true" : "false")
+ << ", " << (!tidz ? "true" : "false") << " >"
+ << " ( ";
+ handle(rop->out());
+ os << ", ";
+ if (has_block_reduce) {
+ os << block_result;
+ } else {
+ handle(rop->in());
+ }
+ os << ", ";
+ os << "reduction_" << op_type << "_" << d_type;
+ os << ", static_cast<" << d_type << "*>(work_buf)";
+ os << ", sync_flags";
+ os << ", reinterpret_cast<" << d_type << "*>(shared_mem)";
+ os << ");\n";
+ }
}
-void IRPrinter::handle(const BroadcastOp* const bop) {
+void IRPrinter::handle(const BroadcastOp* bop) {
indent();
handle(bop->out());
os << "\n";
@@ -424,8 +489,8 @@
os << ";\n";
}
-void IRPrinter::handle(const ForLoop* const fl) {
- if (fl->iter_domain()->isThread()) {
+void IRPrinter::handle(const ForLoop* fl) {
+ if (fl->iter_domain()->isThread() || fl->iter_domain()->isBroadcast()) {
for (auto& expr : fl->constBody().exprs())
handle(expr);
return;
@@ -452,7 +517,7 @@
os << "}\n";
}
-void IRPrinter::handle(const IfThenElse* const ite) {
+void IRPrinter::handle(const IfThenElse* ite) {
indent();
// IF
@@ -480,7 +545,7 @@
os << "}\n";
}
-void IRPrinter::handle(const Allocate* const a) {
+void IRPrinter::handle(const Allocate* a) {
indent();
os << a->buf_type();
if (a->buffer()->getValType() == ValType::TensorView) {
@@ -501,7 +566,7 @@
}
}
-void IRPrinter::handle(const Split* const s) {
+void IRPrinter::handle(const Split* s) {
os << "Split: ";
handle(s->in());
os << " by factor " << s->factor() << " -> ";
@@ -511,20 +576,20 @@
os << "\n";
}
-void IRPrinter::handle(const Merge* const m) {
+void IRPrinter::handle(const Merge* m) {
os << "Merge: ";
handle(m->outer());
os << " and ";
handle(m->inner());
os << " -> ";
handle(m->out());
- --indent_size;
os << "\n";
}
namespace {
-struct ReductionOps : OptOutDispatch {
+class ReductionOps : OptOutDispatch {
+ public:
std::set<std::pair<BinaryOpType, DataType>> rops;
void handle(ReductionOp* rop) override {
rops.emplace(std::pair<BinaryOpType, DataType>{
@@ -541,6 +606,7 @@
return ROPs.rops;
}
};
+
} // namespace
void IRPrinter::printReductionOps(Fusion* fusion) {
@@ -566,7 +632,6 @@
const std::vector<Expr*>& exprs,
const std::string& kernel_name) {
Fusion* fusion = FusionGuard::getCurFusion();
-
printReductionOps(fusion);
printHeader(fusion, kernel_name);
for (auto* expr : exprs) {
@@ -575,7 +640,7 @@
os << "}\n";
}
-std::ostream& operator<<(std::ostream& os, const Statement* const stmt) {
+std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
IRPrinter p(os);
p.handle(stmt);
return os;
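// Illustrative restatement (not from this patch) of the template flags that
// IRPrinter::handle(const ReductionOp*) above passes to the runtime helpers:
// blockReduce gets one flag per reduced thread dimension, while gridReduce gets
// the block-dimension flags followed by the negated thread flags, since
// dimensions already reduced within the block do not participate again.
#include <array>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

using namespace torch::jit::fuser;

std::array<bool, 6> gridReduceTemplateFlags(const ReductionOp* rop) {
  const auto par = rop->getParallelReductionDomains();
  const bool tidx = par.count(ParallelType::TIDx) != 0;
  const bool tidy = par.count(ParallelType::TIDy) != 0;
  const bool tidz = par.count(ParallelType::TIDz) != 0;
  const bool bidx = par.count(ParallelType::BIDx) != 0;
  const bool bidy = par.count(ParallelType::BIDy) != 0;
  const bool bidz = par.count(ParallelType::BIDz) != 0;
  // Order matches the printed call:
  //   reduction::gridReduce< bidx, bidy, bidz, !tidx, !tidy, !tidz >(...)
  return {bidx, bidy, bidz, !tidx, !tidy, !tidz};
}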
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h
index 5034dbf..1857648 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.h
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -10,37 +10,35 @@
namespace jit {
namespace fuser {
-struct Fusion;
+class Fusion;
-struct Statement;
+class Statement;
-struct Val;
-struct Expr;
+class Val;
+class Expr;
-struct UnaryOp;
-struct BinaryOp;
-struct TernaryOp;
-struct ReductionOp;
-struct BroadcastOp;
+class UnaryOp;
+class BinaryOp;
+class TernaryOp;
+class ReductionOp;
+class BroadcastOp;
-struct ForLoop;
-struct IfThenElse;
+class ForLoop;
+class IfThenElse;
-struct TensorDomain;
-struct TensorView;
-struct IterDomain;
-struct TensorIndex;
+class TensorDomain;
+class TensorView;
+class IterDomain;
+class TensorIndex;
-struct TensorContiguity;
+class Split;
+class Merge;
-struct Split;
-struct Merge;
-
-struct Bool;
-struct Float;
-struct Half;
-struct Int;
-struct Add;
+class Bool;
+class Float;
+class Half;
+class Int;
+class Add;
/*
* Define pretty printing functions for all nodes. handle is used so we can take
@@ -50,7 +48,7 @@
* stream operator <<.
*/
-struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
+class TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
public:
std::ostream& os;
bool print_inline_ = false;
@@ -58,6 +56,9 @@
// Track the indentation size for pretty printing
int indent_size = 0;
+ // Handle value mapping
+ bool follow_val_map = true;
+
// Indent the generated code
void indent() {
for (int i = 0; i < indent_size; i++)
@@ -72,54 +73,48 @@
IRPrinter(std::ostream& _os) : os(_os) {}
- virtual void handle(Fusion* const f);
+ virtual void handle(Fusion* f);
// handle calls some non-const fusion ops,
// even though fusion should remain unchanged.
// Need to look into this.
- virtual void handle(const Fusion* const f) {
+ virtual void handle(const Fusion* f) {
handle(const_cast<Fusion*>(f));
}
+
virtual void handle(Fusion& f) {
handle(&f);
}
- virtual void handle(const Statement* const s) {
- OptInConstDispatch::handle(s);
- };
+ void handle(const Statement* s) override;
+ void handle(const Val* v) override;
+ void handle(const Expr* e) override;
- virtual void handle(const Val* const v) {
- OptInConstDispatch::handle(v);
- };
- virtual void handle(const Expr* const e) {
- OptInConstDispatch::handle(e);
- };
+ void handle(const TensorDomain*) override;
+ void handle(const TensorView*) override;
+ void handle(const IterDomain*) override;
+ void handle(const TensorIndex*) override;
- virtual void handle(const TensorDomain* const) override;
- virtual void handle(const TensorView* const) override;
- virtual void handle(const IterDomain* const) override;
- virtual void handle(const TensorIndex* const) override;
+ void handle(const Bool*) override;
+ void handle(const Float*) override;
+ void handle(const Half*) override;
+ void handle(const Int*) override;
+ void handle(const NamedScalar*) override;
- virtual void handle(const Bool* const) override;
- virtual void handle(const Float* const) override;
- virtual void handle(const Half* const) override;
- virtual void handle(const Int* const) override;
- virtual void handle(const NamedScalar* const) override;
+ void handle(const UnaryOp*) override;
+ void handle(const BinaryOp*) override;
+ void handle(const TernaryOp*) override;
+ void handle(const ReductionOp*) override;
+ void handle(const BroadcastOp*) override;
- virtual void handle(const UnaryOp* const) override;
- virtual void handle(const BinaryOp* const) override;
- virtual void handle(const TernaryOp* const) override;
- virtual void handle(const ReductionOp* const) override;
- virtual void handle(const BroadcastOp* const) override;
+ void handle(const ForLoop*) override;
+ void handle(const IfThenElse*) override;
+ void handle(const Allocate*) override;
- virtual void handle(const ForLoop* const) override;
- virtual void handle(const IfThenElse* const) override;
- virtual void handle(const Allocate* const) override;
+ void handle(const Split*) override;
+ void handle(const Merge*) override;
- virtual void handle(const Split* const) override;
- virtual void handle(const Merge* const) override;
-
- void print_inline(const Statement* const stmt) {
+ void print_inline(const Statement* stmt) {
bool prev = print_inline_;
print_inline_ = true;
handle(stmt);
@@ -135,7 +130,7 @@
TORCH_CUDA_API std::ostream& operator<<(
std::ostream& os,
- const Statement* const stmt);
+ const Statement* stmt);
TORCH_CUDA_API std::ostream& operator<<(std::ostream& os, Fusion* f);
TORCH_CUDA_API std::ostream& operator<<(std::ostream& os, Fusion& f);
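// Minimal usage sketch of IRPrinter (illustrative, not from this patch).
// Assumes a Fusion is already active through a FusionGuard, since inline
// printing follows Fusion::origin() and loweredVal().
#include <iostream>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>

using namespace torch::jit::fuser;

void printStatement(const Statement* stmt) {
  // The stream operator routes through IRPrinter::handle().
  std::cout << stmt << "\n";

  // print_inline() temporarily flips print_inline_, so scalars with a known
  // origin are expanded in place as a parenthesized expression.
  IRPrinter printer(std::cout);
  printer.print_inline(stmt);
}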
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 03ff9c4..1ad4f30 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
@@ -13,36 +14,8 @@
namespace fuser {
namespace {
-struct ScalarCheck : OptInDispatch {
- Val* v1_;
- Val* v2_;
- bool same = false;
- void handle(Bool* b) override {
- same = static_cast<Bool*>(v1_)->sameAs(static_cast<Bool*>(v2_));
- }
-
- void handle(Float* f) override {
- same = static_cast<Float*>(v1_)->sameAs(static_cast<Float*>(v2_));
- }
-
- void handle(Half* h) override {
- same = static_cast<Half*>(v1_)->sameAs(static_cast<Half*>(v2_));
- }
-
- void handle(Int* i) override {
- same = static_cast<Int*>(v1_)->sameAs(static_cast<Int*>(v2_));
- }
-
- void handle(NamedScalar* ns) override {
- same =
- static_cast<NamedScalar*>(v1_)->sameAs(static_cast<NamedScalar*>(v2_));
- }
-
- ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) {
- OptInDispatch::handle(v1_);
- }
-
+class ScalarCheck : OptInDispatch {
public:
static bool sameAs(Val* v1, Val* v2) {
if (v1 == v2)
@@ -55,29 +28,73 @@
return false;
ScalarCheck sc(v1, v2);
- return sc.same;
+ return sc.same_;
}
+
+ private:
+ void handle(Bool* b) override {
+ same_ = static_cast<Bool*>(v1_)->sameAs(static_cast<Bool*>(v2_));
+ }
+
+ void handle(Float* f) override {
+ same_ = static_cast<Float*>(v1_)->sameAs(static_cast<Float*>(v2_));
+ }
+
+ void handle(Half* h) override {
+ same_ = static_cast<Half*>(v1_)->sameAs(static_cast<Half*>(v2_));
+ }
+
+ void handle(Int* i) override {
+ same_ = static_cast<Int*>(v1_)->sameAs(static_cast<Int*>(v2_));
+ }
+
+ void handle(NamedScalar* ns) override {
+ same_ =
+ static_cast<NamedScalar*>(v1_)->sameAs(static_cast<NamedScalar*>(v2_));
+ }
+
+ ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) {
+ OptInDispatch::handle(v1_);
+ }
+
+ private:
+ Val* v1_ = nullptr;
+ Val* v2_ = nullptr;
+ bool same_ = false;
};
+
} // namespace
+Bool::Bool(const Bool* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
bool Bool::sameAs(const Bool* const other) const {
if (isConst() && other->isConst())
return *value() == *(other->value());
return this == other;
}
+Float::Float(const Float* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
bool Float::sameAs(const Float* const other) const {
if (isConst() && other->isConst())
return *value() == *(other->value());
return this == other;
}
+Half::Half(const Half* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
bool Half::sameAs(const Half* const other) const {
if (isConst() && other->isConst())
return *value() == *(other->value());
return this == other;
}
+Int::Int(const Int* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
bool Int::sameAs(const Int* const other) const {
if (isConst() && other->isConst())
return *value() == *(other->value());
@@ -91,6 +108,12 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ unary_op_type_(src->unary_op_type_),
+ out_(ir_cloner->clone(src->out_)),
+ in_(ir_cloner->clone(src->in_)) {}
+
bool UnaryOp::sameAs(const UnaryOp* const other) const {
if (this->type() != other->type())
return false;
@@ -109,6 +132,13 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ binary_op_type_(src->binary_op_type_),
+ out_(ir_cloner->clone(src->out_)),
+ lhs_(ir_cloner->clone(src->lhs_)),
+ rhs_(ir_cloner->clone(src->rhs_)) {}
+
bool BinaryOp::sameAs(const BinaryOp* other) const {
if (getBinaryOpType() != other->getBinaryOpType())
return false;
@@ -136,6 +166,14 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ ternary_op_type_(src->ternary_op_type_),
+ out_(ir_cloner->clone(src->out_)),
+ in1_(ir_cloner->clone(src->in1_)),
+ in2_(ir_cloner->clone(src->in2_)),
+ in3_(ir_cloner->clone(src->in3_)) {}
+
bool TernaryOp::sameAs(const TernaryOp* other) const {
if (getTernaryOpType() != other->getTernaryOpType())
return false;
@@ -160,7 +198,7 @@
ndims++;
TORCH_INTERNAL_ASSERT(
- ndims == (int)static_cast<TensorView*>(in_)->nDims(),
+ ndims == (int)in_->as<TensorView>()->domain()->noReductions().size(),
"Invalid broadcast op. Non-broadcasted dims don't match from input to output.");
} else {
TORCH_INTERNAL_ASSERT(
@@ -177,6 +215,11 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ out_(ir_cloner->clone(src->out_)),
+ in_(ir_cloner->clone(src->in_)) {}
+
bool BroadcastOp::sameAs(const BroadcastOp* const other) const {
return other->in() == in() && other->out() == out();
}
@@ -199,6 +242,13 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ reduction_op_type_(src->reduction_op_type_),
+ init_(ir_cloner->clone(src->init_)),
+ out_(ir_cloner->clone(src->out_)),
+ in_(ir_cloner->clone(src->in_)) {}
+
bool ReductionOp::sameAs(const ReductionOp* other) const {
return (
this->in()->sameAs(other->in()) &&
@@ -206,6 +256,37 @@
this->init()->sameAs(other->init()));
}
+std::vector<IterDomain*> ReductionOp::getReductionDomains() const {
+ const Val* out_val = out();
+ TORCH_INTERNAL_ASSERT(
+ out_val->getValType() == ValType::TensorView ||
+ out_val->getValType() == ValType::TensorIndex,
+ "Output of reduction must be TensorView or TensorIndex");
+ // out is a TensorIndex after lowering
+ if (out_val->getValType() == ValType::TensorIndex) {
+ out_val = static_cast<const TensorIndex*>(out_val)->view();
+ }
+ auto vec_domain = out_val->as<TensorView>()->domain()->domain();
+ vec_domain.erase(
+ std::remove_if(
+ vec_domain.begin(),
+ vec_domain.end(),
+ [](IterDomain* id) { return !id->isReduction(); }),
+ vec_domain.end());
+ return vec_domain;
+}
+
+std::unordered_map<ParallelType, IterDomain*> ReductionOp::
+ getParallelReductionDomains() const {
+ std::unordered_map<ParallelType, IterDomain*> parallel_domains;
+ for (auto d : getReductionDomains()) {
+ if (d->isThread()) {
+ parallel_domains.insert(std::make_pair(d->parallel_method(), d));
+ }
+ }
+ return parallel_domains;
+}
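
The getReductionDomains helper above uses the standard erase-remove idiom: the predicate drops every IterDomain that is not a reduction domain, so only reduction domains survive. A minimal standalone sketch of the same idiom, with a stub Domain type standing in for IterDomain (illustration only, not nvFuser code):

#include <algorithm>
#include <vector>

struct Domain {
  bool is_reduction = false;
  bool isReduction() const { return is_reduction; }
};

// Keep only the reduction domains, mirroring the filtering in
// ReductionOp::getReductionDomains.
std::vector<Domain*> reductionOnly(std::vector<Domain*> domains) {
  domains.erase(
      std::remove_if(
          domains.begin(),
          domains.end(),
          [](Domain* d) { return !d->isReduction(); }),
      domains.end());
  return domains;
}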
+
IterDomain::IterDomain(
Val* _start,
Val* _extent,
@@ -240,6 +321,15 @@
this->name_ = fusion_->registerVal(this);
}
+IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner),
+ start_(ir_cloner->clone(src->start_)),
+ extent_(ir_cloner->clone(src->extent_)),
+ parallel_method_(src->parallel_method_),
+ is_reduction_domain_(src->is_reduction_domain_),
+ is_rfactor_domain_(src->is_rfactor_domain_),
+ is_broadcast_domain_(src->is_broadcast_domain_) {}
+
bool IterDomain::sameAs(const IterDomain* const other) const {
bool is_same = isReduction() == other->isReduction() &&
parallel_method() == other->parallel_method();
@@ -393,6 +483,14 @@
this->name_ = fusion_->registerVal(this);
}
+TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner),
+ root_domain_(ir_cloner->clone(src->root_domain_)),
+ domain_(ir_cloner->clone(src->domain_)),
+ no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)),
+ no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)),
+ rfactor_domain_(ir_cloner->clone(src->rfactor_domain_)) {}
+
bool TensorDomain::sameAs(const TensorDomain* const other) const {
if (nDims() != other->nDims())
return false;
@@ -433,6 +531,18 @@
return no_reduction_domain_.size() != domain_.size();
}
+bool TensorDomain::hasBlockReduction() const {
+ return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
+ return id->isReduction() && id->isThreadDim();
+ });
+}
+
+bool TensorDomain::hasGridReduction() const {
+ return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
+ return id->isReduction() && id->isBlockDim();
+ });
+}
+
bool TensorDomain::hasBroadcast() const {
return no_bcast_domain_.size() != domain_.size();
}
@@ -444,6 +554,8 @@
// i here is int, as we want to accept negative value and ::size_type can be a
// uint.
IterDomain* TensorDomain::axis(int i) const {
+ TORCH_INTERNAL_ASSERT(
+ nDims() > 0, "Tried to access an axis in a 0-dim domain");
if (i < 0)
i += nDims();
TORCH_CHECK(
@@ -456,6 +568,7 @@
}
size_t TensorDomain::posOf(IterDomain* id) const {
+ TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to find an axis in a 0-dim domain");
size_t i = 0;
while (i < domain_.size()) {
if (domain_[i] == id)
@@ -468,6 +581,7 @@
// Split "axis" into 2 axes where the inner axes is size of "factor"
// and outer axis is size axis.extent() / factor
void TensorDomain::split(int axis_, unsigned int factor) {
+ TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain");
if (axis_ < 0)
axis_ += nDims();
@@ -485,6 +599,7 @@
// Merge "axis" and "axis+1" into 1 dimension
void TensorDomain::merge(int axis_o, int axis_i) {
+ TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");
if (axis_o < 0)
axis_o += nDims();
@@ -519,6 +634,9 @@
// Reorder axes according to map[old_pos] = new_pos
void TensorDomain::reorder(const std::unordered_map<int, int>& old2new_) {
+ TORCH_INTERNAL_ASSERT(
+ !(nDims() == 0 && old2new_.size() > 0),
+ "Tried to reorder a 0-dim domain");
domain_ = orderedAs(domain_, old2new_);
resetDomains();
}
@@ -526,6 +644,10 @@
std::vector<IterDomain*> TensorDomain::orderedAs(
const std::vector<IterDomain*>& dom,
const std::unordered_map<int, int>& old2new_) {
+ TORCH_INTERNAL_ASSERT(
+ !(dom.size() == 0 && old2new_.size() > 0),
+ "Tried to reorder a 0-dim domain");
+
  // Even though these checks are already in TensorView, we want to redo them as
// we can enter this function from other places, not through TensorView
@@ -678,6 +800,8 @@
// pair is in order where second is the consumer of first
std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
const std::vector<int>& axes_) {
+ TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim domain");
+
std::vector<int> axes(axes_.size());
auto ndims = nDims();
@@ -736,6 +860,13 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+Split::Split(const Split* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ outer_(ir_cloner->clone(src->outer_)),
+ inner_(ir_cloner->clone(src->inner_)),
+ in_(ir_cloner->clone(src->in_)),
+ factor_(ir_cloner->clone(src->factor_)) {}
+
bool Split::sameAs(const Split* const other) const {
return (
outer()->sameAs(other->outer()) && inner()->sameAs(other->inner()) &&
@@ -750,6 +881,12 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+Merge::Merge(const Merge* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ out_(ir_cloner->clone(src->out_)),
+ outer_(ir_cloner->clone(src->outer_)),
+ inner_(ir_cloner->clone(src->inner_)) {}
+
bool Merge::sameAs(const Merge* const other) const {
return (
out()->sameAs(other->out()) && outer()->sameAs(other->outer()) &&
@@ -775,6 +912,13 @@
body().push_back(expr);
}
+ForLoop::ForLoop(const ForLoop* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ index_(ir_cloner->clone(src->index_)),
+ iter_domain_(ir_cloner->clone(src->iter_domain_)),
+ body_(&src->body_, ir_cloner),
+ parent_scope_(ir_cloner->clone(src->parent_scope_)) {}
+
bool ForLoop::sameAs(const ForLoop* other) const {
if (this->iter_domain() != other->iter_domain())
return false;
@@ -798,6 +942,13 @@
else_body_.push_back(expr);
}
+IfThenElse::IfThenElse(const IfThenElse* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ cond_(src->cond_),
+ body_(&src->body_, ir_cloner),
+ else_body_(&src->else_body_, ir_cloner),
+ parent_scope_(ir_cloner->clone(src->parent_scope_)) {}
+
bool IfThenElse::sameAs(const IfThenElse* other) const {
if (!(this->cond()->sameAs(other->cond()) &&
this->constBody().sameAs(other->constBody()) &&
@@ -806,6 +957,11 @@
return true;
}
+TensorIndex::TensorIndex(const TensorIndex* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner),
+ view_(ir_cloner->clone(src->view_)),
+ indices_(ir_cloner->clone(src->indices_)) {}
+
bool TensorIndex::sameAs(const TensorIndex* const other) const {
if (nDims() != other->nDims())
return false;
@@ -821,6 +977,8 @@
}
Val* TensorIndex::index(int i) const {
+ TORCH_INTERNAL_ASSERT(
+ nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
if (i < 0)
i += nDims();
assert(i >= 0 && i < nDims());
@@ -846,6 +1004,11 @@
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
+Allocate::Allocate(const Allocate* src, IrCloner* ir_cloner)
+ : Expr(src, ir_cloner),
+ buffer_(ir_cloner->clone(src->buffer_)),
+ extent_(ir_cloner->clone(src->extent_)) {}
+
DataType Allocate::buf_type() const {
return buffer_->getDataType().value();
}
@@ -861,6 +1024,9 @@
return true;
}
+NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner), name_(src->name_) {}
+
} // namespace fuser
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
index b86501d..e7df6d0 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
@@ -9,11 +9,13 @@
/* ITER VISITOR */
-std::vector<Statement*> IterVisitor::next(Statement* statement) {
+std::vector<Statement*> IterVisitor::next(
+ Statement* statement,
+ bool respect_compute_at) {
if (statement->isVal())
return next(static_cast<Val*>(statement));
else if (statement->isExpr())
- return next(static_cast<Expr*>(statement));
+ return next(static_cast<Expr*>(statement), respect_compute_at);
else
TORCH_INTERNAL_ASSERT(
false, "IterVisitor could not detect type in next_dispatch.");
@@ -26,13 +28,44 @@
return {};
}
-std::vector<Statement*> IterVisitor::next(Expr* expr) {
+std::vector<Statement*> IterVisitor::next(Expr* expr, bool respect_compute_at) {
FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, ");
- return {expr->inputs().begin(), expr->inputs().end()};
+ std::vector<Statement*> next_stmts{expr->inputs().begin(),
+ expr->inputs().end()};
+ if (respect_compute_at) {
+ TORCH_INTERNAL_ASSERT(
+ expr->outputs().size() == 1,
+ "Expressions with multiple outputs are not supported");
+ if (expr->output(0)->getValType().value() == ValType::TensorView) {
+ auto out = expr->output(0)->as<const TensorView>();
+ // Move input TVs that are computed at this expression backward
+ // so that they are visited later. If multiple inputs are
+ // computed at, move TVs that are computed at an inner loop nest
+ // further backward.
+ std::stable_sort(
+ next_stmts.begin(),
+ next_stmts.end(),
+ [out](const Statement* stmt0, const Statement* stmt1) {
+ std::array<const Statement*, 2> inputs{stmt0, stmt1};
+ std::array<int, 2> compute_at_axes{-1, -1};
+ for (int i = 0; i < 2; ++i) {
+ if (inputs[i]->getValType().value() == ValType::TensorView) {
+ auto tv = inputs[i]->as<TensorView>();
+ if (tv->getComputeAtView() == out) {
+ compute_at_axes[i] = tv->getRelativeComputeAtAxis();
+ }
+ }
+ }
+ return compute_at_axes[0] < compute_at_axes[1];
+ });
+ }
+ }
+ return next_stmts;
}
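
The stable_sort in next() orders the inputs by their compute-at position relative to the output: inputs not computed at this output keep the key -1 and stay toward the front, while inputs computed at an inner loop nest (larger relative axis) move toward the back, and ties keep their original order. A small self-contained sketch of that sorting pattern, with a plain struct standing in for a TensorView (illustration only):

#include <algorithm>
#include <vector>

struct Input {
  int compute_at_axis = -1;  // -1 means "not computed at the output"
};

// Move inputs with a larger compute-at axis toward the back while keeping
// the original relative order of inputs with equal keys (stable sort).
void orderByComputeAt(std::vector<Input>& inputs) {
  std::stable_sort(
      inputs.begin(), inputs.end(), [](const Input& a, const Input& b) {
        return a.compute_at_axis < b.compute_at_axis;
      });
}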
-// Remove any stmt in stmts that is in visited
namespace {
+
+// Remove any stmt in stmts that is in visited
void remove_visited(
std::vector<Statement*>& stmts,
const std::unordered_set<Statement*>& visited) {
@@ -47,50 +80,52 @@
to_erase.pop_back();
}
}
+
} // namespace
void IterVisitor::traverseFrom(
Fusion* const fusion,
const std::vector<Val*>& from,
- bool traverseAllPaths) {
+ bool traverseAllPaths,
+ bool respectComputeAt) {
FusionGuard fg(fusion);
std::unordered_set<Statement*> visited;
stmt_stack.clear();
stmt_stack.emplace_back(from.rbegin(), from.rend());
+  // true when returning to a node after visiting all its input
+ // nodes. Nodes are only visited when this is true.
+ bool all_inputs_visited = false;
while (!stmt_stack.empty()) {
- auto next_stmts = next(stmt_stack.back().back());
-
- // Remove statements we already visited if we're not traversing all paths
- if (!traverseAllPaths)
- remove_visited(next_stmts, visited);
-
- // Traverse down until we get to a leaf
- while (!next_stmts.empty()) {
- stmt_stack.emplace_back(next_stmts.rbegin(), next_stmts.rend());
- next_stmts = next(stmt_stack.back().back());
- // Remove statements we already visited if we're not traversing all paths
- if (!traverseAllPaths)
- remove_visited(next_stmts, visited);
- }
-
- // Traverse back up
- // Mark visited
- visited.emplace(stmt_stack.back().back());
- // Handle
- handle(stmt_stack.back().back());
- // Remove
- stmt_stack.back().pop_back();
-
- while (!stmt_stack.empty() && stmt_stack.back().empty()) {
+ auto& current_inputs = stmt_stack.back();
+ // When current_inputs is empty, all the input nodes have been
+ // visited. Return to the output node by popping the stack. Record
+ // all inputs are visited.
+ if (current_inputs.empty()) {
stmt_stack.pop_back();
- if (!stmt_stack.empty()) {
- // Mark visited
- visited.emplace(stmt_stack.back().back());
- // Handle
- handle(stmt_stack.back().back());
- // Remove
- stmt_stack.back().pop_back();
+ all_inputs_visited = true;
+ continue;
+ }
+ const auto& stmt = current_inputs.back();
+ // Visit stmt when all_inputs_visited is true.
+ if (all_inputs_visited) {
+ // Mark visited
+ visited.insert(stmt);
+ // Handle
+ handle(stmt);
+ current_inputs.pop_back();
+ all_inputs_visited = false;
+ } else {
+ // Visit input nodes.
+ auto next_stmts = next(stmt, respectComputeAt);
+ if (!traverseAllPaths) {
+ remove_visited(next_stmts, visited);
+ }
+ if (next_stmts.empty()) {
+ all_inputs_visited = true;
+ } else {
+ stmt_stack.emplace_back(next_stmts.rbegin(), next_stmts.rend());
+ all_inputs_visited = false;
}
}
}
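
The rewritten traverseFrom loop is an iterative post-order traversal: a statement is handled only once the stack frame holding its inputs has been drained, which is what the all_inputs_visited flag tracks. A compact standalone sketch of the same pattern on a toy integer DAG (an illustration of the traversal idea, not the IterVisitor code itself):

#include <functional>
#include <unordered_set>
#include <vector>

// Post-order traversal: handle(n) is called only after every input of n.
void traversePostOrder(
    const std::vector<int>& from,
    const std::function<std::vector<int>(int)>& inputs_of,
    const std::function<void(int)>& handle) {
  std::unordered_set<int> visited;
  std::vector<std::vector<int>> stack{
      std::vector<int>(from.rbegin(), from.rend())};
  bool all_inputs_visited = false;
  while (!stack.empty()) {
    auto& frame = stack.back();
    if (frame.empty()) {      // every input in this frame has been handled
      stack.pop_back();
      all_inputs_visited = true;
      continue;
    }
    const int node = frame.back();
    if (all_inputs_visited) { // inputs done: handle the node itself
      visited.insert(node);
      handle(node);
      frame.pop_back();
      all_inputs_visited = false;
    } else {                  // descend into not-yet-visited inputs
      std::vector<int> next;
      for (int in : inputs_of(node))
        if (!visited.count(in))
          next.push_back(in);
      if (next.empty()) {
        all_inputs_visited = true;
      } else {
        stack.emplace_back(next.rbegin(), next.rend());
      }
    }
  }
}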
@@ -100,7 +135,8 @@
Fusion* const fusion,
bool from_outputs_only,
bool breadth_first,
- bool traverse_all_paths) {
+ bool traverse_all_paths,
+ bool respect_compute_at) {
FusionGuard fg(fusion);
if (breadth_first)
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
@@ -109,7 +145,8 @@
auto term_outs = IterVisitor::getTerminatingOutputs(fusion);
std::vector<Val*> term_val_outs(term_outs.begin(), term_outs.end());
if (!term_val_outs.empty())
- traverseFrom(fusion, term_val_outs, traverse_all_paths);
+ traverseFrom(
+ fusion, term_val_outs, traverse_all_paths, respect_compute_at);
return;
}
@@ -120,28 +157,31 @@
leaves.push_back(val);
if (!leaves.empty())
- traverseFrom(fusion, leaves, traverse_all_paths);
+ traverseFrom(fusion, leaves, traverse_all_paths, respect_compute_at);
}
void IterVisitor::traverse(
Fusion* const fusion,
bool from_outputs_only,
- bool breadth_first) {
- traverse_(fusion, from_outputs_only, breadth_first, false);
+ bool breadth_first,
+ bool respect_compute_at) {
+ traverse_(
+ fusion, from_outputs_only, breadth_first, false, respect_compute_at);
}
void IterVisitor::traverseAllPaths(
Fusion* const fusion,
bool from_outputs_only,
- bool breadth_first) {
- traverse_(fusion, from_outputs_only, breadth_first, true);
+ bool breadth_first,
+ bool respect_compute_at) {
+ traverse_(fusion, from_outputs_only, breadth_first, true, respect_compute_at);
}
namespace {
// Expr sort will take a fusion and return a topologically sorted list of
// expressions.
-struct Exprs : public IterVisitor {
+class Exprs : public IterVisitor {
private:
std::vector<Expr*> exprs;
@@ -161,7 +201,7 @@
// Expr sort will take a fusion and return a topologically sorted list of
// expressions.
-struct Inputs : public IterVisitor {
+class Inputs : public IterVisitor {
private:
std::unordered_set<Val*> inputs;
@@ -179,6 +219,7 @@
return inps.inputs;
}
};
+
} // namespace
std::unordered_set<Val*> IterVisitor::getTerminatingOutputs(
@@ -186,10 +227,12 @@
FusionGuard fg(fusion);
std::unordered_set<Val*> used_vals;
- for (auto expr : Exprs::getExprs(
- fusion,
- std::vector<Val*>(
- fusion->outputs().begin(), fusion->outputs().end()))) {
+
+ const auto exprs = Exprs::getExprs(
+ fusion,
+ std::vector<Val*>(fusion->outputs().begin(), fusion->outputs().end()));
+
+ for (auto expr : exprs) {
for (auto inp : expr->inputs())
used_vals.emplace(inp);
}
@@ -209,7 +252,7 @@
namespace {
-struct AllVals : public IterVisitor {
+class AllVals : public IterVisitor {
private:
std::unordered_set<Val*> vals;
@@ -342,8 +385,45 @@
/* DEPENDENCY CHECKING */
namespace {
+
+// Looks for and returns all values in between dependencies and vals, including
+// them.
+struct Dependencies : public IterVisitor {
+ std::unordered_set<Val*> dependencies_;
+ std::unordered_set<Val*> vals;
+
+ std::vector<Statement*> next(Val* v) override {
+ if (dependencies_.find(v) != dependencies_.end())
+ return std::vector<Statement*>();
+ return IterVisitor::next(v);
+ }
+
+ void handle(Val* val) override {
+ vals.emplace(val);
+ }
+
+ Dependencies(
+ std::unordered_set<Val*> _dependencies,
+ const std::vector<Val*>& of)
+ : dependencies_(std::move(_dependencies)) {
+ traverseFrom(of[0]->fusion(), of, false);
+ };
+
+ public:
+ static std::unordered_set<Val*> getAllVals(
+ const std::unordered_set<Val*>& dependencies,
+ const std::vector<Val*>& of) {
+ if (of.empty())
+ return std::unordered_set<Val*>();
+
+ Dependencies deps(dependencies, of);
+ return deps.vals;
+ }
+};
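
Dependencies reuses the visitor machinery: traversal starts from "of" and walks toward producers, but next() is cut short whenever a value in dependencies_ is reached, so the collected set is every value encountered along the way, including the boundary values themselves. A standalone sketch of the same idea on a plain producer map (hypothetical names, illustration only):

#include <unordered_map>
#include <unordered_set>
#include <vector>

// Collect every node reachable from `of` by walking producer edges,
// without expanding past nodes in `boundary` (boundary nodes themselves
// are still included when reached).
std::unordered_set<int> allValsBetween(
    const std::unordered_set<int>& boundary,
    const std::vector<int>& of,
    const std::unordered_map<int, std::vector<int>>& producers_of) {
  std::unordered_set<int> result;
  std::vector<int> worklist(of.begin(), of.end());
  while (!worklist.empty()) {
    const int v = worklist.back();
    worklist.pop_back();
    if (!result.insert(v).second)
      continue;               // already collected
    if (boundary.count(v))
      continue;               // do not walk past the boundary
    auto it = producers_of.find(v);
    if (it == producers_of.end())
      continue;
    for (int p : it->second)
      worklist.push_back(p);
  }
  return result;
}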
+
// Looks for and returns
-struct DependencyChains : public IterVisitor {
+class DependencyChains : public IterVisitor {
+ public:
std::deque<std::deque<Val*>> dep_chains;
bool is_dependency = false;
std::unordered_set<Val*> dependencies_;
@@ -394,6 +474,7 @@
return dp.dep_chains[0];
}
+ // I don't think this is actually hooked up, but leaving for now.
static std::deque<std::deque<Val*>> getDependencyChains(
Val* dependency,
Val* of) {
@@ -403,14 +484,14 @@
return dp.dep_chains;
}
- static std::deque<std::deque<Val*>> getDependencyChainsTo(Val* dependency) {
+ static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency) {
DependencyChains dp(dependency, true);
if (dp.dep_chains.empty())
return std::deque<std::deque<Val*>>();
return dp.dep_chains;
}
- static std::deque<std::deque<Val*>> getDependencyChainsTo(
+ static std::deque<std::deque<Val*>> getAllUseChains(
const std::unordered_set<Val*>& dependencies) {
DependencyChains dp(dependencies, true);
if (dp.dep_chains.empty())
@@ -437,9 +518,14 @@
return DependencyChains::getDependencyChains(dependency, of);
}
-std::deque<std::deque<Val*>> DependencyCheck::getAllDependencyChainsTo(
- Val* dependency) {
- return DependencyChains::getDependencyChainsTo(dependency);
+std::deque<std::deque<Val*>> DependencyCheck::getAllUseChains(Val* producer) {
+ return DependencyChains::getAllUseChains(producer);
+}
+
+std::unordered_set<Val*> DependencyCheck::getAllValsBetween(
+ const std::unordered_set<Val*>& dependencies,
+ const std::vector<Val*>& of) {
+ return Dependencies::getAllVals(dependencies, of);
}
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h
index 1dcb6bd..117c5f1 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.h
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h
@@ -12,11 +12,11 @@
namespace jit {
namespace fuser {
-struct Statement;
-struct Val;
-struct Expr;
+class Statement;
+class Val;
+class Expr;
-struct Fusion;
+class Fusion;
enum class ValType;
@@ -32,7 +32,8 @@
* TODO: We may want to have ordering of outputs to inputs. I'm not sure why we
* would want this, but seems like it would be a reasonable request.
*/
-struct TORCH_CUDA_API IterVisitor : public OptOutDispatch {
+class TORCH_CUDA_API IterVisitor : public OptOutDispatch {
+ public:
virtual ~IterVisitor() = default;
IterVisitor() = default;
@@ -48,23 +49,25 @@
// These functions will start at outputs and propagate up through the DAG
// to inputs based on depth first traversal. Next could be called on a node
// multiple times.
- virtual std::vector<Statement*> next(Statement* stmt);
- virtual std::vector<Statement*> next(Expr* expr);
+ virtual std::vector<Statement*> next(
+ Statement* stmt,
+ bool respect_compute_at);
+ virtual std::vector<Statement*> next(Expr* expr, bool respect_compute_at);
virtual std::vector<Statement*> next(Val* v);
  // This handle function is called on every Statement* in topological order,
// starting from outputs to inputs.
- virtual void handle(Statement* s) override {
+ void handle(Statement* s) override {
OptOutDispatch::handle(s);
}
  // This handle function is called on every Expr* in topological order,
// starting from outputs to inputs.
- virtual void handle(Expr* e) override {
+ void handle(Expr* e) override {
OptOutDispatch::handle(e);
}
  // This handle function is called on every Val* in topological order,
// starting from outputs to inputs.
- virtual void handle(Val* v) override {
+ void handle(Val* v) override {
OptOutDispatch::handle(v);
}
@@ -79,7 +82,8 @@
Fusion* const fusion,
bool from_outputs_only = false,
bool breadth_first = false,
- bool traverse_all_paths = false);
+ bool traverse_all_paths = false,
+ bool respect_compute_at = false);
public:
// Starts at nodes provided in from, traverses from these nodes to inputs.
@@ -90,23 +94,28 @@
void traverseFrom(
Fusion* const fusion,
const std::vector<Val*>& from,
- bool traverseAllPaths = false);
+ bool traverseAllPaths = false,
+ bool respectComputeAt = false);
// from_outputs_only = true start from outputs registered with fusion,
// from_outputs_only = false start from all leaf nodes,
// bool breadth_first = true is not implemented yet
+ // respect_compute_at = true traverse computeAt input exprs later
void traverse(
Fusion* const fusion,
bool from_outputs_only = false,
- bool breadth_first = false);
+ bool breadth_first = false,
+ bool respect_compute_at = false);
// from_outputs_only = true start from outputs registered with fusion,
// from_outputs_only = false start from all leaf nodes,
// bool breadth_first = true is not implemented yet
+ // respect_compute_at = true traverse computeAt input exprs later
void traverseAllPaths(
Fusion* const fusion,
bool from_outputs_only = false,
- bool breadth_first = false);
+ bool breadth_first = false,
+ bool respect_compute_at = false);
static std::unordered_set<Val*> getTerminatingOutputs(Fusion* const);
@@ -128,7 +137,8 @@
 * outputs to guarantee that we will traverse all outputs of all exprs during
* the backward traversal.
*/
-struct TORCH_CUDA_API BackwardVisitor : public OptOutDispatch {
+class TORCH_CUDA_API BackwardVisitor : public OptOutDispatch {
+ public:
virtual ~BackwardVisitor() = default;
BackwardVisitor() = default;
@@ -187,7 +197,7 @@
bool traverseAllPaths = false);
};
-struct TORCH_CUDA_API DependencyCheck {
+class TORCH_CUDA_API DependencyCheck {
public:
// Returns if "dependency" is a dependency of "of".
static bool isDependencyOf(Val* dependency, Val* of);
@@ -206,7 +216,12 @@
// Finds all Val* paths from all leaf nodes to "dependency". Returns those
// paths. deque[i].back() are leaf nodes, and deque[i][0] is "dependency".
// Returns an empty deque if there are no uses of dependency found.
- static std::deque<std::deque<Val*>> getAllDependencyChainsTo(Val* dependency);
+ static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency);
+
+ // Grab all values that exist between and including provided vals
+ static std::unordered_set<Val*> getAllValsBetween(
+ const std::unordered_set<Val*>& dependencies,
+ const std::vector<Val*>& of);
};
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp
index 567cefa..6d7a091 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -9,6 +9,7 @@
#include <torch/csrc/jit/codegen/cuda/kernel_arg.h>
#include <torch/csrc/jit/codegen/cuda/kernel_resource_strings.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
+#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/resource_guard.h>
#include <fstream>
@@ -19,10 +20,11 @@
namespace fuser {
namespace cuda {
-constexpr auto CG_NAMESPACE = "CudaCodeGen";
-constexpr auto KERNEL_NAME = "kernel";
+constexpr auto kCgNamespace = "CudaCodeGen";
+constexpr auto kKernelName = "kernel";
namespace {
+
// See NOTE [ USE OF NVRTC AND DRIVER API ]
static const at::cuda::NVRTC& nvrtc() {
return at::globalContext().getNVRTC();
@@ -147,24 +149,25 @@
std::pair<std::string, std::string> codeGeneration(Fusion* fusion) {
std::stringstream str_stream;
- str_stream << "namespace " << CG_NAMESPACE << " {\n"
+ str_stream << "namespace " << kCgNamespace << " {\n"
<< code_template_tensor_struct << "\n"
<< code_fp16_support << "\n"
<< code_random_number_gen << "\n"
<< code_helper_funcs << "\n"
- << code_template_block_reduction << "\n";
+ << code_template_block_reduction << "\n"
+ << code_template_grid_reduction << "\n";
std::stringstream cdg;
GPULower gpulw(fusion);
- gpulw.printKernel(str_stream, KERNEL_NAME);
+ gpulw.printKernel(str_stream, kKernelName);
str_stream << "\n} // namespace";
- std::string func_name = std::string(CG_NAMESPACE) + "::" + KERNEL_NAME;
+ std::string func_name = std::string(kCgNamespace) + "::" + kKernelName;
return std::make_pair(func_name, str_stream.str());
-};
+}
bool validateKernelArgTensor(
const at::Tensor& arg,
- const Val* const param,
+ const Val* param,
int device_index,
std::stringstream& msg) {
// Arg is a tensor. Param must be a tensor too.
@@ -219,7 +222,7 @@
bool validateKernelArgScalar(
const c10::TypePtr& arg_type,
- const Val* const param,
+ const Val* param,
std::stringstream& msg) {
if (!param->isScalar()) {
msg << "Argument is a scalar, but the parameter is not.";
@@ -249,7 +252,7 @@
bool validateKernelArg(
const c10::IValue& arg,
- const Val* const param,
+ const Val* param,
int device_index,
std::stringstream& msg) {
if (arg.type()->kind() != c10::TypeKind::TensorType) {
@@ -271,7 +274,7 @@
"Wrong number of kernel inputs.");
for (size_t i = 0; i < inputs.size(); ++i) {
const IValue& arg = inputs[i];
- const Val* const param = entry.fusion_->inputs()[i];
+ const Val* param = entry.fusion_->inputs()[i];
std::stringstream msg;
TORCH_INTERNAL_ASSERT(
validateKernelArg(arg, param, entry.device_, msg),
@@ -290,7 +293,7 @@
"Wrong number of kernel outputs.");
for (size_t i = 0; i < outputs.size(); ++i) {
const at::Tensor& arg = outputs[i];
- const Val* const param = entry.fusion_->outputs()[i];
+ const Val* param = entry.fusion_->outputs()[i];
std::stringstream msg;
TORCH_INTERNAL_ASSERT(
validateKernelArgTensor(arg, param, entry.device_, msg),
@@ -300,6 +303,73 @@
msg.str());
}
}
+
+size_t size(const dim3& d) {
+ return (size_t)d.x * (size_t)d.y * (size_t)d.z;
+}
+
+dim3 dimensionOfReductionBlock(
+ const dim3& block_dim,
+ bool x_thread,
+ bool y_thread,
+ bool z_thread) {
+ return dim3{x_thread ? block_dim.x : 1,
+ y_thread ? block_dim.y : 1,
+ z_thread ? block_dim.z : 1};
+}
+
+int sizeOfReductionBlock(
+ const dim3& block_dim,
+ bool x_thread,
+ bool y_thread,
+ bool z_thread) {
+ return size(
+ dimensionOfReductionBlock(block_dim, x_thread, y_thread, z_thread));
+}
+
+// Returns the total number of reduction segments.
+size_t numberOfReductionSegments(
+ const dim3& grid_dim,
+ bool x_block,
+ bool y_block,
+ bool z_block) {
+ return (x_block ? 1 : grid_dim.x) * (y_block ? 1 : grid_dim.y) *
+ (z_block ? 1 : grid_dim.z);
+}
+
+std::array<size_t, 2> gridReductionTempBufferSizes(CudaKernel* entry) {
+ size_t buffer_size = 0;
+ size_t sync_flag_size = 0;
+ for (auto expr : entry->fusion_->exprs(true)) {
+ if (expr->getExprType() != ExprType::ReductionOp)
+ continue;
+ ReductionOp* rop = static_cast<ReductionOp*>(expr);
+ auto domains = rop->getParallelReductionDomains();
+ bool x_block = domains.find(ParallelType::BIDx) != domains.end();
+ bool y_block = domains.find(ParallelType::BIDy) != domains.end();
+ bool z_block = domains.find(ParallelType::BIDz) != domains.end();
+ // No buffer needed unless it's a grid reduction
+ if (!x_block && !y_block && !z_block)
+ continue;
+    // The assumption here is that the reduction along the block-parallel
+    // domains is done prior to this grid reduction, so those domains
+    // do not need to participate in the grid reduction
+ bool x_thread = domains.find(ParallelType::TIDx) == domains.end();
+ bool y_thread = domains.find(ParallelType::TIDy) == domains.end();
+ bool z_thread = domains.find(ParallelType::TIDz) == domains.end();
+ auto rb_size =
+ sizeOfReductionBlock(entry->block_, x_thread, y_thread, z_thread);
+ auto num_blocks = size(entry->grid_);
+ auto element_size = dataTypeSize(*(rop->out()->getDataType()));
+ auto required_temp_buffer_size = num_blocks * rb_size * element_size;
+ buffer_size = std::max(buffer_size, required_temp_buffer_size);
+ auto flag_size = sizeof(unsigned) *
+ numberOfReductionSegments(entry->grid_, x_block, y_block, z_block);
+ sync_flag_size = std::max(sync_flag_size, flag_size);
+ }
+ return {{buffer_size, sync_flag_size}};
+}
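
To make the sizing above concrete, here is a small host-only arithmetic sketch for a hypothetical launch (grid {128,1,1}, block {256,1,1}, float outputs, reduction parallelized on BIDx only, so the whole thread block participates in the grid reduction); the configuration is an assumption chosen for illustration:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_blocks = 128;             // size(grid) for a {128,1,1} grid
  const size_t rblock_size = 256;            // the whole {256,1,1} block participates
  const size_t element_size = sizeof(float); // dataTypeSize of the reduction output
  const size_t work_buffer_bytes =
      num_blocks * rblock_size * element_size;               // 131072 bytes
  const size_t num_segments = 1;             // BIDx is grouped; grid.y == grid.z == 1
  const size_t sync_flag_bytes = sizeof(unsigned) * num_segments;  // 4 bytes
  std::printf("work buffer: %zu bytes, sync flags: %zu bytes\n",
              work_buffer_bytes, sync_flag_bytes);
  return 0;
}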
+
} // namespace
bool NaivePWKernelArgsReq::matchKernelSize(const at::ArrayRef<IValue> inputs) {
@@ -327,6 +397,16 @@
std::tie(func_name, code) = codeGeneration(entry->fusion_.get());
static int32_t compiled_kernel_id = 0;
+  // We increment the id here instead of at the end of the function so that an
+  // error during jit-compilation does not make the debug message confusing.
+ compiled_kernel_id++;
+ const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG");
+ if (debug_env && atoi(debug_env)) {
+ std::cout << "\n==== codegen output for kernel: " << compiled_kernel_id
+ << " ====" << std::endl
+ << code << std::endl
+ << "====================================" << std::endl;
+ }
// vvv NVRTC COMPILATION vvv
@@ -452,7 +532,8 @@
void runKernel(
CudaKernel* entry,
const at::ArrayRef<IValue> inputs,
- std::vector<at::Tensor> outputs) {
+ const std::vector<at::Tensor>& outputs,
+ const std::vector<int64_t>& broadcasted_shape) {
validateKernelArgs(*entry, inputs, outputs);
const auto prior_device = at::cuda::current_device();
@@ -461,10 +542,32 @@
// TODO: Proper API to establish reasonable launch configurations;
// Naive launch config;
- size_t numel = outputs[0].numel();
+ const size_t numel = outputs[0].numel();
- // TODO: we can't randomly clap down this until we got striding.
- const auto nBlocks = ceilDiv(numel, 128 * entry->unroll_factor_);
+ int blocks = 1;
+ int thread_x = 1;
+ int thread_y = 1;
+ if (!entry->reduction_axes_.empty()) {
+ // TODO: MAJOR HACK! Expr evaluation makes launch configuration much easier
+ blocks = numel;
+ // Translated to `fcd_reduction`
+ if (entry->reduction_axes_.back() ==
+ outputs[0].dim() + ((int)entry->reduction_axes_.size()) - 1) {
+ thread_x = kFcdReductionThreadX;
+ thread_y = 1;
+ } else {
+ thread_x = kNonFcdReductionThreadX;
+ thread_y = kNonFcdReductionThreadY;
+ }
+ } else {
+    // TODO: we can't arbitrarily clamp this down until we have striding.
+ blocks = ceilDiv(numel, kPwThreadX * entry->unroll_factor_);
+ thread_x = kPwThreadX;
+ thread_y = 1;
+ }
+ const auto nBlocks = blocks;
+ const auto nThreadx = thread_x;
+ const auto nThready = thread_y;
KernelArgumentHolder kernel_args;
@@ -473,7 +576,7 @@
// from I/O expected by the generated CUDA kernel.
for (auto& input : inputs) {
if (input.isTensor()) {
- kernel_args.push(input.toTensor(), outputs[0].sizes());
+ kernel_args.push(input.toTensor(), broadcasted_shape);
} else {
kernel_args.push(input);
}
@@ -505,8 +608,8 @@
nBlocks,
1,
1,
- 128,
- 1,
+ nThreadx,
+ nThready,
1,
0,
stream,
@@ -522,7 +625,7 @@
void runTestKernel(
CudaKernel* entry,
const at::ArrayRef<IValue> inputs,
- std::vector<at::Tensor> outputs) {
+ const std::vector<at::Tensor>& outputs) {
validateKernelArgs(*entry, inputs, outputs);
const auto prior_device = at::cuda::current_device();
@@ -540,9 +643,6 @@
KernelArgumentHolder kernel_args;
auto exprs = entry->fusion_->exprs(true);
- bool has_reduction = std::any_of(exprs.begin(), exprs.end(), [](Expr* expr) {
- return expr->getExprType() == ExprType::ReductionOp;
- });
// Naive I/O setup, I'm ignoring all the potential transformation (i.e. I/O
// allocated here from the subgraph could be, and very likely are, different
@@ -555,11 +655,7 @@
TORCH_INTERNAL_ASSERT(
!entry->fusion_->outputs().empty(),
"No output found for this kernel, aborting.");
- if (has_reduction) {
- kernel_args.push(input.toTensor());
- } else {
- kernel_args.push(input.toTensor(), outputs[0].sizes());
- }
+ kernel_args.push(input.toTensor());
} else {
kernel_args.push(input);
}
@@ -585,6 +681,22 @@
kernel_args.push(philox_engine_inputs.second);
}
+ // When the kernel has global reductions, the kernel needs two
+ // additional temporary buffers, one for intermediate results and
+ // another for synchronization among thread blocks.
+ if (entry->fusion_->hasGridReduction()) {
+ auto temp_buf_type = at::kFloat;
+ auto temp_buf_sizes = gridReductionTempBufferSizes(entry);
+ auto options =
+ at::TensorOptions().dtype(temp_buf_type).device(at::kCUDA, 0);
+ at::Tensor reduction_work_buffer = at::empty(
+ {(long)(temp_buf_sizes[0] / c10::elementSize(temp_buf_type))}, options);
+ kernel_args.push(reduction_work_buffer);
+ at::Tensor sync_flags = at::zeros(
+ {(long)(temp_buf_sizes[1] / c10::elementSize(temp_buf_type))}, options);
+ kernel_args.push(sync_flags);
+ }
+
// launch kernel;
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
entry->function_,
diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h
index 2b37938..67050d7 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.h
+++ b/torch/csrc/jit/codegen/cuda/kernel.h
@@ -40,7 +40,7 @@
std::vector<int> dims_;
};
-struct CudaKernel {
+class CudaKernel {
public:
CudaKernel() {
fusion_ = std::make_unique<Fusion>();
@@ -59,6 +59,8 @@
CUfunction function_;
int max_blocks_;
int unroll_factor_ = 1;
+ // mark reduction axes;
+ std::vector<int> reduction_axes_;
// WARNING:
// Block and Grid dimension setting is here for testing purposes only
@@ -89,13 +91,14 @@
TORCH_CUDA_API void runKernel(
CudaKernel* entry,
const at::ArrayRef<c10::IValue> inputs,
- std::vector<at::Tensor> outputs);
+ const std::vector<at::Tensor>& outputs,
+ const std::vector<int64_t>& broadcasted_shape);
// Facility API to run kernel in tests.
TORCH_CUDA_API void runTestKernel(
CudaKernel* entry,
const at::ArrayRef<c10::IValue> inputs,
- std::vector<at::Tensor> outputs);
+ const std::vector<at::Tensor>& outputs);
} // namespace cuda
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/kernel_arg.h b/torch/csrc/jit/codegen/cuda/kernel_arg.h
index 0541647..984afb0 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_arg.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_arg.h
@@ -20,6 +20,30 @@
constexpr int nDims() {
return N;
}
+ void setSize(int i, int64_t s) {
+ size[i] = s;
+ }
+ void setStride(int i, int64_t s) {
+ stride[i] = s;
+ }
+};
+
+template <typename T>
+struct TensorArgCodegen<T, 0> {
+ T& operator[](int64_t ind) {
+ return data[ind];
+ };
+
+ T* data;
+ constexpr int nDims() {
+ return 0;
+ }
+ void setSize(int, int64_t) {
+ TORCH_INTERNAL_ASSERT(false, "Tried to set size of a 0-dim tensor");
+ }
+ void setStride(int, int64_t) {
+ TORCH_INTERNAL_ASSERT(false, "Tried to set stride of a 0-dim tensor");
+ }
};
struct ArgAbstract {
@@ -64,10 +88,10 @@
TENSOR_TYPE instance_;
void setSize(int i, int64_t size) override {
- instance_.size[i] = size;
+ instance_.setSize(i, size);
}
void setStride(int i, int64_t stride) override {
- instance_.stride[i] = stride;
+ instance_.setStride(i, stride);
}
void setPointer(void* ptr) override {
instance_.data = static_cast<decltype(TENSOR_TYPE::data)>(ptr);
@@ -81,6 +105,8 @@
template <typename T>
TensorArgAbstract* getTensorArg(int nDims) {
switch (nDims) {
+ case (0):
+ return new TensorArg<TensorArgCodegen<T, 0>>();
case (1):
return new TensorArg<TensorArgCodegen<T, 1>>();
case (2):
diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
index bfba2ad..0427c8f 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
@@ -20,6 +20,17 @@
int64_t size[N];
int64_t stride[N];
};
+
+// Specialization for the 0-dim case, which does not need size and stride arrays.
+// They would also be an error, since zero-length arrays are not allowed.
+template<typename T>
+struct Tensor<T, 0> {
+ T& operator[](int64_t) {
+ return *data;
+ };
+
+ T* data;
+};
)";
// Code support for FP16 __half type and intrinsics
@@ -184,25 +195,22 @@
// may actually be slower.
template<bool X_REDUCE, bool Y_REDUCE, bool Z_REDUCE, typename T, typename Func>
__inline__ __device__
-void blockReduce(T& out, const T inp_val, Func reduction_op) {
-
- // Use worst case for memory.
- __shared__ T shared_mem[1024];
+void blockReduce(T& out, const T inp_val, Func reduction_op, const dim3& thread_idx, const dim3& block_dim, T* shared_mem) {
unsigned int reduction_size
- = (X_REDUCE ? blockDim.x : 1)
- * (Y_REDUCE ? blockDim.y : 1)
- * (Z_REDUCE ? blockDim.z : 1);
+ = (X_REDUCE ? block_dim.x : 1)
+ * (Y_REDUCE ? block_dim.y : 1)
+ * (Z_REDUCE ? block_dim.z : 1);
// If this thread will output a final result
bool should_write = true;
if (X_REDUCE)
- should_write = should_write && threadIdx.x == 0;
+ should_write = should_write && thread_idx.x == 0;
if (Y_REDUCE)
- should_write = should_write && threadIdx.y == 0;
+ should_write = should_write && thread_idx.y == 0;
if (Z_REDUCE)
- should_write = should_write && threadIdx.z == 0;
+ should_write = should_write && thread_idx.z == 0;
unsigned int reduction_stride;
unsigned int reduction_tid;
@@ -212,23 +220,20 @@
// Transpose Z and Y in the shared memory so Z and X dims are contiguous in smem
reduction_stride = 1;
linear_tid = threadIdx.y * blockDim.z * blockDim.x + threadIdx.z * blockDim.x + threadIdx.x;
- reduction_tid
- = threadIdx.y * blockDim.z * blockDim.x
- + threadIdx.z * blockDim.x
- + threadIdx.x;
+ reduction_tid = threadIdx.z * blockDim.x + threadIdx.x;
} else {
// Normal reduction in order
reduction_stride
= (X_REDUCE ? 1
- : (Y_REDUCE ? blockDim.x
- : (Z_REDUCE ? blockDim.x * blockDim.y : 0)));
+ : (Y_REDUCE ? block_dim.x
+ : (Z_REDUCE ? block_dim.x * block_dim.y : 0)));
- linear_tid = threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x;
+ linear_tid = thread_idx.z * block_dim.y * block_dim.x + thread_idx.y * block_dim.x + thread_idx.x;
reduction_tid
- = ( Z_REDUCE ? threadIdx.z : 0 ) * ( Y_REDUCE ? blockDim.y : 1 ) * ( X_REDUCE ? blockDim.x : 1 )
- + ( Y_REDUCE ? threadIdx.y : 0 ) * ( X_REDUCE ? blockDim.x : 1 )
- + ( X_REDUCE ? threadIdx.x : 0 );
+ = ( Z_REDUCE ? thread_idx.z : 0 ) * ( Y_REDUCE ? block_dim.y : 1 ) * ( X_REDUCE ? block_dim.x : 1 )
+ + ( Y_REDUCE ? thread_idx.y : 0 ) * ( X_REDUCE ? block_dim.x : 1 )
+ + ( X_REDUCE ? thread_idx.x : 0 );
}
assert( reduction_stride != 0 );
@@ -257,7 +262,315 @@
}
)";
+/**
+ Inter-block reduction.
+
+ Function gridReduce performs point-wise reductions of scalars across thread
+ blocks. Thread blocks are disjointly partitioned into groups of thread blocks,
+ "reduction segments," that are collectively defined by boolean template
+ parameters, X_BLOCK, Y_BLOCK and Z_BLOCK. Each of X/Y/Z_BLOCK determines
+ whether thread blocks along the dimension should be grouped into the same
+ reduction segment. Cross-block reducitons are independently done within each
+  reduction segment. Cross-block reductions are done independently within each
+  segment and generate distinct results per segment. For instance, if all of
+ there will be just a single segment consisting of all thread blocks. If none
+ of them are true, each thread block will become a segment by itself, so no
+  of them are true, each thread block will become a segment by itself, so no
+  cross-block reduction will be performed.
+ The input scalars to reduce within each segment are a certain subset of
+ thread-private scalars provided as part of the gridReduce function parameters.
+ Boolean template parameters, X_THREAD, Y_THREAD and Z_THREAD, determine which
+ subset of the scalars should be used for inter-block reductions. Specifically,
+ all the input scalars of threads along each dimension will be used when
+ X/Y/Z_THREAD are true. Otherwise, only the value held at offset 0 of each
+ dimension will be used. Thus, for example, if all of X/Y/Z_THREAD are true,
+ the scalars of all threads in each block will participate in inter-block
+ reductions. If all of them are false, only one scalar of the thread at
+ threadIdx.x == threadIdx.y == threadIdx.z == 0 will be used. In the code
+ below, we call the subset of threads a "reduction block."
+
+ Inter-block reductions perform point-wise reductions of scalars of reduction
+ blocks within each reduction segment. More specifically, let rb be a reduction
+ block and rs be a reduction segment. Let IN(thread_idx, block_idx) denote the
+ input scalar of thread at thread_idx and block_idx. The result of each
+ reduction segment, OUT(thread_idx, block_idx_out), is defined only for each
+ thread_idx in thread block block_idx_out in the segment as follows:
+
+ OUT(thread_idx, block_idx_out) = Reduction of IN(thread_idx, block_idx) for
+ all block_idx in a reduction segment
+
+  OUT is not defined for threads that are not in block_idx_out or not in the
+  reduction block.
+
+ See also the function comment of gridReduce.
+*/
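
For a concrete picture of the segment partitioning described above, the host-only sketch below applies the same formulas as the device helpers that follow to a hypothetical grid of (4, 2, 1) blocks with X_BLOCK=true and Y/Z_BLOCK=false: there are 2 segments, and every block with the same blockIdx.y lands in the same segment.

#include <cstdio>

struct Dim3 { unsigned x, y, z; };  // host stand-in for CUDA's dim3

// Same formulas as number_of_reduction_segments / index_of_reduction_segment,
// specialized to X_BLOCK=true, Y_BLOCK=false, Z_BLOCK=false.
int main() {
  const Dim3 grid{4, 2, 1};
  const unsigned num_segments = grid.y * grid.z;  // (1) * grid.y * grid.z == 2
  std::printf("segments: %u\n", num_segments);
  for (unsigned z = 0; z < grid.z; ++z)
    for (unsigned y = 0; y < grid.y; ++y)
      for (unsigned x = 0; x < grid.x; ++x) {
        // seg_idx accumulates only the non-grouped dimensions (here y and z).
        unsigned seg_idx = z;            // !Z_BLOCK
        seg_idx = seg_idx * grid.y + y;  // !Y_BLOCK
        std::printf("block (%u,%u,%u) -> segment %u\n", x, y, z, seg_idx);
      }
  return 0;
}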
+static auto code_template_grid_reduction = R"(
+namespace reduction {
+
+// Utility functions
+__host__ __device__ __forceinline__ size_t size(const dim3& d) {
+ return (size_t)d.x * (size_t)d.y * (size_t)d.z;
+}
+
+__host__ __device__ __forceinline__ int isize(const dim3& d) {
+ return d.x * d.y * d.z;
+}
+
+__host__ __device__ __forceinline__ size_t offset(const dim3& pos, const dim3& dim) {
+ return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x +
+ (size_t)pos.z * (size_t)dim.x * (size_t)dim.y;
+}
+
+__host__ __device__ __forceinline__ size_t ioffset(const dim3& pos, const dim3& dim) {
+ return pos.x + pos.y * dim.x + pos.z * dim.x * dim.y;
+}
+
+// Returns dim3 of each reduction segment.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ dim3 dimension_of_reduction_segment(const dim3& grid_dim) {
+ return dim3{X_BLOCK ? grid_dim.x : 1,
+ Y_BLOCK ? grid_dim.y : 1,
+ Z_BLOCK ? grid_dim.z : 1};
+}
+
+// Returns the number of blocks in each reduction segment.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t size_of_reduction_segment(const dim3& grid_dim) {
+ return size(dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
+}
+
+// Returns the total number of reduction segments.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t number_of_reduction_segments(const dim3& grid_dim) {
+ return (X_BLOCK ? 1: grid_dim.x) *
+ (Y_BLOCK ? 1 : grid_dim.y) *
+ (Z_BLOCK ? 1 : grid_dim.z);
+}
+
+// Returns the 1-D index of the segment of thread block of block_idx.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t index_of_reduction_segment(const dim3& block_idx,
+ const dim3& grid_dim) {
+ size_t seg_idx = 0;
+ if (!Z_BLOCK)
+ seg_idx += block_idx.z;
+ if (!Y_BLOCK)
+ seg_idx = seg_idx * grid_dim.y + block_idx.y;
+ if (!X_BLOCK)
+ seg_idx = seg_idx * grid_dim.x + block_idx.x;
+ return seg_idx;
+}
+
+// Returns the offset of thread block in its reduction segment.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t offset_in_reduction_segment(const dim3& block_idx,
+ const dim3& grid_dim) {
+ size_t offset = 0;
+ if (Z_BLOCK)
+ offset = offset * grid_dim.z + block_idx.z;
+ if (Y_BLOCK)
+ offset = offset * grid_dim.y + block_idx.y;
+ if (X_BLOCK)
+ offset = offset * grid_dim.x + block_idx.x;
+ return offset;
+}
+
+// Returns dim3 of each reduction block.
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
+__host__ __device__ dim3 dimension_of_reduction_block(const dim3& block_dim) {
+ return dim3{X_THREAD ? block_dim.x : 1,
+ Y_THREAD ? block_dim.y : 1,
+ Z_THREAD ? block_dim.z : 1};
+}
+
+// Returns the number of threads of each reduction block.
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
+__host__ __device__ int size_of_reduction_block(const dim3& block_dim) {
+ return isize(dimension_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(block_dim));
+}
+
+// Returns the linear offset of a thread in a reduction block.
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
+__host__ __device__ int offset_in_reduction_block(const dim3& thread_idx,
+ const dim3& block_dim) {
+ int offset = 0;
+ if (Z_THREAD)
+ offset += thread_idx.z;
+ if (Y_THREAD)
+ offset = offset * block_dim.y + thread_idx.y;
+ if (X_THREAD)
+ offset = offset * block_dim.x + thread_idx.x;
+ return offset;
+}
+
+/** Reduces all the reduction blocks in each reduction segment.
+
+ This is only used by one thread block per reduction segment. The input
+ reduction blocks of the segment are stored in an intermediate buffer pointed
+ by parameter in. Template parameters X/Y/Z_THREAD denote how the reduction
+ block is formed.
+
+  The size of a reduction block is by definition smaller than or equal to the size of
+ a thread block. We use the remaining threads to parallelize reductions across
+ reduction blocks. For example, when X/Y/Z_THREAD = {true, false, false}, we
+ use blockDim.y*blockDim.z threads for each output value. This is done first by
+ loading the input values in parallel and then by reducing across threads of
+ dimensions whose XYZ_THREAD are false.
+
+ Note that what is done here after the loading from global memory is similar to
+ what the existing blockReduce function does. The main difference is that the
+ logical block to reduce is a 2D domain where the leading dimension is the size
+ of a reduction block and the second dimension is the remaining factor in each
+ thread block. For example, when X/Y/Z_THREAD = {false, true, false}, the
+ threads are arranged as (blockDim.y, blockDim.x*blockDim.z). We do not reduce
+ along the first dimension but only the second dimension. So, it is possible to
+ reuse the existing blockReduce with dim3{blockDim.y, blockDim.x*blockDim.z}
+ instead of blockDim and with X_THREAD and Y_THREAD being false and true,
+  respectively. Also, it still needs to shuffle the final output values to their
+  actual corresponding threads. In the case when X/Y/Z_THREAD = {false, true,
+ false}, after the intra-block reduction, the final results will still be held
+ by the first blockDim.y threads, which need to be transferred to threads at
+ threadIdx.x == 0 and threadIdx.z == 0.
+*/
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD,
+ typename T, typename Func>
+__device__ void gridReduceLastBlock(T& out, const T *in, const size_t in_size,
+ Func reduction_op, T* shared_buf) {
+ const int tid = ioffset(threadIdx, blockDim);
+ const int block_size = isize(blockDim);
+ const int rblock_size = size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
+
+ T inp = 0;
+ if (tid < in_size) {
+ inp = in[tid];
+ }
+ for (size_t i = tid + block_size; i < in_size; i += block_size) {
+ reduction_op(inp, in[i]);
+ }
+
+ const auto should_write = (X_THREAD || threadIdx.x == 0) &&
+ (Y_THREAD || threadIdx.y == 0) &&
+ (Z_THREAD || threadIdx.z == 0);
+
+ auto rem_size = block_size / rblock_size;
+
+ if (rem_size > 1) {
+ const int rblock_offset = tid % rblock_size;
+ const int rblock_idx = tid / rblock_size;
+ blockReduce<false, true, false>(
+ inp, inp, reduction_op,
+ dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
+ dim3{(unsigned)rblock_size, (unsigned)rem_size},
+ shared_buf);
+ __syncthreads();
+ if (tid < rblock_size) {
+ shared_buf[tid] = inp;
+ }
+ __syncthreads();
+ if (should_write) {
+ inp = shared_buf[offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
+ threadIdx, blockDim)];
+ }
+ }
+
+ if (should_write) {
+ out = inp;
+ }
+}
+
+__device__ unsigned atomic_inc(unsigned* sync_flag, unsigned max_val) {
+ return atomicInc(sync_flag, max_val - 1);
+}
+
+/** Reduces per-thread values across thread blocks.
+
+Function parameters:
+- out: Per-thread output location
+- inp_val: Per-thread input value
+- reduction_op: Scalar reduction function
+- work_buf: Temporary buffer for cross-block reductions
+- sync_flags: A vector of integers for synchronizations
+- shared_buf: Shared memory buffer for intra-block reduction
+
+Template parameters:
+- X/Y/Z_BLOCK: When true, reduces across thread blocks along the X/Y/Z
+ dimensions
+- X/Y/Z_THREAD: When true, all threads along the X/Y/Z dimensions participate in
+ the cross-block reduction. Otherwise, only threads at offset 0 do.
+- T: Scalar data type of input/output data
+- Func: Type of scalar reduction function
+
+Template parameters X/Y/Z_BLOCK define a group of thread blocks that are reduced together. We call
+it a reduction segment. Some examples are:
+
+Case 1: X/Y/Z_BLOCK == true/true/true -> There is only one segment, which includes all
+  thread blocks. It is effectively the same as the grid.
+Case 2: X/Y/Z_BLOCK == false/false/false -> Each thread block comprises an individual
+ segment by itself.
+Case 3: X/Y/Z_BLOCK == true/false/false -> Each segment contains thread blocks that have
+  the same blockIdx.y and blockIdx.z. There will be gridDim.y*gridDim.z such segments.
+
+X/Y/Z_THREAD defines a sub region of a thread block that should be reduced with
+the sub regions of other thread blocks. We call it a reduction block. E.g.,
+
+Case 1: X/Y/Z_THREAD == false/false/false -> Only thread 0 participates in the
+ cross-block reductions. The reduction block is 1x1x1 with thread 0.
+Case 2: X/Y/Z_THREAD == true/true/true -> All threads in a thread block participate in
+ the cross-block reductions. The reduction block in this case is equivalent to
+ the thread block.
+
+After the function completes, only one thread block per reduction segment gets
+valid reduction results. There is no guarantee which particular block gets the
+final results.
+*/
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK,
+ bool X_THREAD, bool Y_THREAD, bool Z_THREAD,
+ typename T, typename Func>
+__device__ void gridReduce(T& out, T inp_val, Func reduction_op,
+ volatile T* work_buf,
+ unsigned* sync_flags,
+ T* shared_buf) {
+ const auto seg_size =
+ size_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);
+ const auto seg_idx =
+ index_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
+ const auto rblock_size =
+ size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
+
+ // advance to the offset for this segment
+ work_buf += seg_idx * seg_size * rblock_size;
+
+ if ((X_THREAD || threadIdx.x == 0) &&
+ (Y_THREAD || threadIdx.y == 0) &&
+ (Z_THREAD || threadIdx.z == 0)) {
+ auto rblock_offset =
+ offset_in_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
+ auto thread_offset =
+ offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(threadIdx, blockDim);
+ auto work_buf_offset = rblock_size * rblock_offset + thread_offset;
+ work_buf[work_buf_offset] = inp_val;
+ }
+ __syncthreads();
+
+ __shared__ bool last_block;
+ if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
+ __threadfence();
+ auto old = atomic_inc(&sync_flags[seg_idx], seg_size);
+ last_block = old == seg_size - 1;
+ }
+ __syncthreads();
+
+ if (last_block) {
+ // final reduction
+ gridReduceLastBlock<X_THREAD, Y_THREAD, Z_THREAD>(
+ out, (T*)work_buf, seg_size * rblock_size,
+ reduction_op, shared_buf);
+ }
+}
+} // namespace reduction
+)";
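
As a worked example of the shapes discussed in the gridReduceLastBlock comment, assume a hypothetical blockDim of (32, 4, 2) with X_THREAD=true and Y/Z_THREAD=false; the reduction block then holds 32 values, the remaining 8 threads per value parallelize the final reduction, and blockReduce is reused on a logical (32, 8) block. A host-only sketch of that arithmetic (the block shape is an assumption for illustration):

#include <cstdio>

int main() {
  const unsigned block_x = 32, block_y = 4, block_z = 2;
  const unsigned block_size = block_x * block_y * block_z;  // 256 threads
  const unsigned rblock_size = block_x;                     // only X_THREAD is true
  const unsigned rem_size = block_size / rblock_size;       // 8
  // blockReduce is re-invoked on a logical (rblock_size, rem_size) block,
  // reducing only along the second dimension.
  std::printf("logical block for the final reduction: (%u, %u)\n",
              rblock_size, rem_size);
  return 0;
}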
+
} // namespace cuda
} // namespace fuser
} // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index ef0ff56..913c39e 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -1,12 +1,11 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/lower_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
+#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
@@ -14,418 +13,34 @@
namespace jit {
namespace fuser {
-void GPULower::pushBack(Expr* expr) {
- if (active_scope == nullptr)
- lowered_exprs.push_back(expr);
- else
- scope_utils::pushBack(active_scope, expr);
-}
-
-Statement* GPULower::mutate(Expr* expr) {
- Statement* mutated_stmt = OptOutMutator::mutate(expr);
- TORCH_INTERNAL_ASSERT(
- mutated_stmt->isExpr(),
- "Tried to generate a kernel but hit a non expression during lowering: ",
- mutated_stmt);
- return mutated_stmt;
-}
-
-Statement* GPULower::mutate(IfThenElse* ite) {
- Expr* prev_scope = active_scope;
- active_scope = ite;
- std::vector<Expr*> mutated_exprs;
- bool is_mutated = false;
- for (auto expr : ite->body().exprs()) {
- Statement* mutated_stmt = mutate(expr);
- Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
- mutated_exprs.push_back(mutated_expr);
- is_mutated = is_mutated | (mutated_expr != expr);
- }
-
- std::vector<Expr*> mutated_else_exprs;
- for (auto expr : ite->elseBody().exprs()) {
- Statement* mutated_stmt = mutate(expr);
- Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
- mutated_else_exprs.push_back(mutated_expr);
- is_mutated = is_mutated | (mutated_expr != expr);
- }
-
- if (is_mutated) {
- ite->body().clear();
- for (auto expr : mutated_exprs)
- ite->body().push_back(expr);
- ite->elseBody().clear();
- for (auto expr : mutated_else_exprs)
- ite->elseBody().push_back(expr);
- }
-
- active_scope = prev_scope;
-
- if (is_mutated) {
- auto new_ite = new IfThenElse(
- ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope());
- return new_ite;
- }
-
- return ite;
-}
-
-Statement* GPULower::mutate(ForLoop* fl) {
- Expr* prev_scope = active_scope;
- active_scope = fl;
- std::vector<Expr*> mutated_exprs;
- bool is_mutated = false;
- for (auto expr : fl->body().exprs()) {
- Statement* mutated_stmt = mutate(expr);
- Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
- mutated_exprs.push_back(mutated_expr);
- is_mutated = is_mutated | (mutated_expr != expr);
- }
-
- active_scope = prev_scope;
- if (is_mutated) {
- auto newFL = new ForLoop(
- fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope());
- return newFL;
- }
-
- return fl;
-}
-
-Statement* GPULower::mutate(UnaryOp* uop) {
- if (!ir_utils::isTVOp(uop))
- return OptOutMutator::mutate(uop);
-
- TensorIndex* out = Index::getConsumerIndex(
- ir_utils::asTV(uop->out()), scope_utils::getLoops(active_scope));
- Val* in = uop->in();
- if (ir_utils::isTV(in))
- in = Index::getProducerIndex(
- ir_utils::asTV(in),
- ir_utils::asTV(uop->out()),
- scope_utils::getLoops(active_scope));
- Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in);
-
- return new_op;
-}
-
-Statement* GPULower::mutate(BinaryOp* bop) {
- if (!ir_utils::isTVOp(bop))
- return OptOutMutator::mutate(bop);
-
- TensorIndex* out = Index::getConsumerIndex(
- ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
-
- Val* lhs = bop->lhs();
- Val* rhs = bop->rhs();
-
- if (ir_utils::isTV(lhs))
- lhs = Index::getProducerIndex(
- ir_utils::asTV(lhs),
- ir_utils::asTV(bop->out()),
- scope_utils::getLoops(active_scope));
-
- if (ir_utils::isTV(rhs))
- rhs = Index::getProducerIndex(
- ir_utils::asTV(rhs),
- ir_utils::asTV(bop->out()),
- scope_utils::getLoops(active_scope));
-
- Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs);
-
- return new_op;
-}
-
-Statement* GPULower::mutate(TernaryOp* top) {
- if (!ir_utils::isTVOp(top))
- return OptOutMutator::mutate(top);
-
- TensorIndex* out = Index::getConsumerIndex(
- ir_utils::asTV(top->out()), scope_utils::getLoops(active_scope));
- Val* in1 = top->in1();
- Val* in2 = top->in2();
- Val* in3 = top->in3();
-
- if (ir_utils::isTV(in1))
- in1 = Index::getProducerIndex(
- ir_utils::asTV(in1),
- ir_utils::asTV(top->out()),
- scope_utils::getLoops(active_scope));
-
- if (ir_utils::isTV(in2))
- in2 = Index::getProducerIndex(
- ir_utils::asTV(in2),
- ir_utils::asTV(top->out()),
- scope_utils::getLoops(active_scope));
-
- if (ir_utils::isTV(in3))
- in3 = Index::getProducerIndex(
- ir_utils::asTV(in3),
- ir_utils::asTV(top->out()),
- scope_utils::getLoops(active_scope));
-
- Expr* new_op = new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3);
-
- return new_op;
-}
-
-Statement* GPULower::mutate(ReductionOp* rop) {
- TORCH_INTERNAL_ASSERT(
- ir_utils::isTVOp(rop),
- "Cannot have a reduction operation on something other than a tensor view.");
- auto loops = scope_utils::getLoops(active_scope);
- TORCH_INTERNAL_ASSERT(
- std::none_of(
- loops.begin(),
- loops.end(),
- [](ForLoop* fl) {
- return fl->iter_domain()->isBlockDim() &&
- fl->iter_domain()->isReduction();
- }),
- "Reduction on block axes not yet supported.");
-
- bool is_thread_reduce =
- std::any_of(loops.begin(), loops.end(), [](ForLoop* fl) {
- return fl->iter_domain()->isThreadDim() &&
- fl->iter_domain()->isReduction();
- });
-
- TensorIndex* out = Index::getConsumerIndex(ir_utils::asTV(rop->out()), loops);
-
- Val* in = rop->in();
- if (ir_utils::isTV(in))
- in = Index::getProducerIndex(
- ir_utils::asTV(in),
- ir_utils::asTV(rop->out()),
- scope_utils::getLoops(active_scope));
-
- if (is_thread_reduce)
- return new ReductionOp(rop->getReductionOpType(), rop->init(), out, in);
-
- Expr* new_op = new BinaryOp(rop->getReductionOpType(), out, out, in);
-
- return new_op;
-}
-
-Statement* GPULower::mutate(BroadcastOp* bop) {
- if (!ir_utils::isTVOp(bop))
- return OptOutMutator::mutate(bop);
-
- TensorIndex* out = Index::getConsumerIndex(
- ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
- Val* in = bop->in();
- if (ir_utils::isTV(in))
- in = Index::getProducerIndex(
- ir_utils::asTV(in),
- ir_utils::asTV(bop->out()),
- scope_utils::getLoops(active_scope));
- Expr* new_op = new BroadcastOp(out, in);
-
- return new_op;
-}
-
-// TensorViews are all based on symbolic sizes. When we first initialize them we
-// don't know if they're inputs or outputs which would mean that they have
-// runtime shapes. Intermediate tensors (those not going to global memory) do
-// not have this information. Since we need to have the correct information in
-// the kernel being fetched for shapes, we want to replace input and output
-// tensors to reference the runtime structure containing sizes.
-void GPULower::replaceSizes() {
- Fusion* fusion = FusionGuard::getCurFusion();
- // Sizes of inputs/outputs -> T.size[...]
- std::unordered_map<Val*, Val*> size_map;
-
- // Grab inputs and outputs
- std::vector<TensorView*> orig_inp_out;
- std::vector<TensorView*> all_tvs;
-
- for (auto* val : fusion->inputs())
- if (ir_utils::isTV(val))
- orig_inp_out.push_back(ir_utils::asTV(val));
-
- for (auto* val : fusion->outputs())
- if (ir_utils::isTV(val))
- orig_inp_out.push_back(ir_utils::asTV(val));
-
- for (auto* val : fusion->deterministic_vals()) {
- if (ir_utils::isTV(val)) {
- all_tvs.push_back(ir_utils::asTV(val));
- }
- }
-
- // Run through inputs and outputs first. Since we're replacing full
- // tensorviews their names are going to change. We need the new referenc
- // name for the inputs/outputs. This way we won't reference the wrong tensor
- // view. For example T0 may be translated to T9. We don't want our new
- // variable to be T0->size[...] we need it to be T9->size[...]
- //
- // This could be done in a better way but changing split/merge to be a
- // TensorDomain focused operation, then we could simply do this process on
- // domains, instead of tensorviews. This would have the benefit that the
- // TensorView wouldn't change, so users pointers will remain valid. The other
- // option which seems less elegant but would also work is build up the domain
- // on the new tensor, and then simply replace it into the original one.
- for (TensorView* tv : orig_inp_out) {
- // Replace the domain with one based on Ti.size[j]
- std::vector<IterDomain*> new_domain_iters;
- const std::vector<IterDomain*>& root_td = tv->getRootDomain();
-
- for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) {
- // Output sizes could have reduction axes, which isn't what gets output.
- if (root_td[i]->isReduction())
- continue;
-
- Val* orig_size = root_td[i]->extent();
-
- std::stringstream ss;
- ss << "T" << tv->name() << ".size[" << i << "]";
- Val* new_size =
- new NamedScalar(ss.str(), orig_size->getDataType().value());
- if (!orig_size->sameAs(new_size) ||
- size_map.find(orig_size) == size_map.end())
- size_map[orig_size] = new_size;
- }
- }
-
- // If we already lowered all inputs/outputs we can just return.
- if (size_map.size() == 0)
- return;
-
- // Set domains to be based on symbolic sizes (i.e. Ti.size[...])
- for (TensorView* tv : all_tvs) {
- std::vector<IterDomain*> new_domain_iters;
- const std::vector<IterDomain*>& root_td = tv->getRootDomain();
-
- for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) {
- Val* new_size = root_td[i]->extent();
- if (size_map.find(new_size) != size_map.end())
- new_size = size_map[new_size];
-
- new_domain_iters.push_back(new IterDomain(
- root_td[i]->start(),
- new_size,
- root_td[i]->parallel_method(),
- root_td[i]->isReduction(),
- root_td[i]->isRFactorProduct(),
- root_td[i]->isBroadcast()));
- }
-
- TensorDomain* old_domain = tv->domain();
- TensorDomain* new_domain = new TensorDomain(new_domain_iters);
-
- // We should just be able to replace sizes in place, but mutator is setup to
- // do that as it set up to replace vals in Exprs, but
- // IterDomain/TensorDomain are vals.
-
- new_domain = TransformReplay::fullSelfReplay(new_domain, old_domain);
-
- TORCH_INTERNAL_ASSERT(
- old_domain->nDims() == new_domain->nDims(),
- "Tried to set symbolic sizes through the kernel, but hit a snag, Replayed domain should be the same size as the target domain, but got ",
- new_domain->nDims(),
- " and ",
- old_domain->nDims());
- // Parallelize all iter domains
- for (decltype(new_domain->nDims()) i{0}; i < new_domain->nDims(); i++)
- new_domain->axis(i)->parallelize(old_domain->axis(i)->parallel_method());
-
- tv->setDomain(new_domain);
- }
-
- // Adjust memory types to make sure they are valid
- for (TensorView* tv : all_tvs) {
- if (fusion->hasInput(tv) || fusion->hasOutput(tv)) {
- tv->setMemoryType(MemoryType::Global);
- } else {
- if (tv->getMemoryType() == MemoryType::Global)
- tv->setMemoryType(MemoryType::Local);
- }
- }
-}
-
-namespace {
-
-// Some pre-compilation checks
-void validate(Fusion* fusion) {
- FusionGuard fg(fusion);
- fusion->validateInputs();
- for (Val* val : fusion->vals()) {
- if (ir_utils::isTV(val)) {
- TensorView* tv = ir_utils::asTV(val);
- for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
- IterDomain* id = tv->getComputeAtAxis(i).first;
-
- if (id->isBlockDim())
- TORCH_CHECK(
- !id->isReduction(),
- "Parallelization across blocks on reduction axes not support at the moment but found on, ",
- tv,
- ".");
- }
- } // if ir_utils::isTV
- } // for(Val* val : fusion->vals())
-} // validate
-
-} // namespace
-
-// Remove circular computeAt references
-void GPULower::fixComputeAt(Fusion* fusion) {
- FusionGuard fg(fusion);
-
- std::vector<Expr*> exprs = fusion->exprs(true);
- std::set<TensorView*> visited;
- for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
- Expr* expr = *it;
- if (!ir_utils::isTVOp(expr))
- continue;
-
- TensorView* tv = ir_utils::asTV(expr->output(0));
- TensorView* ctv = tv->getComputeAtView();
-
- if (ctv != nullptr && visited.find(ctv) == visited.end()) {
- ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
- tv->clearComputeAt();
- }
- visited.emplace(tv);
- }
-}
-
// Traverse through the fusion and print CUDA code associated with it
std::vector<Expr*> GPULower::getLoweredExprs() {
FusionGuard fg(fusion_);
- // Compute at can have some circular references. Before we can call any tv
- // with tv->getComputeAtAxis(i) we need to break those circular dependencies.
- fixComputeAt(fusion_);
+ // Validate and make some minor modifications in preparation to generate code.
+ PrepareForLowering(fusion_);
- validate(fusion_);
- replaceSizes();
- auto loop_nests =
- LoopNestGenerator::getLoopNest(fusion_, fusion_->exprs(true));
+ auto preds = ThreadPredicates::compute(fusion_);
- auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests);
+ // Run our lowering passes, forwarding the lowered expressions from one pass to the next.
+ auto loop_nests = LoopNestGenerator::getLoopNest(
+ fusion_, fusion_->exprs(true, false, true), preds);
- // Run through loop nests and further lower the expressions
- for (auto* expr : unrolled_loops) {
- Statement* mutated_stmt = mutate(expr);
- TORCH_INTERNAL_ASSERT(
- mutated_stmt->isExpr(),
- "Tried to generate a kernel but hit a non expression during lowering: ",
- mutated_stmt);
- lowered_exprs.push_back(static_cast<Expr*>(mutated_stmt));
- }
+ auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests, preds);
- return lowered_exprs;
+ auto indexed_loops = IndexLowering::getIndexedExprs(fusion_, unrolled_loops);
+
+ return indexed_loops;
}
std::ostream& GPULower::printKernel(
std::ostream& os,
const std::string& kernel_name) {
FusionGuard fg(fusion_);
- getLoweredExprs();
+ auto exprs = getLoweredExprs();
IRPrinter irp(os);
- irp.printKernel(lowered_exprs, kernel_name);
+ irp.printKernel(exprs, kernel_name);
return os;
}
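
For orientation, a minimal call-site sketch of the reworked entry point, assuming a Fusion has already been defined and scheduled elsewhere; only the GPULower constructor and printKernel from this diff are used:

#include <iostream>

#include <torch/csrc/jit/codegen/cuda/lower2device.h>

// Usage sketch: GPULower now only orchestrates the passes.
// PrepareForLowering, ThreadPredicates, LoopNestGenerator, UnrollPass and
// IndexLowering all run inside getLoweredExprs(), which printKernel calls.
void emitKernel(torch::jit::fuser::Fusion* fusion) {
  torch::jit::fuser::GPULower lower(fusion);
  lower.printKernel(std::cout, "CUDAGeneratedKernel");
}
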
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h
index 80607de..57eadbf 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.h
+++ b/torch/csrc/jit/codegen/cuda/lower2device.h
@@ -4,53 +4,13 @@
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <map>
#include <ostream>
-#include <stack>
namespace torch {
namespace jit {
namespace fuser {
-// TODO: Change lowering so it can be called multiple times. It would be good to
-// keep user references intact so they can lower it as they describe the kernel.
-// Right now we can only lower once.
-
-struct TORCH_CUDA_API GPULower : public OptOutMutator {
- private:
- Fusion* const fusion_;
- std::vector<Expr*> lowered_exprs;
- Expr* active_scope = nullptr;
-
- // Wrap pushBack in lower_utils if active_scope is null we want it to go
- // straight to lower_exprs
- void pushBack(Expr*);
-
- // Custom dispatch for Expr, want to find out of it's a TV op
- Statement* mutate(Expr*) final;
-
- // Open the for loop.
- Statement* mutate(ForLoop*) final;
-
- // Open the for loop.
- Statement* mutate(IfThenElse*) final;
-
- // Remake operations with TensorIndex
- Statement* mutate(UnaryOp*) final;
- Statement* mutate(BinaryOp*) final;
- Statement* mutate(TernaryOp*) final;
- Statement* mutate(ReductionOp*) final;
- Statement* mutate(BroadcastOp*) final;
-
- // TensorViews are all based on symbolic sizes. When we first initialize them
- // we don't know if they're inputs or outputs which would mean that they have
- // runtime shapes. Intermediate tensors (those not going to global memory) do
- // not have this information. Since we need to have the correct information in
- // the kernel being fetched for shapes, we want to replace input and output
- // tensors to reference the runtime structure containing sizes.
- void replaceSizes();
- void fixComputeAt(Fusion* fusion);
-
+class TORCH_CUDA_API GPULower {
public:
// Init printer on ostream
GPULower(Fusion* _fusion) : fusion_(_fusion) {}
@@ -60,6 +20,9 @@
std::ostream& printKernel(
std::ostream& _os,
const std::string& kernel_name = "CUDAGeneratedKernel");
+
+ private:
+ Fusion* const fusion_ = nullptr;
};
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp
new file mode 100644
index 0000000..d506d64
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -0,0 +1,225 @@
+#include <torch/csrc/jit/codegen/cuda/index_compute.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+
+#include <torch/csrc/jit/codegen/cuda/lower_index.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+void IndexLowering::pushBack(Expr* expr) {
+ if (active_scope == nullptr)
+ lowered_exprs.push_back(expr);
+ else
+ scope_utils::pushBack(active_scope, expr);
+}
+
+Statement* IndexLowering::mutate(Expr* expr) {
+ Statement* mutated_stmt = OptOutMutator::mutate(expr);
+ TORCH_INTERNAL_ASSERT(
+ mutated_stmt->isExpr(),
+ "Tried to generate a kernel but hit a non expression during lowering: ",
+ mutated_stmt);
+ return mutated_stmt;
+}
+
+Statement* IndexLowering::mutate(IfThenElse* ite) {
+ Expr* prev_scope = active_scope;
+ active_scope = ite;
+ std::vector<Expr*> mutated_exprs;
+ bool is_mutated = false;
+ for (auto expr : ite->body().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+ mutated_exprs.push_back(mutated_expr);
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
+
+ std::vector<Expr*> mutated_else_exprs;
+ for (auto expr : ite->elseBody().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+ mutated_else_exprs.push_back(mutated_expr);
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
+
+ if (is_mutated) {
+ ite->body().clear();
+ for (auto expr : mutated_exprs)
+ ite->body().push_back(expr);
+ ite->elseBody().clear();
+ for (auto expr : mutated_else_exprs)
+ ite->elseBody().push_back(expr);
+ }
+
+ active_scope = prev_scope;
+
+ if (is_mutated) {
+ auto new_ite = new IfThenElse(
+ ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope());
+ return new_ite;
+ }
+
+ return ite;
+}
+
+Statement* IndexLowering::mutate(ForLoop* fl) {
+ Expr* prev_scope = active_scope;
+ active_scope = fl;
+ std::vector<Expr*> mutated_exprs;
+ bool is_mutated = false;
+ for (auto expr : fl->body().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+ mutated_exprs.push_back(mutated_expr);
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
+
+ active_scope = prev_scope;
+ if (is_mutated) {
+ auto newFL = new ForLoop(
+ fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope());
+ return newFL;
+ }
+
+ return fl;
+}
+
+Statement* IndexLowering::mutate(UnaryOp* uop) {
+ if (!ir_utils::isTVOp(uop))
+ return OptOutMutator::mutate(uop);
+
+ TensorIndex* out = Index::getConsumerIndex(
+ ir_utils::asTV(uop->out()), scope_utils::getLoops(active_scope));
+ Val* in = uop->in();
+ if (ir_utils::isTV(in))
+ in = Index::getProducerIndex(
+ ir_utils::asTV(in),
+ ir_utils::asTV(uop->out()),
+ scope_utils::getLoops(active_scope));
+ Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in);
+
+ return new_op;
+}
+
+Statement* IndexLowering::mutate(BinaryOp* bop) {
+ if (!ir_utils::isTVOp(bop))
+ return OptOutMutator::mutate(bop);
+
+ TensorIndex* out = Index::getConsumerIndex(
+ ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
+
+ Val* lhs = bop->lhs();
+ Val* rhs = bop->rhs();
+
+ if (ir_utils::isTV(lhs))
+ lhs = Index::getProducerIndex(
+ ir_utils::asTV(lhs),
+ ir_utils::asTV(bop->out()),
+ scope_utils::getLoops(active_scope));
+
+ if (ir_utils::isTV(rhs))
+ rhs = Index::getProducerIndex(
+ ir_utils::asTV(rhs),
+ ir_utils::asTV(bop->out()),
+ scope_utils::getLoops(active_scope));
+
+ Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs);
+
+ return new_op;
+}
+
+Statement* IndexLowering::mutate(TernaryOp* top) {
+ if (!ir_utils::isTVOp(top))
+ return OptOutMutator::mutate(top);
+
+ TensorIndex* out = Index::getConsumerIndex(
+ ir_utils::asTV(top->out()), scope_utils::getLoops(active_scope));
+ Val* in1 = top->in1();
+ Val* in2 = top->in2();
+ Val* in3 = top->in3();
+
+ if (ir_utils::isTV(in1))
+ in1 = Index::getProducerIndex(
+ ir_utils::asTV(in1),
+ ir_utils::asTV(top->out()),
+ scope_utils::getLoops(active_scope));
+
+ if (ir_utils::isTV(in2))
+ in2 = Index::getProducerIndex(
+ ir_utils::asTV(in2),
+ ir_utils::asTV(top->out()),
+ scope_utils::getLoops(active_scope));
+
+ if (ir_utils::isTV(in3))
+ in3 = Index::getProducerIndex(
+ ir_utils::asTV(in3),
+ ir_utils::asTV(top->out()),
+ scope_utils::getLoops(active_scope));
+
+ Expr* new_op = new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3);
+
+ return new_op;
+}
+
+Statement* IndexLowering::mutate(ReductionOp* rop) {
+ TORCH_INTERNAL_ASSERT(
+ ir_utils::isTVOp(rop),
+ "Cannot have a reduction operation on something other than a tensor view.");
+ auto loops = scope_utils::getLoops(active_scope);
+
+ bool is_private_reduce =
+ std::none_of(loops.begin(), loops.end(), [](ForLoop* fl) {
+ return fl->iter_domain()->isThread() &&
+ fl->iter_domain()->isReduction();
+ });
+
+ TensorIndex* out = Index::getConsumerIndex(ir_utils::asTV(rop->out()), loops);
+
+ Val* in = rop->in();
+ if (ir_utils::isTV(in))
+ in = Index::getProducerIndex(
+ ir_utils::asTV(in),
+ ir_utils::asTV(rop->out()),
+ scope_utils::getLoops(active_scope));
+
+ if (!is_private_reduce)
+ return new ReductionOp(rop->getReductionOpType(), rop->init(), out, in);
+
+ Expr* new_op = new BinaryOp(rop->getReductionOpType(), out, out, in);
+
+ return new_op;
+}
+
+Statement* IndexLowering::mutate(BroadcastOp* bop) {
+ if (!ir_utils::isTVOp(bop))
+ return OptOutMutator::mutate(bop);
+
+ TensorIndex* out = Index::getConsumerIndex(
+ ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
+ Val* in = bop->in();
+ if (ir_utils::isTV(in))
+ in = Index::getProducerIndex(
+ ir_utils::asTV(in),
+ ir_utils::asTV(bop->out()),
+ scope_utils::getLoops(active_scope));
+ Expr* new_op = new BroadcastOp(out, in);
+
+ return new_op;
+}
+
+void IndexLowering::generate(const std::vector<Expr*>& exprs) {
+ // Run through loop nests and further lower the expressions
+ for (auto* expr : exprs) {
+ Statement* mutated_stmt = mutate(expr);
+ TORCH_INTERNAL_ASSERT(
+ mutated_stmt->isExpr(),
+ "Tried to generate a kernel but hit a non expression during lowering: ",
+ mutated_stmt);
+ lowered_exprs.push_back(static_cast<Expr*>(mutated_stmt));
+ }
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h
new file mode 100644
index 0000000..564cc49
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_index.h
@@ -0,0 +1,52 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class TORCH_CUDA_API IndexLowering : public OptOutMutator {
+ private:
+ std::vector<Expr*> lowered_exprs;
+ Expr* active_scope = nullptr;
+
+ // Wraps pushBack from lower_utils; if active_scope is null we want the expr
+ // to go straight to lowered_exprs
+ void pushBack(Expr*);
+
+ // Custom dispatch for Expr, want to find out if it's a TV op
+ Statement* mutate(Expr*) final;
+
+ // Open the for loop.
+ Statement* mutate(ForLoop*) final;
+
+ // Open the if-then-else.
+ Statement* mutate(IfThenElse*) final;
+
+ // Remake operations with TensorIndex
+ Statement* mutate(UnaryOp*) final;
+ Statement* mutate(BinaryOp*) final;
+ Statement* mutate(TernaryOp*) final;
+ Statement* mutate(ReductionOp*) final;
+ Statement* mutate(BroadcastOp*) final;
+ void generate(const std::vector<Expr*>& exprs);
+
+ public:
+ static std::vector<Expr*> getIndexedExprs(
+ Fusion* fusion,
+ std::vector<Expr*> incoming_exprs) {
+ FusionGuard fg(fusion);
+ IndexLowering il;
+ il.generate(incoming_exprs);
+ return il.lowered_exprs;
+ }
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
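
The mutate() overloads above all follow the same rebuild-only-if-changed pattern. A standalone sketch of that pattern in plain C++ (none of these types are nvFuser types; the rewrite rule is a stand-in for replacing TV accesses with indexed accesses):

#include <cstdio>
#include <vector>

// A node is reconstructed only when at least one child (or its own payload)
// actually changed; otherwise the original pointer is returned so untouched
// subtrees are shared.
struct Node {
  int value;
  std::vector<Node*> children;
};

Node* mutate(Node* n) {
  std::vector<Node*> new_children;
  bool is_mutated = false;
  for (Node* c : n->children) {
    Node* m = mutate(c);
    new_children.push_back(m);
    is_mutated = is_mutated | (m != c);
  }
  // Example rewrite rule: clamp negative payloads to zero.
  int new_value = n->value < 0 ? 0 : n->value;
  is_mutated = is_mutated | (new_value != n->value);
  if (!is_mutated)
    return n;
  return new Node{new_value, new_children};
}

int main() {
  Node leaf{-3, {}};
  Node root{1, {&leaf}};
  Node* out = mutate(&root);
  std::printf("root replaced: %d, leaf value: %d\n", out != &root,
              out->children[0]->value); // root replaced: 1, leaf value: 0
  return 0;
}
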
diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
index acd48b1..93f2613 100644
--- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
@@ -2,304 +2,89 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
namespace torch {
namespace jit {
namespace fuser {
-bool operator==(
- const std::pair<IterDomain*, TensorView*>& p1,
- const std::pair<IterDomain*, TensorView*>& p2) {
- return p1.first->sameAs(p2.first) && p1.second == p2.second;
-}
-
-// all the way in the loop nest, grab predicate
-/*
-for( i : ceil(I/4) ) {
- for( j : ceil(J/128) ) {
-
- if( i * 4 + 3 < I && j * 128 + 127 < J ){
- for( k : 4)
- for( l : 128 )
- T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
- } else {
- for( k : 4 )
- for( l : 128 )
- if( i * 4 + k < I && j * 128 + l < J)
- T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
- }
-
- }
-}
-*/
-
-// Custom dispatch for Expr, want to find out of it's a TV op
-void UnrollPass::handle(Expr* expr) {
- OptOutDispatch::handle(expr);
-}
-
-namespace {
-Bool* getPredicate(TensorView* tv, std::vector<Val*> inds_) {
- TORCH_INTERNAL_ASSERT(
- inds_.size() == tv->nDims() ||
- inds_.size() == tv->domain()->noReductions().size());
-
- std::vector<Val*> inds;
- if (inds_.size() < tv->nDims()) {
- size_t i_ = 0;
- for (size_t i = 0; i < tv->nDims() && i_ < inds_.size(); i++) {
- if (tv->axis(i)->isReduction())
- inds.push_back(new Int(0));
- else
- inds.push_back(inds_[i_++]);
- }
- } else {
- inds = inds_;
- }
- if (tv->nDims() > inds.size()) {
- for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
- if (tv->axis(i)->isReduction())
- inds.insert(inds.begin() + i, new Int(0));
- }
- }
- std::vector<Bool*> all_preds = PredicateCompute::computePredicates(
- new TensorIndex(tv, IndexCompute::get(tv->domain(), inds)));
-
- std::vector<Bool*> preds;
-
- for (Bool* pred : all_preds)
- if (!(pred->isConst()) || !(pred->isConst() && pred->value().value()))
- preds.push_back(pred);
-
- if (preds.size() == 0)
- return new Bool(true);
-
- Val* cond = preds[0];
-
- for (decltype(preds.size()) i{1}; i < preds.size(); i++) {
- cond = andOp(cond, preds[i]);
- }
-
- TORCH_INTERNAL_ASSERT(
- cond->getValType().value() == ValType::Scalar &&
- cond->getDataType().value() == DataType::Bool,
- "Error computing predicate, should be returning a Bool, but returning ",
- cond->getDataType().value());
-
- return static_cast<Bool*>(cond);
-}
-} // namespace
-
-// This function is one huge mess that should be refactored.
-// It handles the unrolling and predicate generation
-void UnrollPass::handle(ForLoop* fl) {
- // Setup for loop scoping
- for_loops.push_back(fl);
- bool prev_unroll = within_unroll;
- within_unroll = ir_utils::isUnrolledFor(fl) || within_unroll;
-
- for (auto expr : fl->body().exprs()) {
- OptOutDispatch::handle(expr);
- }
-
- TensorView* out = nullptr;
- bool has_global = false;
- for (Expr* expr : fl->body().exprs())
- if (ir_utils::isTVOp(expr)) {
- // Predicate determining op for unroll
- out = ir_utils::asTV(expr->output(0));
- has_global = has_global || out->getMemoryType() == MemoryType::Global;
- for (auto inp : expr->inputs())
- if (ir_utils::isTV(inp))
- has_global = has_global ||
- ir_utils::asTV(inp)->getMemoryType() == MemoryType::Global;
- }
-
- bool has_TV_op = out != nullptr;
-
- if (within_unroll && has_TV_op && has_global) {
- // Setup unrolled loop information:
-
- // Indices used to detect when we can unroll a loop safely
- // For loops outside the unroll, it's just he index, for loops inside
- // the unroll, if it's a thread it's the thread index, otherwise it's
- // the size-1
- std::vector<Val*> unroll_pred_inds;
- auto it = for_loops.begin();
- while (it != for_loops.end()) {
- if (ir_utils::isUnrolledFor(*it))
- break;
- unroll_pred_inds.push_back((*it)->index());
- it++;
- }
-
- TORCH_INTERNAL_ASSERT(
- it != for_loops.end(),
- "Error unrolling loops, expected an unrolled loop but wasn't found.");
-
- // This is the outer most loop that needs to be unrolled
- ForLoop* first_unroll = *it;
-
- // Indicies inside the unroll
- while (it != for_loops.end()) {
- IterDomain* id = (*it)->iter_domain();
- if (id->isThread())
- unroll_pred_inds.push_back((*it)->index());
- else
- unroll_pred_inds.push_back(sub(id->extent(), new Int(1)));
- it++;
- }
-
- // Make predicates for the unrolling, and the epilogue
- Bool* unroll_predicate = getPredicate(out, unroll_pred_inds);
- // Make the IfThenElse controlling the unrolling
- IfThenElse* unroll_ite =
- new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope());
-
- // Get the loop nest for the unrolled path
- ForLoop* unrolled_loop =
- scope_utils::cloneLoopNest(first_unroll, unroll_ite);
- unroll_ite->body().push_back(unrolled_loop);
-
- // Loop nest for inlined path
- ForLoop* inlined_loop =
- scope_utils::cloneLoopNest(first_unroll, unroll_ite);
- unroll_ite->elseBody().push_back(inlined_loop);
-
- // Inner most inlined loop
- Expr* inner_most_inlined_loop =
- scope_utils::firstInnerMostScope(inlined_loop);
-
- loop_replacement_map.insert({first_unroll, unroll_ite});
-
- for (auto expr : fl->body().exprs()) {
- if (!ir_utils::isTVOp(expr))
- continue;
-
- // Setup the expressions that need predicates around them.
- Bool* inline_predicate = getPredicate(out, ir_utils::indices(for_loops));
-
- IfThenElse* inline_ite =
- new IfThenElse(inline_predicate, {expr}, {}, inner_most_inlined_loop);
- std::unordered_map<Expr*, Expr*> inline_replacement_map;
- inline_replacement_map.emplace(std::pair<Expr*, Expr*>(expr, inline_ite));
- scope_utils::replaceExprsInScope(
- inner_most_inlined_loop, inline_replacement_map);
-
- } // for expr
- } else { // if(!within_unroll)
- // modify in place, so grab a copy of exprs first.
- std::vector<Expr*> exprs(
- fl->body().exprs().begin(), fl->body().exprs().end());
-
- for (auto expr : exprs) {
- if (!ir_utils::isTVOp(expr))
- continue;
-
- TensorView* out = ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]);
-
- Bool* pred = getPredicate(out, ir_utils::indices(for_loops));
-
- // If we need a predicate, put expr inside an if then else
- if (!(pred->isConst()) || !(pred->isConst() && pred->value().value())) {
- IfThenElse* inline_ite =
- new IfThenElse(pred, {expr}, {}, for_loops.back());
- for_loops.back()->body().insert_before(expr, inline_ite);
- for_loops.back()->body().erase(expr);
- }
- }
- } // else (if(!within_unroll))
-
- for_loops.pop_back();
- within_unroll = prev_unroll;
-}
-
-// Generate the loop nest structure and place it in lowered_exprs
-void UnrollPass::computeMap() {
- FusionGuard fg(fusion_);
-
- // Run through loop nests and further lower the expressions
- for (auto* expr : incoming_exprs_) {
- OptOutDispatch::handle(expr);
- }
-}
-
-std::vector<Expr*> UnrollPass::runPass(
- Fusion* fusion,
- const std::vector<Expr*>& exprs) {
- FusionGuard fg(fusion);
- UnrollPass up(fusion, exprs);
- up.computeMap();
- std::vector<Expr*> mutated_exprs;
- for (Expr* expr : exprs) {
- if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) {
- mutated_exprs.push_back(up.loop_replacement_map[expr]);
- } else {
- if (ir_utils::isScope(expr))
- scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
- mutated_exprs.push_back(expr);
- }
- }
- return mutated_exprs;
-}
-
-void LoopNestGenerator::pushAlloc(TensorView* tv) {
+// Create, place, and return the allocation for tv
+Expr* LoopNestGenerator::pushAlloc(TensorView* tv) {
TORCH_INTERNAL_ASSERT(
!(FusionGuard::getCurFusion()->hasInput(tv) ||
FusionGuard::getCurFusion()->hasOutput(tv)),
"Tried to allocate an input or output tensor.");
- // Compute at axis can be == tv->nDims() meaning it's inline
- decltype(tv->nDims()) alloc_pos = 0;
- // Do we need to close to root and alloc there?
- bool reset = true;
- while (alloc_pos <= tv->nDims()) {
+ // First figure out which loop nest this allocation needs to be placed in
+ // Do we need to place the allocation at the root?
+ size_t alloc_pos = 0;
+ // If there's no computeAt, then we want to be allocated at the root
+ while (alloc_pos <= tv->nDims() && tv->hasComputeAt()) {
+ // If we have a computeAt and we reached the computeAt position, that's where it goes
if (tv->hasComputeAt() && alloc_pos == tv->getThisComputeAtAxis()) {
- reset = false;
break;
}
+
+ // If we found an unroll, we want to place the allocation outside the unroll
if (alloc_pos < tv->nDims() &&
tv->getComputeAtAxis(alloc_pos).first->parallel_method() ==
ParallelType::Unroll) {
- reset = false;
break;
}
alloc_pos++;
}
- alloc_pos = reset ? 0 : alloc_pos;
+ // Grab the dimensions the allocation will be based on
std::vector<Val*> alloc_dims;
for (auto i = alloc_pos; i < tv->nDims(); i++) {
IterDomain* dim = tv->getComputeAtAxis(i).first;
- if (dim->isThreadDim() || dim->isReduction())
+ if (
+ // If shared memory, don't use any IDs bound to a grid dimension
+ (tv->memory_type_ == MemoryType::Shared && dim->isBlockDim()) ||
+ // If local memory, don't use any IDs bound to a grid or block dimension
+ (tv->memory_type_ == MemoryType::Local && dim->isThread()) ||
+ // If we're reducing this dimension, don't use it in the allocation
+ // computation
+ dim->isReduction() ||
+ // If this is a broadcast dimension, don't use it in the allocation
+ // computation
+ dim->isBroadcast())
continue;
alloc_dims.push_back(dim->extent());
}
- Val* size;
+ // Multiply all the dimensions we're going to use for the allocation together
+ // to get the total size
+ Val* size = nullptr;
if (alloc_dims.size() == 0) {
size = new Int(1);
} else {
size = alloc_dims[0];
- for (decltype(alloc_dims.size()) i{1}; i < alloc_dims.size(); i++) {
+ for (size_t i = 1; i < alloc_dims.size(); i++) {
size = mul(size, alloc_dims[i]);
}
}
+
+ // Create the allocation node
Allocate* alloc = new Allocate(tv, size);
+ // Place the allocation
if (alloc_pos == 0) {
+ // If we allocate at the root, insert at the beginning of the lowered
+ // expressions
lowered_exprs.insert(lowered_exprs.begin(), alloc);
} else if (alloc_pos == for_loops.size()) {
- // inline
- scope_utils::pushBack(for_loops[alloc_pos - 1], alloc);
+ // If we allocate inline, push to the back of the last for loop
+ scope_utils::pushBack(for_loops[for_loops.size() - 1], alloc);
} else {
+ // Otherwise we allocate in some loop nest that is not inline, or root, so
+ // insert right before the loop we're just outside of
scope_utils::insertBefore(
for_loops[alloc_pos - 1], for_loops[alloc_pos], alloc);
}
+
+ return alloc;
}
void LoopNestGenerator::openFor(std::pair<IterDomain*, TensorView*> id_pair) {
@@ -329,29 +114,39 @@
scope_utils::pushBack(for_loops.back(), expr);
}
-// Update for loop structure based on this TensorView
-void LoopNestGenerator::initReduction(TensorView* tv, Val* init_val) {
- // This logic was taken from allocation placement, as we want to initialize
- // the reduction buffers right after they're allocated. Compute at axis can be
- // == tv->nDims() meaning it's inline
- decltype(tv->nDims()) alloc_pos = 0;
- // Do we need to close to root and alloc there?
- bool reset = true;
- while (alloc_pos <= tv->nDims()) {
+// Update for loop structure based on this TensorView. If there's an allocation
+// stmt, send it in so we can make sure this initialization is inserted after
+// it.
+void LoopNestGenerator::initReduction(
+ TensorView* tv,
+ Val* init_val,
+ Expr* alloc_expr) {
+ // This logic was taken from pushAlloc, as the initialization loop nest will
+ // go at the same place.
+
+ // First figure out which loop nest this allocation needs to be placed in
+ // Do we need to place the allocation at the root?
+ size_t alloc_pos = 0;
+ // If there's no computeAt, then we want to be allocated at the root
+ while (alloc_pos <= tv->nDims() && tv->hasComputeAt()) {
+ // If we have a computeAt and we reached the computeAt position, that's where it goes
if (tv->hasComputeAt() && alloc_pos == tv->getThisComputeAtAxis()) {
- reset = false;
break;
}
+
+ // If we found an unroll, we want to place the allocation outside the unroll
if (alloc_pos < tv->nDims() &&
tv->getComputeAtAxis(alloc_pos).first->parallel_method() ==
ParallelType::Unroll) {
- reset = false;
break;
}
alloc_pos++;
}
- alloc_pos = reset ? 0 : alloc_pos;
+ // Grab the IDs that will be involved in the initialization, ignore reduction
+ // dimensions. Everything else will be iterated over to cover the entire
+ // buffer. Index compute will ignore [block, grid]Dims depending on buffer
+ // memory location
std::vector<IterDomain*> ids;
for (auto i = alloc_pos; i < tv->nDims(); i++) {
IterDomain* dim = tv->getComputeAtAxis(i).first;
@@ -360,45 +155,91 @@
ids.push_back(dim);
}
+ // Unsafe clone, as we want an exact replica of tv so we can create a UnaryOp
+ // to set the buffer to the init_val.
auto clone = tv->unsafeClone();
+ if (thread_predicates_.find(tv) != thread_predicates_.end()) {
+ thread_predicates_[clone] = thread_predicates_[tv];
+ }
+ // The initialization stmt that will be located inside the loop nest (if there
+ // is one)
auto init_stmt = new UnaryOp(UnaryOpType::Set, clone, init_val);
- Expr* init = nullptr;
+ // Init a pointer that will become the entirety of the initialization
+ Expr* init_loop_nest = nullptr;
+
+ // The for loop that we will place the initialization within (alloc_pos - 1),
+ // if one exists. Once we're done this inner_fl will be the inner most loop
+ // containing the init_stmt
ForLoop* inner_fl = nullptr;
if (alloc_pos >= 1)
inner_fl = for_loops[alloc_pos - 1];
+
+ // Work through the iter domains that we need to initialize on, outside to
+ // inside, to construct the loop nest for the initialization.
for (auto id : ids) {
ForLoop* new_fl;
+
if (id->isThread()) {
+ // If based on a thread, make sure we get the named Int right
std::stringstream ss;
ss << id->parallel_method();
new_fl = new ForLoop(
new NamedScalar(ss.str(), DataType::Int), id, {}, inner_fl);
} else {
+ // Otherwise it's just a new Int
new_fl = new ForLoop(new Int(), id, {}, inner_fl);
}
- if (init == nullptr) {
- init = new_fl;
- inner_fl = new_fl;
+ if (init_loop_nest == nullptr) {
+ // If this is our first generated loop, then it will be our outer most
+ // loop nest
+ init_loop_nest = new_fl;
} else {
+ // Otherwise place it inside the last generated loop
inner_fl->body().push_back(new_fl);
- inner_fl = new_fl;
}
+ // Advance the inner most for loop pointer
+ inner_fl = new_fl;
}
- if (init == nullptr) {
- init = init_stmt;
+
+ if (init_loop_nest == nullptr) {
+ // If no loops were generated, then our init_stmt is all we need
+ init_loop_nest = init_stmt;
} else {
+ // If there were for loops generated, place the init_stmt in the inner most
+ // for loop.
inner_fl->body().push_back(init_stmt);
}
+ // Place the initialization loop nest
if (alloc_pos == 0) {
- lowered_exprs.insert(lowered_exprs.begin(), init);
+ // If it goes at the root, look for the provided allocation if it
+ // exists, and place after it.
+ if (alloc_expr != nullptr) {
+ bool found = false;
+ for (auto it = lowered_exprs.begin(); it != lowered_exprs.end(); it++) {
+ if ((*it) == alloc_expr) {
+ lowered_exprs.insert(it + 1, init_loop_nest);
+ found = true;
+ break;
+ }
+ }
+ TORCH_INTERNAL_ASSERT(
+ found,
+ "Could not figure out where to initialize the buffer for ",
+ tv);
+ } else {
+ lowered_exprs.insert(lowered_exprs.begin(), init_loop_nest);
+ }
} else if (alloc_pos == for_loops.size()) {
- scope_utils::pushBack(for_loops[alloc_pos - 1], init);
+ // If we allocate inline, push to the back of the last for loop
+ scope_utils::pushBack(for_loops[for_loops.size() - 1], init_loop_nest);
} else {
+ // Otherwise we allocate in some loop nest that is not inline, or root, so
+ // insert right before the loop we're just outside of
scope_utils::insertBefore(
- for_loops[alloc_pos - 1], for_loops[alloc_pos], init);
+ for_loops[alloc_pos - 1], for_loops[alloc_pos], init_loop_nest);
}
}
@@ -437,15 +278,17 @@
openFor(out->getComputeAtAxis((int)compute_at_scope.size()));
}
+ Expr* alloc_stmt = nullptr;
// 3) Allocate the output.
if (!FusionGuard::getCurFusion()->hasInput(out) &&
!FusionGuard::getCurFusion()->hasOutput(out))
- pushAlloc(out);
+ alloc_stmt = pushAlloc(out);
// 4) If this is a reduction, initialize the output (open for loops to inner
- // most, predicate, initialize, F predicate, close to computeAt)
+ // most, predicate, initialize, place right after the allocation if one exists, close
+ // to computeAt)
if (out->hasReduction())
- initReduction(out, static_cast<ReductionOp*>(expr)->init());
+ initReduction(out, static_cast<ReductionOp*>(expr)->init(), alloc_stmt);
// 5) Open to inner most loop
for (decltype(out->nDims()) i = for_loops.size(); i < out->nDims(); i++)
@@ -472,4 +315,4 @@
} // namespace fuser
} // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
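
A standalone sketch of the filtering pushAlloc performs before multiplying extents into an allocation size (plain C++ with stand-in types, not the nvFuser API): reduction and broadcast dimensions never contribute, shared-memory buffers drop grid-bound dimensions, and local buffers drop every parallel dimension.

#include <cstdint>
#include <cstdio>
#include <vector>

enum class Mem { Local, Shared, Global };

struct Dim {
  int64_t extent;
  bool is_reduction = false;
  bool is_broadcast = false;
  bool is_thread = false; // bound to threadIdx.*
  bool is_block = false;  // bound to blockIdx.*
};

int64_t allocationSize(const std::vector<Dim>& dims, Mem mem) {
  int64_t size = 1;
  for (const Dim& d : dims) {
    if (d.is_reduction || d.is_broadcast)
      continue; // never materialized in the buffer
    if (mem == Mem::Shared && d.is_block)
      continue; // shared memory is per-block; grid dims drop out
    if (mem == Mem::Local && (d.is_thread || d.is_block))
      continue; // registers are per-thread; all parallel dims drop out
    size *= d.extent;
  }
  return size;
}

int main() {
  // A 128-wide TIDx dim plus a serial dim of 4: a local buffer only needs 4.
  std::vector<Dim> dims = {{128, false, false, true, false}, {4}};
  std::printf("%lld\n",
              static_cast<long long>(allocationSize(dims, Mem::Local)));
  return 0;
}
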
diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h
index 008216e..f6ecb57 100644
--- a/torch/csrc/jit/codegen/cuda/lower_loops.h
+++ b/torch/csrc/jit/codegen/cuda/lower_loops.h
@@ -8,77 +8,81 @@
namespace jit {
namespace fuser {
-struct UnrollPass : public OptOutDispatch {
+/*
+ * Loop nest generator pass will get IR that looks something like:
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ * and will generate the loop nest structure for these exprs like:
+ *
+ * for( i : I0o{ceil(I0/4)} ) {
+ * for( j : I1o{ceil(I1/128)} ) {
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * It does not generate predicates, but it will generate allocations, and loop
+ * nests to initialize reduction buffers.
+ *
+ */
+class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
private:
- std::unordered_map<Expr*, Expr*> loop_replacement_map;
- Fusion* fusion_;
- const std::vector<Expr*>& incoming_exprs_;
-
- // Keep all for loops conveniently to make unrolling easier
- std::vector<ForLoop*> for_loops;
-
- // keep track if we're within an unrolled loop
- bool within_unroll = false;
-
- // Custom dispatch for Expr, want to find out of it's a TV op
- void handle(Expr*) final;
-
- // Open the for loop.
- void handle(ForLoop*) final;
-
- UnrollPass(Fusion* _fusion, const std::vector<Expr*>& _incoming_exprs)
- : fusion_(_fusion), incoming_exprs_(_incoming_exprs) {}
-
- void computeMap();
-
- public:
- static std::vector<Expr*> runPass(
- Fusion* fusion,
- const std::vector<Expr*>& exprs);
-};
-
-struct TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
- private:
+ // Lowered exprs to return
std::vector<Expr*> lowered_exprs;
+ // Fusion pointer for convenience
Fusion* fusion_;
- // Keep all for loops conveniently to make unrolling easier
+ // Keep all for loops conveniently to make unrolling easier, basically just a
+ // stack of the active for_loops
std::vector<ForLoop*> for_loops;
- // computeAT scope is determined by the iterat domain, and the tensor view it
- // belongs to (the final TensorView when following the computeAt path)
+
+ // Track the active computeAt scope, and what view we're "computeAt-ing" into
std::vector<std::pair<IterDomain*, TensorView*>> compute_at_scope;
- // Get Register allocation statement for tensorview
- void pushAlloc(TensorView*);
+ // Predicates from ThreadPredicates that we will extend to reduction buffer
+ // initialization
+ std::unordered_map<const TensorView*, Bool*>& thread_predicates_;
- // Open a new inner most for loop
+ // Create, place, and return the allocation for tv
+ Expr* pushAlloc(TensorView*);
+
+ // Open a new inner most for loop, track which TV it was constructed from
+ // according to the computeAt chain.
void openFor(std::pair<IterDomain*, TensorView*>);
+
+ // Close the inner most for loop
void popFor();
// Wrap pushBack in lower_utils if active_scope is null we want it to go
// straight to lower_exprs
void pushBack(Expr*);
- // Update for loop structure based on this TensorView
+ // Update for loop structure based on this TensorView, see implementation for
+ // more details
void updateLoopNest(TensorView*);
- // Update for loop structure based on this TensorView
- void initReduction(TensorView* tv, Val* init_val);
+ // Initialize a buffer to init_val. If this buffer is in smem or registers,
+ // pass in its allocation statement so we can make sure that this
+ // initialization comes after the allocation.
+ void initReduction(TensorView* tv, Val* init_val, Expr* alloc_expr = nullptr);
- // Check if a TV op, generate for loop nest around it
+ // Check if expr is a TV op and handle accordingly.
void handle(Expr*) final;
- // Generate the loop nest structure and place it in lowered_exprs
+ // Run the pass and accumulate output in lowered_exprs
void generate(const std::vector<Expr*>& exprs);
- LoopNestGenerator(Fusion* _fusion) : fusion_(_fusion) {}
+ LoopNestGenerator(
+ Fusion* _fusion,
+ std::unordered_map<const TensorView*, Bool*>& _thread_predicates)
+ : fusion_(_fusion), thread_predicates_(_thread_predicates) {}
public:
static std::vector<Expr*> getLoopNest(
Fusion* fusion,
- std::vector<Expr*> exprs) {
+ std::vector<Expr*> exprs,
+ std::unordered_map<const TensorView*, Bool*>& thread_predicates) {
FusionGuard fg(fusion);
- LoopNestGenerator lng(fusion);
+ LoopNestGenerator lng(fusion, thread_predicates);
lng.generate(exprs);
return lng.lowered_exprs;
}
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
new file mode 100644
index 0000000..a39a825
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -0,0 +1,187 @@
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+const static std::unordered_map<ParallelType, int> pt_to_offset{
+ {ParallelType::BIDx, 0},
+ {ParallelType::BIDy, 1},
+ {ParallelType::BIDz, 2},
+ {ParallelType::TIDx, 3},
+ {ParallelType::TIDy, 4},
+ {ParallelType::TIDz, 5}};
+
+const static std::unordered_map<int, ParallelType> offset_to_pt{
+ {0, ParallelType::BIDx},
+ {1, ParallelType::BIDy},
+ {2, ParallelType::BIDz},
+ {3, ParallelType::TIDx},
+ {4, ParallelType::TIDy},
+ {5, ParallelType::TIDz}};
+
+static constexpr int num_p_type = 6;
+
+namespace {
+
+void flip_true(std::bitset<num_p_type>& bits, const ParallelType p_type) {
+ if (pt_to_offset.find(p_type) == pt_to_offset.end()) {
+ TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
+ }
+ bits[pt_to_offset.at(p_type)] = true;
+}
+
+Val* threadPredicate(int i) {
+ if (offset_to_pt.find(i) == offset_to_pt.end()) {
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "Invalid int for predicate computation, should be from [0-5], but recieved, ",
+ i,
+ ".");
+ }
+ return eq(
+ new NamedScalar(stringifyThread(offset_to_pt.at(i)), DataType::Int),
+ new Int(0));
+}
+
+Bool* getThreadPredicate(std::bitset<num_p_type> bits) {
+ if (bits.none())
+ return new Bool(true);
+
+ Val* pred = nullptr;
+
+ for (int i = 0; i < num_p_type; i++) {
+ if (bits[i]) {
+ if (pred == nullptr) {
+ pred = threadPredicate(i);
+ } else {
+ pred = andOp(pred, threadPredicate(i));
+ }
+ }
+ }
+
+ // Should never be hit.
+ TORCH_INTERNAL_ASSERT(pred != nullptr);
+
+ TORCH_INTERNAL_ASSERT(
+ pred->getDataType().value() == DataType::Bool,
+ "Tried to return a predicate that is not a bool val.");
+
+ return pred->as<Bool>();
+}
+
+} // namespace
+
+std::bitset<num_p_type> ThreadPredicates::getThreadPredicates(
+ const TensorView* tv) {
+ TORCH_INTERNAL_ASSERT(
+ thread_predicates.find(tv) != thread_predicates.end(),
+ "Invalid predicate initialization, couldn't find ",
+ tv);
+ return thread_predicates[tv];
+}
+
+// Update the thread_predicates bitset based on the provided Expr
+void ThreadPredicates::updateBitSet(Expr* expr) {
+ // Which predicates were set for the inputs
+ std::bitset<num_p_type> input_preds;
+
+ // Which dims are reductions in inputs
+ std::bitset<num_p_type> input_reductions;
+
+ // Which dims are bcast in inputs
+ std::bitset<num_p_type> input_bcasts;
+
+ // Run through inputs and update bitsets
+ for (const auto* inp : expr->inputs()) {
+ if (!ir_utils::isTV(inp))
+ continue;
+
+ auto tv_inp = ir_utils::asConstTV(inp);
+ TORCH_INTERNAL_ASSERT(
+ thread_predicates.find(tv_inp) != thread_predicates.end(),
+ "Thread predicate map was not initialized, couldn't find ",
+ inp);
+
+ input_preds |= thread_predicates[tv_inp];
+
+ std::bitset<num_p_type> id_reductions;
+ std::bitset<num_p_type> id_bcasts;
+ std::bitset<num_p_type> id_ptypes;
+
+ for (auto id : tv_inp->domain()->domain()) {
+ if (id->isThread()) {
+ flip_true(id_ptypes, id->parallel_method());
+ if (id->isReduction())
+ flip_true(id_reductions, id->parallel_method());
+ if (id->isBroadcast())
+ flip_true(id_bcasts, id->parallel_method());
+ }
+ }
+
+ // Validate the combination of ptypes, reductions, bcasts
+ for (size_t i = 0; i < num_p_type; i++) {
+ if (input_reductions[i]) {
+ if (id_ptypes[i]) {
+ TORCH_INTERNAL_ASSERT(
+ id_reductions[i],
+ "Mismatched parallelized reductions found on inputs of epxr: ",
+ expr);
+ TORCH_CHECK(
+ !id_bcasts[i],
+ "Invalid broadcast and reduction combination, tried to parallelize both with the same thread dim: ",
+ inp);
+ }
+ }
+ }
+
+ // Accumulate
+ input_reductions |= id_reductions;
+ input_bcasts |= id_bcasts;
+ }
+
+ // Update map for this tv, before accumulating to other inputs
+ // Add any reductions this id has to any input predicates
+ auto output_preds = input_preds | input_reductions;
+
+ // Figure out which dims bcast wants to reset
+ auto bcast_reset_map = output_preds & input_bcasts;
+
+ // Flip it to make a bit mask
+ bcast_reset_map = ~bcast_reset_map;
+
+ // Get rid of any reductions which are bcasted
+ output_preds &= bcast_reset_map;
+
+ // Run through outputs and set bitset predicates
+ for (const auto* out : expr->outputs()) {
+ if (!ir_utils::isTV(out))
+ continue;
+ thread_predicates[ir_utils::asConstTV(out)] = output_preds;
+ }
+}
+ThreadPredicates::ThreadPredicates(Fusion* _fusion) : fusion_(_fusion) {
+ for (auto inp : fusion_->inputs())
+ if (ir_utils::isTV(inp))
+ thread_predicates[ir_utils::asConstTV(inp)] = std::bitset<num_p_type>();
+}
+
+std::unordered_map<const TensorView*, Bool*> ThreadPredicates::compute(
+ Fusion* fusion) {
+ ThreadPredicates tp(fusion);
+ for (auto expr : fusion->exprs(true))
+ tp.updateBitSet(expr);
+ std::unordered_map<const TensorView*, Bool*> preds;
+ for (auto entry : tp.thread_predicates) {
+ preds[entry.first] = getThreadPredicate(entry.second);
+ }
+ return preds;
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
new file mode 100644
index 0000000..df09b72
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
@@ -0,0 +1,44 @@
+#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+#include <bitset>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class TORCH_CUDA_API ThreadPredicates {
+ private:
+ Fusion* fusion_;
+
+ /*
+ * Map from tensorview to bit set representing <BIDx, BIDy, BIDz, TIDx, TIDy,
+ * TIDz>. If any dependency of TV had a parallelized reduction, we will track
+ * it here. This will be used for predicate generation to prevent
+ * parallelization on that axis. This is important if we have a reduction on
+ * for example TIDx, as the reduced value is only valid on threadIdx.x == 0
+ * therefore if we use that value later in the kernel we have that predicate.
+ * If we follow a reduction parallelized on TIDx with a broadcast on TIDx we
+ * no longer need the predicate and can reset the bit accordingly
+ */
+ std::unordered_map<const TensorView*, std::bitset<6>> thread_predicates;
+
+ // Update the thread_predicates bitset based on provided Expr
+ void updateBitSet(Expr*);
+
+ // Safety wrapper to access thread_predicates
+ std::bitset<6> getThreadPredicates(const TensorView*);
+
+ ThreadPredicates(Fusion* _fusion);
+
+ public:
+ // Computes any thread predicates that need to be applied when computing a
+ // TensorView.
+ static std::unordered_map<const TensorView*, Bool*> compute(Fusion* fusion);
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
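
A standalone sketch of the bookkeeping described in the class comment above (plain C++, not the nvFuser API): bit i set means "this tensor's value is only valid where parallel dim i == 0", which is exactly the guard a consumer of a parallelized reduction needs, and a later broadcast on the same dim clears the bit again.

#include <bitset>
#include <iostream>
#include <string>

constexpr int kNumPTypes = 6; // BIDx, BIDy, BIDz, TIDx, TIDy, TIDz
const char* const kDimNames[kNumPTypes] = {
    "blockIdx.x", "blockIdx.y", "blockIdx.z",
    "threadIdx.x", "threadIdx.y", "threadIdx.z"};

std::string predicateString(const std::bitset<kNumPTypes>& bits) {
  if (bits.none())
    return "true";
  std::string pred;
  for (int i = 0; i < kNumPTypes; i++) {
    if (!bits[i])
      continue;
    if (!pred.empty())
      pred += " && ";
    pred += std::string(kDimNames[i]) + " == 0";
  }
  return pred;
}

int main() {
  std::bitset<kNumPTypes> bits;
  bits[3] = true; // a reduction parallelized on TIDx
  std::cout << predicateString(bits) << "\n"; // threadIdx.x == 0
  bits[3] = false; // a later broadcast on TIDx clears the bit
  std::cout << predicateString(bits) << "\n"; // true
  return 0;
}
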
diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
new file mode 100644
index 0000000..91c0d28
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
@@ -0,0 +1,244 @@
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/index_compute.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
+
+#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+Bool* UnrollPass::getThreadPredicate(const TensorView* tv) {
+ TORCH_INTERNAL_ASSERT(
+ thread_predicates_.find(tv) != thread_predicates_.end(),
+ "Invalid predicate initialization, couldn't find ",
+ tv);
+ return thread_predicates_[tv];
+}
+
+// Custom dispatch for Expr, want to find out if it's a TV op
+void UnrollPass::handle(Expr* expr) {
+ OptOutDispatch::handle(expr);
+}
+
+namespace {
+Bool* getPredicate(TensorView* tv, std::vector<Val*> inds_, Bool* thread_pred) {
+ TORCH_INTERNAL_ASSERT(
+ inds_.size() == tv->nDims() ||
+ inds_.size() == tv->domain()->noReductions().size());
+
+ // Do we need to adjust for reduction axes?
+ bool reductions = inds_.size() != tv->nDims();
+
+ std::vector<Val*> inds;
+ if (reductions) {
+ for (size_t ind_i = 0, tv_i = 0; tv_i < tv->nDims();) {
+ if (tv->axis(tv_i++)->isReduction()) {
+ inds.push_back(new Int(0));
+ } else {
+ TORCH_INTERNAL_ASSERT(
+ ind_i < inds_.size(), "Ran out of indices to generate predicate.");
+ inds.push_back(inds_[ind_i++]);
+ }
+ }
+ } else {
+ inds = inds_;
+ }
+
+ if (tv->nDims() > inds.size()) {
+ for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
+ if (tv->axis(i)->isReduction())
+ inds.insert(inds.begin() + i, new Int(0));
+ }
+ }
+ std::vector<Bool*> all_preds = PredicateCompute::computePredicates(
+ new TensorIndex(tv, IndexCompute::get(tv->domain(), inds)));
+
+ all_preds.push_back(thread_pred);
+
+ std::vector<Bool*> preds;
+
+ for (Bool* pred : all_preds)
+ if (!(pred->isConst()) || !(pred->isConst() && pred->value().value()))
+ preds.push_back(pred);
+
+ if (preds.size() == 0)
+ return new Bool(true);
+
+ Val* cond = preds[0];
+
+ for (decltype(preds.size()) i{1}; i < preds.size(); i++) {
+ cond = andOp(cond, preds[i]);
+ }
+
+ TORCH_INTERNAL_ASSERT(
+ cond->getValType().value() == ValType::Scalar &&
+ cond->getDataType().value() == DataType::Bool,
+ "Error computing predicate, should be returning a Bool, but returning ",
+ cond->getDataType().value());
+
+ return static_cast<Bool*>(cond);
+}
+} // namespace
+
+// This function is one huge mess that should be refactored.
+// It handles the unrolling and predicate generation
+void UnrollPass::handle(ForLoop* fl) {
+ // Setup for loop scoping
+ for_loops.push_back(fl);
+ bool prev_unroll = within_unroll;
+ within_unroll = ir_utils::isUnrolledFor(fl) || within_unroll;
+
+ for (auto expr : fl->body().exprs()) {
+ OptOutDispatch::handle(expr);
+ }
+
+ TensorView* out = nullptr;
+ bool has_global = false;
+ for (Expr* expr : fl->body().exprs())
+ if (ir_utils::isTVOp(expr)) {
+ // Predicate determining op for unroll
+ out = ir_utils::asTV(expr->output(0));
+ has_global = has_global || out->getMemoryType() == MemoryType::Global;
+ for (auto inp : expr->inputs())
+ if (ir_utils::isTV(inp))
+ has_global = has_global ||
+ ir_utils::asTV(inp)->getMemoryType() == MemoryType::Global;
+ }
+
+ bool has_TV_op = out != nullptr;
+
+ if (within_unroll && has_TV_op && has_global) {
+ // Setup unrolled loop information:
+
+ // Indices used to detect when we can unroll a loop safely
+ // For loops outside the unroll, it's just the index, for loops inside
+ // the unroll, if it's a thread it's the thread index, otherwise it's
+ // the size-1
+ std::vector<Val*> unroll_pred_inds;
+ auto it = for_loops.begin();
+ while (it != for_loops.end()) {
+ if (ir_utils::isUnrolledFor(*it))
+ break;
+ unroll_pred_inds.push_back((*it)->index());
+ it++;
+ }
+
+ TORCH_INTERNAL_ASSERT(
+ it != for_loops.end(),
+ "Error unrolling loops, expected an unrolled loop but wasn't found.");
+
+ // This is the outer most loop that needs to be unrolled
+ ForLoop* first_unroll = *it;
+
+ // Indices inside the unroll
+ while (it != for_loops.end()) {
+ IterDomain* id = (*it)->iter_domain();
+ if (id->isThread())
+ unroll_pred_inds.push_back((*it)->index());
+ else
+ unroll_pred_inds.push_back(sub(id->extent(), new Int(1)));
+ it++;
+ }
+
+ // Make predicates for the unrolling, and the epilogue
+ Bool* unroll_predicate =
+ getPredicate(out, unroll_pred_inds, getThreadPredicate(out));
+ // Make the IfThenElse controlling the unrolling
+ IfThenElse* unroll_ite =
+ new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope());
+
+ // Get the loop nest for the unrolled path
+ ForLoop* unrolled_loop =
+ scope_utils::cloneLoopNest(first_unroll, unroll_ite);
+ unroll_ite->body().push_back(unrolled_loop);
+
+ // Loop nest for inlined path
+ ForLoop* inlined_loop =
+ scope_utils::cloneLoopNest(first_unroll, unroll_ite);
+ unroll_ite->elseBody().push_back(inlined_loop);
+
+ // Inner most inlined loop
+ Expr* inner_most_inlined_loop =
+ scope_utils::firstInnerMostScope(inlined_loop);
+
+ loop_replacement_map.insert({first_unroll, unroll_ite});
+
+ for (auto expr : fl->body().exprs()) {
+ if (!ir_utils::isTVOp(expr))
+ continue;
+
+ // Setup the expressions that need predicates around them.
+ Bool* inline_predicate = getPredicate(
+ out, ir_utils::indices(for_loops), getThreadPredicate(out));
+
+ IfThenElse* inline_ite =
+ new IfThenElse(inline_predicate, {expr}, {}, inner_most_inlined_loop);
+ std::unordered_map<Expr*, Expr*> inline_replacement_map;
+ inline_replacement_map.emplace(std::pair<Expr*, Expr*>(expr, inline_ite));
+ scope_utils::replaceExprsInScope(
+ inner_most_inlined_loop, inline_replacement_map);
+
+ } // for expr
+ } else { // if(!within_unroll)
+ // modify in place, so grab a copy of exprs first.
+ const std::vector<Expr*> exprs = fl->body().exprs();
+
+ for (auto expr : exprs) {
+ if (!ir_utils::isTVOp(expr))
+ continue;
+
+ TensorView* out = ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]);
+
+ Bool* pred = getPredicate(
+ out, ir_utils::indices(for_loops), getThreadPredicate(out));
+
+ // If we need a predicate, put expr inside an if then else
+ if (!(pred->isConst()) || !(pred->isConst() && pred->value().value())) {
+ IfThenElse* inline_ite =
+ new IfThenElse(pred, {expr}, {}, for_loops.back());
+ for_loops.back()->body().insert_before(expr, inline_ite);
+ for_loops.back()->body().erase(expr);
+ }
+ }
+ } // else (if(!within_unroll))
+
+ for_loops.pop_back();
+ within_unroll = prev_unroll;
+}
+
+// Run through the incoming expressions and build the loop replacement map
+void UnrollPass::computeMap() {
+ FusionGuard fg(fusion_);
+
+ // Run through loop nests and further lower the expressions
+ for (auto* expr : incoming_exprs_) {
+ OptOutDispatch::handle(expr);
+ }
+}
+
+std::vector<Expr*> UnrollPass::runPass(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs,
+ std::unordered_map<const TensorView*, Bool*>& thread_predicates) {
+ FusionGuard fg(fusion);
+ UnrollPass up(fusion, exprs, thread_predicates);
+ up.computeMap();
+ std::vector<Expr*> mutated_exprs;
+ for (Expr* expr : exprs) {
+ if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) {
+ mutated_exprs.push_back(up.loop_replacement_map[expr]);
+ } else {
+ if (ir_utils::isScope(expr))
+ scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
+ mutated_exprs.push_back(expr);
+ }
+ }
+ return mutated_exprs;
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h
new file mode 100644
index 0000000..d421676
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h
@@ -0,0 +1,102 @@
+#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
+
+#include <bitset>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+/*
+ * A bit deceptively: UnrollPass adds all predicates, so it needs to be run even
+ * if we don't unroll any loops.
+ *
+ * Unrolling pass will get IR that looks something like:
+ * for( i : I0o{ceil(I0/4)} ) {
+ * for( j : I1o{ceil(I1/128)} ) {
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * And it will return the following:
+ * for( i : I0o{ceil(I0/4)} ) {
+ * for( j : I1o{ceil(I1/128)} ) {
+ *
+ * if( i * 4 + 3 < I && j * 128 + 127 < J ){
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
+ * } else {
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * if( i * 4 + k < I && j * 128 + l < J)
+ * T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
+ * }
+ *
+ * }
+ * }
+ *
+ * As can be seen, it generates two sets of loops for I0i{4} and I1i{128}. The
+ * first set is protected by a predicate that makes sure there's a full internal
+ * tile we can iterate over, which lets us remove the predicate nested in the
+ * innermost loop. The second set of loops still has a predicate in the
+ * innermost loop, making sure that we cover edges and corners.
+ */
+
+class TORCH_CUDA_API UnrollPass : public OptOutDispatch {
+ private:
+ // Wrapper to access thread_predicates_
+ Bool* getThreadPredicate(const TensorView*);
+
+ // We will track which loops in the incoming IR will be replaced and by what
+ std::unordered_map<Expr*, Expr*> loop_replacement_map;
+ // Hold on to a reference to the fusion for convenience
+ Fusion* fusion_;
+ // Hold on to the incoming exprs, but don't modify them. We don't set the
+ // Expr* to be const as Exprs are effectively const by their interface design
+ const std::vector<Expr*>& incoming_exprs_;
+
+ // Keep track of the enclosing for loops to make unrolling easier
+ std::vector<ForLoop*> for_loops;
+
+ // Map from TensorView to its thread predicate
+ std::unordered_map<const TensorView*, Bool*>& thread_predicates_;
+
+ // Keep track of whether we're within an unrolled loop
+ bool within_unroll = false;
+
+ // Custom dispatch for Expr, want to find out if it's a TV op
+ void handle(Expr*) final;
+
+ // Open the for loop.
+ void handle(ForLoop*) final;
+
+ // Constructor
+ UnrollPass(
+ Fusion* _fusion,
+ const std::vector<Expr*>& _incoming_exprs,
+ std::unordered_map<const TensorView*, Bool*>& _thread_predicates)
+ : fusion_(_fusion),
+ incoming_exprs_(_incoming_exprs),
+ thread_predicates_(_thread_predicates) {}
+
+ // Generate the for Expr replacement map
+ void computeMap();
+
+ public:
+ // Take the incoming fusion and exprs and run loop unrolling, returning the
+ // new IR.
+ static std::vector<Expr*> runPass(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs,
+ std::unordered_map<const TensorView*, Bool*>& thread_predicates);
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
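
To make the tiling/predication strategy described in the header comment above concrete, here is a minimal standalone sketch in plain C++ (no nvFuser types; `I`, `kTile`, and `T` are illustrative stand-ins, not names from this patch): full tiles take the hoisted predicate and run an unpredicated inner loop, while the edge tile keeps a per-element predicate.

// Standalone illustration of the predicate-hoisting strategy described above:
// a 1-D tensor of extent I is tiled by a factor of 4. Tiles that are provably
// in-bounds take the unpredicated "fast" path; the remainder tile keeps a
// per-element predicate. Plain C++, no nvFuser types.
#include <cstdio>
#include <vector>

int main() {
  const int I = 10;     // logical extent
  const int kTile = 4;  // unroll/tile factor
  std::vector<int> T(I, 0);

  for (int i = 0; i < (I + kTile - 1) / kTile; ++i) {
    if (i * kTile + (kTile - 1) < I) {
      // Full tile: the hoisted predicate above covers every element, so the
      // inner loop runs without per-element checks (and could be unrolled).
      for (int k = 0; k < kTile; ++k)
        T[i * kTile + k] = 1;
    } else {
      // Edge tile: keep the predicate in the innermost loop.
      for (int k = 0; k < kTile; ++k)
        if (i * kTile + k < I)
          T[i * kTile + k] = 1;
    }
  }

  for (int v : T)
    std::printf("%d ", v);
  std::printf("\n");
  return 0;
}
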
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index 208916e..75dcf0e 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -9,7 +9,7 @@
// START SCOPE HELPER SYSTEMS
namespace {
-struct Loops : private OptInDispatch {
+class Loops : private OptInDispatch {
private:
std::deque<ForLoop*> loops;
void handle(ForLoop* fl) final {
@@ -34,7 +34,7 @@
}
};
-struct forLoopCount : private OptInDispatch {
+class forLoopCount : private OptInDispatch {
private:
unsigned int count_ = 0;
@@ -60,7 +60,7 @@
}
};
-struct scopePushBack : private OptInDispatch {
+class scopePushBack : private OptInDispatch {
private:
Expr* expr_;
void handle(ForLoop* fl) final {
@@ -87,7 +87,7 @@
}
};
-struct scopeInsertBefore : private OptInDispatch {
+class scopeInsertBefore : private OptInDispatch {
private:
Expr* ref_;
Expr* expr_;
@@ -115,7 +115,7 @@
}
};
-struct parentScope : private OptInDispatch {
+class parentScope : private OptInDispatch {
private:
Expr* parent_ = nullptr;
@@ -139,7 +139,7 @@
}
};
-struct scopeClearExprs : private OptInDispatch {
+class scopeClearExprs : private OptInDispatch {
private:
void handle(ForLoop* fl) final {
fl->body().clear();
@@ -169,7 +169,7 @@
"Assert Scope failed when calling a scope_util function.");
}
-struct CloneLoopNest : public OptOutMutator {
+class CloneLoopNest : public OptOutMutator {
private:
Expr* parent_scope_ = nullptr;
Expr* to_clone_ = nullptr;
@@ -199,50 +199,7 @@
}
};
-struct ReplaceExprsInScope : public OptOutDispatch {
- private:
- std::unordered_map<Expr*, Expr*> replacement_map_;
-
- void handle(Expr* expr) final {
- OptOutDispatch::handle(expr);
- }
-
- void handle(ForLoop* fl) final {
- for (Expr* expr : fl->body().exprs()) {
- auto it = replacement_map_.find(expr);
- if (it == replacement_map_.end()) {
- handle(expr);
- continue;
- }
- fl->body().insert_before(expr, replacement_map_[expr]);
- fl->body().erase(expr);
- }
- }
-
- void handle(IfThenElse* ite) final {
- for (Expr* expr : ite->body().exprs()) {
- auto it = replacement_map_.find(expr);
- if (it == replacement_map_.end()) {
- handle(expr);
- continue;
- }
- ite->body().insert_before(expr, replacement_map_[expr]);
- ite->body().erase(expr);
- }
- for (Expr* expr : ite->elseBody().exprs()) {
- auto it = replacement_map_.find(expr);
- if (it == replacement_map_.end()) {
- handle(expr);
- continue;
- }
- ite->elseBody().insert_before(expr, replacement_map_[expr]);
- ite->elseBody().erase(expr);
- }
- }
-
- ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> _replacement_map)
- : replacement_map_(std::move(_replacement_map)) {}
-
+class ReplaceExprsInScope : public OptOutDispatch {
public:
static void replace(
Expr* scope,
@@ -250,9 +207,40 @@
ReplaceExprsInScope reis(std::move(replacement_map));
reis.handle(scope);
}
+
+ private:
+ explicit ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> replacement_map)
+ : replacement_map_(std::move(replacement_map)) {}
+
+ void handleScope(Scope& scope) {
+ for (size_t i = 0; i < scope.size(); ++i) {
+ const auto it = replacement_map_.find(scope[i]);
+ if (it == replacement_map_.end()) {
+ handle(scope[i]);
+ continue;
+ }
+ scope[i] = it->second;
+ }
+ }
+
+ void handle(Expr* expr) final {
+ OptOutDispatch::handle(expr);
+ }
+
+ void handle(ForLoop* fl) final {
+ handleScope(fl->body());
+ }
+
+ void handle(IfThenElse* ite) final {
+ handleScope(ite->body());
+ handleScope(ite->elseBody());
+ }
+
+ private:
+ std::unordered_map<Expr*, Expr*> replacement_map_;
};
-struct FirstInnerMostScope : private OptInDispatch {
+class FirstInnerMostScope : private OptInDispatch {
private:
Expr* active_scope = nullptr;
@@ -295,6 +283,10 @@
FirstInnerMostScope fims;
Expr* inner = fims.getInner(scope);
+
+ if (inner == nullptr)
+ return scope;
+
while (fims.getInner(inner) != nullptr)
inner = fims.getInner(inner);
return inner;
@@ -411,13 +403,13 @@
return ids;
}
-bool isTV(const Val* const val) {
+bool isTV(const Val* val) {
return val->getValType().value() == ValType::TensorView;
}
// Check if we're a TensorView op that we can generate code for.
bool isTVOp(const Expr* expr) {
- if (expr->nOutputs() == 1 && isTV(expr->output(0)) &&
+ if (expr->outputs().size() == 1 && isTV(expr->output(0)) &&
(expr->getExprType().value() == ExprType::BinaryOp ||
expr->getExprType().value() == ExprType::UnaryOp ||
expr->getExprType().value() == ExprType::TernaryOp ||
@@ -462,7 +454,7 @@
return static_cast<ForLoop*>(expr);
}
-const TensorView* asConstTV(const Val* const val) {
+const TensorView* asConstTV(const Val* val) {
TORCH_INTERNAL_ASSERT(isTV(val));
return static_cast<const TensorView*>(val);
}
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
new file mode 100644
index 0000000..31a2fc2
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -0,0 +1,130 @@
+#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
+#include <torch/csrc/jit/codegen/cuda/type.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+// Some pre-compilation checks
+static void IrValidate(Fusion* fusion) {
+ fusion->validateInputs();
+ for (Val* val : fusion->vals()) {
+ if (ir_utils::isTV(val)) {
+ TensorView* tv = ir_utils::asTV(val);
+ for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
+ IterDomain* id = tv->getComputeAtAxis(i).first;
+
+ if (id->isBlockDim()) {
+ TORCH_CHECK(
+ !id->isBroadcast(),
+ "Parallelization across blocks on broadcast axes is not supported, but found on, ",
+ tv,
+ ".");
+ }
+ }
+ }
+ }
+}
+
+// Remove circular computeAt references
+void IrFixComputeAt(Fusion* fusion) {
+ std::vector<Expr*> exprs = fusion->exprs(true);
+ std::set<TensorView*> visited;
+ for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
+ Expr* expr = *it;
+ if (!ir_utils::isTVOp(expr))
+ continue;
+
+ TensorView* tv = ir_utils::asTV(expr->output(0));
+ TensorView* ctv = tv->getComputeAtView();
+
+ if (ctv != nullptr && visited.find(ctv) == visited.end()) {
+ ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
+ tv->clearComputeAt();
+ }
+ visited.emplace(tv);
+ }
+}
+
+void IrBuildSizesMap(Fusion* fusion) {
+ // Sizes of inputs/outputs -> T.size[...]
+ std::unordered_map<Val*, Val*> size_map;
+
+ // Grab inputs and outputs
+ std::vector<TensorView*> inputs_and_outputs;
+ for (auto val : fusion->inputs()) {
+ if (ir_utils::isTV(val)) {
+ inputs_and_outputs.push_back(val->as<TensorView>());
+ }
+ }
+ for (auto val : fusion->outputs()) {
+ if (ir_utils::isTV(val)) {
+ inputs_and_outputs.push_back(val->as<TensorView>());
+ }
+ }
+
+ // Run through inputs and outputs first. Since we're replacing full
+ // tensorviews, their names are going to change. We need the new reference
+ // name for the inputs/outputs so we won't reference the wrong tensor view.
+ // For example, T0 may be translated to T9. We don't want our new variable to
+ // be T0->size[...]; we need it to be T9->size[...]
+ //
+ // This could be done in a better way: by changing split/merge to be a
+ // TensorDomain-focused operation, we could simply do this process on
+ // domains instead of tensorviews. This would have the benefit that the
+ // TensorView wouldn't change, so users' pointers would remain valid. The
+ // other option, which seems less elegant but would also work, is to build up
+ // the domain on the new tensor, and then simply replace it into the original
+ // one.
+ for (TensorView* tv : inputs_and_outputs) {
+ // Replace the domain with one based on Ti.size[j]
+ std::vector<IterDomain*> new_domain_iters;
+ const std::vector<IterDomain*>& root_td = tv->getRootDomain();
+
+ size_t dim = 0;
+ for (auto id : root_td) {
+ // Output sizes could include reduction axes, which aren't part of what gets output.
+ if (id->isReduction())
+ continue;
+
+ Val* orig_size = id->extent();
+
+ std::stringstream ss;
+ ss << "T" << tv->name() << ".size[" << dim++ << "]";
+ Val* new_size =
+ new NamedScalar(ss.str(), orig_size->getDataType().value());
+ if (!orig_size->sameAs(new_size) ||
+ size_map.find(orig_size) == size_map.end())
+ size_map[orig_size] = new_size;
+ }
+ }
+
+ fusion->setValuesMap(size_map);
+}
+
+void IrAdjustMemoryTypes(Fusion* fusion) {
+ for (auto val : fusion->deterministic_vals()) {
+ if (ir_utils::isTV(val)) {
+ auto tv = val->as<TensorView>();
+ if (fusion->hasInput(tv) || fusion->hasOutput(tv)) {
+ tv->setMemoryType(MemoryType::Global);
+ } else if (tv->getMemoryType() == MemoryType::Global) {
+ tv->setMemoryType(MemoryType::Local);
+ }
+ }
+ }
+}
+
+void PrepareForLowering(Fusion* fusion) {
+ FusionGuard fg(fusion);
+
+ IrFixComputeAt(fusion);
+ IrValidate(fusion);
+ IrBuildSizesMap(fusion);
+ IrAdjustMemoryTypes(fusion);
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h
new file mode 100644
index 0000000..6990012
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+/*
+ * Currently this does the following:
+ *
+ * (1) Run a validation pass on the IR making sure there are no mistakes or
+ * unsupported scheduling.
+ *
+ * (2) Create a mapping from symbolic sizes to named scalars,
+ *     i.e. T0[i0] -> T0[T0.size[0]]
+ *
+ * (3) Change the computeAt structure to make sure it follows the
+ *     expression structure.
+ *
+ * (4) Adjust TensorView memory types to make sure they are valid
+ */
+
+void TORCH_CUDA_API PrepareForLowering(Fusion* fusion);
+
+// Compute at can have some circular references. Before we can call any tv
+// with tv->getComputeAtAxis(i) we need to break those circular dependencies.
+void IrFixComputeAt(Fusion* fusion);
+
+// TensorViews are all based on symbolic sizes. When we first initialize them
+// we don't know if they're inputs or outputs, which would mean that they have
+// runtime shapes. Intermediate tensors (those not going to global memory) do
+// not have this information. Since the kernel needs to fetch the correct shape
+// information at runtime, we want to replace input and output tensors with
+// references to the runtime structure containing their sizes.
+void IrBuildSizesMap(Fusion* fusion);
+
+// Adjust memory types to make sure they are valid
+void IrAdjustMemoryTypes(Fusion* fusion);
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
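
As a rough illustration of the size-map step described above (symbolic sizes mapped to named scalars such as T0.size[0]), the following standalone C++ sketch mirrors the renaming that IrBuildSizesMap performs; the Axis struct and the tensor name T9 are hypothetical stand-ins, not real codegen IR types.

// A minimal sketch (plain C++, no nvFuser types) of the renaming performed by
// IrBuildSizesMap: each non-reduction root axis of an input/output tensor gets
// its symbolic extent mapped to a named scalar "T<name>.size[d]".
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Axis {
  std::string extent;  // symbolic extent, e.g. "i0"
  bool is_reduction;
};

int main() {
  // Pretend this is the root domain of output tensor T9: [i0, r1, i2].
  std::vector<Axis> root = {{"i0", false}, {"r1", true}, {"i2", false}};
  std::map<std::string, std::string> size_map;

  size_t dim = 0;
  for (const auto& id : root) {
    if (id.is_reduction)  // reduction axes are not part of the output
      continue;
    size_map[id.extent] = "T9.size[" + std::to_string(dim++) + "]";
  }

  for (const auto& kv : size_map)
    std::printf("%s -> %s\n", kv.first.c_str(), kv.second.c_str());
  return 0;
}
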
diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp
index 486b0b4..e8b029c 100644
--- a/torch/csrc/jit/codegen/cuda/manager.cpp
+++ b/torch/csrc/jit/codegen/cuda/manager.cpp
@@ -3,7 +3,6 @@
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/shape_inference.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
@@ -26,15 +25,6 @@
return req_ptr;
}
-// TODO: contiguity could be used for better kernel launch config.
-TensorContiguity infer_contiguity_from_tensor_type(
- const std::shared_ptr<c10::TensorType>& tensor_type) {
- TORCH_INTERNAL_ASSERT(tensor_type->isComplete());
- return TensorContiguity(
- *(tensor_type->sizes().concrete_sizes()),
- *(tensor_type->strides().concrete_sizes()));
-}
-
// CudaFusionManager holds compiled `CudaKernel` and handles all interfacing
// including compilation and execution.
//
@@ -88,7 +78,8 @@
int32_t kernel_id,
std::shared_ptr<Graph>& graph,
const at::ArrayRef<IValue> inputs,
- std::vector<at::Tensor> outputs) {
+ const std::vector<at::Tensor>& outputs,
+ const std::vector<int64_t>& broadcasted_shape) {
std::lock_guard<std::mutex> guard(mutex_);
TORCH_CHECK(
kernel_cache_.count(kernel_id) != 0, "kernel id not recognized");
@@ -98,7 +89,7 @@
if (cuda_kernel) {
// TODO: update launch config for specific sizes;
// maybe we should store it in CudaKernel and compute it later
- runKernel(*cuda_kernel, inputs, outputs);
+ runKernel(*cuda_kernel, inputs, outputs, broadcasted_shape);
} else {
// TODO: this should somehow be done after kernel compilation.
// we will want compileKernel to return a heuristic
@@ -127,7 +118,7 @@
// NVRTC compile kernel
compileKernel(cuda_kernel.value());
- runKernel(*cuda_kernel, inputs, outputs);
+ runKernel(*cuda_kernel, inputs, outputs, broadcasted_shape);
}
}
@@ -164,7 +155,7 @@
fusion_node->i_(attr::cache_id, fusion_cache_id);
}
-void runCudaFusionGroup(const Node* const fusion_node, Stack& stack) {
+void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
TORCH_CHECK(
fusion_node->kind() == prim::CudaFusionGroup,
"prim::CudaFusionGroup expected");
@@ -179,22 +170,28 @@
std::shared_ptr<Graph> graph = fusion_node->g(attr::Subgraph)->copy();
auto execute_lambda = [&]() {
- auto nInputs = graph->inputs().size();
+ const auto nInputs = graph->inputs().size();
at::ArrayRef<IValue> inputs = last(stack, nInputs);
// shape inference in graph
// update shape information per the new inputs;
EraseShapeInformation(graph);
- for (decltype(nInputs) i = 0; i < nInputs; i++) {
+ for (size_t i = 0; i < nInputs; i++) {
graph->inputs()[i]->setType(inputs[i].type());
}
// shape inference
ShapeTypePropagate(graph);
+ // TODO: temporary WAR that allows us to handle fusions with uniform output
+ // shape and consistent broadcast scheme. The definition is loose and the
+ // implementation is risky. We'll do this properly when we integrate proper
+ // broadcast support.
+ std::vector<int64_t> broadcasted_shape;
+
// we need to construct outputs;
std::vector<at::Tensor> outputs;
- for (const auto* const output : graph->outputs()) {
- auto type = output->type()->expect<TensorType>();
+ for (const auto* output : graph->outputs()) {
+ const auto type = output->type()->expect<TensorType>();
// Expect output to be tensor;
TORCH_CHECK(
type && type->isComplete(),
@@ -213,11 +210,34 @@
const auto sizes = extractSizes(type);
const auto strides = extractStrides(type);
- auto tensor = at::empty_strided(sizes, strides, options);
+ const auto tensor = at::empty_strided(sizes, strides, options);
outputs.push_back(tensor);
+
+ // TODO: unsafe broadcast assumption. We assume all outputs from the fusion
+ // have identical sizes when broadcasting.
+ if (broadcasted_shape.empty()) {
+ if (!hasReductionNode(graph->block())) {
+ broadcasted_shape = sizes;
+ } else if (isReductionNode(output->node())) {
+ auto i_type =
+ output->node()->inputs()[0]->type()->expect<TensorType>();
+ TORCH_CHECK(
+ i_type && i_type->sizes().isComplete(),
+ "Complete TensorType for output is expected.");
+ broadcasted_shape = extractSizes(i_type);
+ } else {
+ // TODO: this assert is not foolproof. We could have ignored a
+ // pre-reduction tensor marked as output after we first encountered a
+ // reduction output tensor.
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "pre-reduction tensor output for reduction fusion is not properly supported yet.");
+ }
+ }
}
+
CudaFusionManager::getManager().runFusionNode(
- kernel_id, graph, inputs, outputs);
+ kernel_id, graph, inputs, outputs, broadcasted_shape);
drop(stack, inputs.size());
stack.insert(
stack.end(),
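
The broadcasted-shape workaround above boils down to a small decision: use the output sizes for point-wise fusions, and the reduction input's sizes for reduction-at-end fusions. A standalone sketch of that selection logic follows; FakeOutput and pickBroadcastedShape are illustrative names only, not part of the JIT IR.

// Standalone sketch (plain C++) of the temporary broadcasted-shape selection:
// for a fusion without reductions the output shape is used directly; for a
// reduction-at-end fusion the shape of the reduction's input is used instead.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

struct FakeOutput {
  bool produced_by_reduction;
  std::vector<std::int64_t> sizes;                  // sizes of the output
  std::vector<std::int64_t> reduction_input_sizes;  // sizes of the reduction's input
};

std::vector<std::int64_t> pickBroadcastedShape(
    const std::vector<FakeOutput>& outputs,
    bool fusion_has_reduction) {
  std::vector<std::int64_t> broadcasted_shape;
  for (const auto& out : outputs) {
    if (!broadcasted_shape.empty())
      continue;  // assume all outputs share one broadcasted shape (the WAR)
    if (!fusion_has_reduction) {
      broadcasted_shape = out.sizes;
    } else {
      // Only reduction outputs are supported for reduction fusions.
      assert(out.produced_by_reduction);
      broadcasted_shape = out.reduction_input_sizes;
    }
  }
  return broadcasted_shape;
}

int main() {
  std::vector<FakeOutput> outs = {{true, {8}, {8, 128}}};
  auto shape = pickBroadcastedShape(outs, /*fusion_has_reduction=*/true);
  for (auto s : shape)
    std::printf("%lld ", static_cast<long long>(s));
  std::printf("\n");
  return 0;
}
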
diff --git a/torch/csrc/jit/codegen/cuda/manager.h b/torch/csrc/jit/codegen/cuda/manager.h
index f5ce2d5..9fb1bd1 100644
--- a/torch/csrc/jit/codegen/cuda/manager.h
+++ b/torch/csrc/jit/codegen/cuda/manager.h
@@ -28,9 +28,7 @@
// Current protocol is that the function allocates output tensor append them to
// `stack` after execution.
// TODO: support shape inferencing. Right now we only handles static shape
-TORCH_CUDA_API void runCudaFusionGroup(
- const Node* const fusion_node,
- Stack& stack);
+TORCH_CUDA_API void runCudaFusionGroup(const Node* fusion_node, Stack& stack);
TORCH_CUDA_API void CudaFuseGraph(std::shared_ptr<Graph>& graph);
diff --git a/torch/csrc/jit/codegen/cuda/mutator.h b/torch/csrc/jit/codegen/cuda/mutator.h
index c8b4c6f..822fba5 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.h
+++ b/torch/csrc/jit/codegen/cuda/mutator.h
@@ -11,7 +11,7 @@
namespace jit {
namespace fuser {
-struct Fusion;
+class Fusion;
/*
* Mutators are the mechanism used to modify IR nodes. Since most nodes are
@@ -25,7 +25,7 @@
// Search through "within" and replace all instances of "instance" with the
// value "with".
-struct TORCH_CUDA_API ReplaceAll : public OptOutMutator {
+class TORCH_CUDA_API ReplaceAll : public OptOutMutator {
private:
// Will look in fusion and if we're replacing an input or output we register
// those changes
diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index a70640b..e266f0a 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -19,23 +19,74 @@
namespace fuser {
namespace cuda {
-constexpr auto NUM_UNARY_OPS = 31;
-constexpr auto NUM_BINARY_OPS = 24;
-constexpr auto NUM_BINARY_OPS_WITH_ALPHA = 4;
-constexpr auto NUM_LERP_OPS = 2;
+constexpr auto kNumUnaryOps = 31;
+constexpr auto kNumBinaryOps = 24;
+constexpr auto kNumBinaryOpsWithAlpha = 4;
+constexpr auto kNumLerpOps = 2;
namespace {
typedef Val* CgValue;
typedef Expr* CgOp;
-typedef void (
- *ParseFuncPtr)(const Node* const, std::unordered_map<size_t, CgValue>&);
+typedef void (*ParseFuncPtr)(const Node*, std::unordered_map<size_t, CgValue>&);
+typedef bool (*MergeQueryFuncPtr)(const Node*);
+
+std::vector<int> reductionAxes(TensorView* tv) {
+ size_t n_dims = tv->nDims();
+ std::vector<int> reduction_axes;
+ for (size_t i = 0; i < n_dims; i++) {
+ if (tv->axis(i)->isReduction()) {
+ reduction_axes.emplace_back(i);
+ }
+ }
+ return reduction_axes;
+}
+
+// Coalesces all reduction axes to the right side and returns the total number
+// of reduction axes
+size_t coalescReduction(TensorView* tv) {
+ auto reduction_axes = reductionAxes(tv);
+ size_t n_dims = tv->nDims();
+ std::unordered_map<int, int> coalesc_permute;
+ for (size_t i = 0; i < reduction_axes.size(); i++) {
+ size_t new_pos = i + n_dims - reduction_axes.size();
+ if (new_pos == size_t(reduction_axes[i])) {
+ break;
+ } else {
+ coalesc_permute[reduction_axes[i]] = new_pos;
+ }
+ }
+ if (!coalesc_permute.empty()) {
+ tv->reorder(coalesc_permute);
+ }
+ return reduction_axes.size();
+}
// TODO: add a mutex to make it thread safe.
class IrParser {
+ class RegistrationEntry {
+ public:
+ RegistrationEntry(ParseFuncPtr parse_f, MergeQueryFuncPtr merge_f = nullptr)
+ : parse_f_(parse_f), merge_f_(merge_f) {}
+
+ void parse(const Node* node, std::unordered_map<size_t, CgValue>& values) {
+ parse_f_(node, values);
+ }
+
+ bool is_compatible(const Node* node) {
+ if (merge_f_ == nullptr) {
+ return true;
+ }
+ return merge_f_(node);
+ }
+
+ private:
+ ParseFuncPtr parse_f_;
+ MergeQueryFuncPtr merge_f_;
+ };
+
private:
- static const int nthreads = 128;
static const int unroll_factor = 4;
public:
@@ -52,18 +103,45 @@
FusionGuard fg(cuda_kernel_->fusion_.get());
auto block = graph_->block();
- // in case of broadcast, we don't support explicit broadcast, so we need to
- // convert/expand all inputs tensors to comply to the broadcasted size.
- // This supports very limited case, which we try to accomodate in graph
- // partition, that we only merge nodes with identical output shapes.
- int broadcast_dim =
- block->outputs()[0]->type()->cast<TensorType>()->dim().value();
+ // [ Note - broadcast support in integration ]
+ //
+ // in case of broadcast, we don't support explicit broadcast:
+ // 1. for point-wise fusion, we need to convert/expand all input
+ // tensors to comply with the broadcasted size. This supports a very limited
+ // case, which we try to accommodate in graph partition: we only merge
+ // nodes with identical output shapes.
+ // 2. in case of reduction-at-end fusion, right now we only support a single
+ // reduction operation in the fusion, hence we can use the same logic as PW
+ // fusion and convert/expand all inputs to the input tensor of the reduction
+ // op.
+
+ // TODO: proper broadcast support in integration
+ int broadcast_dim = -1;
+ // the broadcast support hack is handled differently for reduction fusions.
+ if (hasReductionNode(graph_->block())) {
+ // reduction-at-end fusion: broadcast all inputs to the tensor before the
+ // reduction
+ // TODO: Not perfectly safe! We could have an intermediate output that is
+ // not part of the outputs of reduction operations. But we have a similar
+ // limitation for broadcast support in PW fusion. We should properly fix
+ // this after broadcast integration.
+ broadcast_dim = block->outputs()[0]
+ ->node()
+ ->inputs()[0]
+ ->type()
+ ->cast<TensorType>()
+ ->dim()
+ .value();
+ } else {
+ // point-wise fusion, broadcast all inputs to output size.
+ broadcast_dim =
+ block->outputs()[0]->type()->cast<TensorType>()->dim().value();
+ }
// register all inputs;
// shape propagation during parsing is effctively done in parsing rules, as
// we only explicitly register inputs in the graph.
for (auto val : block->inputs()) {
- TORCH_CHECK(registerValue(val, broadcast_dim));
+ TORCH_INTERNAL_ASSERT(registerValue(val, broadcast_dim));
cuda_kernel_->fusion_->addInput(value_map_[val->unique()]);
auto opt_dtype = value_map_[val->unique()]->getDataType();
@@ -78,12 +156,17 @@
// TODO: disable unroll to ensure rand_like generates identical output as
// with eager mode
bool disable_unroll = false;
+ bool has_reduction = false;
+ bool fcd_reduction = false;
// compose nodes in topo order;
for (const JitOp* node : block->nodes()) {
processJitNode(node);
if (node->kind() == aten::rand_like) {
disable_unroll = true;
}
+ if (node->kind() == aten::sum) {
+ has_reduction = true;
+ }
}
// mark output;
@@ -102,51 +185,134 @@
cuda_kernel_->fusion_->addOutput(out);
- // Merge all dimensions because we're only supporting pointwise
- while (out->nDims() > 1)
- out->merge(0, 1);
- // Split into 128 which will be bockDim.x
- out->split(0, nthreads);
- // Split by another 4 which will be our unroll factor
- auto ur_factor = disable_unroll ? 1 : unroll_factor;
- if (!disable_unroll) {
- out->split(0, ur_factor);
- cuda_kernel_->unroll_factor_ = ur_factor;
- }
- }
+ // TODO: has_reduction for scheduling should be done on a per-output-tensor
+ // basis.
+ if (has_reduction) {
+ // TODO: this scheduling only works for a single reduction operation in
+ // the fusion, in which case we can coalesce all reduction axes and
+ // merge them together. (same applies to iteration axes)
+ // TODO: does this work for multiple outputs?
- // Run through outputs, grab all inputs of outputs
- // squeeze with computeAt to set overall structure.
- for (auto output : cuda_kernel_->fusion_->outputs()) {
- if (output->getValType() != ValType::TensorView)
- continue;
- TensorView* out_tv = static_cast<TensorView*>(output);
- for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
- if (inp->getValType().value() == ValType::TensorView)
- static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
- }
- out_tv->axis(0)->parallelize(ParallelType::BIDx);
- }
+ // query if fastest changing dimension (FCD) is a reduction
+ fcd_reduction = out->axis((int)out->nDims() - 1)->isReduction();
- // Run through intermediates, unroll, and bind their axes
- for (auto val : cuda_kernel_->fusion_->vals()) {
- if (val->getValType().value() != ValType::TensorView)
- continue;
- TensorView* tv = static_cast<TensorView*>(val);
+ // TODO: could really use evaluation here. Launch configuration is
+ // imposed by transformation and the information should be
+ // embedded in codegen IR.
+ cuda_kernel_->reduction_axes_ = reductionAxes(out);
- // Should be true for all intermediates, but if one isn't hooked
- // up right, skip it and hope for the best for now
- if (!disable_unroll && tv->nDims() == 3) {
- tv->axis(-2)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
+ // We coalesce all reduction axes to the right;
+ size_t num_reduction_axes = coalescReduction(out);
+
+ // Merge all iteration dimensions
+ while (out->nDims() > num_reduction_axes + 1) {
+ out->merge(0, 1);
+ }
+ // Merge all reduction dimensions
+ while (out->nDims() > 2) {
+ out->merge(1, 2);
+ }
+
} else {
- if (tv->nDims() == 2)
+ // Merge all dimensions because we're only supporting pointwise
+ while (out->nDims() > 1)
+ out->merge(0, 1);
+ // Split into 128 which will be blockDim.x
+ out->split(0, kPwThreadX);
+ // Split by another 4 which will be our unroll factor
+ auto ur_factor = disable_unroll ? 1 : unroll_factor;
+ if (!disable_unroll) {
+ out->split(0, ur_factor);
+ cuda_kernel_->unroll_factor_ = ur_factor;
+ }
+ }
+ }
+
+ if (has_reduction) {
+ // Run through outputs, grab all inputs of outputs
+ // squeeze with computeAt to set overall structure.
+ for (auto output : cuda_kernel_->fusion_->outputs()) {
+ if (output->getValType() != ValType::TensorView)
+ continue;
+ TensorView* out_tv = static_cast<TensorView*>(output);
+
+ // fcd_reduction could be queried later via
+ // cuda_kernel_->reduction_axes_, which would ensure we have proper
+ // launch configuration.
+ TensorView* intermediate;
+ if (fcd_reduction) {
+ out_tv->split(-1, kFcdReductionThreadX);
+ // necessary to avoid dynamic allocation on intermediates;
+ intermediate = out_tv->rFactor({-2});
+ } else {
+ // TODO: we don't need a full warp here, this should be determined by
+ // element data type
+ out_tv->split(0, kNonFcdReductionThreadX);
+ out_tv->split(
+ -1, kNonFcdReductionThreadY); // necessary to avoid dynamic
+ // allocation on intermediates;
+ intermediate = out_tv->rFactor({-2});
+ }
+ for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
+ // scheduling of inputs shouldn't change with different fcd_reduction
+ if (inp->getValType().value() == ValType::TensorView) {
+ static_cast<TensorView*>(inp)->computeAt(intermediate, -1);
+ }
+ }
+ // scheduling of inputs shouldn't change with different fcd_reduction
+ intermediate->computeAt(out_tv, -2);
+ if (fcd_reduction) {
+ out_tv->axis(0)->parallelize(ParallelType::BIDx);
+ } else {
+ out_tv->axis(0)->parallelize(ParallelType::BIDx);
+ out_tv->axis(1)->parallelize(ParallelType::TIDx);
+ }
+ }
+ // Run through all values, unroll, and bind their axes
+ for (auto val : cuda_kernel_->fusion_->vals()) {
+ if (val->getValType().value() != ValType::TensorView)
+ continue;
+ TensorView* tv = static_cast<TensorView*>(val);
+ if (fcd_reduction) {
tv->axis(-1)->parallelize(ParallelType::TIDx);
+ } else {
+ tv->axis(-1)->parallelize(ParallelType::TIDy);
+ }
+ }
+ } else {
+ // Run through outputs, grab all inputs of outputs
+ // squeeze with computeAt to set overall structure.
+ for (auto output : cuda_kernel_->fusion_->outputs()) {
+ if (output->getValType() != ValType::TensorView)
+ continue;
+ TensorView* out_tv = static_cast<TensorView*>(output);
+ for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
+ if (inp->getValType().value() == ValType::TensorView)
+ static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
+ }
+ out_tv->axis(0)->parallelize(ParallelType::BIDx);
+ }
+
+ // Run through all values, unroll, and bind their axes
+ for (auto val : cuda_kernel_->fusion_->vals()) {
+ if (val->getValType().value() != ValType::TensorView)
+ continue;
+ TensorView* tv = static_cast<TensorView*>(val);
+
+ // Should be true for all intermediates, but if one isn't hooked
+ // up right, skip it and hope for the best for now
+ if (!disable_unroll && tv->nDims() == 3) {
+ tv->axis(-2)->parallelize(ParallelType::Unroll);
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
+ } else {
+ if (tv->nDims() == 2)
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
+ }
}
}
}
- static bool canParseNode(const Node* const node) {
+ static bool canParseNode(const Node* node) {
if (init_registry_) {
// TODO: mutex this guy;
registerJitOperator();
@@ -160,17 +326,39 @@
}
for (auto& pair_op_func : iter->second) {
if (node->matches(pair_op_func.first->schema())) {
- return true;
+ return pair_op_func.second.is_compatible(node);
}
}
return false;
}
+ static bool isReductionNode(const Node* node) {
+ if (init_registry_) {
+ // TODO: mutex this guy;
+ registerJitOperator();
+ init_registry_ = false;
+ }
+
+ return jit_reduction_op_registry_.count(node->kind());
+ }
+
+ // TODO: is_reduction is too hacky here. We should categorize operation types
+ // based on their memory access patterns, which would affect fusion
+ // strategy and partition logic.
static void registerParseRule(
std::shared_ptr<Operator>& op,
- ParseFuncPtr fn) {
+ ParseFuncPtr parse_fn,
+ MergeQueryFuncPtr merge_query_fn = nullptr,
+ bool is_reduction = false) {
jit_operator_registry_[Symbol::fromQualString(op->schema().name())]
- .emplace_back(std::make_pair(op, fn));
+ .emplace_back(
+ std::piecewise_construct,
+ std::forward_as_tuple(op),
+ std::forward_as_tuple(parse_fn, merge_query_fn));
+ if (is_reduction) {
+ jit_reduction_op_registry_.emplace(
+ Symbol::fromQualString(op->schema().name()));
+ }
}
private:
@@ -179,7 +367,7 @@
// This is a one-time look up, our hash registry indexes on the pointer in
// OperatorRegistry.
- std::array<const char*, NUM_BINARY_OPS_WITH_ALPHA> BinaryOpWithAlpha = {
+ std::array<const char*, kNumBinaryOpsWithAlpha> BinaryOpWithAlpha = {
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
"aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
"aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
@@ -188,7 +376,7 @@
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*);
static std::unordered_map<
@@ -218,7 +406,7 @@
});
}
- std::array<const char*, NUM_BINARY_OPS> BinaryOp = {
+ std::array<const char*, kNumBinaryOps> BinaryOp = {
"aten::div(Tensor self, Tensor other) -> Tensor",
"aten::div(Tensor self, Scalar other) -> Tensor",
"aten::mul(Tensor self, Tensor other) -> Tensor",
@@ -247,7 +435,7 @@
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
static std::unordered_map<Symbol, BinaryOpType> op_mapping(
{{aten::div, BinaryOpType::Div},
@@ -275,7 +463,7 @@
}
// TODO: cast operations should be merged in.
- std::array<const char*, NUM_UNARY_OPS> UnaryOp = {
+ std::array<const char*, kNumUnaryOps> UnaryOp = {
"aten::neg(Tensor self) -> Tensor",
"aten::abs(Tensor self) -> Tensor",
"aten::log(Tensor self) -> Tensor",
@@ -312,7 +500,7 @@
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
static std::unordered_map<Symbol, UnaryOpType> op_mapping({
{aten::neg, UnaryOpType::Neg},
@@ -359,7 +547,7 @@
"aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
@@ -373,7 +561,7 @@
"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
auto th = value_map[node->inputs()[1]->unique()];
@@ -389,7 +577,7 @@
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
// TODO: we need to get a proper lower bound per dtype in operand.
@@ -410,7 +598,7 @@
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto condition = value_map[node->inputs()[0]->unique()];
auto x = value_map[node->inputs()[1]->unique()];
@@ -422,14 +610,14 @@
}
{
- std::array<const char*, NUM_LERP_OPS> LerpOp = {
+ std::array<const char*, kNumLerpOps> LerpOp = {
"aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
for (auto signature : LerpOp) {
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto self = value_map[node->inputs()[0]->unique()];
auto end = value_map[node->inputs()[1]->unique()];
@@ -446,7 +634,7 @@
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
registerParseRule(
ptr_op,
- [](const Node* const node,
+ [](const Node* node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto self = value_map[node->inputs()[0]->unique()];
auto tensor1 = value_map[node->inputs()[1]->unique()];
@@ -457,6 +645,48 @@
value_map.emplace(node->output()->unique(), out);
});
}
+
+ {
+ auto ptr_op = getOperatorForLiteral(
+ "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
+ registerParseRule(
+ ptr_op,
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
+ auto self = value_map[node->input(0)->unique()];
+ auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
+ TORCH_INTERNAL_ASSERT(
+ dims_list.has_value(), "requires static reduce axes");
+ auto keepdim = constant_as<bool>(node->input(2));
+ std::vector<int> dims;
+ for (const auto dim : dims_list->vec()) {
+ dims.emplace_back(static_cast<int>(dim));
+ }
+ TORCH_INTERNAL_ASSERT(
+ keepdim.has_value() && !keepdim.value(),
+ "Keep dim in reduction is not a const false");
+ auto out = sum(self->as<TensorView>(), dims);
+ value_map.emplace(node->output()->unique(), out);
+ },
+ [](const Node* node) -> bool {
+ // we don't support cast of output types yet;
+ if (!node->inputs()[3]->type()->isSubtypeOf(
+ static_cast<c10::TypePtr>(NoneType::get()))) {
+ return false;
+ }
+ // we don't support dynamic reduction axes;
+ if (node->inputs()[1]->node()->kind() != prim::Constant) {
+ return false;
+ }
+ // we don't support keepdim yet;
+ if (node->inputs()[2]->node()->kind() != prim::Constant ||
+ *constant_as<bool>(node->input(2))) {
+ return false;
+ }
+ return true;
+ },
+ true);
+ }
}
void processJitNode(const JitOp* node) {
@@ -464,22 +694,27 @@
// partition doesn't take constant node explicitly, but it does and copy
// constant into subgraph. So we need to register constants in codegen IR;
for (auto output : node->outputs()) {
- TORCH_CHECK(registerScalar(output));
+ TORCH_INTERNAL_ASSERT(
+ registerScalar(output),
+ "registration of output failed at index ",
+ output->offset(),
+ " for node ",
+ *node);
}
} else {
auto iter = IrParser::jit_operator_registry_.find(node->kind());
// make sure we have a parser for the op;
- TORCH_CHECK(
+ TORCH_INTERNAL_ASSERT(
iter != IrParser::jit_operator_registry_.end(),
"CudaFusionGroup Parser doesn't handle operator kind(): ",
node->kind().toDisplayString());
for (auto& pair_op_func : iter->second) {
if (node->matches(pair_op_func.first->schema())) {
- pair_op_func.second(node, value_map_);
+ pair_op_func.second.parse(node, value_map_);
return;
}
}
- TORCH_CHECK(
+ TORCH_INTERNAL_ASSERT(
false,
"CudaFusionGroup Parser doesn't recognize operator overload:",
canonicalSchemaString(node->schema()));
@@ -511,9 +746,24 @@
value_map_.emplace(val->unique(), cg_val);
return true;
} else if (val->type()->isSubtypeOf(
+ static_cast<c10::TypePtr>(BoolType::get()))) {
+ CgValue cg_val;
+ if (auto ival = constant_as<bool>(val)) {
+ cg_val = new Bool(ival.value());
+ } else {
+ cg_val = new Bool();
+ }
+ value_map_.emplace(val->unique(), cg_val);
+ return true;
+ } else if (val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
// TODO: should we consider adding support for NoneType;
return true;
+ } else if (val->type()->cast<ListType>()) {
+ // TODO: we don't support list type in codegen yet;
+ // This is a WAR to allow axes of reduction to be passed as constant list;
+ // We simply ignore conversion if the scalar value is a constant;
+ return toIValue(val).has_value();
}
return false;
}
@@ -524,6 +774,9 @@
// TODO: make this a static function in Tensor class;
// create tensor;
if (broadcast_dim >= 0) {
+ TORCH_INTERNAL_ASSERT(
+ broadcast_dim >= (int)*tensor_type->dim(),
+ "attempt to broadcast a tensor to shrinked dimension is invalid");
tensor_type = tensor_type->withDim(broadcast_dim);
}
// TODO: make this a static function in Tensor class;
@@ -543,20 +796,41 @@
// parsing rule registry.
static std::unordered_map<
Symbol,
- std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
+ std::vector<std::pair<std::shared_ptr<Operator>, RegistrationEntry>>>
jit_operator_registry_;
+ static std::unordered_set<Symbol> jit_reduction_op_registry_;
static bool init_registry_;
};
std::unordered_map<
Symbol,
- std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
+ std::vector<
+ std::pair<std::shared_ptr<Operator>, IrParser::RegistrationEntry>>>
IrParser::jit_operator_registry_;
+std::unordered_set<Symbol> IrParser::jit_reduction_op_registry_;
bool IrParser::init_registry_ = true;
} // namespace
-bool isNodeParsible(const Node* const node) {
+bool hasReductionNode(const Block* block) {
+ for (auto node : block->nodes()) {
+ if (isReductionNode(node)) {
+ return true;
+ }
+ for (auto block : node->blocks()) {
+ if (hasReductionNode(block)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool isReductionNode(const Node* node) {
+ return IrParser::isReductionNode(node);
+}
+
+bool isNodeParsible(const Node* node) {
return IrParser::canParseNode(node);
}
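
The reduction schedule above first permutes reduction axes to the rightmost positions (coalescReduction) so that iteration axes and reduction axes can each be merged into a single dimension afterwards. The following standalone sketch shows a simplified version of the permutation that step computes; the axis layout and names are hypothetical and no nvFuser types are used.

// Standalone sketch of the axis-coalescing step used in the reduction
// schedule: reduction axes are permuted to the rightmost positions so that
// all iteration axes and all reduction axes can each be merged afterwards.
#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
  // Axis layout of a hypothetical output: 'true' marks a reduction axis.
  std::vector<bool> is_reduction = {false, true, false, true};  // [i0, r1, i2, r3]
  const size_t n_dims = is_reduction.size();

  std::vector<int> reduction_axes;
  for (size_t i = 0; i < n_dims; ++i)
    if (is_reduction[i])
      reduction_axes.push_back(static_cast<int>(i));

  // Build old-position -> new-position map, mirroring the permutation that
  // coalescReduction requests via TensorView::reorder.
  std::unordered_map<int, int> permute;
  for (size_t i = 0; i < reduction_axes.size(); ++i) {
    const int new_pos = static_cast<int>(i + n_dims - reduction_axes.size());
    if (new_pos != reduction_axes[i])
      permute[reduction_axes[i]] = new_pos;
  }

  for (const auto& kv : permute)
    std::printf("axis %d -> position %d\n", kv.first, kv.second);
  // For [i0, r1, i2, r3] this requests r1 -> position 2 (r3 already sits at
  // the tail), giving [i0, i2, r1, r3].
  return 0;
}
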
diff --git a/torch/csrc/jit/codegen/cuda/parser.h b/torch/csrc/jit/codegen/cuda/parser.h
index 53eebef..38c0cb2 100644
--- a/torch/csrc/jit/codegen/cuda/parser.h
+++ b/torch/csrc/jit/codegen/cuda/parser.h
@@ -26,8 +26,17 @@
namespace fuser {
namespace cuda {
+constexpr int kPwThreadX = 128;
+constexpr int kFcdReductionThreadX = 128;
+constexpr int kNonFcdReductionThreadX = 32;
+constexpr int kNonFcdReductionThreadY = 32;
+
+TORCH_CUDA_API bool hasReductionNode(const Block* block);
+
+TORCH_CUDA_API bool isReductionNode(const Node* node);
+
// returns whether or not a parsing function exists for the given node type.
-TORCH_CUDA_API bool isNodeParsible(const Node* const node);
+TORCH_CUDA_API bool isNodeParsible(const Node* node);
// lowers PyTorch jit graph to `Fusion`.
TORCH_CUDA_API void parseJitIR(
diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp
index e96d97f..09eb585 100644
--- a/torch/csrc/jit/codegen/cuda/partition.cpp
+++ b/torch/csrc/jit/codegen/cuda/partition.cpp
@@ -13,7 +13,7 @@
// 1. TensorType
// 2. on the same device;
// TODO: update this when codegen can output scalar
-static c10::optional<c10::Device> getDevice(const Value* const value) {
+static c10::optional<c10::Device> getDevice(const Value* value) {
if (!value->type()->isSubtypeOf(TensorType::get())) {
// not tensor type, return false as the op is not outputing scalar.
return c10::nullopt;
@@ -21,7 +21,7 @@
return value->type()->expect<TensorType>()->device();
}
-static c10::optional<c10::Device> getDevice(const Node* const node) {
+static c10::optional<c10::Device> getDevice(const Node* node) {
auto outputs = node->outputs();
for (auto output : outputs) {
auto device = getDevice(output);
@@ -51,25 +51,39 @@
return device->is_cuda();
}
-inline bool isFusableNode(const Node* const node) {
+inline bool isFusableNode(const Node* node) {
// checks if node is compatible with parser:
// 1. if we have a parsing rule; or 2. if the node is already a fusion group.
return (isNodeParsible(node) || node->kind() == prim::CudaFusionGroup);
}
+bool hasReductionOperation(const Node* node) {
+ if (isReductionNode(node)) {
+ return true;
+ }
+ if (node->kind() == prim::CudaFusionGroup) {
+ for (auto n : node->g(attr::Subgraph)->nodes()) {
+ if (hasReductionOperation(n)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
} // namespace
-bool isFusableCudaFusionGroup(const Node* const node) {
+bool isFusableCudaFusionGroup(const Node* node) {
if (isFusableNode(node)) {
return isFusableDevice(node);
}
return false;
}
-bool isFusableCudaFusionGroup(
- const Node* const fusion,
- const Node* const node) {
- if (isFusableCudaFusionGroup(node)) {
+bool isFusableCudaFusionGroup(const Node* fusion, const Node* node) {
+ // TODO: lift the restriction of not fusing a producer that contains a
+ // reduction once we have proper scheduling.
+ if (isFusableCudaFusionGroup(node) && !hasReductionOperation(node)) {
// TODO: ensure legit fusion.
// issue 0: currently codegen doesn't support broadcasting, except in the
// form of stride 0.
diff --git a/torch/csrc/jit/codegen/cuda/partition.h b/torch/csrc/jit/codegen/cuda/partition.h
index e46ad8e..21a44d8 100644
--- a/torch/csrc/jit/codegen/cuda/partition.h
+++ b/torch/csrc/jit/codegen/cuda/partition.h
@@ -19,12 +19,12 @@
namespace fuser {
namespace cuda {
-TORCH_CUDA_API bool isFusableCudaFusionGroup(const Node* const node);
+TORCH_CUDA_API bool isFusableCudaFusionGroup(const Node* node);
// consider if `node` could be fused into `fusion`
TORCH_CUDA_API bool isFusableCudaFusionGroup(
- const Node* const fusion,
- const Node* const node);
+ const Node* fusion,
+ const Node* node);
} // namespace cuda
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.h b/torch/csrc/jit/codegen/cuda/predicate_compute.h
index 620bdc2..f671d31 100644
--- a/torch/csrc/jit/codegen/cuda/predicate_compute.h
+++ b/torch/csrc/jit/codegen/cuda/predicate_compute.h
@@ -32,7 +32,8 @@
namespace jit {
namespace fuser {
-struct PredicateCompute {
+class PredicateCompute {
+ public:
// Return if there are any predicates
static bool hasPredicates(const TensorIndex*);
diff --git a/torch/csrc/jit/codegen/cuda/register_interface.cpp b/torch/csrc/jit/codegen/cuda/register_interface.cpp
index 663518b..cedf44c 100644
--- a/torch/csrc/jit/codegen/cuda/register_interface.cpp
+++ b/torch/csrc/jit/codegen/cuda/register_interface.cpp
@@ -12,7 +12,8 @@
namespace cuda {
namespace {
-struct RegisterInterface {
+class RegisterInterface {
+ public:
RegisterInterface() {
auto ptr = getFuserInterface();
ptr->fn_compile_n_ = &compileCudaFusionGroup;
diff --git a/torch/csrc/jit/codegen/cuda/shape_inference.cpp b/torch/csrc/jit/codegen/cuda/shape_inference.cpp
index f6f0643..8038b2b 100644
--- a/torch/csrc/jit/codegen/cuda/shape_inference.cpp
+++ b/torch/csrc/jit/codegen/cuda/shape_inference.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/codegen/cuda/shape_inference.h>
#include <c10/core/ScalarType.h>
+#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <ATen/ExpandUtils.h>
@@ -105,7 +106,7 @@
// to neither type promoteion nor shape.
case aten::add:
case aten::sub: {
- auto promoted_type = binary_broadcast_type(
+ const auto promoted_type = binary_broadcast_type(
node->input(0)->type()->cast<TensorType>(),
node->input(1)->type()->cast<TensorType>());
node->output()->setType(promoted_type);
@@ -118,7 +119,7 @@
case aten::ge:
case aten::ne:
case aten::eq: {
- auto promoted_type = binary_broadcast_type(
+ const auto promoted_type = binary_broadcast_type(
node->input(0)->type()->cast<TensorType>(),
node->input(1)->type()->cast<TensorType>(),
at::ScalarType::Bool);
@@ -126,7 +127,7 @@
break;
}
case aten::where: {
- auto promoted_type = binary_broadcast_type(
+ const auto promoted_type = binary_broadcast_type(
node->input(1)->type()->cast<TensorType>(),
node->input(2)->type()->cast<TensorType>());
node->output()->setType(promoted_type);
@@ -141,8 +142,21 @@
node->output()->setType(promoted_type);
break;
}
+ case aten::sum: {
+ const auto out_type = node->input(0)->type()->cast<TensorType>();
+ const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
+ const auto keepdim = constant_as<bool>(node->input(2));
+ TORCH_CHECK(
+ dims.has_value() && keepdim.has_value(),
+ "Shape inference cannot handle options.");
+ node->output()->setType(
+ unary_reduce_type(out_type, dims->vec(), keepdim.value()));
+ break;
+ }
default:
- TORCH_CHECK(false, "shape/type inference failed.");
+ TORCH_CHECK(
+ false,
+ "shape/type inference failed, unrecognized operation encountered.");
// TODO: generate a proper error log, as this probably means something
// went unexpected.
break;
@@ -154,6 +168,29 @@
}
protected:
+ TensorTypePtr unary_reduce_type(
+ const TensorTypePtr& op,
+ const std::vector<int64_t>& dims,
+ bool keepdim) {
+ TORCH_CHECK(
+ op->scalarType().has_value() && op->device().has_value() &&
+ op->sizes().isComplete(),
+ "requires complete shape on input");
+ std::vector<int64_t> output_size;
+ std::vector<int64_t> input_size = *op->sizes().concrete_sizes();
+ for (size_t i = 0; i < input_size.size(); i++) {
+ if (std::find(dims.begin(), dims.end(), i) == dims.end()) {
+ output_size.emplace_back(input_size[i]);
+ } else if (keepdim) {
+ // Pushing size 1 here to maintain the reduction dimension because
+ // keepdim is true;
+ output_size.emplace_back(1);
+ }
+ }
+ return TensorType::createContiguous(
+ *op->scalarType(), *op->device(), output_size);
+ }
+
// TODO: we should comply to codegen type promotion.
TensorTypePtr binary_broadcast_type(
TensorTypePtr const& op0,
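
The shape rule implemented by unary_reduce_type above can be summarized as: drop reduced dimensions, or keep them as size 1 when keepdim is true. Below is a standalone sketch of that size computation with illustrative names only (reduceSizes is not part of this patch).

// Standalone sketch of the reduction output-shape rule: reduced dimensions are
// dropped, or kept as size 1 when keepdim is true.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<std::int64_t> reduceSizes(
    const std::vector<std::int64_t>& input_size,
    const std::vector<std::int64_t>& dims,
    bool keepdim) {
  std::vector<std::int64_t> output_size;
  for (size_t i = 0; i < input_size.size(); ++i) {
    const bool reduced =
        std::find(dims.begin(), dims.end(), static_cast<std::int64_t>(i)) !=
        dims.end();
    if (!reduced) {
      output_size.push_back(input_size[i]);
    } else if (keepdim) {
      output_size.push_back(1);  // keep the reduced dimension as size 1
    }
  }
  return output_size;
}

int main() {
  // sum over dim 1 of a [2, 3, 4] tensor.
  for (bool keepdim : {false, true}) {
    auto out = reduceSizes({2, 3, 4}, {1}, keepdim);
    std::printf("keepdim=%d:", keepdim);
    for (auto s : out)
      std::printf(" %lld", static_cast<long long>(s));
    std::printf("\n");
  }
  return 0;
}
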
diff --git a/torch/csrc/jit/codegen/cuda/tensor_meta.cpp b/torch/csrc/jit/codegen/cuda/tensor_meta.cpp
deleted file mode 100644
index 275afc4..0000000
--- a/torch/csrc/jit/codegen/cuda/tensor_meta.cpp
+++ /dev/null
@@ -1,290 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
-#include <algorithm>
-#include <numeric>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-
-//#define TC_DEBUG
-
-/*
- * [Note - TensorContiguity implementation]
- *
- * contiguity_
- * stores contiguity information for each dimension:
- * number stored should be between 0 to N+2;
- * 0 - this axis requires broadcasting;
- * X - where X belongs [-N, -1] U [1, N] means current axis is immediately
- * outside of axis `abs(X)-1`. If X < 0, The two axis are contiguous
- * and can be collapsed;
- * N+1 - the fastest changing dimension, If -N-1, means its contiguous in
- * storage (stride == 1);
- * N+2 - Unknown;
- *
- * TODO sorted_axes_ is something hard to maintain in a meaningful way during
- * merge, maybe we'll drop it if it's not of much help for kernel
- * generation
- * sorted_axes_
- * This is a helper field, list of axes sorted by their stride;
- * Default would be `[0, 1, 2, ..., N-1]`
- * Given sorted_axes_[i] == X means: the i-th axis should be X (if X belongs
- * [0, N-1]). If X == -1, that means it's unknown. This could happen when we
- * merge two TensorContiguity and their order of axes are not consistent.
- *
- * The design of TensorContiguity is to handle two things:
- * 1. Contiguity check - whether or not the contiguity information has
- * changed. To do this, we can simply compare TensorContiguity::contiguity_
- * between two instances;
- * 2. Kernel generation
- * By looking at contiguity_ flag, we can make correct decision like:
- * a. collpasing dimensions;
- * b. kernel binding;
- *
- * merging two TensorContiguity would check their contiguity_ flag and mark
- * accordinly.
- *
- * Issues with current implementation [definitely not complete]:
- * 1. stride for size-1 dimension.
- * Because of PyTorch implementation, stride for size-1 dimension is ill
- * defined and can't properly express intended layout.
- */
-
-// debug print. remove this guy!
-#ifdef TC_DEBUG
-template <typename T>
-std::ostream& operator<<(std::ostream& os, const std::vector<T>& data) {
- os << "(";
- for (auto i = data.begin(); i != data.end(); i++) {
- os << (*i);
- os << " ";
- }
- return os << ")";
-}
-#endif
-
-TensorContiguity::TensorContiguity(
- const std::vector<int64_t>& sizes,
- const std::vector<int64_t>& strides) {
-#ifdef TC_DEBUG
- std::cout << "==== contiguity ====" << std::endl;
- std::cout << "sizes: " << sizes << std::endl;
- std::cout << "strides: " << strides << std::endl;
-#endif
-
- // assert consistent dimensionality;
- assert(sizes.size() == strides.size());
-
- // check for size 0 tensor;
- // size 0 tensor is not handled yet and we should not treat it as broadcast
- assert(std::none_of(
- sizes.begin(), sizes.end(), [](int64_t size) { return size == 0; }));
-
- int dim = sizes.size();
- contiguity_.resize(dim);
- sorted_axes_.resize(dim);
-
- // TODO:
- if (dim <= 0) {
- return;
- }
-
- std::iota(sorted_axes_.begin(), sorted_axes_.end(), 0);
-
- // sort axes per their strides
- // It's important that we use stable sort here, as higher
- std::stable_sort(
- sorted_axes_.begin(),
- sorted_axes_.end(),
- [&strides](int64_t s_a, int64_t s_b) {
- return strides[s_a] > strides[s_b];
- });
-
-#ifdef TC_DEBUG
- std::cout << "sorted index: " << sorted_axes_ << std::endl;
-#endif
-
- // Update contiguity flag all the way until the second to the last;
- for (int i = 0; i < dim; i++) {
- // decending strides: strides[axis_p_1] <= stride[axis_p];
- int axis_p = sorted_axes_[i];
- int stride_p = strides[axis_p];
- if (stride_p == 0) {
- contiguity_[axis_p] = 0; // mark axis_p as broadcast
- } else {
- if (i + 1 == dim) {
- contiguity_[axis_p] = dim + 1;
- if (stride_p == 1) {
- contiguity_[axis_p] *= -1; // we mark axis_p as contiguous in memory;
- }
- break;
- }
- // Check if we should skip the check for collapsing, if:
- // 1. we are at the fastest changing dimension already.
- // (i == dim-1)
- // or
- // 2. the next dimension is a broadcast dimension.
- // (strides[sorted_axes_[i+1]] == 0))
- if ((i == dim - 1) || (strides[sorted_axes_[i + 1]] == 0)) {
- // axis_p is the fastest changing dimension.
- // dim+1 is out of range for next axis, so we would know it's the last
- // dimension.
- contiguity_[axis_p] = dim + 1;
- if (stride_p == 1) {
- // we mark axis_p as contiguous in memory by setting it to negative.
- contiguity_[axis_p] *= -1;
- }
- } else {
- int axis_p_1 = sorted_axes_[i + 1];
- // mark axis_p_1 as the neighboring axis;
- // Notice the compensation for 1-based indexing.
- contiguity_[axis_p] = axis_p_1 + 1;
-
- // Check if axis_p could collapse down to axis_p_1;
- // [Note] Do NOT specialize on size-1 dimension! Two issues:
- // Although size-1 could collapse with any dimension, that's going to
- // be a specialization to the static shape information -> hence not an
- // intended protocol.
- // size-1 collapsing could also be misleading, as the size-1 axis
- // could falsely give the impression that its neighboring axes are
- // collapsible, while they are not.
- // i.e. size[4, 1, 4]; stride[8, 1, 1];
- // both axis 0 and 2 could fuse with axis 1 separately. but we cannot
- // fuse them all together.
- if (stride_p == sizes[axis_p_1] * strides[axis_p_1]) {
- // negative number to specify it's collapsable.
- contiguity_[axis_p] *= -1; // mark axis_p as broadcast
- }
- }
- }
- }
-
-#ifdef TC_DEBUG
- std::cout << "contiguity flag: " << contiguity_ << std::endl;
- std::cout << "==== done contiguity ====" << std::endl;
-#endif
-}
-
-bool TensorContiguity::isBroadcastDim(int axis) const {
- assert(axis >= 0 && axis < rank());
- return contiguity_[axis] == 0;
-}
-
-std::vector<int> TensorContiguity::getBroadcastDims() const {
- std::vector<int> ret;
- for (decltype(contiguity_.size()) i{0}; i < contiguity_.size(); i++) {
- if (contiguity_[i] == 0) {
- ret.emplace_back(static_cast<int>(i));
- }
- }
- return ret;
-}
-
-// we are checking if axis can merge to right.
-// axis_right == axis + 1,
-bool TensorContiguity::canCollapseToHigher(int axis) const {
- // not necessary as to check `assert(axis < rank()-1);` as
- // canCollapseLowerHigher would assert on that;
- return canCollapseLowerHigher(axis, axis + 1);
-}
-
-int TensorContiguity::rank() const {
- return contiguity_.size();
-}
-
-bool TensorContiguity::canCollapseLowerHigher(int lower_axis, int higher_axis)
- const {
- // we are checking if axis can merge to right.
- // we mark contiguity_ as -(target_axis + 1), if it's collapsible;
- assert(
- lower_axis >= 0 && lower_axis < rank() && higher_axis >= 0 &&
- higher_axis < rank());
- return contiguity_[lower_axis] == -(higher_axis + 1);
-}
-
-int TensorContiguity::getFCD() const {
- for (decltype(contiguity_.size()) i{0}; i < contiguity_.size(); i++) {
- if (contiguity_[i] == (-((int)contiguity_.size()) - 1))
- return i;
- }
- return -1;
-}
-
-bool TensorContiguity::isIdentical(const TensorContiguity& tc) const {
- for (int i = 0; i < rank(); i++) {
- if (tc.contiguity_[i] != contiguity_[i]) {
- return false;
- }
- }
- return true;
-}
-
-bool TensorContiguity::isCompatible(const TensorContiguity& tc) const {
- assert(false); // not yet implemented;
- return false;
-}
-bool TensorContiguity::hasContiguousFCD() const {
- for (decltype(contiguity_.size()) i{0}; i < contiguity_.size(); i++) {
- if (contiguity_[i] == (-((int)contiguity_.size()) - 1))
- return true;
- }
- return false;
-}
-
-int TensorContiguity::getAxisByStride(int order) const {
- assert(order >= 0 && order < rank());
- return sorted_axes_[order];
-}
-
-const std::vector<int>& TensorContiguity::getAxesOrderedByStride() const {
- return sorted_axes_;
-}
-
-const std::vector<int>& TensorContiguity::getContiguityTag() const {
- return contiguity_;
-}
-
-const std::vector<int>& TensorContiguity::getSortedAxesTag() const {
- return sorted_axes_;
-}
-
-void TensorContiguity::merge(const TensorContiguity& tc) {
- // TODO: different rank not supported yet; This could be done if we follow
- // numpy broadcasting rule across multiple operands. We simply insert
- // dummy dimensions at the left for tc with lower rank()
- // see [Note - TensorContiguity implementation]
- int dim = rank();
- assert(dim == tc.rank());
-
- for (int i = 0; i < dim; i++) {
- int cont_flag = tc.contiguity_[i];
-
- if (cont_flag != contiguity_[i]) {
- if (cont_flag == -contiguity_[i]) {
- // If sorting should remain, we preserve the information but only relax
- // the contiguity information;
- contiguity_[i] = std::abs(cont_flag);
- } else {
- // mark contiguity as unknown otherwise.
- contiguity_[i] = dim + 2;
- }
- cont_flag = contiguity_[i];
- }
-
- // TODO: can we update sorted_axes_ information via contiguity flag?
- if (tc.sorted_axes_[i] != sorted_axes_[i]) {
- // mark sorted_axes_ as unknown;
- sorted_axes_[i] = -1;
- }
- }
-
-#ifdef TC_DEBUG
- std::cout << "merging" << std::endl;
- std::cout << "sorted index: " << sorted_axes_ << std::endl;
- std::cout << "contiguity flag: " << contiguity_ << std::endl;
-#endif
-}
-
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/tensor_meta.h b/torch/csrc/jit/codegen/cuda/tensor_meta.h
deleted file mode 100644
index 0f85ad2..0000000
--- a/torch/csrc/jit/codegen/cuda/tensor_meta.h
+++ /dev/null
@@ -1,85 +0,0 @@
-#pragma once
-
-#include <c10/core/Device.h>
-#include <c10/core/DeviceType.h>
-#include <c10/util/Exception.h>
-#include <torch/csrc/WindowsTorchApiMacro.h> // TORCH_CUDA_API
-
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <cstdint>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-/*
- Issues that we are not solving:
- 1. strides for trivial dimensions (size-1 dimension);
- 2. memory overlap / interleave;
- TODO: docs explaining the protocol; what is stored and how does merge work
- */
-struct TORCH_CUDA_API TensorContiguity {
- TensorContiguity(
- const std::vector<int64_t>& size,
- const std::vector<int64_t>& stride);
-
- // gives broadcast information per axis;
- bool isBroadcastDim(int axis) const;
-
- // returns all axes that requires broadcast;
- std::vector<int> getBroadcastDims() const;
-
- // gives contiguity information per axis;
- // This basically calls to canCollapseLowerHigher(axis, axis+1);
- bool canCollapseToHigher(int axis) const;
-
- // return the rank of the tensor;
- int rank() const;
-
- // check contiguity
- bool isIdentical(const TensorContiguity& tc) const;
-
- // TODO: check for compatiblity
- // I need feedback on this one. What do we check? order + broadcast rule?
- bool isCompatible(const TensorContiguity& tc) const;
-
- /*******************************************************************************
- * Future proof support
- * we don't need these yet.
- * TODO: we probably won't need this until much later, but let's try solve
- * the problem that doesn't exist yet;
- ******************************************************************************/
-
- // [NOTE] the order of the argument matters:
- // canCollapseLowerHigher(x, y) differs from canCollapseLowerHigher(y, x)
- bool canCollapseLowerHigher(int lower_axis, int higher_axis) const;
-
- // FCD: Fast changing dimension, the dimension with smallest stride (>0).
- // returns -1 if FCD doesn't exist (e.g. fully broadcast)
- int getFCD() const;
- // Check if FCD exist and has stride == 1.
- bool hasContiguousFCD() const;
-
- // This is used to support rational binding;
- // similarly return -1 means it's unknown;
- int getAxisByStride(int order) const;
- const std::vector<int>& getAxesOrderedByStride() const;
-
- // TODO: we should encode this to a single integer with restricted rank.
- const std::vector<int>& getContiguityTag() const;
- const std::vector<int>& getSortedAxesTag() const;
-
- void merge(const TensorContiguity& tc);
-
- protected:
- // contiguity_ : contiguity and broadcast;
- std::vector<int> contiguity_;
-
- // sorted_axes_ : axes ordered by strides (slow dimension to fast dimension).
- std::vector<int> sorted_axes_;
-};
-
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
index 44dd66a..dd8ac6f 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -1,9 +1,14 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/compute_at.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/mutator.h>
+// #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
+
+// Cleanup
+// #include <torch/csrc/jit/codegen/cuda/mutator.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
@@ -38,10 +43,26 @@
this->name_ = fusion_->registerVal(this);
}
+TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner),
+ domain_(ir_cloner->clone(src->domain_)),
+ compute_at_view_(ir_cloner->clone(src->compute_at_view_)),
+ relative_compute_at_axis_(src->relative_compute_at_axis_),
+ this_compute_at_axis_(src->this_compute_at_axis_),
+ memory_type_(src->memory_type_) {}
+
bool TensorView::hasReduction() const {
return domain()->hasReduction();
}
+bool TensorView::hasBlockReduction() const {
+ return domain()->hasBlockReduction();
+}
+
+bool TensorView::hasGridReduction() const {
+ return domain()->hasGridReduction();
+}
+
bool TensorView::hasBroadcast() const {
return domain()->hasBroadcast();
}
@@ -55,6 +76,8 @@
}
IterDomain* TensorView::axis(int pos) const {
+ TORCH_INTERNAL_ASSERT(
+ nDims() > 0, "Tried to access an axis in a 0-dim TensorView");
if (pos < 0)
pos += domain()->nDims();
TORCH_CHECK(
@@ -98,6 +121,17 @@
"Invalid computeAt, reduction domain inside computeAt axis.");
}
+void TensorView::setComputeAt(
+ TensorView* computeAtView,
+ int thisPos,
+ int relPos) {
+ compute_at_view_ = computeAtView;
+ relative_compute_at_axis_ = relPos;
+ this_compute_at_axis_ = thisPos;
+ TORCH_INTERNAL_ASSERT(
+ this_compute_at_axis_ <= nDims(), "Manually set an invalid computeAt.");
+}
+
void TensorView::copyDomain(const TensorDomain* td) {
std::vector<IterDomain*> idv;
for (decltype(td->nDims()) i = 0; i < td->nDims(); i++)
@@ -116,7 +150,8 @@
size_t pos_cav = 0, pos_this = 0;
while ((int)pos_this < pos) {
TORCH_INTERNAL_ASSERT(
- pos_cav < nDims(), "Error computing relative position in computeAt.");
+ pos_cav < compute_at_view_->nDims(),
+ "Error computing relative position in computeAt.");
if (compute_at_view_->axis(pos_cav)->isBroadcast() &&
!(axis(pos_this)->isBroadcast())) {
pos_cav++;
@@ -160,239 +195,27 @@
this_compute_at_axis_ = pos_this;
}
-// Actually applies transformation
-void TensorView::computeAt_impl(
- TensorView* consumer,
- int consumer_compute_at_axis) {
- // Reset view otherwise will conflict with replay.
- clearComputeAt();
- // replay this as consumer / producer as consumer
- TransformReplay::replayPasC(this, consumer, consumer_compute_at_axis);
- setComputeAt(consumer, consumer_compute_at_axis);
-}
-
-// Actually applies transformation
-void TensorView::forwardComputeAt_impl(
- TensorView* producer,
- int producer_compute_at_axis) {
- // Reset view otherwise will conflict with replay.
- producer->clearComputeAt();
- TransformReplay::replayCasP(this, producer, producer_compute_at_axis);
- producer->setComputeAt(this, producer_compute_at_axis);
-}
-
-namespace {
-// Wrapper around set_intersection
-template <typename T>
-std::set<T> set_intersection(const std::set<T>& set1, const std::set<T>& set2) {
- std::set<T> intersection;
- std::set_intersection(
- set1.begin(),
- set1.end(),
- set2.begin(),
- set2.end(),
- std::inserter(intersection, intersection.begin()));
- return intersection;
-}
-
-// convert an iterable of Val* to be an iterable of TensorView*
-template <typename T1, typename T2>
-T1 tv_iterable(const T2& val_iterable) {
- T1 tv_iterable = T1();
- std::transform(
- val_iterable.begin(),
- val_iterable.end(),
- std::back_inserter(tv_iterable),
- [](Val* v) {
- TORCH_INTERNAL_ASSERT(
- v->getValType().value() == ValType::TensorView,
- "When following the computeAt dependency chain, a non TensorView value was found.");
- return static_cast<TensorView*>(v);
- });
- return tv_iterable;
-}
-} // namespace
-
TensorView* TensorView::computeAt(TensorView* consumer, int axis) {
- TORCH_CHECK(
- this->fusion() == consumer->fusion(),
- this,
- " and ",
- consumer,
- " are not in the same fusion.");
-
- FusionGuard fg(this->fusion());
-
+ // Make sure this and consumer are not the same tensor, that's illegal
TORCH_CHECK(
!this->sameAs(consumer), "Cannot call this->computeAt(this, ...)");
+ // We support negative axes, so increment it by consumer->nDims() + 1 and make
+ // sure the result is within consumer->nDims() + 1. Being at consumer->nDims()
+ // means producer will be computed inline with consumer, hence the +1.
if (axis < 0)
- // Compute at is a bit strange where size is the maximum acceptable value
- // instead of size-1
axis += int(consumer->nDims()) + 1;
-
TORCH_CHECK(
axis >= 0 && (unsigned int)axis < consumer->nDims() + 1,
"Compute at called on an axis outside valid range.");
- // If not direct relationship follow dependency chain from consumer to
- // producer.
- auto dep_chains = DependencyCheck::getAllDependencyChains(this, consumer);
-
- std::deque<Val*> dep_chain;
- if (!dep_chains.empty())
- dep_chain = dep_chains.front();
-
- // Make sure there is a dependency chain, if not it's an invalid computeAt.
- // We could do indirect computeAts, but it's not supported at this time.
- TORCH_CHECK(
- !dep_chain.empty(),
- "Compute At expects ",
- this,
- " is a dependency of ",
- consumer,
- ", however it is not.");
-
- // Validate dependency chain returned as expected
- TORCH_INTERNAL_ASSERT(
- dep_chain.back() == consumer && dep_chain[0] == this,
- "Error computing dependency chain.");
-
- // Start the replay going from consumer, through the dependency chain to
- // producer. After this section, producer should look like consumer, and there
- // should be a computeAt chain going from producer to consumer. Proper
- // computeAts are setup, though they will be over-written in a later stage.
- while (dep_chain.size() > 1) {
- Val* consumer_val = dep_chain.back();
- dep_chain.pop_back();
- Val* producer_val = dep_chain.back();
-
- TORCH_INTERNAL_ASSERT(
- consumer_val->getValType().value() == ValType::TensorView &&
- producer_val->getValType().value() == ValType::TensorView,
- "When following the computeAt dependency chain, a non TensorView value was found.");
-
- TensorView* running_consumer = static_cast<TensorView*>(consumer_val);
- TensorView* running_producer = static_cast<TensorView*>(producer_val);
- // Axis is relative to consumer, however as we propagate computeAt, it may
- // move. This is why we have TensorView->getThisComputeAtAxis() which
- // returns where in a TensorView does the computeAt (relative to consumer)
- // line up. Mismatch is due to broadcast.
- int compute_at_axis = axis;
- if (running_consumer != consumer)
- compute_at_axis = (int)running_consumer->getThisComputeAtAxis();
- running_producer->computeAt_impl(running_consumer, compute_at_axis);
- }
-
- /*
- * Compute At has now worked from consumer to producer, transforming producer
- * to match computeAt selected in consumer We now need to work from producer
- * up to its consumers (including indirect consumption) so their use also
- * matches. If we can find a TV that contains all uses of producer (common
- * consumer), we can terminate this propagation there. If not, we need to
- * propagate all the way to outputs.
- */
-
- // Start looking for a common consumer of producer
-
- // Grab all uses of producer in fusion
- auto val_all_consumer_chains =
- DependencyCheck::getAllDependencyChainsTo(this);
-
- // Convert dep chains to tensor view chains
- std::deque<std::deque<TensorView*>> all_consumer_chains;
- for (const auto& val_dep_chain : val_all_consumer_chains)
- all_consumer_chains.push_back(
- tv_iterable<std::deque<TensorView*>>(val_dep_chain));
-
- // Set arith to find a common consumer, start with first use chain of producer
- std::set<TensorView*> common_consumers(
- all_consumer_chains.front().begin(), all_consumer_chains.front().end());
-
- // Run through all use chains of producer, and intersect them
- for (auto dep_chain : all_consumer_chains)
- common_consumers = set_intersection(
- common_consumers,
- std::set<TensorView*>(dep_chain.begin(), dep_chain.end()));
-
- // Remove all TVs between producer and consumer as we don't want a common
- // consumer placed logically before consumer provided in computeAt
- for (const auto& dep_chain : dep_chains) {
- auto tv_chain = tv_iterable<std::deque<TensorView*>>(dep_chain);
- for (auto tv : tv_chain) {
- if (tv != consumer)
- common_consumers.erase(tv);
- }
- }
-
- // If there is a common consumer, grab the first one (topologically)
- TensorView* common_consumer = nullptr;
- if (!common_consumers.empty()) {
- for (TensorView* tv : all_consumer_chains.front())
- if (common_consumers.find(tv) != common_consumers.end()) {
- common_consumer = tv;
- break;
- }
- }
-
- // Forward propagate the transformationthrough all use chains until
- // common_consumer if there is one otherwise until we hit all output TVs
- std::set<TensorView*> output_set;
- // computeAt axis in outputs don't necessarily match up, make sure to keep the
- // relative computeAt position in each output
- std::vector<std::pair<TensorView*, int>> ordered_outputs;
- for (auto dep_chain : all_consumer_chains) {
- // All dep chains start with this.
- TORCH_INTERNAL_ASSERT(
- dep_chain.front() == this,
- "Invalid dependency chain found during computeAt, ",
- dep_chain.front(),
- " should be ",
- this);
- TORCH_INTERNAL_ASSERT(
- this->hasComputeAt(),
- "Error detected during computeAt, ",
- this,
- ", should have a computeAt set at this point even though we will over-write it.");
- int running_producer_compute_at = (int)this->getThisComputeAtAxis();
- while (dep_chain.size() > 1) {
- TensorView* running_producer = dep_chain.front();
- dep_chain.pop_front();
- TensorView* running_consumer = dep_chain.front();
-
- if (running_producer == common_consumer)
- break;
- // Axis is relative to consumer, and may not necessarily apply to all
- // intermediate steps. Fortunately producer is guarenteed to have a valid
- // computeAt set, so we can use the compute at axis relative to producer.
- running_consumer->forwardComputeAt_impl(
- running_producer, running_producer_compute_at);
- running_producer_compute_at =
- (int)running_producer->getThisComputeAtAxis();
- int consumer_compute_at =
- (int)running_producer->getRelativeComputeAtAxis();
-
- if (dep_chain.size() == 1) { // last one
- if (output_set.find(running_consumer) == output_set.end()) {
- output_set.emplace(running_consumer);
- ordered_outputs.emplace_back(std::pair<TensorView*, int>(
- running_consumer, consumer_compute_at));
- }
- }
- }
- }
-
- if (!ordered_outputs.empty())
- for (auto it = ordered_outputs.begin(); it + 1 != ordered_outputs.end();
- it++)
- (*it).first->computeAt_impl(
- (*(it + 1)).first,
- (*(it + 1)).second); // use recorded position, not axis.
+ ComputeAt::run(this, consumer, (unsigned int)axis);
return this;
}
TensorView* TensorView::split(int axis, unsigned int factor) {
+ TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView");
if (axis < 0)
axis += domain()->nDims();
@@ -411,6 +234,7 @@
// Merge "axis" and "axis+1" into 1 dimension
TensorView* TensorView::merge(int axis_o, int axis_i) {
+ TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim TensorView");
if (axis_o < 0)
axis_o += domain()->nDims();
@@ -434,19 +258,15 @@
}
TensorView* TensorView::reorder(const std::unordered_map<int, int>& old2new_) {
+ TORCH_INTERNAL_ASSERT(
+ !(nDims() == 0 && old2new_.size() > 0),
+ "Tried to reorder a 0-dim TensorView");
domain()->reorder(old2new_);
return this;
}
-/*
- * Take reduction axes out of this domain, and create a new domain. New domain
- * will be used to create this domain. For example: TV1[I0, I1] = TV0[I0, R0,
- * R1, I1] TV0->rfactor({1}) TV0 is transformed to -> TV0[I0, R1, I1] The
- * TensorView returned is: TV2[I0, R0, I3, I1] The reduction will now beset
- * as: TV1[I0, R1, I1] = TV2[I0, R0, I3, I1] TV0[I0, I1] = TV1[I0, R1, I1]
- */
-
TensorView* TensorView::rFactor(const std::vector<int>& axes) {
+ TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
FusionGuard fg(this->fusion());
Expr* origin_expr = this->fusion()->origin(this);
TORCH_CHECK(
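The negative-axis convention used by computeAt above is worth spelling out: the valid range runs up to consumer->nDims() inclusive, where nDims() itself means the producer is computed fully inline with the consumer. A minimal standalone sketch of that normalization, using a hypothetical normalizeComputeAtAxis helper that is not part of this PR:

#include <stdexcept>

// Hypothetical helper mirroring the wrap-around in TensorView::computeAt:
// a negative axis is shifted by ndims + 1, and the result must land in
// [0, ndims], where ndims itself means "compute fully inline".
inline unsigned int normalizeComputeAtAxis(int axis, unsigned int consumer_ndims) {
  if (axis < 0)
    axis += static_cast<int>(consumer_ndims) + 1;
  if (axis < 0 || static_cast<unsigned int>(axis) > consumer_ndims)
    throw std::out_of_range("Compute at called on an axis outside valid range.");
  return static_cast<unsigned int>(axis);
}

// For a 3-D consumer: normalizeComputeAtAxis(-1, 3) == 3 (fully inline),
// normalizeComputeAtAxis(-4, 3) == 0, and -5 or 4 are rejected.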
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
index 42ecdbe..20cf158 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
@@ -26,7 +26,7 @@
// going to replay the split on
auto it = id_map_.find(id_in);
if (it == id_map_.end()) {
- if (check_all_ops_run_) {
+ if (error_on_failure_) {
TORCH_INTERNAL_ASSERT(
false, "Transform traversal failed, dependencies not met.");
} else {
@@ -67,18 +67,42 @@
// we're going to replay the merge on
auto it_outer = id_map_.find(id_outer);
auto it_inner = id_map_.find(id_inner);
- if (it_outer == id_map_.end() || it_inner == id_map_.end()) {
- if (check_all_ops_run_) {
- TORCH_INTERNAL_ASSERT(
- false, "Transform traversal failed, dependencies not met.");
- } else {
- return;
+
+ const bool outer_found = it_outer != id_map_.end();
+ const bool outer_bcast = id_outer->isBroadcast();
+ const bool inner_found = it_inner != id_map_.end();
+ const bool inner_bcast = id_inner->isBroadcast();
+
+ // If either are not found
+ if (!outer_found || !inner_found) {
+ // If both aren't found, it's a failure
+ // If outer is found && inner is bcast it is not a failure
+ // If inner is found && outer is bcast it is not a failure
+ if (!(outer_found || inner_found) || (outer_found && !inner_bcast) ||
+ (inner_found && !outer_bcast)) {
+ if (error_on_failure_) {
+ TORCH_INTERNAL_ASSERT(
+ false, "Transform traversal failed, dependencies not met.");
+ } else {
+ return;
+ }
}
}
+ // If we merge a broadcast dim with a non-broadcast dim, just remap the output
+ // to the non-broadcast dim.
+ if (inner_found && !outer_found && outer_bcast) {
+ id_map_[m->out()] = it_inner->second;
+ return;
+ }
+ if (outer_found && !inner_found && inner_bcast) {
+ id_map_[m->out()] = it_outer->second;
+ return;
+ }
+
// Grab the IDs we're going to replay this merge on
- auto id_outer_mapped = (*it_outer).second;
- auto id_inner_mapped = (*it_inner).second;
+ const auto id_outer_mapped = it_outer->second;
+ const auto id_inner_mapped = it_inner->second;
// Make sure these IDs are leaf IDs (meaning they have no uses we generated)
TORCH_INTERNAL_ASSERT(
@@ -107,15 +131,15 @@
ReplayTransformations::ReplayTransformations(
const std::vector<IterDomain*>& _target_domain,
std::unordered_map<IterDomain*, IterDomain*> _id_map,
- bool _check_all_ops_run)
+ bool _error_on_failure)
: target_domain_(_target_domain),
id_map_(std::move(_id_map)),
- check_all_ops_run_(_check_all_ops_run) {
+ error_on_failure_(_error_on_failure) {
// Make sure id_map has all the inputs needed to replay target_domain
auto inps = IterVisitor::getInputsTo(
std::vector<Val*>(target_domain_.begin(), target_domain_.end()));
- if (check_all_ops_run_)
+ if (error_on_failure_)
std::for_each(inps.begin(), inps.end(), [this](Val* val) {
TORCH_INTERNAL_ASSERT(
val->getValType().value() == ValType::IterDomain,
@@ -149,7 +173,7 @@
target_domain_.begin(), target_domain_.end());
traverseFrom(traversal_vals[0]->fusion(), traversal_vals);
- if (check_all_ops_run_)
+ if (error_on_failure_)
TORCH_INTERNAL_ASSERT(
leaf_ids_.size() >= target_domain_.size(),
"Transform traversal failed, did not find enough output IterDomains.");
@@ -158,7 +182,7 @@
for (auto out : target_domain_) {
auto it_replayed = id_map_.find(out);
if (it_replayed == id_map_.end()) {
- if (check_all_ops_run_) {
+ if (error_on_failure_) {
TORCH_INTERNAL_ASSERT(
false,
"Transform traversal failed, could not find expected output.");
@@ -170,7 +194,9 @@
auto it_leaf = leaf_ids_.find(id_replayed);
TORCH_INTERNAL_ASSERT(
it_leaf != leaf_ids_.end(),
- "Transform Traversal failed, expected matched output to be a leaf of the replay, but was not.");
+ "Transform Traversal failed, expected a replayed dim for ",
+ out,
+ " but one was not created.");
}
// Populate leaf_vec_ in a deterministic manner. This is deterministic
@@ -275,7 +301,7 @@
continue;
}
- if (t_expr->nOutputs() != r_expr->nOutputs()) {
+ if (t_expr->outputs().size() != r_expr->outputs().size()) {
TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
continue;
}
@@ -318,7 +344,7 @@
}
// Add outputs to map.
- for (size_t i = 0; i < t_expr->nOutputs(); i++) {
+ for (size_t i = 0; i < t_expr->outputs().size(); i++) {
auto t_out = t_expr->output(i);
auto r_out = r_expr->output(i);
if (t_out->getValType() == ValType::IterDomain &&
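The broadcast handling added to the merge replay above boils down to a small decision table: if both inputs of the merge are mapped, the merge is replayed; if exactly one input is missing and the missing side is a broadcast dimension, the merge output simply reuses the mapped side; anything else is a failed dependency. A condensed, self-contained sketch of that rule (the enum and function names are illustrative, not the fuser's types):

// Illustrative classification of the logic in ReplayTransformations::handle(Merge*).
enum class MergeReplay { ReplayBoth, ReuseInner, ReuseOuter, Fail };

inline MergeReplay classifyMerge(
    bool outer_found, bool outer_bcast,
    bool inner_found, bool inner_bcast) {
  if (outer_found && inner_found)
    return MergeReplay::ReplayBoth;  // both inputs mapped: replay the merge
  if (inner_found && !outer_found && outer_bcast)
    return MergeReplay::ReuseInner;  // missing outer is broadcast: alias inner
  if (outer_found && !inner_found && inner_bcast)
    return MergeReplay::ReuseOuter;  // missing inner is broadcast: alias outer
  return MergeReplay::Fail;          // dependencies not met
}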
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h
index 645d702..d1aaddd 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.h
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.h
@@ -24,7 +24,7 @@
};
// Simply grabs all exprs needed to produce provided outputs.
-struct Exprs : public IterVisitor {
+class Exprs : public IterVisitor {
private:
std::vector<Expr*> exprs;
void handle(Expr* e) override {
@@ -44,14 +44,27 @@
} // namespace
-struct TORCH_CUDA_API ReplayTransformations : public IterVisitor {
+// Uses the history of _target_domain, and replays that history using the
+// provided map.
+//
+// target_domain contains the history we want replayed.
+//
+// id_map maps IterDomains in that history to the IterDomains we want it
+// replayed on.
+//
+// error_on_failure = true will cause the replay to error if we can't replay any
+// operation in target_domain's history due to missing IDs in the id_map.
+//
+// If error_on_failure = false, replay will replay everything it can, and ignore
+// operations it can't.
+class TORCH_CUDA_API ReplayTransformations : public IterVisitor {
protected:
const std::vector<IterDomain*>& target_domain_;
std::unordered_map<IterDomain*, IterDomain*> id_map_;
std::unordered_map<IterDomain*, size_t> leaf_ids_;
std::vector<IterDomain*> leaf_vec_;
size_t counter = 0;
- bool check_all_ops_run_ = true;
+ bool error_on_failure_ = true;
bool ran_replay = false; // Mark if replay has been run
using IterVisitor::handle;
@@ -60,23 +73,16 @@
// TODO: HANDLE RFACTOR DOMAINS
// We're going to replay this split operation on the corresponding ID
- virtual void handle(Split* s) override;
+ void handle(Split* s) override;
// We're going to replay this merge operation on the corresponding IDs
- virtual void handle(Merge* m) override;
+ void handle(Merge* m) override;
public:
- // Uses the history of _target_domain, and replays that history using the
- // provided map target_domain contains the history we want replayed, and
- // id_map maps IterDomains in that history to the IterDomains we want it
- // replayed on. check_all_ops_run will cause the replay to error if we can't
- // play any operation in target_domain's history because the IDs are not in
- // the id_map. If check_all_ops_run = false, replay will replay everything it
- // can, and ignore operations it can't.
ReplayTransformations(
const std::vector<IterDomain*>& _target_domain,
std::unordered_map<IterDomain*, IterDomain*> _id_map,
- bool _check_all_ops_run = true);
+ bool _error_on_failure = true);
// Replays outputs that were generated from ids.first on ids.second
void runReplay();
@@ -155,7 +161,7 @@
 * to the outputs of the equivalent expr in replay_domain's history.
*/
-struct TORCH_CUDA_API BestEffortReplay {
+class TORCH_CUDA_API BestEffortReplay {
private:
std::unordered_map<IterDomain*, IterDomain*> id_map_;
std::unordered_map<IterDomain*, size_t> leaf_ids_;
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
index 2da65bd..ccfa63e 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -15,7 +15,7 @@
namespace {
-struct ReplaySelf : public ReplayTransformations {
+class ReplaySelf : public ReplayTransformations {
private:
// Took a good bit of this from ReplayTransformations::handle(Split...)
void handle(Split* s) override {
@@ -129,8 +129,8 @@
// Self replay.
TensorDomain* TransformReplay::fullSelfReplay(
- TensorDomain* new_self_root,
- TensorDomain* self) {
+ const TensorDomain* new_self_root,
+ const TensorDomain* self) {
TORCH_INTERNAL_ASSERT(
new_self_root->nDims() == self->rootDomain().size(),
"Invalid number of IterDomains provided.");
@@ -174,15 +174,14 @@
return new TensorDomain(new_self_root->domain(), new_domain);
}
-// Replay producer as consumer.
// Producer could have rfactor axes which consumer may want replayed. We can
// "replay" them as long as it doesn't modify the root rfactor axes. What we
// really want to do is validate if we replayed these axes to the ones they
// mapped to in the consumer the operations would all be the same. then we want
// to start the replay of the producer from the rfactor root axes, not the root.
-TensorDomain* TransformReplay::replayPasC(
- TensorDomain* producer,
- TensorDomain* consumer,
+std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
+ const TensorDomain* producer,
+ const TensorDomain* consumer,
int consumer_compute_at_axis) {
if (consumer_compute_at_axis < 0)
consumer_compute_at_axis += (int)consumer->nDims() + 1;
@@ -192,17 +191,9 @@
"Invalid axis in transform replayPasC.");
// consumer ids we need to match in producer
- std::vector<IterDomain*> consumer_CA_ids;
- {
- int itc = 0;
- while (itc < consumer_compute_at_axis) {
- if (consumer->axis(itc)->isBroadcast()) {
- itc++;
- } else {
- consumer_CA_ids.emplace_back(consumer->axis(itc++));
- }
- }
- }
+ std::vector<IterDomain*> consumer_CA_ids(
+ consumer->domain().begin(),
+ consumer->domain().begin() + consumer_compute_at_axis);
// Figure out all inputs required to generate the compute_at dimensions
std::unordered_set<Val*> consumer_CA_root_ids = IterVisitor::getInputsTo(
@@ -228,7 +219,8 @@
{
size_t itc = 0, itp = 0;
while (itc < consumer_root.size() || itp < producer_root.size()) {
- if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast()) {
+ if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast() &&
+ (itp >= producer_root.size() || !producer_root[itp]->isBroadcast())) {
itc++;
continue;
}
@@ -273,11 +265,14 @@
// rest
for (auto c_id : consumer_CA_ids) {
auto it = replay_PasC.getReplay().find(c_id);
- TORCH_INTERNAL_ASSERT(
- it != replay_PasC.getReplay().end(),
- "Could not find axis, ",
- c_id,
- ", requested in replay.");
+ if (it == replay_PasC.getReplay().end()) {
+ TORCH_INTERNAL_ASSERT(
+ c_id->isBroadcast(),
+ "Could not find axis, ",
+ c_id,
+ ", requested in replay.");
+ continue;
+ }
if (leaf_ids.find(it->second) != leaf_ids.end())
leaf_ids.erase(it->second);
}
@@ -328,15 +323,19 @@
// Add axes in (1)
for (auto c_id : consumer_CA_ids) {
auto it = replay_PasC.getReplay().find(c_id);
- TORCH_INTERNAL_ASSERT(
- it != replay_PasC.getReplay().end(),
- "Could not find axis, ",
- c_id,
- ", requested in replay.");
+ if (it == replay_PasC.getReplay().end()) {
+ TORCH_INTERNAL_ASSERT(
+ c_id->isBroadcast(),
+ "Could not find axis, ",
+ c_id,
+ ", requested in replay.");
+ continue;
+ }
new_IDs.push_back(it->second);
used_IDs.emplace(it->second);
}
+ unsigned int producer_compute_at_axis = new_IDs.size();
// Add axes in (2)
std::unordered_set<IterDomain*> consumer_CA_ids_set(
consumer_CA_ids.begin(), consumer_CA_ids.end());
@@ -369,16 +368,16 @@
TensorDomain* replayed = new TensorDomain(
producer->rootDomain(), producer->rfactorDomain(), new_IDs);
- return replayed;
+ return {replayed, producer_compute_at_axis};
}
-// Replay consumer as producer.
-TensorDomain* TransformReplay::replayCasP(
- TensorDomain* consumer,
- TensorDomain* producer,
+std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
+ const TensorDomain* consumer,
+ const TensorDomain* producer,
int producer_compute_at_axis) {
if (producer_compute_at_axis < 0)
producer_compute_at_axis += (int)producer->nDims() + 1;
+
TORCH_INTERNAL_ASSERT(
producer_compute_at_axis >= 0 &&
(unsigned int)producer_compute_at_axis <= producer->nDims(),
@@ -397,21 +396,31 @@
}
}
- // Figure out all inputs required to generate the compute_at dimensions
- std::unordered_set<Val*> producer_CA_root_ids = IterVisitor::getInputsTo(
- std::vector<Val*>(producer_CA_ids.begin(), producer_CA_ids.end()));
-
// Map of producer_CA_root_ids to related producer_CA_ids
id_map replay_root_map;
// Grab root domains of producer and consumer
std::vector<IterDomain*> consumer_root = consumer->rootDomain();
std::vector<IterDomain*> producer_root = producer->rootDomain();
+
// If producer has an rfactor root, that's the one that will match the
// consumer
if (producer->hasRFactor())
producer_root = producer->rfactorDomain();
+ // Figure out all inputs required to generate the compute_at dimensions
+ std::unordered_set<Val*> all_CA_id_deps = DependencyCheck::getAllValsBetween(
+ std::unordered_set<Val*>(
+ producer->rootDomain().begin(), producer->rootDomain().end()),
+ std::vector<Val*>(producer_CA_ids.begin(), producer_CA_ids.end()));
+
+ // Figure out which root IDs we need:
+ std::unordered_set<Val*> producer_CA_root_ids;
+ for (Val* val : producer_root) {
+ if (all_CA_id_deps.find(val) != all_CA_id_deps.end())
+ producer_CA_root_ids.emplace(val);
+ }
+
// Track which root axes in consumer we send to replay
std::unordered_set<IterDomain*> consumer_roots4replay;
// Map related axes from producer and consumer roots. Make sure we go to the
@@ -419,7 +428,8 @@
{
size_t itc = 0, itp = 0;
while (itc < consumer_root.size() || itp < producer_root.size()) {
- if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast()) {
+ if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast() &&
+ (itp >= producer_root.size() || !producer_root[itp]->isBroadcast())) {
itc++;
continue;
}
@@ -561,11 +571,11 @@
TensorDomain* replayed = new TensorDomain(
consumer->rootDomain(), consumer->rfactorDomain(), new_IDs);
- return replayed;
+ return {replayed, producer_CA_ids.size()};
}
// replay Producer as Consumer
-TensorView* TransformReplay::replayPasC(
+std::pair<TensorView*, unsigned int> TransformReplay::replayPasC(
TensorView* producer,
TensorView* consumer,
int compute_at_axis) {
@@ -573,26 +583,26 @@
// tensor view. When this happens, just return the target view.
if (producer == consumer)
- return producer;
+ return {producer, 0};
- TensorDomain* td =
+ std::pair<TensorDomain*, unsigned int> replay =
replayPasC(producer->domain(), consumer->domain(), compute_at_axis);
- producer->setDomain(td);
- return producer;
+ producer->setDomain(replay.first);
+ return {producer, replay.second};
}
-TensorView* TransformReplay::replayCasP(
+std::pair<TensorView*, unsigned int> TransformReplay::replayCasP(
TensorView* consumer,
TensorView* producer,
int compute_at_axis) {
// If this is a reduction operation, we may call transform_replay on the same
// tensor view. When this happens, just return the target view.
if (consumer == producer)
- return consumer;
- TensorDomain* td =
+ return {consumer, 0};
+ std::pair<TensorDomain*, unsigned int> replay =
replayCasP(consumer->domain(), producer->domain(), compute_at_axis);
- consumer->setDomain(td);
- return consumer;
+ consumer->setDomain(replay.first);
+ return {consumer, replay.second};
}
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h
index 65118b7..d43d71f 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.h
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.h
@@ -116,40 +116,39 @@
*
*/
-struct TensorDomain;
-struct TensorView;
+class TensorDomain;
+class TensorView;
-struct TORCH_CUDA_API TransformReplay {
- private:
+class TORCH_CUDA_API TransformReplay {
public:
- // Replay producer as consumer.
- static TensorDomain* replayPasC(
- TensorDomain* producer,
- TensorDomain* consumer,
+ // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
+ static std::pair<TensorDomain*, unsigned int> replayPasC(
+ const TensorDomain* producer,
+ const TensorDomain* consumer,
int consumer_compute_at_axis);
- // Replay producer as consumer.
- static TensorView* replayPasC(
+ // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
+ static std::pair<TensorView*, unsigned int> replayPasC(
TensorView* producer,
TensorView* consumer,
int consumer_compute_at_axis);
- // Replay producer as consumer.
- static TensorDomain* replayCasP(
- TensorDomain* consumer,
- TensorDomain* producer,
+ // Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
+ static std::pair<TensorDomain*, unsigned int> replayCasP(
+ const TensorDomain* consumer,
+ const TensorDomain* producer,
int producer_compute_at_axis);
- // Replay producer as consumer.
- static TensorView* replayCasP(
+ // Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
+ static std::pair<TensorView*, unsigned int> replayCasP(
TensorView* consumer,
TensorView* producer,
int producer_compute_at_axis);
// Self replay.
static TensorDomain* fullSelfReplay(
- TensorDomain* new_self_root,
- TensorDomain* self);
+ const TensorDomain* new_self_root,
+ const TensorDomain* self);
};
} // namespace fuser
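With replayPasC/replayCasP now returning the replayed tensor together with the compute-at position it produced, callers can consume both in one step instead of re-deriving the position. A hedged sketch of the intended calling pattern; the surrounding ComputeAt code is not part of this diff, so everything beyond the replayPasC call itself is an assumption:

// Assumed usage; `producer`, `consumer`, and `axis` come from the caller.
std::pair<TensorView*, unsigned int> replay =
    TransformReplay::replayPasC(producer, consumer, axis);
TensorView* replayed_producer = replay.first;  // producer, transformed like consumer
unsigned int producer_ca_pos = replay.second;  // computeAt position in the producer
// A caller such as ComputeAt could then record the relationship, e.g.:
// replayed_producer->setComputeAt(consumer, producer_ca_pos, axis);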
diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
index d4152bc..ac10289 100644
--- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
@@ -11,7 +11,7 @@
namespace {
-struct ReplayRFactor : public ReplayTransformations {
+class ReplayRFactor : public ReplayTransformations {
private:
// Took a good bit of this from ReplayTransformations::handle(Split...)
void handle(Split* s) override {
@@ -258,7 +258,7 @@
id->extent(),
id->parallel_method(),
false,
- true,
+ false,
false);
} else {
new_root[i] = id->clone();
diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_rfactor.h
index 9fc2e15..6d0977f 100644
--- a/torch/csrc/jit/codegen/cuda/transform_rfactor.h
+++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.h
@@ -14,7 +14,7 @@
// TODO: Only replay dispatch is really borrowed from TransformIter, we should
// reevaluate the reuse of dispatch for classes that inherit TransformIter.
-struct TORCH_CUDA_API TransformRFactor {
+class TORCH_CUDA_API TransformRFactor {
public:
 // Create a copy of td, change its history by preserving axes so they appear in
// the root domain
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index dbdcb33..fee6910 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -444,7 +444,7 @@
TORCH_CUDA_API std::ostream& operator<<(
std::ostream& out,
const ParallelType ptype) {
- return out << parallel_type2string(ptype);
+ return out << stringifyThread(ptype);
}
TORCH_CUDA_API std::ostream& operator<<(
@@ -471,6 +471,10 @@
return thread_size2string(ptype);
}
+std::string stringifyThread(const ParallelType ptype) {
+ return parallel_type2string(ptype);
+}
+
TORCH_CUDA_API c10::optional<std::string> cast_func_str(
const std::pair<DataType, DataType>& cast) {
const char* str = supported_casts2string(cast);
@@ -478,6 +482,21 @@
: c10::nullopt;
}
+size_t dataTypeSize(DataType type) {
+ switch (type) {
+ case DataType::Bool:
+ return sizeof(bool);
+ case DataType::Float:
+ return 4;
+ case DataType::Half:
+ return 2;
+ case DataType::Int:
+ return 4;
+ default:
+ TORCH_INTERNAL_ASSERT(false, "Size undefined for data type, ", type);
+ }
+}
+
} // namespace fuser
} // namespace jit
} // namespace torch
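dataTypeSize gives the per-element byte size the code generator can use when sizing buffers. A standalone sketch mirroring the switch above, with an illustrative use; `numel` is a made-up parameter, not something defined in this PR:

#include <cstddef>
#include <stdexcept>

enum class DataType { Bool, Float, Half, Int, Null };

// Mirrors the dataTypeSize switch added in type.cpp.
inline std::size_t dataTypeSizeSketch(DataType type) {
  switch (type) {
    case DataType::Bool:  return sizeof(bool); // typically 1 byte
    case DataType::Float: return 4;            // fp32
    case DataType::Half:  return 2;            // fp16
    case DataType::Int:   return 4;            // 32-bit int
    default: throw std::runtime_error("Size undefined for data type");
  }
}

// Bytes for a contiguous buffer of `numel` half-precision elements:
// std::size_t bytes = numel * dataTypeSizeSketch(DataType::Half);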
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 760ff58..afabfa9 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -26,6 +26,7 @@
enum class DataType { Bool, Float, Half, Int, Null };
enum class ExprType {
+ Invalid,
UnaryOp,
BinaryOp,
TernaryOp,
@@ -93,7 +94,7 @@
// TypeAs,
// Logical Ops
- // Int operations, leave position oif Mod we depend on its location of first
+ // Int operations, leave position of Mod as we depend on its location being first
Mod,
CeilDiv,
And,
@@ -136,6 +137,7 @@
TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const ParallelType);
std::string stringifyThreadSize(const ParallelType);
+std::string stringifyThread(const ParallelType);
TORCH_CUDA_API c10::optional<std::string> inline_op_str(const UnaryOpType);
TORCH_CUDA_API c10::optional<std::string> inline_op_str(const BinaryOpType);
@@ -143,6 +145,8 @@
TORCH_CUDA_API c10::optional<std::string> cast_func_str(
const std::pair<DataType, DataType>&);
+size_t dataTypeSize(DataType type);
+
} // namespace fuser
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp
index cce431e..6c49b7f 100644
--- a/torch/csrc/jit/codegen/cuda/utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/utils.cpp
@@ -2,8 +2,6 @@
#include <c10/util/Exception.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
-
#include <algorithm>
#include <ostream>