New CUDA Fuser: Unrolling support, interface refactor (#36435)
Summary:
Unrolling support has been added in a way that we get good performing code on GPUs. Not sure how long this link will last but an example of a generated unrolled kernel is:
https://godbolt.org/z/i0uAv3
What can be seen from there is multiple calls of "ld.global.f32" without "ld.store.f32" in between them (and vice versa). This means that we are launching multiple loads that can be run in parallel, as well as multiple stores that can be run in parallel. This can be a crucial optimization for memory bound kernels. This was generally a point of concern in TVM as an attempt of a similar kernel from TVM produces: https://godbolt.org/z/Vu97vG which surrounds load - store pairs in conditional branches preventing the benefits of unrolling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36435
Reviewed By: ZolotukhinM
Differential Revision: D21024011
Pulled By: soumith
fbshipit-source-id: e852e282fa7a304aba962e1926f756098c011fe0
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 721bb8a..f10dcfd 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -587,8 +587,11 @@
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp
diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 5837002..c58a7fc 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -8,7 +8,6 @@
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
@@ -27,7 +26,7 @@
TensorView* makeDummyTensor(int nDims) {
std::vector<IterDomain*> dom;
for (int i = 0; i < nDims; i++)
- dom.push_back(new IterDomain(new Int()));
+ dom.push_back(new IterDomain(new Int(0), new Int()));
return new TensorView(new TensorDomain(dom), DataType::Float);
}
@@ -384,22 +383,22 @@
tv = tv->split(2, 2);
TORCH_CHECK(tv->nDims() == 4);
- Expr* outer = tv->axis(2)->size()->getOrigin();
+ Expr* outer = tv->axis(2)->extent()->getOrigin();
TORCH_CHECK(
outer->getExprType().value() == ExprType::BinaryOp &&
static_cast<BinaryOp*>(outer)->getBinaryOpType() ==
BinaryOpType::CeilDiv &&
static_cast<BinaryOp*>(outer)->lhs()->sameAs(
- tv->getRootDomain()->axis(2)->size()) &&
+ tv->getRootDomain()->axis(2)->extent()) &&
static_cast<Int*>(static_cast<BinaryOp*>(outer)->rhs())
->sameAs(new Int(2)));
IterDomain* inner = static_cast<IterDomain*>(tv->axis(3));
TORCH_CHECK(
- inner->size()->isScalar() &&
- static_cast<Int*>(inner->size())->isConst() &&
- static_cast<Int*>(inner->size())->value().value() == 2);
+ inner->extent()->isScalar() &&
+ static_cast<Int*>(inner->extent())->isConst() &&
+ static_cast<Int*>(inner->extent())->value().value() == 2);
}
void testGPU_FusionTVMerge() {
@@ -409,15 +408,15 @@
TensorView* tv = makeDummyTensor(3);
tv = tv->merge(1);
- Expr* axisOp = tv->axis(1)->size()->getOrigin();
+ Expr* axisOp = tv->axis(1)->extent()->getOrigin();
TORCH_CHECK(
tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp &&
static_cast<BinaryOp*>(axisOp)->getBinaryOpType() == BinaryOpType::Mul &&
static_cast<BinaryOp*>(axisOp)->lhs() ==
- tv->getRootDomain()->axis(1)->size() &&
+ tv->getRootDomain()->axis(1)->extent() &&
static_cast<BinaryOp*>(axisOp)->rhs() ==
- tv->getRootDomain()->axis(2)->size());
+ tv->getRootDomain()->axis(2)->extent());
}
void testGPU_FusionTVReorder() {
@@ -527,10 +526,10 @@
void testGPU_FusionParser() {
auto g = std::make_shared<Graph>();
const auto graph0_string = R"IR(
- graph(%0 : Float(2, 3, 4),
- %1 : Float(2, 3, 4)):
- %c0 : Float(2, 3, 4) = aten::mul(%0, %1)
- %d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
+ graph(%0 : Float(2),
+ %1 : Float(2)):
+ %c0 : Float(2) = aten::mul(%0, %1)
+ %d0 : Float(2) = aten::mul(%c0, %0)
return (%d0))IR";
torch::jit::parseIR(graph0_string, g.get());
@@ -555,17 +554,37 @@
<< " return (a + b - 1) / b;\n"
<< "}\n"
<< "\n"
- << "__global__ void CUDAGeneratedKernel(Tensor<float> T4, Tensor<float> T5, Tensor<float> T6){\n"
- << " float T7[1];\n"
- << " if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T5.size[2] ) / T5.size[1] ) < T5.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T5.size[2] ) % T5.size[1] ) < T5.size[1] ) ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T5.size[2] ) < T5.size[2] ) ) ) { \n"
- << " T7[ 0 ]\n"
- << " = T4[ ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T4.size[2] ) / T4.size[1] ) * T4.stride[0] ) + ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T4.size[2] ) % T4.size[1] ) * T4.stride[1] ) + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T4.size[2] ) * T4.stride[2] ) ]\n"
- << " * T5[ ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T5.size[2] ) / T5.size[1] ) * T5.stride[0] ) + ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T5.size[2] ) % T5.size[1] ) * T5.stride[1] ) + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T5.size[2] ) * T5.stride[2] ) ];\n"
+ << "__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){\n"
+ << " float T2[4];\n"
+ << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T1.size[0] ) ) { \n"
+ << " for(size_t i64 = 0; i64 < 4; ++i64 ) {\n"
+ << " T2[ i64 ]\n"
+ << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
+ << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
+ << " }\n"
+ << " } else { \n"
+ << " for(size_t i64 = 0; i64 < 4; ++i64 ) {\n"
+ << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) < T1.size[0] ) ) { \n"
+ << " T2[ i64 ]\n"
+ << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
+ << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
+ << " }\n"
+ << " }\n"
<< " }\n"
- << " if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T6.size[2] ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T6.size[2] ) % T6.size[1] ) < T6.size[1] ) ) && ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T6.size[2] ) < T6.size[2] ) ) ) { \n"
- << " T6[ ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T6.size[2] ) / T6.size[1] ) * T6.stride[0] ) + ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T6.size[2] ) % T6.size[1] ) * T6.stride[1] ) + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T6.size[2] ) * T6.stride[2] ) ]\n"
- << " = T7[ 0 ]\n"
- << " * T4[ ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T4.size[2] ) / T4.size[1] ) * T4.stride[0] ) + ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) / T4.size[2] ) % T4.size[1] ) * T4.stride[1] ) + ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) % T4.size[2] ) * T4.stride[2] ) ];\n"
+ << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
+ << " for(size_t i65 = 0; i65 < 4; ++i65 ) {\n"
+ << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
+ << " = T2[ i65 ]\n"
+ << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
+ << " }\n"
+ << " } else { \n"
+ << " for(size_t i65 = 0; i65 < 4; ++i65 ) {\n"
+ << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
+ << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
+ << " = T2[ i65 ]\n"
+ << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
+ << " }\n"
+ << " }\n"
<< " }\n"
<< "}\n";
@@ -675,63 +694,54 @@
fusion.addOutput(tv2);
tv0->computeAt(tv2, -1);
- /*
- std::stringstream ref;
- ref
- << "__device__ int ceilDiv(const int a, const int b) {\n"
- << " return (a + b - 1) / b;\n"
- << "}\n"
- << "\n"
- << "__global__ void CUDAGeneratedKernel(Tensor<float> T3){\n"
- << " for(size_t i49{0}; i49 < i38; ++i49 ) {\n"
- << " for(size_t i50{0}; i50 < i37; ++i50 ) {\n"
- << " for(size_t i51{0}; i51 < 2; ++i51 ) {\n"
- << " for(size_t i52{0}; i52 < i41; ++i52 ) {\n"
- << " float T5[1];\n"
- << " if ( ( ( ( ( ( i50 * 4 ) + ( i49 / T3.size[1] ) ) < T3.size[0]
- ) && ( ( i49 % T3.size[1] ) < T3.size[1] ) ) && ( ( ( i52 * 2 ) + i51 ) <
- T3.size[2] ) ) ) { \n"
- << " T5[ 0 ]\n"
- << " = float(0)\n"
- << " + float(1);\n"
- << " }\n"
- << " float T4[1];\n"
- << " if ( ( ( ( ( ( i50 * 4 ) + ( i49 / T3.size[1] ) ) < T3.size[0]
- ) && ( ( i49 % T3.size[1] ) < T3.size[1] ) ) && ( ( ( i52 * 2 ) + i51 ) <
- T3.size[2] ) ) ) { \n"
- << " T4[ 0 ]\n"
- << " = T5[ 0 ]\n"
- << " + float(2);\n"
- << " }\n"
- << " if ( ( ( ( ( ( i50 * 4 ) + ( i49 / T3.size[1] ) ) < T3.size[0]
- ) && ( ( i49 % T3.size[1] ) < T3.size[1] ) ) && ( ( ( i52 * 2 ) + i51 ) <
- T3.size[2] ) ) ) { \n"
- << " T3[ ( ( i50 * 4 ) + ( i49 / T3.size[1] ) ) + ( i49 %
- T3.size[1] ) + ( ( i52 * 2 ) + i51 ) ]\n"
- << " = T4[ 0 ]\n"
- << " + float(3);\n"
- << " }\n"
- << " }\n"
- << " }\n"
- << " }\n"
- << " }\n"
- << "}\n"
- ;
- GPULower gpulw(&fusion);
- std::stringstream cdg;
- gpulw.printKernel(cdg);
+ std::stringstream ref;
+ ref << "__device__ int ceilDiv(const int a, const int b) {\n"
+ << " return (a + b - 1) / b;\n"
+ << "}\n"
+ << "\n"
+ << "__global__ void CUDAGeneratedKernel(Tensor<float, 3> T2){\n"
+ << " for(size_t i82 = 0; i82 < ( 4 * T2.size[1] ); ++i82 ) {\n"
+ << " for(size_t i83 = 0; i83 < ( ceilDiv(T2.size[0], 4) ); ++i83 ) {\n"
+ << " for(size_t i84 = 0; i84 < 2; ++i84 ) {\n"
+ << " for(size_t i85 = 0; i85 < ( ceilDiv(T2.size[2], 2) ); ++i85 ) {\n"
+ << " float T0[1];\n"
+ << " if ( ( ( ( ( ( i83 * 4 ) + ( i82 / T2.size[1] ) ) < T2.size[0] ) && ( ( i82 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i85 * 2 ) + i84 ) < T2.size[2] ) ) ) { \n"
+ << " T0[ 0 ]\n"
+ << " = float(0)\n"
+ << " + float(1);\n"
+ << " }\n"
+ << " float T1[1];\n"
+ << " if ( ( ( ( ( ( i83 * 4 ) + ( i82 / T2.size[1] ) ) < T2.size[0] ) && ( ( i82 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i85 * 2 ) + i84 ) < T2.size[2] ) ) ) { \n"
+ << " T1[ 0 ]\n"
+ << " = T0[ 0 ]\n"
+ << " + float(2);\n"
+ << " }\n"
+ << " if ( ( ( ( ( ( i83 * 4 ) + ( i82 / T2.size[1] ) ) < T2.size[0] ) && ( ( i82 % T2.size[1] ) < T2.size[1] ) ) && ( ( ( i85 * 2 ) + i84 ) < T2.size[2] ) ) ) { \n"
+ << " T2[ ( ( ( i83 * 4 ) + ( i82 / T2.size[1] ) ) * T2.stride[0] ) + ( ( i82 % T2.size[1] ) * T2.stride[1] ) + ( ( ( i85 * 2 ) + i84 ) * T2.stride[2] ) ]\n"
+ << " = T1[ 0 ]\n"
+ << " + float(3);\n"
+ << " }\n"
+ << " }\n"
+ << " }\n"
+ << " }\n"
+ << " }\n"
+ << "}\n";
- if (ref.str().size() != cdg.str().size() ||
- ref.str().compare(cdg.str()) != 0) {
- std::cerr
- << " Codegen mismatch, codegen possibly changed, or is incorrect. "
- << " \n ========= REF ========= \n"
- << ref.str() << "\n========= RESULT ========== \n"
- << cdg.str() << "\n=================" << std::endl;
- TORCH_CHECK(false);
- }
- */
+ GPULower gpulw(&fusion);
+ std::stringstream cdg;
+ gpulw.printKernel(cdg);
+
+ if (ref.str().size() != cdg.str().size() ||
+ ref.str().compare(cdg.str()) != 0) {
+ std::cerr
+ << " Codegen mismatch, codegen possibly changed, or is incorrect. "
+ << " \n ========= REF ========= \n"
+ << ref.str() << "\n========= RESULT ========== \n"
+ << cdg.str() << "\n=================" << std::endl;
+ TORCH_CHECK(false);
+ }
+
torch::jit::fuser::cuda::CudaKernel prog;
prog.device_ = 0;
// These can be set to anything as there are no bindings!
@@ -744,7 +754,7 @@
at::Tensor output = at::empty({16, 8, 8}, options);
std::vector<at::Tensor> outputs{{output}};
- torch::jit::fuser::cuda::compileKernel(fusion, prog);
+ torch::jit::fuser::cuda::compileKernel(fusion, &prog);
torch::jit::fuser::cuda::runTestKernel(prog, {}, outputs);
at::Tensor output_ref = at::zeros_like(output, options);
@@ -780,35 +790,28 @@
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
- /*
std::stringstream ref;
- ref
- << "__device__ int ceilDiv(const int a, const int b) {\n"
- << " return (a + b - 1) / b;\n"
- << "}\n"
- << "\n"
- << "__global__ void kernel(Tensor<float> T0, Tensor<float> T1, Tensor<float>
- T3){\n"
- << " for( size_t i13 = 0; i13 < 4; ++i13 ) {\n"
- << " for( size_t i15 = 0; i15 < T1.size[1]; ++i15 ) {\n"
- << " float T2[1];\n"
- << " if( ( ( ( blockIdx.x * 4 ) + i13 ) < T1.size[0] ) ) {\n"
- << " T2[0]\n"
- << " = T1[( ( blockIdx.x * 4 ) + i13 ) * T1.stride[0] + i15 *
- T1.stride[1] + threadIdx.x * T1.stride[2]]\n"
- << " + float(2);\n"
- << " }\n"
- << " if( ( ( ( blockIdx.x * 4 ) + i13 ) < T1.size[0] ) ) {\n"
- << " T3[( ( blockIdx.x * 4 ) + i13 ) * T3.stride[0] + i15 *
- T3.stride[1] + threadIdx.x * T3.stride[2]]\n"
- << " = T0[( ( blockIdx.x * 4 ) + i13 ) * T0.stride[0] + i15 *
- T0.stride[1] + threadIdx.x * T0.stride[2]]\n"
- << " + T2[0];\n"
- << " }\n"
- << " }\n"
- << " }\n"
- << "}\n"
- ;
+ ref << "__device__ int ceilDiv(const int a, const int b) {\n"
+ << " return (a + b - 1) / b;\n"
+ << "}\n"
+ << "\n"
+ << "__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 3> T1, Tensor<float, 3> T3){\n"
+ << " for(size_t i33 = 0; i33 < 4; ++i33 ) {\n"
+ << " for(size_t i34 = 0; i34 < T3.size[1]; ++i34 ) {\n"
+ << " float T2[1];\n"
+ << " if ( ( ( ( blockIdx.x * 4 ) + i33 ) < T3.size[0] ) ) { \n"
+ << " T2[ 0 ]\n"
+ << " = T1[ ( ( ( blockIdx.x * 4 ) + i33 ) * T1.stride[0] ) + ( i34 * T1.stride[1] ) + ( threadIdx.x * T1.stride[2] ) ]\n"
+ << " + float(2);\n"
+ << " }\n"
+ << " if ( ( ( ( blockIdx.x * 4 ) + i33 ) < T3.size[0] ) ) { \n"
+ << " T3[ ( ( ( blockIdx.x * 4 ) + i33 ) * T3.stride[0] ) + ( i34 * T3.stride[1] ) + ( threadIdx.x * T3.stride[2] ) ]\n"
+ << " = T0[ ( ( ( blockIdx.x * 4 ) + i33 ) * T0.stride[0] ) + ( i34 * T0.stride[1] ) + ( threadIdx.x * T0.stride[2] ) ]\n"
+ << " + T2[ 0 ];\n"
+ << " }\n"
+ << " }\n"
+ << " }\n"
+ << "}\n";
GPULower gpulw(&fusion);
std::stringstream cdg;
@@ -823,7 +826,7 @@
<< cdg.str() << "\n=================" << std::endl;
TORCH_CHECK(false);
}
- */
+
torch::jit::fuser::cuda::CudaKernel prog;
prog.device_ = 0;
prog.grid(4);
@@ -838,7 +841,7 @@
std::vector<at::Tensor> inputs{{input1, input2}};
std::vector<at::Tensor> outputs{{output}};
- torch::jit::fuser::cuda::compileKernel(fusion, prog);
+ torch::jit::fuser::cuda::compileKernel(fusion, &prog);
torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);
at::Tensor tv2_ref = input2 + 2.0;
@@ -856,7 +859,7 @@
// Set up symbolic sizes for the axes should be dimensionality of the problem
std::vector<IterDomain*> dom;
for (int i = 0; i < nDims; i++)
- dom.push_back(new IterDomain(new Int()));
+ dom.push_back(new IterDomain(new Int(0), new Int()));
// Set up your input tensor views
TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
@@ -907,7 +910,7 @@
std::vector<at::Tensor> inputs{{input1, input2}};
std::vector<at::Tensor> outputs{{output}};
- torch::jit::fuser::cuda::compileKernel(fusion, prog);
+ torch::jit::fuser::cuda::compileKernel(fusion, &prog);
torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);
at::Tensor tv2_ref = input2 + 2.0;
@@ -936,13 +939,18 @@
// Register your outputs
fusion.addOutput(tv3);
+ tv3->split(0, 4);
+
// For all inputs, computeAt the output inline, temporaries should be squeezed
// between them
- tv0->computeAt(tv3, -1);
- tv1->computeAt(tv3, -1);
+ tv0->computeAt(tv3, 1);
+ tv1->computeAt(tv3, 1);
// Parallelize TV3
tv3->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(1)->parallelize(ParallelType::Unroll);
+ tv3->axis(1)->parallelize(ParallelType::Unroll);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
torch::jit::fuser::cuda::CudaKernel prog;
@@ -959,7 +967,7 @@
std::vector<at::Tensor> inputs{{input1, input2}};
std::vector<at::Tensor> outputs{{output}};
- torch::jit::fuser::cuda::compileKernel(fusion, prog);
+ torch::jit::fuser::cuda::compileKernel(fusion, &prog);
torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);
at::Tensor check = at::full({1, 128}, 4, options);
@@ -972,14 +980,16 @@
FusionGuard fg(&fusion);
const auto TV0 = new TensorView(
- new TensorDomain({new IterDomain(new Int(16))}), DataType::Float);
+ new TensorDomain({new IterDomain(new Int(0), new Int(16))}),
+ DataType::Float);
const auto TV1 = new TensorView(
- new TensorDomain({new IterDomain(new Int(16))}), DataType::Float);
+ new TensorDomain({new IterDomain(new Int(0), new Int(16))}),
+ DataType::Float);
fusion.addInput(TV0);
fusion.addInput(TV1);
- auto ID0 = new IterDomain(new Int(8));
+ auto ID0 = new IterDomain(new Int(0), new Int(8));
TensorView* TV2 = static_cast<TensorView*>(add(TV0, TV1));
BinaryOp* op = static_cast<BinaryOp*>(TV2->getOrigin());
@@ -1000,7 +1010,69 @@
}
}
-void testGPU_Fusion() {}
+void testGPU_FusionLoopUnroll() {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(1);
+ TensorView* tv1 = makeDummyTensor(1);
+
+ // Register your inputs
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+
+ // Do math with it, it returns a `Val*` but can be static_casted back to
+ // TensorView
+ TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(2.0)));
+ TensorView* tv3 = static_cast<TensorView*>(add(tv0, tv2));
+
+ // Register your outputs
+ fusion.addOutput(tv3);
+
+ int block_size = 16;
+
+ tv3->split(0, block_size);
+ tv3->split(0, 4);
+
+ // For all inputs, computeAt the output inline, temporaries should be squeezed
+ // between them
+ tv0->computeAt(tv3, 1);
+ tv1->computeAt(tv3, 1);
+
+ // Parallelize
+ tv2->axis(1)->parallelize(ParallelType::Unroll);
+ tv3->axis(1)->parallelize(ParallelType::Unroll);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+
+ // GPULower lower(&fusion);
+ // lower.printKernel(std::cout);
+
+ int inp_size = 129;
+
+ torch::jit::fuser::cuda::CudaKernel prog;
+ prog.device_ = 0;
+ prog.grid((inp_size + 63) / 64);
+ prog.block(block_size);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+ at::Tensor input1 = at::ones({inp_size}, options);
+ at::Tensor input2 = at::ones_like(input1);
+
+ at::Tensor output = at::empty_like(input1);
+ std::vector<at::Tensor> inputs{{input1, input2}};
+ std::vector<at::Tensor> outputs{{output}};
+
+ torch::jit::fuser::cuda::compileKernel(fusion, &prog);
+ torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);
+
+ at::Tensor check = at::full({inp_size}, 4, options);
+
+ TORCH_CHECK(output.equal(check));
+}
} // namespace jit
} // namespace torch
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 32c5695..8c4aca4 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -117,7 +117,8 @@
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel) \
- _(GPU_FusionForLoop)
+ _(GPU_FusionForLoop) \
+ _(GPU_FusionLoopUnroll)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 25b29c8..d7af37e 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -86,20 +86,57 @@
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@skipIfRocm
def test_scalar_input(self):
- def t(x, y, z):
- # type: (Tensor, Tensor, float) -> Tensor
+ def t(x : torch.Tensor, y : torch.Tensor, z : float):
o = x + y
o = o + z
return o
t_jit = torch.jit.script(t)
- x = torch.randn(4, 8, dtype=torch.float, device="cuda")
- y = torch.randn(4, 8, dtype=torch.float, device="cuda")
+ x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+ y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
+ y = y.expand(4, 8, 32, 32)
jit_o = t_jit(x, y, 2.0)
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+ @skipIfRocm
+ def test_broadcasting(self):
+ def t(x : torch.Tensor, y : torch.Tensor, z : float):
+ o = x + y
+ o = o + z
+ return o
+ t_jit = torch.jit.script(t)
+ x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+ y = torch.randn(32, 32, dtype=torch.float, device="cuda")
+ jit_o = t_jit(x, y, 2.0)
+ jit_o = t_jit(x, y, 2.0)
+ o = t(x, y, 2.0)
+ self.assertEqual(o, jit_o)
+ self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+ @skipIfRocm
+ def test_broadcasting_multiple_output_shape(self):
+ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+ o = x + 12
+ o1 = o + y
+ o2 = o + z
+ oo = o1.sum() + o2.sum()
+ return oo
+ t_jit = torch.jit.script(t)
+ x = torch.randn(32, 32, dtype=torch.float, device="cuda")
+ y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
+ z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
+ jit_o = t_jit(x, y, z)
+ jit_o = t_jit(x, y, z)
+ o = t(x, y, z)
+ self.assertEqual(o, jit_o)
+ # Currently cannot fuse this
+ self.assertFalse(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
if __name__ == '__main__':
run_tests()
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 78a150b..6da96a1 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -246,6 +246,9 @@
"torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
"torch/csrc/jit/codegen/cuda/kernel.cpp",
+ "torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_loops.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_utils.cpp",
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/mutator.cpp",
diff --git a/torch/csrc/jit/codegen/cuda/data_struct_str.h b/torch/csrc/jit/codegen/cuda/data_struct_str.h
deleted file mode 100644
index bcfae73..0000000
--- a/torch/csrc/jit/codegen/cuda/data_struct_str.h
+++ /dev/null
@@ -1,11 +0,0 @@
-STRINGIFY(template <typename T> struct Tensor {
- public:
- T& operator[](int ind) {
- return data[ind];
- };
-
- int size[8];
- int stride[8];
- T* data;
- int nDim;
-};)
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index f37fe39..8c871ba 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -1,6 +1,5 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index 82239e8..0567293 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
namespace torch {
@@ -33,6 +34,22 @@
return es.exprs;
}
+void InputsOf::handle(TensorView* tv) {
+ if (FusionGuard::getCurFusion()->hasInput(tv))
+ inputs.push_back(tv);
+}
+
+std::vector<TensorView*> InputsOf::output(Fusion* fusion, Val* output_) {
+ TORCH_CHECK(
+ fusion->hasOutput(output_),
+ "Asked for the inputs of ",
+ output_,
+ " however, it is not an output of the provided fusion.");
+ InputsOf io;
+ io.traverseFrom(FusionGuard::getCurFusion(), {output_});
+ return io.inputs;
+}
+
Fusion::~Fusion() {
{
auto it = val_set_.begin();
@@ -140,6 +157,10 @@
return ExprSort::getExprs(this, from_outputs_only, breadth_first);
}
+std::vector<TensorView*> Fusion::inputsOf(Val* val) {
+ return InputsOf::output(this, val);
+}
+
void Fusion::print() {
FusionGuard fg(this);
std::cout << "%kernel {\n";
diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h
index 15facd7..342ee81 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.h
+++ b/torch/csrc/jit/codegen/cuda/fusion.h
@@ -49,6 +49,7 @@
*/
struct Fusion;
+struct TensorView;
// Fusion Guard is our "context manager". It holds the actrive fusion and allows
// it to be accessed anywhere through FusionGuard::getCurFusion().
@@ -79,6 +80,20 @@
bool breadth_first);
};
+// Expr sort will take a fusion and return a topologically sorted list of
+// expressions.
+struct InputsOf : public IterVisitor {
+ using IterVisitor::handle;
+
+ private:
+ std::vector<TensorView*> inputs;
+
+ void handle(TensorView* tv) override;
+
+ public:
+ static std::vector<TensorView*> output(Fusion* fusion, Val* output_);
+};
+
/*
* Fusion is mutable but unique. Nodes cannot be copied in any way from one
* Fusion to another. If anything like that is desired, it would require
@@ -139,6 +154,8 @@
bool from_outputs_only = false,
bool breadth_first = false);
+ std::vector<TensorView*> inputsOf(Val* val);
+
// Print this fusion to cout.
void print();
@@ -174,8 +191,6 @@
// Return the Expr that produces val (const version)
const Expr* origin(const Val* val) const;
- bool lowered = false;
-
private:
// Sets of all Vals/Exprs registered with this fusion
std::set<Val*> val_set_;
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp
index ddc3e2e..98046c1 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -21,7 +21,7 @@
ax >= 0 && ax < indices.size(),
"Hit an invalid MERGE transformation during IndexCompute, axis is not within bounds.");
- Val* I = expr->in()->axis(ax + 1)->size();
+ Val* I = expr->in()->axis(ax + 1)->extent();
Val* ind = indices[ax];
indices[ax] = div(ind, I);
indices.insert(indices.begin() + ax + 1, mod(ind, I));
@@ -62,10 +62,10 @@
TensorDomain* td = tv->domain();
- bool exclude_reduction = td->size() > indices.size();
+ bool exclude_reduction = td->nDims() > indices.size();
TORCH_CHECK(
- exclude_reduction || td->size() == indices.size(),
+ exclude_reduction || td->nDims() == indices.size(),
"For IndexCompute the number of axis should match the number of dimensions"
" in the TensorView.");
@@ -73,7 +73,7 @@
// being consumed, not produced, then insert dummy dimensions in the
// indices for bookkeeping while replaying split/merge/reorder operations.
if (exclude_reduction)
- for (decltype(td->size()) i{0}; i < td->size(); i++)
+ for (decltype(td->nDims()) i{0}; i < td->nDims(); i++)
if (td->axis(i)->isReduction())
indices.insert(indices.begin() + i, new Int(-1));
@@ -83,7 +83,7 @@
TensorDomain* root = TransformIter::runBackward(td, true);
TORCH_INTERNAL_ASSERT(
- root->size() == indices.size(),
+ root->nDims() == indices.size(),
"Error during IndexCompute. The number of indices generated"
" after running the transformations backwards should match"
" the number of dimensions of the root TensorView.");
@@ -91,7 +91,7 @@
// Remove indices associated with reduction axes, we had them just for
// bookkeeping.
if (exclude_reduction) {
- for (auto i = root->size() - 1; i >= 0; i--)
+ for (auto i = root->nDims() - 1; i >= 0; i--)
if (root->axis(i)->isReduction())
indices.erase(indices.begin() + i);
}
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index ca5750f..29167ba 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -1,9 +1,8 @@
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
@@ -41,6 +40,9 @@
}
}
+// Traverse origin of all values involved in constructing the provided val.
+// Check if all values involved are constant values, meaning the provided
+// val is also a constant value.
namespace {
struct ConstCheck : OptOutConstDispatch {
@@ -64,7 +66,6 @@
void handle(const NamedScalar* const ns) override {
is_const_ = false;
}
-
void handle(const Val* const val) override {
const Expr* orig = FusionGuard::getCurFusion()->origin(val);
if (orig != nullptr)
@@ -88,6 +89,22 @@
return ConstCheck::isConst(this);
}
+bool Val::isZeroInt() const {
+ if (isConstScalar() && getValType().value() == ValType::Scalar &&
+ getDataType().value() == DataType::Int &&
+ static_cast<const Int*>(this)->value().value() == 0)
+ return true;
+ return false;
+}
+
+bool Val::isOneInt() const {
+ if (isConstScalar() && getValType().value() == ValType::Scalar &&
+ getDataType().value() == DataType::Int &&
+ static_cast<const Int*>(this)->value().value() == 1)
+ return true;
+ return false;
+}
+
c10::optional<DataType> Val::getDataType() const {
TORCH_INTERNAL_ASSERT(
dtype_ != DataType::Null, "Value does not have a data type.");
@@ -147,6 +164,10 @@
return true;
}
+void Scope::clear() {
+ this->exprs_ = std::vector<Expr*>();
+}
+
bool IRInputOutput::hasInput(const Val* const input) const {
for (auto val : inputs_)
if (val == input)
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index 2c493fd..884df4e 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -178,6 +178,9 @@
return isScalar() && dtype_ == DataType::Int;
}
+ bool isZeroInt() const;
+ bool isOneInt() const;
+
// Returns the Expr that this value is an output of, returns nullptr if none
// was found
Expr* getOrigin();
@@ -251,6 +254,8 @@
bool sameAs(const Scope& other) const;
+ void clear();
+
private:
std::vector<Expr*> exprs_;
};
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index 3bedc89..7031bb3 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -18,7 +18,7 @@
/*
* A Float32 value. For now we don't have any other type besides
- * Float32.reorder_ This value can be a symbolic value (defined after the kernel
+ * Float32. This value can be a symbolic value (defined after the kernel
* is compiled) or a constant value (inlined into the kernel definition).
*/
struct TORCH_CUDA_API Float : public Val {
@@ -76,13 +76,14 @@
return maybe_value_;
}
- bool sameAs(const Int* const other) const;
+ virtual bool sameAs(const Int* const other) const;
private:
const c10::optional<int> maybe_value_;
};
struct TransformReplay;
+struct TransformIter;
struct OptOutMutator;
struct GPULower;
/*
@@ -92,6 +93,14 @@
* these transformations are kept and used for generating actual code referncing
* physical memory. Generally when users are thinking of code generation in
* reference to a Tensor, this is the class they should be interacting with.
+ *
+ * The reason we need both TensorView and TensorDomain is that we need to have a
+ * record of both what is being computed and how it is being computed. For
+ * Example we may have the operation: TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
+ * The mathematical operationss here are on the tensor views TV1, TV2, and TV3.
+ * This operation is a pointwise operation. To compute this pointwise operation
+ * we iterate over the 3D TensorDomain [I, J, K], where K is the fastest
+ * changing dimension.
*/
struct TORCH_CUDA_API TensorView : public Val {
~TensorView() = default;
@@ -156,28 +165,17 @@
// Split "axis" into 2 axes where the inner axes is size of "factor"
// and outer axis is size axis.size() / factor
- TensorView* split(int axis, int factor) {
- return split_(this, axis, factor);
- }
+ TensorView* split(int axis, int factor);
// Merge "axis" and "axis+1" into 1 dimension
- TensorView* merge(int axis) {
- return merge_(this, axis);
- }
+ TensorView* merge(int axis);
- // Reorder axes according to map[old_pos] = new_pos
- TensorView* reorder(const std::unordered_map<int, int>& map) {
- return reorder_(this, map);
- }
+ // Reorder axes according to axis2pos[old_pos] = new_pos
+ TensorView* reorder(const std::unordered_map<int, int>& axis2pos);
- // Implementations for split/merge/reorder
- friend TORCH_CUDA_API TensorView* split_(TensorView*, int axis, int factor);
- friend TORCH_CUDA_API TensorView* merge_(TensorView*, int axis);
- friend TORCH_CUDA_API TensorView* reorder_(
- TensorView*,
- const std::unordered_map<int, int>&);
- friend TORCH_CUDA_API OptOutMutator;
friend TORCH_CUDA_API TransformReplay;
+ friend TORCH_CUDA_API TransformIter;
+ friend TORCH_CUDA_API OptOutMutator;
friend TORCH_CUDA_API GPULower;
protected:
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index 58cf680..e49a1eb 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -99,9 +99,10 @@
};
/*
- * Simply a representation of an iterable from 0 to size. TensorDomains which
- * represent how to iterate over a tensor is made up of IterDomains. We directly
- * set parallization strategies on IterDomains.
+ * Simply a representation of an annotated 1D iterable from start to extent.
+ * TensorDomains which represent how to iterate over a tensor is made up of
+ * IterDomains to form an ND iterable. We directly set parallization strategies
+ * on IterDomains.
*/
struct TORCH_CUDA_API IterDomain : public Val {
~IterDomain() = default;
@@ -109,7 +110,8 @@
IterDomain() = delete;
IterDomain(
- Val* int_size,
+ Val* _start,
+ Val* _extent,
ParallelType _parallel_method = ParallelType::Serial,
bool _reduction_domain = false);
@@ -157,7 +159,14 @@
TORCH_CHECK(
t != ParallelType::Vectorize, "Vectorization not yet supported.");
if (t == ParallelType::Unroll)
- TORCH_CHECK(false, "Unrolling not yet supported.");
+ TORCH_CHECK(
+ start()->isZeroInt() && extent()->isConstScalar(),
+ "Unrolling only supported with start = 0 and extent as a const int, but got ",
+ "a start of ",
+ start(),
+ " and extent ",
+ extent(),
+ " .");
}
}
@@ -165,7 +174,10 @@
return parallel_method_;
}
- Val* size() const;
+ Val* start() const noexcept {
+ return start_;
+ }
+ Val* extent() const;
IterDomain(const IterDomain& other) = delete;
IterDomain& operator=(const IterDomain& other) = delete;
@@ -174,12 +186,24 @@
IterDomain& operator=(IterDomain&& other) = delete;
private:
- Val* const size_;
+ Val* const start_;
+ Val* const extent_;
ParallelType parallel_method_ = ParallelType::Serial;
bool is_reduction_domain_;
};
-
-// A list of IterDomains representing how to iterate across a given Tensor.
+/*
+ * TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
+ * logical axis in its associated tensor. TensorDomain does not directly hold
+ * the Tensor it is associated with, and in theory could be associated with
+ * multiple tensors. TensorDomain's primary responsibility is to provide a
+ * mechanism to access history of transformations that were used to generate it.
+ * This is done through the normal interaction of Expr/Val in Fusion. i.e. if we
+ * want to know the previous operation generating a particular TensorDomain we
+ * can simply call FusionGuard::getCurFusion()->origin(a_tensor_domain) which
+ * should give us an operation in the list [split, merge, reorder] or similar
+ * operations that take in a TensorDomain, applies a transformation and outputs
+ * a tensor domain.
+ */
struct TORCH_CUDA_API TensorDomain : public Val {
~TensorDomain() = default;
@@ -192,7 +216,7 @@
TensorDomain(std::vector<IterDomain*> _domain)
: Val(ValType::TensorDomain), domain_(_domain) {}
- std::vector<IterDomain*>::size_type size() const {
+ std::vector<IterDomain*>::size_type nDims() const {
return domain_.size();
}
@@ -208,6 +232,18 @@
// uint.
IterDomain* axis(int i) const;
+ // Split "axis" into 2 axes where the inner axes is size of "factor"
+ // and outer axis is size axis.size() / factor
+ TensorDomain* split(int axis, int factor);
+
+ // Merge "axis" and "axis+1" into 1 dimension
+ TensorDomain* merge(int axis);
+
+ // Reorder axes according to map[old_pos] = new_pos
+ TensorDomain* reorder(const std::unordered_map<int, int>& axis2pos);
+
+ TensorDomain* rootDomain();
+
private:
std::vector<IterDomain*> domain_;
};
@@ -329,7 +365,7 @@
~ForLoop() = default;
ForLoop(
Val* _index,
- IterDomain* _range,
+ IterDomain* _iter_domain,
const std::vector<Expr*>& _body = {},
Expr* parent_scope = nullptr);
@@ -343,8 +379,8 @@
return index_;
}
- IterDomain* range() const noexcept {
- return range_;
+ IterDomain* iter_domain() const noexcept {
+ return iter_domain_;
}
Scope& body() noexcept {
@@ -365,7 +401,7 @@
private:
Val* const index_;
- IterDomain* const range_;
+ IterDomain* const iter_domain_;
Scope body_;
Expr* parent_scope_;
};
@@ -458,7 +494,7 @@
"Cannot index with a value other than an int.");
}
- std::vector<Val*>::size_type size() const {
+ std::vector<Val*>::size_type nDims() const {
return indices_.size();
}
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
index a266518..4223d6a 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -1,7 +1,6 @@
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
#include <iostream>
@@ -42,7 +41,9 @@
for (Val* val : vals) {
switch (val->getValType().value()) {
case (ValType::TensorView):
- os << "Tensor<" << val->getDataType().value() << "> T" << val->name();
+ os << "Tensor<" << val->getDataType().value() << ", "
+ << static_cast<TensorView*>(val)->getRootDomain()->nDims() << "> T"
+ << val->name();
break;
case (ValType::Scalar):
os << val->getDataType().value() << " " << val;
@@ -70,9 +71,9 @@
void IRPrinter::handle(const TensorDomain* const td) {
os << "[ ";
- for (std::vector<const IterDomain*>::size_type i = 0; i < td->size(); i++) {
+ for (std::vector<const IterDomain*>::size_type i = 0; i < td->nDims(); i++) {
handle(td->axis(i));
- if (i != td->size() - 1)
+ if (i != td->nDims() - 1)
os << ", ";
}
os << " ]";
@@ -107,8 +108,13 @@
default:
os << id->parallel_method();
}
+
os << "{";
- print_inline(id->size());
+ if (!id->start()->isZeroInt()) {
+ print_inline(id->start());
+ os << " : ";
+ }
+ print_inline(id->extent());
os << "}";
}
@@ -256,7 +262,7 @@
}
void IRPrinter::handle(const ForLoop* const fl) {
- if (fl->range()->isThread()) {
+ if (fl->iter_domain()->isThread()) {
for (auto& expr : fl->constBody().exprs())
handle(expr);
return;
@@ -265,10 +271,12 @@
indent();
os << "for(size_t ";
handle(fl->index());
- os << "{0}; ";
+ os << " = ";
+ print_inline(fl->iter_domain()->start());
+ os << "; ";
handle(fl->index());
os << " < ";
- print_inline(fl->range()->size());
+ print_inline(fl->iter_domain()->extent());
os << "; ++";
handle(fl->index());
os << " ) {\n";
@@ -343,8 +351,6 @@
const std::vector<Expr*>& exprs,
const std::string& kernel_name) {
Fusion* fusion = FusionGuard::getCurFusion();
- // if(exprs.size() != 0)
- // fusion = exprs[0]->fusion();
printHeader(fusion, kernel_name);
for (auto* expr : exprs) {
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h
index e0c4df5..50cc4d9 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.h
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -47,6 +47,7 @@
*/
struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
+ public:
std::ostream& os;
bool print_inline_ = false;
@@ -65,7 +66,6 @@
void printHeader(Fusion* fusion, const std::string& kernel_name_);
- public:
IRPrinter(std::ostream& _os) : os(_os) {}
virtual void handle(Fusion* const f);
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index daa79e7..38a0e7e 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -2,7 +2,9 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
+#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <sstream>
@@ -10,6 +12,46 @@
namespace jit {
namespace fuser {
+namespace {
+struct ScalarCheck : OptInDispatch {
+ Val* v1_;
+ Val* v2_;
+ bool same = false;
+
+ void handle(Float* f) override {
+ same = static_cast<Float*>(v1_)->sameAs(static_cast<Float*>(v2_));
+ }
+
+ void handle(Int* i) override {
+ same = static_cast<Int*>(v1_)->sameAs(static_cast<Int*>(v2_));
+ }
+
+ void handle(NamedScalar* ns) override {
+ same =
+ static_cast<NamedScalar*>(v1_)->sameAs(static_cast<NamedScalar*>(v2_));
+ }
+
+ ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) {
+ OptInDispatch::handle(v1_);
+ }
+
+ public:
+ static bool sameAs(Val* v1, Val* v2) {
+ if (v1 == v2)
+ return true;
+
+ if (v1->getValType() != v2->getValType())
+ return false;
+
+ if (v1->getDataType() != v2->getDataType())
+ return false;
+
+ ScalarCheck sc(v1, v2);
+ return sc.same;
+ }
+};
+} // namespace
+
bool Float::sameAs(const Float* const other) const {
if (isConst() && other->isConst())
return *value() == *(other->value());
@@ -56,50 +98,53 @@
}
IterDomain::IterDomain(
- Val* _size,
+ Val* _start,
+ Val* _extent,
ParallelType _parallel_method,
bool _reduction_domain)
: Val(ValType::IterDomain, DataType::Int),
- size_(_size),
+ start_(_start),
+ extent_(_extent),
parallel_method_(_parallel_method),
is_reduction_domain_(_reduction_domain) {
TORCH_INTERNAL_ASSERT(
- _size->isAnInt(),
- "Cannot create an iter domain over a size that is not an int.");
+ _extent->isAnInt(),
+ "Cannot create an iter domain over an extent that is not an int but recieved ",
+ _extent,
+ " .");
+ TORCH_INTERNAL_ASSERT(
+ _start->isAnInt(),
+ "Cannot create an iter domain with a start that is not an int but recieved ",
+ _extent,
+ " .");
}
bool IterDomain::sameAs(const IterDomain* const other) const {
bool is_same = isReduction() == other->isReduction() &&
parallel_method() == other->parallel_method();
+ is_same = is_same && ScalarCheck::sameAs(extent(), other->extent());
+ is_same = is_same && ScalarCheck::sameAs(start(), other->start());
- if (size()->getValType() == ValType::NamedScalar &&
- other->size()->getValType() == ValType::NamedScalar) {
- is_same = is_same &&
- (static_cast<NamedScalar*>(size())->name().compare(
- static_cast<NamedScalar*>(other->size())->name()) == 0);
- } else {
- is_same = is_same && size()->sameAs(other->size());
- }
return is_same;
}
-Val* IterDomain::size() const {
+Val* IterDomain::extent() const {
if (isThread()) {
- if (size_->getValType() == ValType::Scalar)
- if (static_cast<Int*>(size_)->isConst())
- return size_;
+ if (extent_->getValType() == ValType::Scalar)
+ if (static_cast<Int*>(extent_)->isConst())
+ return extent_;
std::string parallel_dim = stringifyThreadSize(parallel_method_);
return new NamedScalar(parallel_dim, DataType::Int);
}
- return size_;
+ return extent_;
}
bool TensorDomain::sameAs(const TensorDomain* const other) const {
- if (size() != other->size())
+ if (nDims() != other->nDims())
return false;
- for (decltype(size()) i = 0; i < size(); i++)
+ for (decltype(nDims()) i = 0; i < nDims(); i++)
if (!(axis(i)->sameAs(other->axis(i))))
return false;
@@ -118,12 +163,230 @@
// uint.
IterDomain* TensorDomain::axis(int i) const {
if (i < 0)
- i += size();
+ i += nDims();
TORCH_CHECK(
- i >= 0 && i < size(), "Tried to access axis ", i, " in domain ", this);
+ i >= 0 && i < nDims(), "Tried to access axis ", i, " in domain ", this);
return domain_[i];
}
+// Split "axis" into 2 axes where the inner axes is size of "factor"
+// and outer axis is size axis.extent() / factor
+TensorDomain* TensorDomain::split(int axis_, int factor) {
+ if (axis_ < 0)
+ axis_ += nDims();
+
+ TORCH_INTERNAL_ASSERT(
+ axis_ >= 0 && axis_ < nDims(),
+ "Tried to split on axis outside TensorDomain's range.");
+
+ IterDomain* id = axis(axis_);
+
+ TORCH_CHECK(
+ id->start()->isZeroInt(),
+ "Splitting IterDomains with starting values that aren't 0, is not supported at this time.");
+
+ if (id->parallel_method() != ParallelType::Serial)
+ TORCH_CHECK(
+ false,
+ "Splitting an axis of non-Serial iteration is not supported at this time."
+ " Parallelization strategy must be set after calling split.");
+
+ std::vector<IterDomain*> new_domain;
+
+ Int* fact = new Int(factor);
+ Int* one = new Int(1);
+
+ for (decltype(nDims()) i = 0; i < nDims(); i++) {
+ if (i != axis_)
+ new_domain.push_back(axis(i));
+ else {
+ // outer loop size
+ Val* vo = ceilDiv(id->extent(), fact);
+ Int* so = static_cast<Int*>(vo);
+
+ // outer loop IterDomain
+ IterDomain* ido = new IterDomain(
+ new Int(0), so, id->parallel_method(), id->isReduction());
+ new_domain.push_back(ido);
+
+ // inner loop IterDomain
+ IterDomain* idi = new IterDomain(
+ new Int(0), fact, id->parallel_method(), id->isReduction());
+ new_domain.push_back(idi);
+ }
+ }
+ TensorDomain* split_td = new TensorDomain(new_domain);
+ Split* split_node =
+ new Split(split_td, this, axis_, fact); // For record keeping
+ return split_td;
+}
+
+// Merge "axis" and "axis+1" into 1 dimension
+TensorDomain* TensorDomain::merge(int axis_) {
+ if (axis_ < 0)
+ axis_ += nDims();
+
+ TORCH_CHECK(
+ axis_ >= 0 && axis_ + 1 < nDims(),
+ "Trying to merge axis_ outside of TensorView's range.");
+
+ IterDomain* first = axis(axis_);
+ IterDomain* second = axis(axis_ + 1);
+
+ TORCH_CHECK(
+ first->start()->isZeroInt() && second->start()->isZeroInt(),
+ "Merging IterDomains with starting values that aren't 0, is not supported at this time.");
+ TORCH_CHECK(
+ first->isReduction() == second->isReduction(),
+ "Merging domains requires that they're either both a reduction axis_, or both an iteration axis_.");
+ TORCH_CHECK(
+ first->parallel_method() == second->parallel_method(),
+ "Axes must have matching parallel types.");
+
+ Val* merged_id_size = mul(first->extent(), second->extent());
+ IterDomain* merged_id = new IterDomain(
+ new Int(0),
+ static_cast<Int*>(merged_id_size),
+ first->parallel_method(),
+ first->isReduction());
+
+ std::vector<IterDomain*> new_domain;
+ for (decltype(nDims()) i = 0; i < nDims(); i++) {
+ if (i < axis_ || i > axis_ + 1)
+ new_domain.push_back(axis(i));
+ else if (i == axis_) {
+ new_domain.push_back(merged_id);
+ }
+ }
+ TensorDomain* merged_td = new TensorDomain(new_domain);
+ Merge* merge_node = new Merge(merged_td, this, axis_); // For record keeping
+ return merged_td;
+}
+
+// Reorder axes according to map[old_pos] = new_pos
+TensorDomain* TensorDomain::reorder(
+ const std::unordered_map<int, int>& axis2pos_) {
+ // START VALIDATION CHECKS
+ // Eventhough these checks are already in TensorView, we want to redo them as
+ // we can enter this function from other places, not through TensorView
+
+ // adjust based on negative values (any negative values gets nDims added to
+ // it)
+ std::unordered_map<int, int> axis2pos;
+ auto ndims = nDims();
+ std::transform(
+ axis2pos_.begin(),
+ axis2pos_.end(),
+ std::inserter(axis2pos, axis2pos.begin()),
+ [ndims](std::unordered_map<int, int>::value_type entry) {
+ return std::unordered_map<int, int>::value_type({
+ entry.first < 0 ? entry.first + ndims : entry.first,
+ entry.second < 0 ? entry.second + ndims : entry.second,
+ });
+ });
+
+ // Check if any adjusted values are < 0, or >= nDims, which are invalid
+ bool out_of_range = std::any_of(
+ axis2pos.begin(),
+ axis2pos.end(),
+ [ndims](std::unordered_map<int, int>::value_type entry) {
+ return entry.first < 0 || entry.first >= ndims || entry.second < 0 ||
+ entry.second >= ndims;
+ });
+
+ TORCH_CHECK(
+ !out_of_range,
+ "TensorView reorder axes are outside the number of dimensions in the TensorView.")
+
+ // Going to use sets, to see if any duplicate values are in the map.
+
+ std::set<int> old_pos_set;
+ std::transform(
+ axis2pos.begin(),
+ axis2pos.end(),
+ std::inserter(old_pos_set, old_pos_set.begin()),
+ [](std::unordered_map<int, int>::value_type entry) {
+ return entry.first;
+ });
+
+ std::set<int> new_pos_set;
+ std::transform(
+ axis2pos.begin(),
+ axis2pos.end(),
+ std::inserter(new_pos_set, new_pos_set.begin()),
+ [](std::unordered_map<int, int>::value_type entry) {
+ return entry.first;
+ });
+
+ // Error out if duplicate values are found.
+ TORCH_CHECK(
+ old_pos_set.size() == axis2pos.size() &&
+ new_pos_set.size() == axis2pos.size(),
+ "Duplicate entries in transformation map sent to TensorView reorder.");
+
+ // END VALIDATION CHECKS
+
+ // Map to save, from previous order, to new order.
+ std::vector<int> pos2axis(ndims, -1);
+
+ // Go through each old and new position, make sure they're within 0-ndims
+ for (std::pair<int, int> elem : axis2pos) {
+ int old_pos = elem.first;
+ int new_pos = elem.second;
+
+ assert(old_pos >= 0 && old_pos < ndims && new_pos >= 0 && new_pos < ndims);
+
+ if (pos2axis[new_pos] != -1)
+ TORCH_CHECK(false, "Reorder found duplicate destination positions.");
+
+ pos2axis[new_pos] = old_pos;
+ }
+
+ std::set<int> old_positions(pos2axis.begin(), pos2axis.end());
+ old_positions.erase(-1);
+
+ if (old_positions.size() != axis2pos.size())
+ TORCH_INTERNAL_ASSERT(
+ false, "Reorder found duplicate destination positions.");
+
+ std::set<int> all_positions;
+ for (decltype(ndims) i{0}; i < ndims; i++)
+ all_positions.insert(i);
+
+ // Check what positions haven't been specified.
+ std::set<int> positions_left;
+ std::set_difference(
+ all_positions.begin(),
+ all_positions.end(),
+ old_positions.begin(),
+ old_positions.end(),
+ std::inserter(positions_left, positions_left.end()));
+
+ // Fill in positions that weren't specified, in relative order,
+ // in empty spots in the set of new positions.
+ // pos2axis[new_position] = old_position
+ auto it = positions_left.begin(); // old positions left
+ std::transform(
+ pos2axis.begin(), pos2axis.end(), pos2axis.begin(), [&it](int i) -> int {
+ return i == -1 ? *it++ : i;
+ });
+
+ std::vector<IterDomain*> reordered_domain;
+ std::transform(
+ pos2axis.begin(),
+ pos2axis.end(),
+ std::back_inserter(reordered_domain),
+ [this](int i) -> IterDomain* { return this->axis(i); });
+
+ TensorDomain* reordered_td = new TensorDomain(reordered_domain);
+ Reorder* merge_node = new Reorder(reordered_td, this, pos2axis);
+ return reordered_td;
+}
+
+TensorDomain* TensorDomain::rootDomain() {
+ return TransformIter::getRoot(this);
+}
+
Split::Split(TensorDomain* _out, TensorDomain* _in, int _axis, Int* _factor)
: Expr(ExprType::Split),
out_{_out},
@@ -174,25 +437,25 @@
ForLoop::ForLoop(
Val* _index,
- IterDomain* _range,
+ IterDomain* _iter_domain,
const std::vector<Expr*>& _body,
Expr* _parent_scope)
: Expr(ExprType::ForLoop),
index_{_index},
- range_{_range},
+ iter_domain_{_iter_domain},
parent_scope_{_parent_scope} {
TORCH_INTERNAL_ASSERT(
_index->isAnInt(),
"Cannot create a for loop with an index that is not an int.");
addInput(_index);
- addInput(_range);
+ addInput(_iter_domain);
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
for (Expr* expr : _body)
body().push_back(expr);
}
bool ForLoop::sameAs(const ForLoop* other) const {
- if (this->range() != other->range())
+ if (this->iter_domain() != other->iter_domain())
return false;
if (!(constBody().sameAs(other->constBody())))
return false;
@@ -223,13 +486,13 @@
}
bool TensorIndex::sameAs(const TensorIndex* const other) const {
- if (size() != other->size())
+ if (nDims() != other->nDims())
return false;
if (!view()->sameAs(other->view()))
return false;
- for (decltype(size()) i = 0; i < size(); i++)
+ for (decltype(nDims()) i = 0; i < nDims(); i++)
if (!(index(i)->sameAs(other->index(i))))
return false;
@@ -238,8 +501,8 @@
Val* TensorIndex::index(int i) const {
if (i < 0)
- i += size();
- assert(i >= 0 && i < size());
+ i += nDims();
+ assert(i >= 0 && i < nDims());
return indices_[i];
}
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h
index 030da4b..74021e3 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.h
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h
@@ -36,6 +36,8 @@
struct TORCH_CUDA_API IterVisitor : public OptOutDispatch {
virtual ~IterVisitor() = default;
+ using OptOutDispatch::handle;
+
IterVisitor() = default;
IterVisitor(const IterVisitor& other) = default;
@@ -52,13 +54,13 @@
std::vector<Statement*> next(Expr* expr);
std::vector<Statement*> next(Val* v);
- void handle(Statement* s) {
+ virtual void handle(Statement* s) {
OptOutDispatch::handle(s);
}
- void handle(Expr* e) {
+ virtual void handle(Expr* e) {
OptOutDispatch::handle(e);
}
- void handle(Val* v) {
+ virtual void handle(Val* v) {
OptOutDispatch::handle(v);
}
@@ -94,10 +96,10 @@
// when handle is called on val, we know 2 things. Val is a dependency of of.
// and dep_chain contains the values in between of and dependency.
- void handle(Val* val);
+ void handle(Val* val) override;
// When we handle an expr we pop off its outputs from the dep_chain
- void handle(Expr* expr);
+ void handle(Expr* expr) override;
// When we visit an Expr we place its outputs on the dep_chain
void toVisitCallback(Statement* stmt);
diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp
index aaafcf1..3313485 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -1,11 +1,16 @@
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <iostream>
-
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
+#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/util/ArrayRef.h>
+
+#include <torch/csrc/jit/codegen/cuda/kernel.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_arg.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_resource_strings.h>
+#include <torch/csrc/jit/codegen/cuda/lower2device.h>
+
#include <torch/csrc/jit/resource_guard.h>
+#include <iostream>
namespace torch {
namespace jit {
@@ -25,59 +30,137 @@
return (a + b - 1) / b;
}
+// Go through a tensor, and grab it's sizes/strides potentially broadcasted
+struct ExtractSizeStride {
+ std::vector<int64_t> sizes;
+ std::vector<int64_t> strides;
+
+ ExtractSizeStride(
+ const at::Tensor& val,
+ c10::optional<at::IntArrayRef> broadcasted_size = c10::nullopt) {
+ if (broadcasted_size) {
+ int b_dim = (int)broadcasted_size->size();
+ int o_dim = (int)val.dim();
+ TORCH_CHECK(b_dim >= o_dim);
+ for (int i = 0; i < b_dim; i++) {
+ sizes.push_back(broadcasted_size->at(i));
+ int index = i + o_dim - b_dim;
+ if (index < 0) {
+ strides.push_back(0);
+ } else if (val.sizes()[index] == sizes[i]) {
+ strides.push_back(val.strides()[index]);
+ } else {
+ TORCH_CHECK(
+ val.sizes()[index] == 1,
+ "Not compatible dimension size for broadcast");
+ strides.push_back(0);
+ }
+ }
+ } else {
+ auto o_dim = val.dim();
+ for (decltype(val.dim()) i{0}; i < o_dim; i++) {
+ sizes.push_back(val.sizes()[i]);
+ strides.push_back(val.strides()[i]);
+ }
+ }
+ }
+};
+
+struct KernelArgumentHolder {
+ private:
+ std::vector<ArgAbstract*> arguments;
+ std::vector<void*> void_ptrs;
+ bool changed = true;
+
+ public:
+ virtual ~KernelArgumentHolder() {
+ for (auto arg : arguments)
+ delete arg;
+ }
+
+ // Push a tensor to the arguments
+ void push(
+ const at::Tensor& val,
+ c10::optional<at::IntArrayRef> broadcasted_size = c10::nullopt) {
+ changed = true;
+ ExtractSizeStride ess(val, std::move(broadcasted_size));
+ int nDims = ess.sizes.size();
+
+ c10::ScalarType dtype = val.scalar_type();
+ TensorArgAbstract* tensor_arg = getTensorArg(dtype, nDims);
+ tensor_arg->setPointer(val.data_ptr());
+ for (int i = 0; i < nDims; i++) {
+ tensor_arg->setSize(i, ess.sizes[i]);
+ tensor_arg->setStride(i, ess.strides[i]);
+ }
+ arguments.push_back(tensor_arg);
+ }
+
+ // Push a scalar or integer to the arguments
+ void push(const IValue& val) {
+ changed = true;
+ TORCH_INTERNAL_ASSERT(
+ val.isScalar(),
+ "Tried to push an arg to run in a fused kernel, expected a scalar but got, ",
+ val);
+ switch (val.toScalar().type()) {
+ case (c10::ScalarType::Double):
+ arguments.push_back(new FloatArg((float)val.toDouble()));
+ return;
+ case (c10::ScalarType::Long):
+ arguments.push_back(new IntArg((int)val.toInt()));
+ return;
+ default:
+ TORCH_INTERNAL_ASSERT(
+ false,
+ " Tried to create argument to send to a fused kernel, but got an unexpected type.");
+ }
+ TORCH_INTERNAL_ASSERT(
+ false,
+ " Tried to create argument to send to a fused kernel, but got a non-scalar type.");
+ }
+
+ // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
+ // in the buffer
+ void** getBuffer() {
+ if (changed) {
+ void_ptrs = std::vector<void*>(arguments.size(), nullptr);
+ for (decltype(arguments.size()) i{0}; i < arguments.size(); i++)
+ void_ptrs[i] = static_cast<void*>(arguments[i]->arg());
+ changed = false;
+ }
+ return void_ptrs.data();
+ }
+};
+
std::pair<std::string, std::string> codeGeneration(Fusion& fusion) {
std::stringstream str_stream;
-
- str_stream << "namespace " << CG_NAMESPACE << " {\n" << typeinfo << "\n";
+ str_stream << "namespace " << CG_NAMESPACE << " {\n"
+ << code_template_tensor_struct << "\n";
std::stringstream cdg;
GPULower gpulw(&fusion);
gpulw.printKernel(str_stream, KERNEL_NAME);
str_stream << "\n} // namespace";
std::string func_name = std::string(CG_NAMESPACE) + "::" + KERNEL_NAME;
-
return std::make_pair(func_name, str_stream.str());
};
-void prepare_argument(
- std::vector<void*>& arguments,
- std::vector<Tensor<float>>& tensor_args,
- const at::Tensor& val) {
- tensor_args.emplace_back();
- Tensor<float>& t = tensor_args.back();
- // passing address, type doesn't really matter here;
- t.data = static_cast<float*>(val.data_ptr());
-
- for (decltype(val.dim()) i{0}; i < val.dim(); i++) {
- t.size[i] = val.sizes()[i];
- t.stride[i] = val.strides()[i];
- }
-
- arguments.push_back(&(tensor_args.back()));
-};
-
-void prepare_argument(
- std::vector<void*>& arguments,
- std::vector<Tensor<float>>& tensor_args,
- std::vector<int>& int_args,
- std::vector<float>& float_args,
- const IValue& val) {
- if (val.isTensor()) {
- prepare_argument(arguments, tensor_args, val.toTensor());
- } else if (val.isDouble()) {
- float_args.push_back(val.to<float>());
- arguments.push_back(&(float_args.back()));
- } else if (val.isInt()) {
- int_args.push_back(val.to<int>());
- arguments.push_back(&(int_args.back()));
- } else {
- TORCH_CHECK(false, "Not supported input IValue encounted.");
- }
-};
-
} // namespace
-void compileKernel(Fusion& fusion, CudaKernel& entry) {
+bool KernelArgsReq::matchKernelSize(const at::IntArrayRef inputs) {
+ if (inputs.size() != low_.size()) {
+ return false;
+ }
+ for (decltype(inputs.size()) i{0}; i < inputs.size(); i++) {
+ if (inputs[i] < low_[i] || inputs[i] > hi_[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void compileKernel(Fusion& fusion, CudaKernel* entry) {
// generating cuda code;
std::string code;
std::string func_name;
@@ -96,14 +179,14 @@
// set device for the operation;
const auto prior_device = at::cuda::current_device();
- at::cuda::set_device(entry.device_);
+ at::cuda::set_device(entry->device_);
const auto prop = at::cuda::getCurrentDeviceProperties();
int nvrtc_major, nvrtc_minor;
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
// Short-circuits if NVRTC version too low
- AT_ASSERT(nvrtc_major >= 6);
+ TORCH_INTERNAL_ASSERT(nvrtc_major >= 6);
// Major and minor is determined by device properties and
// possibly "downcompiled" to a lower (compatible) compute architecture
// based on the NVRTC version
@@ -143,55 +226,50 @@
ptx.resize(ptx_size);
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data()));
- AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&(entry.module_), ptx.data()));
+ AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&(entry->module_), ptx.data()));
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleGetFunction(
- &(entry.function_), entry.module_, lowered_kernel_name));
+ &(entry->function_), entry->module_, lowered_kernel_name));
AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
- &entry.max_blocks_, entry.function_, 128, 0));
- entry.max_blocks_ *= prop->multiProcessorCount;
+ &entry->max_blocks_, entry->function_, 128, 0));
+ entry->max_blocks_ *= prop->multiProcessorCount;
}
void runKernel(
- CudaKernel& entry,
+ CudaKernel* entry,
const at::ArrayRef<IValue>& inputs,
std::vector<at::Tensor>& outputs) {
const auto prior_device = at::cuda::current_device();
- at::cuda::set_device(entry.device_);
+ at::cuda::set_device(entry->device_);
auto stream = at::cuda::getCurrentCUDAStream();
// TODO: Proper API to establish reasonable launch configurations;
// Naive launch config;
size_t numel = outputs[0].numel();
- const auto nBlocks = std::min(entry.max_blocks_, ceilDiv(numel, 128));
- // TODO: Proper API to tranform JIT I/O Tensor to CodeGen I/O Tensor
- std::vector<void*> arguments;
+ // TODO: we can't randomly clap down this until we got striding.
+ // const auto nBlocks = std::min(entry->max_blocks_, ceilDiv(numel, 128));
+ const auto nBlocks = ceilDiv(numel, 128);
- // TODO: There are better ways to do this;
- // argument holder;
- // host code, `T` in `Tensor<T>` doesn't really matter, as we only interact
- // with the address; Just put a float here to simply the argument holder.
- auto max_capacity = inputs.size() + outputs.size();
- std::vector<Tensor<float>> tensor_args;
- std::vector<int> int_args;
- std::vector<float> float_args;
- tensor_args.reserve(max_capacity);
- int_args.reserve(max_capacity);
- float_args.reserve(max_capacity);
+ KernelArgumentHolder kernel_args;
// Naive I/O setup, I'm ignoring all the potential transformation (i.e. I/O
// allocated here from the subgraph could be, and very likely are, different
// from I/O expected by the generated CUDA kernel.
for (auto& input : inputs) {
- prepare_argument(arguments, tensor_args, int_args, float_args, input);
+ if (input.isTensor()) {
+ kernel_args.push(input.toTensor(), outputs[0].sizes());
+ } else {
+ kernel_args.push(input);
+ }
}
+
for (auto& output : outputs) {
- prepare_argument(arguments, tensor_args, output);
+ kernel_args.push(output);
}
// launch kernel;
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
- entry.function_,
+ entry->function_,
nBlocks,
1,
1,
@@ -200,17 +278,11 @@
1,
0,
stream,
- arguments.data(),
+ kernel_args.getBuffer(),
nullptr));
// Resets device (see at::DeviceGuard notes above)
at::cuda::set_device(prior_device);
-
- /*
- for (auto& output : outputs) {
- output.fill_(0.24);
- }
- */
}
// WARNING:
@@ -223,29 +295,17 @@
at::cuda::set_device(entry.device_);
auto stream = at::cuda::getCurrentCUDAStream();
- // TODO: Proper API to tranform JIT I/O Tensor to CodeGen I/O Tensor
- std::vector<void*> arguments;
-
- // TODO: There are better ways to do this;
- // argument holder;
- // host code, `T` in `Tensor<T>` doesn't really matter, as we only interact
- // with the address; Just put a float here to simply the argument holder.
- auto max_capacity = inputs.size() + outputs.size();
- std::vector<Tensor<float>> tensor_args;
- std::vector<int> int_args;
- std::vector<float> float_args;
- tensor_args.reserve(max_capacity);
- int_args.reserve(max_capacity);
- float_args.reserve(max_capacity);
+ KernelArgumentHolder kernel_args;
// Naive I/O setup, I'm ignoring all the potential transformation (i.e. I/O
// allocated here from the subgraph could be, and very likely are, different
// from I/O expected by the generated CUDA kernel.
for (auto& input : inputs) {
- prepare_argument(arguments, tensor_args, input);
+ kernel_args.push(input, outputs[0].sizes());
}
+
for (auto& output : outputs) {
- prepare_argument(arguments, tensor_args, output);
+ kernel_args.push(output);
}
// launch kernel;
@@ -259,7 +319,7 @@
entry.block_.z,
0,
stream,
- arguments.data(),
+ kernel_args.getBuffer(),
nullptr));
// Resets device (see at::DeviceGuard notes above)
diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h
index 3de9b10..949bafc 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.h
+++ b/torch/csrc/jit/codegen/cuda/kernel.h
@@ -24,10 +24,15 @@
namespace fuser {
namespace cuda {
-// include IO data structure for host code
-#define STRINGIFY(...) __VA_ARGS__
-#include <torch/csrc/jit/codegen/cuda/data_struct_str.h>
-#undef STRINGIFY
+// Not checking explicit broadcasting yet.
+// check only shape falls in the range;
+struct KernelArgsReq {
+ // We are checking accumulated output shape for now, this is a restricting
+ // aproach, we should check applicability on input tensor shapes instead.
+ bool matchKernelSize(const c10::IntArrayRef inputs);
+ std::vector<size_t> low_;
+ std::vector<size_t> hi_;
+};
class CudaKernel {
public:
@@ -61,23 +66,16 @@
dim3 grid_;
};
-// include IO data structure for stringification
-#define STRINGIFY(...) #__VA_ARGS__
-static auto typeinfo =
-#include "data_struct_str.h"
- ;
-#undef STRINGIFY
-
// compile Fusion to CUDA functions:
// 1. JIT compilation via nvrtc to generate CUDA c++ kernel code;
// 2. CUDA Drive API to load CUDA c++ kernel code as function_;
-TORCH_CUDA_API void compileKernel(Fusion& fusion, CudaKernel& entry);
+TORCH_CUDA_API void compileKernel(Fusion& fusion, CudaKernel* entry);
// run loaded kernel through Function.
// inputs/outputs is given in the sense of a PyTorch JIT ir node. This function
// wraps IO data structure for tensors on host.
TORCH_CUDA_API void runKernel(
- CudaKernel& entry,
+ CudaKernel* entry,
const at::ArrayRef<IValue>& inputs,
std::vector<at::Tensor>& outputs);
diff --git a/torch/csrc/jit/codegen/cuda/kernel_arg.h b/torch/csrc/jit/codegen/cuda/kernel_arg.h
new file mode 100644
index 0000000..51a588d
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/kernel_arg.h
@@ -0,0 +1,117 @@
+#pragma once
+
+#include <c10/core/ScalarType.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+// This should match the tensor used in the code generation (almost exactly)
+template <typename T, int N>
+struct TensorArgCodegen {
+ T& operator[](int64_t ind) {
+ return data[ind];
+ };
+
+ T* data;
+ int64_t size[N];
+ int64_t stride[N];
+ constexpr int nDims() {
+ return N;
+ }
+};
+
+struct ArgAbstract {
+ virtual ~ArgAbstract() {}
+ virtual void* arg() = 0;
+};
+
+struct IntArg : public ArgAbstract {
+ int val_;
+ IntArg(int _val) : val_(_val){};
+ void* arg() {
+ return &val_;
+ }
+};
+
+struct FloatArg : public ArgAbstract {
+ float val_;
+ FloatArg(float _val) : val_(_val){};
+ void* arg() {
+ return &val_;
+ }
+};
+
+struct TensorArgAbstract : ArgAbstract {
+ virtual ~TensorArgAbstract(){};
+ virtual void setSize(int i, int64_t size) = 0;
+ virtual void setStride(int i, int64_t stride) = 0;
+ virtual void setPointer(void* ptr) = 0;
+};
+
+// This should match the tensor used in the code generation (almost exactly)
+template <typename TENSOR_TYPE>
+struct TensorArg : public TensorArgAbstract {
+ TENSOR_TYPE instance_;
+
+ void setSize(int i, int64_t size) override {
+ instance_.size[i] = size;
+ }
+ void setStride(int i, int64_t stride) override {
+ instance_.stride[i] = stride;
+ }
+ void setPointer(void* ptr) override {
+ instance_.data = static_cast<decltype(TENSOR_TYPE::data)>(ptr);
+ }
+
+ void* arg() override {
+ return &instance_;
+ }
+};
+
+template <typename T>
+TensorArgAbstract* getTensorArg(int nDims) {
+ switch (nDims) {
+ case (1):
+ return new TensorArg<TensorArgCodegen<T, 1>>();
+ case (2):
+ return new TensorArg<TensorArgCodegen<T, 2>>();
+ case (3):
+ return new TensorArg<TensorArgCodegen<T, 3>>();
+ case (4):
+ return new TensorArg<TensorArgCodegen<T, 4>>();
+ case (5):
+ return new TensorArg<TensorArgCodegen<T, 5>>();
+ case (6):
+ return new TensorArg<TensorArgCodegen<T, 6>>();
+ case (7):
+ return new TensorArg<TensorArgCodegen<T, 7>>();
+ case (8):
+ return new TensorArg<TensorArgCodegen<T, 8>>();
+ default:
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "Tried to gerneate a tensor to run a generated kernel with ",
+ nDims,
+ " dimensions, however it must be a 1-8 dimensional tensor.");
+ }
+}
+
+TensorArgAbstract* getTensorArg(c10::ScalarType dtype, int nDims) {
+ switch (dtype) {
+ case (at::kFloat):
+ return getTensorArg<float>(nDims);
+ default:
+ TORCH_CHECK(
+ false,
+ "Dtype: ",
+ dtype,
+ " not currently supported in code generated kernels.");
+ }
+}
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp
new file mode 100644
index 0000000..717fb91
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp
@@ -0,0 +1,29 @@
+#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
+
+/*
+ */
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+at::optional<CudaKernel*> CudaKernelCache::getKernelPtr(
+ c10::IntArrayRef sizes) {
+ for (auto& iter : kernels_) {
+ if (iter.first.matchKernelSize(sizes)) {
+ return &(iter.second);
+ }
+ }
+ return at::nullopt;
+}
+
+CudaKernel* CudaKernelCache::allocateKernelInCache(KernelArgsReq args_req) {
+ kernels_.emplace_back(std::make_pair(std::move(args_req), CudaKernel()));
+ return &(kernels_.back().second);
+}
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h
new file mode 100644
index 0000000..2d1276a
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/kernel.h>
+
+/*
+ */
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+class CudaKernelCache {
+ public:
+ CudaKernelCache() = default;
+
+ at::optional<CudaKernel*> getKernelPtr(c10::IntArrayRef sizes);
+ CudaKernel* allocateKernelInCache(KernelArgsReq args_req);
+
+ // private:
+ // TODO: In theory we should assume contiguity remain constant across runs
+ // (job for BailOut node from profiling executor). In reality we might
+ // want to be safe and cache on that as well.
+ // Assuming constant nDims. Cache of kernels targetting different tensor size;
+ // We should flatten
+ std::vector<std::pair<KernelArgsReq, CudaKernel>> kernels_;
+};
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
new file mode 100644
index 0000000..d074355
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
@@ -0,0 +1,28 @@
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+// IO data structure for kernel code;
+static auto code_template_tensor_struct = R"(
+typedef unsigned char uint8_t;
+typedef signed char int8_t;
+typedef short int int16_t;
+typedef long long int int64_t;
+
+template<typename T, int N>
+struct Tensor {
+ T& operator[](int64_t ind) {
+ return data[ind];
+ };
+
+ T* data;
+ int64_t size[N];
+ int64_t stride[N];
+};
+)";
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index 98c56dc..0953b08 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -1,7 +1,12 @@
+#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/index_compute.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
+#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
+#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
@@ -11,191 +16,6 @@
namespace jit {
namespace fuser {
-// START HELPER FUNCTIONS
-namespace {
-
-bool isTV(const Val* const val) {
- return val->getValType().value() == ValType::TensorView;
-}
-
-// Check if we're a TensorView op that we can generate code for.
-bool isTVOp(const Expr* expr) {
- if (expr->nOutputs() == 1 && isTV(expr->output(0)) &&
- (expr->getExprType().value() == ExprType::BinaryOp ||
- expr->getExprType().value() == ExprType::UnaryOp))
- return true;
- return false;
-}
-
-TensorView* asTV(Val* val) {
- TORCH_INTERNAL_ASSERT(isTV(val));
- return static_cast<TensorView*>(val);
-}
-
-const TensorView* asConstTV(const Val* const val) {
- TORCH_INTERNAL_ASSERT(isTV(val));
- return static_cast<const TensorView*>(val);
-}
-
-struct parentScope_ : private OptInDispatch {
- private:
- Expr* parent_ = nullptr;
-
- void handle(ForLoop* fl) final {
- parent_ = fl->parentScope();
- }
-
- void handle(IfThenElse* ite) final {
- parent_ = ite->parentScope();
- }
-
- void handle(Expr* expr) final {
- OptInDispatch::handle(expr);
- }
-
- public:
- static Expr* parent(Expr* scope) {
- parentScope_ sp;
- sp.handle(scope);
- return sp.parent_;
- }
-};
-
-struct forLoopCount : private OptInDispatch {
- private:
- unsigned int count_ = 0;
-
- void handle(ForLoop* fl) final {
- count_++;
- }
-
- void handle(IfThenElse* ite) final {}
-
- void handle(Expr* expr) final {
- OptInDispatch::handle(expr);
- }
-
- public:
- static unsigned int count(Expr* scope) {
- forLoopCount flc;
- Expr* it = scope;
- while (it != nullptr) {
- flc.handle(it);
- it = parentScope_::parent(it);
- }
- return flc.count_;
- }
-};
-
-struct scopePushBack : private OptInDispatch {
- private:
- Expr* _expr = nullptr;
- void handle(ForLoop* fl) final {
- fl->body().push_back(_expr);
- }
-
- void handle(IfThenElse* ite) final {
- ite->body().push_back(_expr);
- }
-
- void handle(Expr* expr) final {
- OptInDispatch::handle(expr);
- }
-
- public:
- static void pushBack(Expr* scope, Expr* expr) {
- scopePushBack pb;
- TORCH_INTERNAL_ASSERT(
- expr != nullptr && scope != nullptr,
- "Cannot push back, scope or expr is a nullptr.");
- pb._expr = expr;
- pb.handle(scope);
- }
-};
-
-struct forLoopIndices : private OptInDispatch {
- private:
- std::vector<Val*> inds_;
- void handle(ForLoop* fl) final {
- inds_.insert(inds_.begin(), fl->index());
- }
-
- void handle(IfThenElse* ite) final {}
-
- void handle(Expr* expr) final {
- OptInDispatch::handle(expr);
- }
-
- public:
- static std::vector<Val*> get(Expr* scope) {
- forLoopIndices fli;
- Expr* it = scope;
- while (it != nullptr) {
- fli.handle(it);
- it = parentScope_::parent(it);
- }
- return fli.inds_;
- }
-};
-
-struct forLoopIDs : private OptInDispatch {
- private:
- std::vector<IterDomain*> IDs_;
- void handle(ForLoop* fl) final {
- IDs_.insert(IDs_.begin(), fl->range());
- }
-
- void handle(IfThenElse* ite) final {}
-
- void handle(Expr* expr) final {
- OptInDispatch::handle(expr);
- }
-
- public:
- static std::vector<IterDomain*> get(Expr* scope) {
- forLoopIDs fli;
- Expr* it = scope;
- while (it != nullptr) {
- fli.handle(it);
- it = parentScope_::parent(it);
- }
- return fli.IDs_;
- }
-};
-
-} // namespace
-// END HELPER FUNCTIONS
-
-// Open a new inner most for loop
-void GPULower::openFor(IterDomain* id) {
- ForLoop* new_scope = nullptr;
- if (id->isThread()) {
- new_scope = new ForLoop(
- new NamedScalar(stringify(id->parallel_method()), DataType::Int),
- id,
- {},
- active_scope);
- } else {
- new_scope = new ForLoop(new Int(), id, {}, active_scope);
- }
- pushBack(new_scope);
- active_scope = new_scope;
-}
-
-// Close the inner most scope
-void GPULower::closeScope() {
- TORCH_INTERNAL_ASSERT(
- active_scope != nullptr,
- "Tried to close the active scope, but there isn't one set.");
- Expr* parent = parentScope_::parent(active_scope);
- active_scope = parent;
-}
-
-// Close all scopes
-void GPULower::resetScope() {
- active_scope = nullptr;
-}
-
// Clear out the last recorded computeAtView
void GPULower::clearActiveView() {
active_view_axis = 0;
@@ -208,14 +28,6 @@
active_view = tv->getComputeAtView();
}
-std::vector<Val*> GPULower::getLoopIndices() {
- return forLoopIndices::get(active_scope);
-}
-
-std::vector<IterDomain*> GPULower::getLoopIterDomains() {
- return forLoopIDs::get(active_scope);
-}
-
TensorIndex* GPULower::getGlobalProducerIndex(
TensorView* producer,
TensorView* consumer) {
@@ -224,14 +36,14 @@
// This replay will ignore reduction dimensions on the producer
TransformReplay::fullReplay(consumer, cloned_tv);
TORCH_INTERNAL_ASSERT(
- getLoopIndices().size() == cloned_tv->nDims(),
+ scope_utils::getLoopIndices(active_scope).size() == cloned_tv->nDims(),
"Dimensionality error in code generator while computing indexing.");
- const std::vector<Val*> computed_inds =
- IndexCompute::computeIndices(cloned_tv, getLoopIndices());
+ const std::vector<Val*> computed_inds = IndexCompute::computeIndices(
+ cloned_tv, scope_utils::getLoopIndices(active_scope));
TORCH_INTERNAL_ASSERT(
- computed_inds.size() == producer->getRootDomain()->size(),
+ computed_inds.size() == producer->getRootDomain()->nDims(),
"Dimensionality error in code generator while computing indexing.");
std::vector<Val*> strided_inds;
@@ -253,18 +65,23 @@
TensorView* producer,
TensorView* consumer) {
TORCH_INTERNAL_ASSERT(
- computeForDepth() == producer->nDims(),
+ scope_utils::computeForDepth(active_scope) == producer->nDims(),
"Expected a tensor with ",
- computeForDepth(),
+ scope_utils::computeForDepth(active_scope),
" dimensions but got one with ",
producer->nDims());
- std::vector<Val*> loopInds = getLoopIndices();
- std::vector<IterDomain*> ranges = getLoopIterDomains();
+ std::vector<Val*> loopInds = scope_utils::getLoopIndices(active_scope);
+ std::vector<IterDomain*> ranges =
+ scope_utils::getLoopIterDomains(active_scope);
std::vector<Val*> computed_inds;
std::vector<IterDomain*> used_ranges;
+ bool unrolled = false;
for (decltype(loopInds.size()) i{0}; i < loopInds.size(); i++) {
- if (producer->hasComputeAt() && i < producer->getComputeAtAxis())
+ if (ranges[i]->parallel_method() == ParallelType::Unroll)
+ unrolled = true;
+ if (!unrolled && producer->hasComputeAt() &&
+ i < producer->getComputeAtAxis())
continue;
if (ranges[i]->isThread())
continue;
@@ -275,7 +92,7 @@
for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
Val* ind = computed_inds[i];
for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++)
- ind = mul(ind, used_ranges[i]->size());
+ ind = mul(ind, used_ranges[i]->extent());
computed_inds[i] = ind;
}
if (computed_inds.size() == 0)
@@ -295,14 +112,14 @@
TensorIndex* GPULower::getGlobalConsumerIndex(TensorView* consumer) {
TORCH_INTERNAL_ASSERT(
- getLoopIndices().size() == consumer->nDims(),
+ scope_utils::getLoopIndices(active_scope).size() == consumer->nDims(),
"Dimensionality error in code generator while computing indexing.");
- const std::vector<Val*> computed_inds =
- IndexCompute::computeIndices(consumer, getLoopIndices());
+ const std::vector<Val*> computed_inds = IndexCompute::computeIndices(
+ consumer, scope_utils::getLoopIndices(active_scope));
TORCH_INTERNAL_ASSERT(
- computed_inds.size() == consumer->getRootDomain()->size(),
+ computed_inds.size() == consumer->getRootDomain()->nDims(),
"Dimensionality error in code generator while computing indexing.");
std::vector<Val*> strided_inds;
@@ -322,19 +139,23 @@
TensorIndex* GPULower::getLocalConsumerIndex(TensorView* consumer) {
TORCH_INTERNAL_ASSERT(
- computeForDepth() == consumer->nDims(),
+ scope_utils::computeForDepth(active_scope) == consumer->nDims(),
"Expected a tensor with ",
- computeForDepth(),
+ scope_utils::computeForDepth(active_scope),
" dimensions but got one with ",
consumer->nDims());
- std::vector<Val*> loopInds = getLoopIndices();
- std::vector<IterDomain*> ranges = getLoopIterDomains();
+ std::vector<Val*> loopInds = scope_utils::getLoopIndices(active_scope);
+ std::vector<IterDomain*> ranges =
+ scope_utils::getLoopIterDomains(active_scope);
std::vector<Val*> computed_inds;
std::vector<IterDomain*> used_ranges;
-
+ bool unrolled = false;
for (decltype(loopInds.size()) i{0}; i < loopInds.size(); i++) {
- if (i < consumer->getComputeAtAxis())
+ if (ranges[i]->parallel_method() == ParallelType::Unroll)
+ unrolled = true;
+ if (!unrolled && consumer->hasComputeAt() &&
+ i < consumer->getComputeAtAxis())
continue;
if (ranges[i]->isThread())
continue;
@@ -345,7 +166,7 @@
for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
Val* ind = computed_inds[i];
for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++)
- ind = mul(ind, used_ranges[i]->size());
+ ind = mul(ind, used_ranges[i]->extent());
computed_inds[i] = ind;
}
@@ -364,202 +185,115 @@
return getLocalConsumerIndex(consumer);
}
-// Track how far our for loop scope is
-unsigned int GPULower::computeForDepth() {
- return forLoopCount::count(active_scope);
-}
-
-// Push an expr to the active scope
void GPULower::pushBack(Expr* expr) {
- if (active_scope == nullptr) {
- lowered_exprs.push_back(expr);
- return;
- }
- scopePushBack::pushBack(active_scope, expr);
-}
-
-// Return the parent of the active scope
-Expr* GPULower::parentScope() {
if (active_scope == nullptr)
- return nullptr;
- return parentScope_::parent(active_scope);
+ lowered_exprs.push_back(expr);
+ else
+ scope_utils::pushBack(active_scope, expr);
}
-Allocate* GPULower::getAlloc(TensorView* tv) {
+Statement* GPULower::mutate(Expr* expr) {
+ Statement* mutated_stmt = OptOutMutator::mutate(expr);
TORCH_INTERNAL_ASSERT(
- !(FusionGuard::getCurFusion()->hasInput(tv) ||
- FusionGuard::getCurFusion()->hasOutput(tv)),
- "Tried to allocate an input or output tensor.");
-
- std::vector<Val*> alloc_dims;
-
- for (decltype(tv->nDims()) i = tv->getComputeAtAxis(); i < tv->nDims(); i++) {
- IterDomain* dim = tv->getComputeAtAxis(i);
- if (dim->isThreadDim())
- continue;
- alloc_dims.push_back(dim->size());
- }
-
- Val* size;
- if (alloc_dims.size() == 0) {
- size = new Int(1);
- } else {
- size = alloc_dims[0];
- for (decltype(alloc_dims.size()) i{1}; i < alloc_dims.size(); i++) {
- size = mul(size, alloc_dims[i]);
- }
- }
- return new Allocate(tv, size);
+ mutated_stmt->isExpr(),
+ "Tried to generate a kernel but hit a non expression during lowering: ",
+ mutated_stmt);
+ return mutated_stmt;
}
-IfThenElse* GPULower::getPredicate(const TensorView* const pred_tv) {
- TensorIndex* ti = new TensorIndex(
- pred_tv, IndexCompute::computeIndices(pred_tv, getLoopIndices()));
-
- std::vector<Int*> all_preds = PredicateCompute::computePredicates(ti);
-
- std::vector<Int*> preds;
-
- Int* one = new Int(1);
-
- for (Int* pred : all_preds)
- if (!pred->sameAs(one))
- preds.push_back(pred);
-
- if (preds.size() == 0) {
- return new IfThenElse(one, {}, {}, active_scope);
+Statement* GPULower::mutate(IfThenElse* ite) {
+ Expr* prev_scope = active_scope;
+ active_scope = ite;
+ std::vector<Expr*> mutated_exprs;
+ bool is_mutated = false;
+ for (auto expr : ite->body().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+ mutated_exprs.push_back(mutated_expr);
+ is_mutated = is_mutated | (mutated_expr != expr);
}
- Int* cond = preds[0];
+ std::vector<Expr*> mutated_else_exprs;
+ for (auto expr : ite->elseBody().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+ mutated_else_exprs.push_back(mutated_expr);
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
- for (decltype(preds.size()) i{1}; i < preds.size(); i++)
- cond = static_cast<Int*>(andOp(cond, preds[i]));
+ if (is_mutated) {
+ ite->body().clear();
+ for (auto expr : mutated_exprs)
+ ite->body().push_back(expr);
+ ite->elseBody().clear();
+ for (auto expr : mutated_else_exprs)
+ ite->elseBody().push_back(expr);
+ }
- return new IfThenElse(cond, {}, {}, active_scope);
+ active_scope = prev_scope;
+
+ if (is_mutated) {
+ auto new_ite = new IfThenElse(
+ ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope());
+ return new_ite;
+ }
+
+ return ite;
}
-// Custom dispatch for Expr, want to find out of it's a TV op
-void GPULower::handle(Expr* expr) {
- if (!isTVOp(expr))
- return;
+Statement* GPULower::mutate(ForLoop* fl) {
+ Expr* prev_scope = active_scope;
+ active_scope = fl;
+ std::vector<Expr*> mutated_exprs;
+ bool is_mutated = false;
+ for (auto expr : fl->body().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+ mutated_exprs.push_back(mutated_expr);
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
- TensorView* out = static_cast<TensorView*>(expr->output(0));
+ active_scope = prev_scope;
- updateView(out);
+ if (is_mutated) {
+ auto newFL = new ForLoop(
+ fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope());
+ return newFL;
+ }
- // 8) Run operation
- OptOutDispatch::handle(expr);
-
- // 9) Close predicate
- if (active_scope != nullptr &&
- active_scope->getExprType() == ExprType::IfThenElse)
- closeScope();
+ return fl;
}
-void GPULower::handle(UnaryOp* uop) {
- TORCH_INTERNAL_ASSERT(
- isTV(uop->out()),
- "Expected a tensor view but got ",
- uop->out()->getValType().value());
- TensorIndex* out = getConsumerIndex(asTV(uop->out()));
+Statement* GPULower::mutate(UnaryOp* uop) {
+ if (!ir_utils::isTVOp(uop))
+ return OptOutMutator::mutate(uop);
+
+ TensorIndex* out = getConsumerIndex(ir_utils::asTV(uop->out()));
Val* in = uop->in();
- if (isTV(in))
- in = getProducerIndex(asTV(in), asTV(uop->out()));
- pushBack(new UnaryOp(uop->getUnaryOpType(), out, in));
+ if (ir_utils::isTV(in))
+ in = getProducerIndex(ir_utils::asTV(in), ir_utils::asTV(uop->out()));
+ Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in);
+
+ return new_op;
}
-void GPULower::handle(BinaryOp* bop) {
- TORCH_INTERNAL_ASSERT(
- isTV(bop->out()),
- "Expected a tensor view but got ",
- bop->out()->getValType().value());
- TensorIndex* out = getConsumerIndex(asTV(bop->out()));
+Statement* GPULower::mutate(BinaryOp* bop) {
+ if (!ir_utils::isTVOp(bop))
+ return OptOutMutator::mutate(bop);
+
+ TensorIndex* out = getConsumerIndex(ir_utils::asTV(bop->out()));
Val* lhs = bop->lhs();
Val* rhs = bop->rhs();
- if (isTV(lhs))
- lhs = getProducerIndex(asTV(lhs), asTV(bop->out()));
+ if (ir_utils::isTV(lhs))
+ lhs = getProducerIndex(ir_utils::asTV(lhs), ir_utils::asTV(bop->out()));
- if (isTV(rhs))
- rhs = getProducerIndex(asTV(rhs), asTV(bop->out()));
+ if (ir_utils::isTV(rhs))
+ rhs = getProducerIndex(ir_utils::asTV(rhs), ir_utils::asTV(bop->out()));
- pushBack(new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs));
-}
+ Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs);
-/*
- * This is one of the most complex parts of the code lowering logic. what we
- * need to do is: 1) Reduce loop structure
- * - Reset all loops if active_view == nullptr (I'm not the last in a series
- * of computeAts)
- * - Else reduce to active_view_axis if loop_depth > active_view_axis
- * 2) Set active_view(_axis)
- * - If there is a computeAt set for this TV
- * 3) Open to compute At
- * - If there is a computeAt set for this TV
- * 4) Allocate the output.
- * 5) If this is a reduction, initialize the output (open for loops to inner
- * most, predicate, initialize, close predicate, close to computeAt) 6) Open to
- * inner most loop 7) Open predicate 8) Run operation 9) Close predicate
- */
-
-// Update fors based on tv.
-void GPULower::updateView(TensorView* tv) {
- // 1) Reduce loop structure
- if (active_view == nullptr) {
- // - Reset all loops if active_view == nullptr (I'm not the last in a series
- // of computeAts)
- resetScope();
- } else {
- // - Else reduce to active_view_axis if loop_depth > active_view_axis
- auto depth = computeForDepth();
- for (auto i = depth; i > active_view_axis; i--) {
- closeScope();
- }
- }
- if (tv->hasComputeAt()) {
- // 2) Set active_view(_axis)
- // - If there is a computeAt set for this TV
- setActiveView(tv);
-
- // 3) Open to compute At
- // - If there is a computeAt set for this TV
- auto depth = computeForDepth();
- for (auto i = depth; i < tv->getComputeAtAxis(); i++)
- openFor(tv->getComputeAtAxis(i));
- } else {
- if (active_view != nullptr)
- // If we're the last computeAt of a block, active view should match this
- // tv
- TORCH_INTERNAL_ASSERT(
- tv->sameAs(active_view),
- "Error detected in code lowering. Expected ",
- active_view,
- " but recieved ",
- tv);
- clearActiveView();
- }
-
- // 4) Allocate the output.
-
- if (!FusionGuard::getCurFusion()->hasInput(tv) &&
- !FusionGuard::getCurFusion()->hasOutput(tv)) {
- pushBack(getAlloc(tv));
- }
-
- // TODO:
- // 5) If this is a reduction, initialize the output (open for loops to inner
- // most, predicate, initialize, close predicate, close to computeAt)
-
- // 6) Open to inner most loop
- for (decltype(tv->nDims()) i = computeForDepth(); i < tv->nDims(); i++)
- openFor(tv->getComputeAtAxis(i));
-
- // 7) Open predicate
- IfThenElse* pred = getPredicate(tv);
- if (!pred->cond()->sameAs(new Int(1))) {
- pushBack(pred);
- active_scope = pred;
- }
+ return new_op;
}
// TensorViews are all based on symbolic sizes. When we first initialize them we
@@ -572,20 +306,22 @@
Fusion* fusion = FusionGuard::getCurFusion();
// Sizes of inputs/outputs -> T.size[...]
std::unordered_map<Val*, Val*> size_map;
- // Replacement of full tensor views
- std::unordered_map<Val*, Val*> tv_map;
// Grab inputs and outputs
std::vector<TensorView*> orig_inp_out;
- std::vector<TensorView*> orig_intermediates;
+ std::vector<TensorView*> all_tvs;
+
+ for (auto* val : fusion->inputs())
+ if (ir_utils::isTV(val))
+ orig_inp_out.push_back(ir_utils::asTV(val));
+
+ for (auto* val : fusion->outputs())
+ if (ir_utils::isTV(val))
+ orig_inp_out.push_back(ir_utils::asTV(val));
for (auto* val : fusion->deterministic_vals()) {
- if (isTV(val)) {
- if (fusion->hasInput(val) || fusion->hasOutput(val)) {
- orig_inp_out.push_back(asTV(val));
- } else {
- orig_intermediates.push_back(asTV(val));
- }
+ if (ir_utils::isTV(val)) {
+ all_tvs.push_back(ir_utils::asTV(val));
}
}
@@ -602,104 +338,105 @@
// option which seems less elegant but would also work is build up the domain
// on the new tensor, and then simply replace it into the original one.
for (TensorView* tv : orig_inp_out) {
- TensorView* new_tv =
- new TensorView(tv->domain(), tv->getDataType().value());
-
- // We can place the new_tv in the map right away.
- tv_map[tv] = new_tv;
-
// Replace the domain with one based on Ti.size[j]
- std::vector<IterDomain*> new_domain;
+ std::vector<IterDomain*> new_domain_iters;
TensorDomain* root_td = tv->getRootDomain();
- for (decltype(root_td->size()) i{0}; i < root_td->size(); i++) {
- Val* orig_size = root_td->axis(i)->size();
+ for (decltype(root_td->nDims()) i{0}; i < root_td->nDims(); i++) {
+ Val* orig_size = root_td->axis(i)->extent();
std::stringstream ss;
- ss << "T" << new_tv->name() << ".size[" << i << "]";
+ ss << "T" << tv->name() << ".size[" << i << "]";
Val* new_size =
new NamedScalar(ss.str(), orig_size->getDataType().value());
- size_map[orig_size] = new_size;
-
- new_domain.push_back(new IterDomain(
- new_size,
- root_td->axis(i)->parallel_method(),
- root_td->axis(i)->isReduction()));
+ if (!orig_size->sameAs(new_size) ||
+ size_map.find(orig_size) == size_map.end())
+ size_map[orig_size] = new_size;
}
- new_tv->setDomain(new TensorDomain(new_domain));
}
- for (TensorView* tv : orig_intermediates) {
- TensorView* new_tv =
- new TensorView(tv->domain(), tv->getDataType().value());
- tv_map[tv] = new_tv;
+ // If we already lowered all inputs/outputs we can just return.
+ if (size_map.size() == 0)
+ return;
- std::vector<IterDomain*> new_domain;
+ for (TensorView* tv : all_tvs) {
+ std::vector<IterDomain*> new_domain_iters;
TensorDomain* root_td = tv->getRootDomain();
- for (decltype(root_td->size()) i{0}; i < root_td->size(); i++) {
- Val* new_size = root_td->axis(i)->size();
+ for (decltype(root_td->nDims()) i{0}; i < root_td->nDims(); i++) {
+ Val* new_size = root_td->axis(i)->extent();
if (size_map.find(new_size) != size_map.end())
new_size = size_map[new_size];
- new_domain.push_back(new IterDomain(
+ new_domain_iters.push_back(new IterDomain(
+ root_td->axis(i)->start(),
new_size,
root_td->axis(i)->parallel_method(),
root_td->axis(i)->isReduction()));
}
- new_tv->setDomain(new TensorDomain(new_domain));
- }
- // Now that we have the base tensor views. Lets fix its members.
- for (auto entry : tv_map) {
- TensorView* orig_tv = asTV(entry.first);
- TensorView* new_tv = asTV(entry.second);
+ TensorDomain* old_domain = tv->domain();
+ TensorDomain* new_domain = TransformReplay::fullReplay(
+ old_domain, new TensorDomain(new_domain_iters));
- // Domain in the new TV is the root domain, replay it like the original
- // domain.
- TransformReplay::fullReplay(orig_tv, new_tv);
-
+ TORCH_INTERNAL_ASSERT(
+ old_domain->nDims() == new_domain->nDims(),
+ "Tried to set symbolic sizes through the kernel, but hit a snag, Replayed domain should be the same size as the target domain, but got ",
+ new_domain->nDims(),
+ " and ",
+ old_domain->nDims());
// Parallelize all iter domains
- for (decltype(new_tv->domain()->size()) i{0}; i < new_tv->domain()->size();
- i++)
- new_tv->axis(i)->parallelize(orig_tv->axis(i)->parallel_method());
+ for (decltype(new_domain->nDims()) i{0}; i < new_domain->nDims(); i++)
+ new_domain->axis(i)->parallelize(old_domain->axis(i)->parallel_method());
- // Set compute at view and axis
- TensorView* computeAtTV = orig_tv->getComputeAtView();
- if (computeAtTV != nullptr) {
- TORCH_INTERNAL_ASSERT(
- tv_map.find(computeAtTV) != tv_map.end(),
- "Expected to find a translation for ",
- computeAtTV,
- " but one wasn't found.");
- new_tv->setComputeAt(
- asTV(tv_map[computeAtTV]), (int)(orig_tv->getComputeAtAxis()));
- }
+ tv->setDomain(new_domain);
}
-
- ReplaceAll::instancesOf(tv_map);
}
+namespace {
+
+// Some pre-compilation checks
+void validate(Fusion* fusion) {
+ for (Val* val : fusion->vals()) {
+ if (ir_utils::isTV(val)) {
+ TensorView* tv = ir_utils::asTV(val);
+ for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
+ IterDomain* id = tv->getComputeAtAxis(i);
+
+ if (id->isThread())
+ TORCH_CHECK(
+ !id->isReduction(),
+ "Parallelization on reduction axes not support at the moment found on, ",
+ tv,
+ ".");
+ }
+ } // if ir_utils::isTV
+ } // for(Val* val : fusion->vals())
+
+} // validate
+} // namespace
+
// Traverse through the fusion and print CUDA code associated with it
std::vector<Expr*> GPULower::getLoweredExprs() {
FusionGuard fg(fusion_);
- TORCH_CHECK(
- !fusion_->lowered,
- "Fusions can only be lowered once as of now. You could reuse the lowering using",
- " std::vector<Expr*> GPULower::getLoweredExprs() the result can be printed as",
- " a kernel with IRPrinter irp(os); irp.printKernel(lowered_exprs, kernel_name);");
+ validate(fusion_);
// Initialize members of the class
- lowered_exprs = std::vector<Expr*>();
active_view = nullptr;
active_view_axis = 0;
replaceSizes();
- // Run through and lower the expressions
- std::vector<Expr*> exprs = fusion_->exprs(true);
- for (auto* expr : exprs)
- handle(expr);
+ auto loop_nests = LoopNestGenerator::getLoopNest(fusion_);
+ auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests);
+ // Run through loop nests and further lower the expressions
+ for (auto* expr : unrolled_loops) {
+ Statement* mutated_stmt = mutate(expr);
+ TORCH_INTERNAL_ASSERT(
+ mutated_stmt->isExpr(),
+ "Tried to generate a kernel but hit a non expression during lowering: ",
+ mutated_stmt);
+ lowered_exprs.push_back(static_cast<Expr*>(mutated_stmt));
+ }
- fusion_->lowered = true;
return lowered_exprs;
}
@@ -707,7 +444,6 @@
std::ostream& os,
const std::string& kernel_name) {
FusionGuard fg(fusion_);
-
getLoweredExprs();
IRPrinter irp(os);
@@ -717,4 +453,4 @@
} // namespace fuser
} // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h
index 00018ad..ff92826 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.h
+++ b/torch/csrc/jit/codegen/cuda/lower2device.h
@@ -2,12 +2,7 @@
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <map>
#include <ostream>
@@ -21,71 +16,61 @@
// keep user references intact so they can lower it as they describe the kernel.
// Right now we can only lower once.
-struct TORCH_CUDA_API GPULower : public OptOutDispatch {
+struct TORCH_CUDA_API GPULower : public OptOutMutator {
private:
bool lowered = false;
- Fusion* fusion_;
+ Fusion* const fusion_;
std::vector<Expr*> lowered_exprs;
Expr* active_scope = nullptr;
+
// Track the last computeAt TensorView and axis
const TensorView* active_view;
unsigned int active_view_axis;
- // Open a new inner most for loop
- void openFor(IterDomain*);
- // Close the inner most for loop
- void closeScope();
- // Close all for loops
- void resetScope();
// Clear out the last recorded computeAtView
void clearActiveView();
// Set active views from computeAtView
void setActiveView(const TensorView* const);
- // Grab the index variables of the active loop nest
- std::vector<Val*> getLoopIndices();
- // Grab the iterDomains of the active loops
- std::vector<IterDomain*> getLoopIterDomains();
- // Gets the indexing of a TensorView producer. These are values consumed in a
- // TensorView Expr. We use the consumer (left hand side of the =) to compute
- // the indexing into the consumer.
+
+ // Indexing functions
+ // Consumer = Producer
+ // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
+ // Producer indexing dispatch
TensorIndex* getProducerIndex(TensorView* producer, TensorView* consumer);
+ // Producer if it's in global memory
TensorIndex* getGlobalProducerIndex(
TensorView* producer,
TensorView* consumer);
+ // Producer indexing if it's in registers
TensorIndex* getLocalProducerIndex(
TensorView* producer,
TensorView* consumer);
+ // Consumer index dispatch
TensorIndex* getConsumerIndex(TensorView* consumer);
+ // Consumer indexing if it's in global memory
TensorIndex* getGlobalConsumerIndex(TensorView* consumer);
+ // Consumer indexing if it's in local memory
TensorIndex* getLocalConsumerIndex(TensorView* consumer);
- // Track how far our for loop scope is
- unsigned int computeForDepth();
- // Push an expr to the active scope
- void pushBack(Expr* expr);
- // Return the parent of the active scope
- Expr* parentScope();
-
- // Get Register allocation statement for tensorview
- Allocate* getAlloc(TensorView*);
// Get a predicate based on a particular tensorview
IfThenElse* getPredicate(const TensorView* const);
+ // Wrap pushBack in lower_utils if active_scope is null we want it to go
+ // straight to lower_exprs
+ void pushBack(Expr*);
+
// Custom dispatch for Expr, want to find out of it's a TV op
- void handle(Expr*) final;
+ Statement* mutate(Expr*) final;
+
+ // Open the for loop.
+ Statement* mutate(ForLoop*) final;
+
+ // Open the for loop.
+ Statement* mutate(IfThenElse*) final;
// Remake operations with TensorIndex
- void handle(UnaryOp*) final;
- void handle(BinaryOp*) final;
-
- // Ignore split/merge/reorder operations,
- // we don't want to print them.
- void handle(Split*) final {}
- void handle(Merge*) final {}
- void handle(Reorder*) final {}
-
- // Update for loop structure based on producing provided TensorView
- void updateView(TensorView*);
+ Statement* mutate(UnaryOp*) final;
+ Statement* mutate(BinaryOp*) final;
// TensorViews are all based on symbolic sizes. When we first initialize them
// we don't know if they're inputs or outputs which would mean that they have
diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
new file mode 100644
index 0000000..7eee9ef
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
@@ -0,0 +1,384 @@
+#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+
+#include <torch/csrc/jit/codegen/cuda/index_compute.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+// all the way in the loop nest, grab predicate
+/*
+for( i : ceil(I/4) ) {
+ for( j : ceil(J/128) ) {
+
+ if( i * 4 + 3 < I && j * 128 + 127 < J ){
+ for( k : 4)
+ for( l : 128 )
+ T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
+ } else {
+ for( k : 4 )
+ for( l : 128 )
+ if( i * 4 + k < I && j * 128 + l < J)
+ T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
+ }
+
+ }
+}
+*/
+
+// Custom dispatch for Expr, want to find out of it's a TV op
+void UnrollPass::handle(Expr* expr) {
+ OptOutDispatch::handle(expr);
+}
+
+namespace {
+Int* getPredicate(const TensorView* const pred_tv, std::vector<Val*> indices) {
+ TensorIndex* ti = new TensorIndex(
+ pred_tv, IndexCompute::computeIndices(pred_tv, std::move(indices)));
+ std::vector<Int*> all_preds = PredicateCompute::computePredicates(ti);
+
+ std::vector<Int*> preds;
+
+ Int* one = new Int(1);
+
+ for (Int* pred : all_preds)
+ if (!pred->sameAs(one))
+ preds.push_back(pred);
+
+ if (preds.size() == 0) {
+ return one;
+ } else {
+ Int* cond = preds[0];
+
+ for (decltype(preds.size()) i{1}; i < preds.size(); i++)
+ cond = static_cast<Int*>(andOp(cond, preds[i]));
+
+ return cond;
+ }
+}
+} // namespace
+
+// Open the for loop.
+void UnrollPass::handle(ForLoop* fl) {
+ // Setup for loop scoping
+ for_loops.push_back(fl);
+ bool prev_unroll = within_unroll;
+ within_unroll = ir_utils::isUnrolledFor(fl) || within_unroll;
+
+ for (auto expr : fl->body().exprs()) {
+ OptOutDispatch::handle(expr);
+ }
+
+ TensorView* out;
+ bool has_TV_op = false;
+ for (Expr* expr : fl->body().exprs())
+ if (ir_utils::isTVOp(expr)) {
+ // Predicate determining op for unroll
+ out = ir_utils::asTV(expr->output(0));
+ has_TV_op = true;
+ break;
+ }
+
+ if (within_unroll && has_TV_op) {
+ // Setup unrolled loop information:
+
+ // Indices used to detect when we can unroll a loop safely
+ // For loops outside the unroll, it's just he index, for loops inside
+ // the unroll, if it's a thread it's the thread index, otherwise it's
+ // the size-1
+ std::vector<Val*> unroll_pred_inds;
+ auto it = for_loops.begin();
+ while (it != for_loops.end()) {
+ if (ir_utils::isUnrolledFor(*it))
+ break;
+ unroll_pred_inds.push_back((*it)->index());
+ it++;
+ }
+
+ // This is the outer most loop that needs to be unrolled
+ ForLoop* first_unroll = *it;
+
+ // Indicies inside the unroll
+ while (it != for_loops.end()) {
+ IterDomain* id = (*it)->iter_domain();
+ if (id->isThread())
+ unroll_pred_inds.push_back((*it)->index());
+ else
+ unroll_pred_inds.push_back(sub(id->extent(), new Int(1)));
+ it++;
+ }
+
+ // Make predicates for the unrolling, and the epilogue
+ Int* unroll_predicate = getPredicate(out, unroll_pred_inds);
+
+ // Make the IfThenElse controlling the unrolling
+ IfThenElse* unroll_ite =
+ new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope());
+
+ // Get the loop nest for the unrolled path
+ ForLoop* unrolled_loop =
+ scope_utils::cloneLoopNest(first_unroll, unroll_ite);
+ unroll_ite->body().push_back(unrolled_loop);
+
+ // Loop nest for inlined path
+ ForLoop* inlined_loop =
+ scope_utils::cloneLoopNest(first_unroll, unroll_ite);
+ unroll_ite->elseBody().push_back(inlined_loop);
+
+ // Inner most inlined loop
+ Expr* inner_most_inlined_loop =
+ scope_utils::firstInnerMostScope(inlined_loop);
+
+ loop_replacement_map.insert({first_unroll, unroll_ite});
+
+ bool first_expr = true;
+ for (auto expr : fl->body().exprs()) {
+ if (!ir_utils::isTVOp(expr))
+ continue;
+
+ // Setup the expressions that need predicates around them.
+ Int* inline_predicate =
+ getPredicate(out, scope_utils::getLoopIndices(for_loops.back()));
+ IfThenElse* inline_ite =
+ new IfThenElse(inline_predicate, {expr}, {}, inner_most_inlined_loop);
+ std::unordered_map<Expr*, Expr*> inline_replacement_map;
+ inline_replacement_map.emplace(std::pair<Expr*, Expr*>(expr, inline_ite));
+ scope_utils::replaceExprsInScope(
+ inner_most_inlined_loop, inline_replacement_map);
+
+ } // for expr
+
+ } else { // if(!within_unroll)
+
+ for (auto expr : fl->body().exprs()) {
+ if (!ir_utils::isTVOp(expr))
+ continue;
+
+ // ! within_unroll
+ TensorView* out = ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]);
+ Int* pred =
+ getPredicate(out, scope_utils::getLoopIndices(for_loops.back()));
+ if (!pred->isOneInt()) {
+ IfThenElse* inline_ite =
+ new IfThenElse(pred, {expr}, {}, for_loops.back());
+ for_loops.back()->body().insert_before(expr, inline_ite);
+ for_loops.back()->body().erase(expr);
+ }
+ }
+ } // else (if(!within_unroll))
+ for_loops.pop_back();
+ bool within_unroll = prev_unroll;
+}
+
+// Generate the loop nest structure and place it in lowered_exprs
+void UnrollPass::computeMap() {
+ FusionGuard fg(fusion_);
+
+ // Initialize members of the class
+ active_view = nullptr;
+ active_view_axis = 0;
+
+ // Run through loop nests and further lower the expressions
+ for (auto* expr : incoming_exprs_) {
+ OptOutDispatch::handle(expr);
+ }
+}
+
+std::vector<Expr*> UnrollPass::runPass(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs) {
+ FusionGuard fg(fusion);
+ UnrollPass up(fusion, exprs);
+ up.computeMap();
+ std::vector<Expr*> mutated_exprs;
+ for (Expr* expr : exprs) {
+ if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) {
+ mutated_exprs.push_back(up.loop_replacement_map[expr]);
+ } else {
+ if (ir_utils::isScope(expr))
+ scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
+ mutated_exprs.push_back(expr);
+ }
+ }
+ return mutated_exprs;
+}
+
+void LoopNestGenerator::pushAlloc(TensorView* tv) {
+ TORCH_INTERNAL_ASSERT(
+ !(FusionGuard::getCurFusion()->hasInput(tv) ||
+ FusionGuard::getCurFusion()->hasOutput(tv)),
+ "Tried to allocate an input or output tensor.");
+
+ // Compute at axis can be == tv->nDims() meaning it's inline
+ decltype(tv->nDims()) alloc_pos = 0;
+ bool reset = true;
+ while (alloc_pos <= tv->nDims()) {
+ if (tv->hasComputeAt() && alloc_pos == tv->getComputeAtAxis()) {
+ reset = false;
+ break;
+ }
+ if (alloc_pos < tv->nDims() &&
+ tv->getComputeAtAxis(alloc_pos)->parallel_method() ==
+ ParallelType::Unroll) {
+ reset = false;
+ break;
+ }
+ alloc_pos++;
+ }
+ alloc_pos = reset ? 0 : alloc_pos;
+
+ std::vector<Val*> alloc_dims;
+ for (auto i = alloc_pos; i < tv->nDims(); i++) {
+ IterDomain* dim = tv->getComputeAtAxis(i);
+ if (dim->isThreadDim())
+ continue;
+ // TORCH_INTERNAL_ASSERT()
+ alloc_dims.push_back(dim->extent());
+ }
+
+ Val* size;
+ if (alloc_dims.size() == 0) {
+ size = new Int(1);
+ } else {
+ size = alloc_dims[0];
+ for (decltype(alloc_dims.size()) i{1}; i < alloc_dims.size(); i++) {
+ size = mul(size, alloc_dims[i]);
+ }
+ }
+ Allocate* alloc = new Allocate(tv, size);
+ if (alloc_pos == 0) {
+ lowered_exprs.insert(lowered_exprs.begin(), alloc);
+ } else if (alloc_pos == for_loops.size()) {
+ // inline
+ scope_utils::pushBack(for_loops[alloc_pos - 1], alloc);
+ } else {
+ scope_utils::insertBefore(
+ for_loops[alloc_pos - 1], for_loops[alloc_pos], alloc);
+ }
+}
+
+// Clear out the last recorded computeAtView
+void LoopNestGenerator::clearActiveView() {
+ active_view_axis = 0;
+ active_view = nullptr;
+}
+
+// Set active views from computeAtView
+void LoopNestGenerator::setActiveView(const TensorView* const tv) {
+ active_view_axis = tv->getComputeAtAxis();
+ active_view = tv->getComputeAtView();
+}
+
+void LoopNestGenerator::openFor(IterDomain* id) {
+ if (for_loops.size() > 0) {
+ ForLoop* new_scope = scope_utils::openFor(for_loops.back(), id);
+ for_loops.push_back(new_scope);
+ } else {
+ for_loops.push_back(scope_utils::openFor(nullptr, id));
+ lowered_exprs.push_back(for_loops.back());
+ }
+}
+
+void LoopNestGenerator::pushBack(Expr* expr) {
+ if (for_loops.size() == 0)
+ lowered_exprs.push_back(expr);
+ else
+ scope_utils::pushBack(for_loops.back(), expr);
+}
+
+/*
+ * This is one of the most complex parts of the code lowering logic. what we
+ * need to do is: 1) Reduce loop structure
+ * - Reset all loops if active_view == nullptr (I'm not the last in a series
+ * of computeAts)
+ * - Else reduce to active_view_axis if loop_depth > active_view_axis
+ * 2) Set active_view(_axis)
+ * - If there is a computeAt set for this TV
+ * 3) Open to compute At
+ * - If there is a computeAt set for this TV
+ * 4) Allocate the output.
+ * 5) If this is a reduction, initialize the output (open for loops to inner
+ * most, predicate, initialize, close predicate, close to computeAt) 6) Open to
+ * inner most loop 7) Open predicate 8) Run operation 9) Close predicate
+ */
+
+// Update fors based on tv.
+void LoopNestGenerator::updateLoopNest(TensorView* tv) {
+ // 1) Reduce loop structure
+ if (active_view != nullptr) {
+ // - Else reduce to active_view_axis if loop_depth > active_view_axis
+ auto depth = for_loops.size();
+ for (auto i = depth; i > active_view_axis; i--) {
+ for_loops.pop_back();
+ }
+ }
+
+ if (tv->hasComputeAt()) {
+ // 2) Set active_view(_axis)
+ // - If there is a computeAt set for this TV
+ setActiveView(tv);
+
+ // 3) Open to compute At
+ // - If there is a computeAt set for this TV
+ auto depth = for_loops.size();
+
+ for (auto i = depth; i < tv->getComputeAtAxis(); i++)
+ openFor(tv->getComputeAtAxis(i));
+ } else {
+ if (active_view != nullptr)
+ // If we're the last computeAt of a block, active view should match this
+ // tv
+ TORCH_INTERNAL_ASSERT(
+ tv->sameAs(active_view),
+ "Error detected in code lowering. Expected ",
+ active_view,
+ " but recieved ",
+ tv);
+
+ clearActiveView();
+ }
+ // 4) Allocate the output.
+ if (!FusionGuard::getCurFusion()->hasInput(tv) &&
+ !FusionGuard::getCurFusion()->hasOutput(tv)) {
+ pushAlloc(tv);
+ }
+ // TODO:
+ // 5) If this is a reduction, initialize the output (open for loops to inner
+ // most, predicate, initialize, close predicate, close to computeAt)
+
+ // 6) Open to inner most loop
+ for (decltype(tv->nDims()) i = for_loops.size(); i < tv->nDims(); i++)
+ openFor(tv->getComputeAtAxis(i));
+}
+
+// Custom dispatch for Expr, want to find out of it's a TV op
+void LoopNestGenerator::handle(Expr* expr) {
+ if (!ir_utils::isTVOp(expr))
+ return;
+
+ TensorView* out = static_cast<TensorView*>(expr->output(0));
+ updateLoopNest(out);
+
+ pushBack(expr);
+}
+
+// Generate the loop nest structure and place it in lowered_exprs
+void LoopNestGenerator::generate() {
+ FusionGuard fg(fusion_);
+
+ // Initialize members of the class
+ lowered_exprs = std::vector<Expr*>();
+ active_view = nullptr;
+ active_view_axis = 0;
+
+ std::vector<Expr*> exprs = fusion_->exprs(true);
+ for (auto* expr : exprs)
+ handle(expr);
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h
new file mode 100644
index 0000000..3cc1ef8
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_loops.h
@@ -0,0 +1,93 @@
+#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+namespace torch {
+namespace jit {
+namespace fuser {
+
+struct UnrollPass : public OptOutDispatch {
+ private:
+ std::unordered_map<Expr*, Expr*> loop_replacement_map;
+ Fusion* fusion_;
+ const std::vector<Expr*>& incoming_exprs_;
+
+ // Keep all for loops conveniently to make unrolling easier
+ std::vector<ForLoop*> for_loops;
+
+ // keep track if we're within an unrolled loop
+ bool within_unroll = false;
+
+ // Track the last computeAt TensorView and axis
+ const TensorView* active_view;
+ unsigned int active_view_axis;
+
+ // Custom dispatch for Expr, want to find out of it's a TV op
+ void handle(Expr*) final;
+
+ // Open the for loop.
+ void handle(ForLoop*) final;
+
+ UnrollPass(Fusion* _fusion, const std::vector<Expr*>& _incoming_exprs)
+ : fusion_(_fusion), incoming_exprs_(_incoming_exprs) {}
+
+ void computeMap();
+
+ public:
+ static std::vector<Expr*> runPass(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs);
+};
+
+struct TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
+ private:
+ std::vector<Expr*> lowered_exprs;
+ Fusion* fusion_;
+
+ // Track the last computeAt TensorView and axis
+ const TensorView* active_view;
+ unsigned int active_view_axis;
+
+ // Keep all for loops conveniently to make unrolling easier
+ std::vector<ForLoop*> for_loops;
+
+ // Get Register allocation statement for tensorview
+ void pushAlloc(TensorView*);
+
+ // Clear out the last recorded computeAtView
+ void clearActiveView();
+ // Set active views from computeAtView
+ void setActiveView(const TensorView* const);
+
+ // Open a new inner most for loop
+ void openFor(IterDomain*);
+
+ // Wrap pushBack in lower_utils if active_scope is null we want it to go
+ // straight to lower_exprs
+ void pushBack(Expr*);
+
+ // Update for loop structure based on this TensorView
+ void updateLoopNest(TensorView*);
+
+ // Check if a TV op, generate for loop nest around it
+ void handle(Expr*) final;
+
+ // Generate the loop nest structure and place it in lowered_exprs
+ void generate();
+
+ LoopNestGenerator(Fusion* _fusion) : fusion_(_fusion) {}
+
+ public:
+ static std::vector<Expr*> getLoopNest(Fusion* fusion) {
+ FusionGuard fg(fusion);
+ LoopNestGenerator lng(fusion);
+ lng.generate();
+ return lng.lowered_exprs;
+ }
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
new file mode 100644
index 0000000..bb71dfb
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -0,0 +1,491 @@
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+namespace torch {
+namespace jit {
+namespace fuser {
+
+namespace scope_utils {
+
+// START SCOPE HELPER SYSTEMS
+namespace {
+
+struct forLoopIndices : private OptInDispatch {
+ private:
+ std::vector<Val*> inds_;
+ void handle(ForLoop* fl) final {
+ inds_.insert(inds_.begin(), fl->index());
+ }
+
+ void handle(IfThenElse* ite) final {}
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ public:
+ static std::vector<Val*> get(Expr* scope) {
+ forLoopIndices fli;
+ Expr* it = scope;
+ while (it != nullptr) {
+ fli.handle(it);
+ it = scope_utils::getParent(it);
+ }
+ return fli.inds_;
+ }
+};
+
+struct forLoopIDs : private OptInDispatch {
+ private:
+ std::vector<IterDomain*> IDs_;
+ void handle(ForLoop* fl) final {
+ IDs_.insert(IDs_.begin(), fl->iter_domain());
+ }
+
+ void handle(IfThenElse* ite) final {}
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ public:
+ static std::vector<IterDomain*> get(Expr* scope) {
+ forLoopIDs fli;
+ Expr* it = scope;
+ while (it != nullptr) {
+ fli.handle(it);
+ it = scope_utils::getParent(it);
+ }
+ return fli.IDs_;
+ }
+};
+
+struct forLoopCount : private OptInDispatch {
+ private:
+ unsigned int count_ = 0;
+
+ void handle(ForLoop* fl) final {
+ count_++;
+ }
+
+ void handle(IfThenElse* ite) final {}
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ public:
+ static unsigned int get(Expr* scope) {
+ forLoopCount flc;
+ Expr* it = scope;
+ while (it != nullptr) {
+ flc.handle(it);
+ it = scope_utils::getParent(it);
+ }
+ return flc.count_;
+ }
+};
+
+struct scopePushBack : private OptInDispatch {
+ private:
+ Expr* expr_;
+ void handle(ForLoop* fl) final {
+ fl->body().push_back(expr_);
+ }
+
+ void handle(IfThenElse* ite) final {
+ ite->body().push_back(expr_);
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ scopePushBack(Expr* expr) : expr_(expr) {}
+
+ public:
+ static void push(Expr* scope, Expr* expr) {
+ scopePushBack pb(expr);
+ TORCH_INTERNAL_ASSERT(
+ expr != nullptr && scope != nullptr,
+ "Cannot push back, scope or expr is a nullptr.");
+ pb.handle(scope);
+ }
+};
+
+struct scopeInsertBefore : private OptInDispatch {
+ private:
+ Expr* ref_;
+ Expr* expr_;
+ void handle(ForLoop* fl) final {
+ fl->body().insert_before(ref_, expr_);
+ }
+
+ void handle(IfThenElse* ite) final {
+ ite->body().insert_before(ref_, expr_);
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ scopeInsertBefore(Expr* ref, Expr* expr) : ref_(ref), expr_(expr) {}
+
+ public:
+ static void insert(Expr* scope, Expr* ref, Expr* expr) {
+ scopeInsertBefore scb(ref, expr);
+ TORCH_INTERNAL_ASSERT(
+ expr != nullptr && scope != nullptr,
+ "Cannot push back, scope or expr is a nullptr.");
+ scb.handle(scope);
+ }
+};
+
+struct parentScope : private OptInDispatch {
+ private:
+ Expr* parent_ = nullptr;
+
+ void handle(ForLoop* fl) final {
+ parent_ = fl->parentScope();
+ }
+
+ void handle(IfThenElse* ite) final {
+ parent_ = ite->parentScope();
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ public:
+ static Expr* get(Expr* scope) {
+ parentScope sp;
+ sp.handle(scope);
+ return sp.parent_;
+ }
+};
+
+struct scopeClearExprs : private OptInDispatch {
+ private:
+ Expr* _expr = nullptr;
+ void handle(ForLoop* fl) final {
+ fl->body().clear();
+ }
+
+ void handle(IfThenElse* ite) final {
+ ite->body().clear();
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ public:
+ static void clear(Expr* scope) {
+ scopeClearExprs sce;
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr, "Cannot clear scope, scope is a nullptr.");
+ sce.handle(scope);
+ }
+};
+
+void assertScope(Expr* expr) {
+ TORCH_INTERNAL_ASSERT(
+ expr->getExprType() == ExprType::ForLoop ||
+ expr->getExprType() == ExprType::IfThenElse,
+ "Assert Scope failed when calling a scope_util function.");
+}
+
+struct CloneLoopNest : public OptOutMutator {
+ private:
+ Expr* parent_scope_ = nullptr;
+ Expr* to_clone_ = nullptr;
+
+ Statement* mutate(ForLoop* fl) final {
+ std::vector<Expr*> mutated_exprs;
+ for (Expr* expr : fl->body().exprs()) {
+ mutated_exprs.push_back(ir_utils::asExpr(OptOutMutator::mutate(expr)));
+ }
+ if (fl == to_clone_)
+ return new ForLoop(
+ fl->index(), fl->iter_domain(), mutated_exprs, parent_scope_);
+ return new ForLoop(
+ fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope());
+ }
+
+ CloneLoopNest(Expr* _to_clone, Expr* _parent_scope)
+ : parent_scope_(_parent_scope), to_clone_(_to_clone) {}
+
+ public:
+ static ForLoop* getClone(ForLoop* _to_clone, Expr* _parent_scope) {
+ TORCH_INTERNAL_ASSERT(
+ _to_clone != nullptr,
+ "Tried to clone a scope, but received a nullptr.");
+ CloneLoopNest cln(_to_clone, _parent_scope);
+ return ir_utils::asForLoop(ir_utils::asExpr(cln.mutate(_to_clone)));
+ }
+};
+
+struct ReplaceExprsInScope : public OptOutDispatch {
+ private:
+ std::unordered_map<Expr*, Expr*> replacement_map_;
+
+ void handle(Expr* expr) final {
+ OptOutDispatch::handle(expr);
+ }
+
+ void handle(ForLoop* fl) final {
+ for (Expr* expr : fl->body().exprs()) {
+ auto it = replacement_map_.find(expr);
+ if (it == replacement_map_.end()) {
+ handle(expr);
+ continue;
+ }
+ fl->body().insert_before(expr, replacement_map_[expr]);
+ fl->body().erase(expr);
+ }
+ }
+
+ void handle(IfThenElse* ite) final {
+ for (Expr* expr : ite->body().exprs()) {
+ auto it = replacement_map_.find(expr);
+ if (it == replacement_map_.end()) {
+ handle(expr);
+ continue;
+ }
+ ite->body().insert_before(expr, replacement_map_[expr]);
+ ite->body().erase(expr);
+ }
+ for (Expr* expr : ite->elseBody().exprs()) {
+ auto it = replacement_map_.find(expr);
+ if (it == replacement_map_.end()) {
+ handle(expr);
+ continue;
+ }
+ ite->elseBody().insert_before(expr, replacement_map_[expr]);
+ ite->elseBody().erase(expr);
+ }
+ }
+
+ ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> _replacement_map)
+ : replacement_map_(std::move(_replacement_map)) {}
+
+ public:
+ static void replace(
+ Expr* scope,
+ std::unordered_map<Expr*, Expr*> replacement_map) {
+ ReplaceExprsInScope reis(std::move(replacement_map));
+ reis.handle(scope);
+ }
+};
+
+struct FirstInnerMostScope : private OptInDispatch {
+ private:
+ Expr* active_scope = nullptr;
+
+ void handle(ForLoop* fl) final {
+ for (auto expr : fl->body().exprs()) {
+ if (ir_utils::isScope(expr)) {
+ active_scope = expr;
+ return;
+ }
+ }
+ active_scope = nullptr;
+ }
+
+ void handle(IfThenElse* ite) final {
+ for (auto expr : ite->body().exprs()) {
+ if (ir_utils::isScope(expr)) {
+ active_scope = expr;
+ return;
+ }
+ }
+ for (auto expr : ite->elseBody().exprs()) {
+ if (ir_utils::isScope(expr)) {
+ active_scope = expr;
+ return;
+ }
+ }
+ active_scope = nullptr;
+ }
+
+ Expr* getInner(Expr* expr) {
+ OptInDispatch::handle(expr);
+ return active_scope;
+ }
+
+ public:
+ static Expr* get(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr,
+ "Tried to get inner most scope, but was provided nullptr.");
+
+ FirstInnerMostScope fims;
+ Expr* inner = fims.getInner(scope);
+ while (fims.getInner(inner) != nullptr)
+ inner = fims.getInner(inner);
+ return inner;
+ }
+};
+
+// END SCOPE HELPER SYSTEMS
+} // namespace
+
+// Grab the index variables of the active loop nest
+std::vector<Val*> getLoopIndices(Expr* scope) {
+ if (scope == nullptr)
+ return std::vector<Val*>();
+ assertScope(scope);
+ return forLoopIndices::get(scope);
+}
+
+// Grab the iterDomains of the active loops
+std::vector<IterDomain*> getLoopIterDomains(Expr* scope) {
+ if (scope == nullptr)
+ return std::vector<IterDomain*>();
+ assertScope(scope);
+ return forLoopIDs::get(scope);
+}
+
+// Track how far our for loop scope is
+unsigned int computeForDepth(Expr* scope) {
+ if (scope == nullptr)
+ return 0;
+ assertScope(scope);
+ return forLoopCount::get(scope);
+}
+
+// Push back an expr to scope
+void pushBack(Expr* scope, Expr* expr) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr, "Scope is a nullptr, cannot push an expr to it.");
+ assertScope(scope);
+ scopePushBack::push(scope, expr);
+}
+
+// Insert expr in scope before ref
+void insertBefore(Expr* scope, Expr* ref, Expr* expr) {
+ scopeInsertBefore::insert(scope, ref, expr);
+}
+
+// Return the parent of the active scope
+Expr* getParent(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr,
+ "Tried to close the active scope, but there isn't one set.");
+ assertScope(scope);
+ return parentScope::get(scope);
+}
+
+// Open a new inner most for loop
+ForLoop* openFor(Expr* scope, IterDomain* id) {
+ ForLoop* new_scope = nullptr;
+ if (id->isThread()) {
+ new_scope = new ForLoop(
+ new NamedScalar(stringify(id->parallel_method()), DataType::Int),
+ id,
+ {},
+ scope);
+ } else {
+ new_scope = new ForLoop(new Int(), id, {}, scope);
+ }
+ if (scope != nullptr)
+ pushBack(scope, new_scope);
+ return new_scope;
+}
+
+// Close the inner most for loop
+Expr* closeScope(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr, "Tried to close a scope but got a nullptr.");
+ return getParent(scope);
+}
+
+// Clear all expressions from the scope
+Expr* clearScope(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr, "Tried to clear a scope but got a nullptr.");
+ assertScope(scope);
+ scopeClearExprs::clear(scope);
+ return scope;
+}
+
+ForLoop* cloneLoopNest(ForLoop* to_clone, Expr* parent_scope) {
+ return CloneLoopNest::getClone(to_clone, parent_scope);
+}
+
+void replaceExprsInScope(
+ Expr* scope,
+ std::unordered_map<Expr*, Expr*> replacement_map) {
+ TORCH_INTERNAL_ASSERT(
+ replacement_map.find(scope) == replacement_map.end(),
+ "Error trying to replace expressions in a scope, scope wants to be replaced entirely.");
+ ReplaceExprsInScope::replace(scope, std::move(replacement_map));
+}
+
+Expr* firstInnerMostScope(Expr* scope) {
+ return FirstInnerMostScope::get(scope);
+}
+
+} // namespace scope_utils
+
+namespace ir_utils {
+
+bool isTV(const Val* const val) {
+ return val->getValType().value() == ValType::TensorView;
+}
+
+// Check if we're a TensorView op that we can generate code for.
+bool isTVOp(const Expr* expr) {
+ if (expr->nOutputs() == 1 && isTV(expr->output(0)) &&
+ (expr->getExprType().value() == ExprType::BinaryOp ||
+ expr->getExprType().value() == ExprType::UnaryOp))
+ return true;
+ return false;
+}
+
+void ASSERT_EXPR(Statement* stmt) {
+ TORCH_INTERNAL_ASSERT(
+ stmt->isExpr(),
+ "Tried to generate a kernel but hit a non expression during lowering: ",
+ stmt);
+}
+
+Expr* asExpr(Statement* stmt) {
+ ASSERT_EXPR(stmt);
+ return static_cast<Expr*>(stmt);
+}
+
+TensorView* asTV(Val* val) {
+ TORCH_INTERNAL_ASSERT(isTV(val));
+ return static_cast<TensorView*>(val);
+}
+
+bool isScope(const Expr* expr) {
+ return expr->getExprType() == ExprType::ForLoop ||
+ expr->getExprType() == ExprType::IfThenElse;
+}
+
+ForLoop* asForLoop(Statement* stmt) {
+ Expr* expr = asExpr(stmt);
+ TORCH_INTERNAL_ASSERT(expr->getExprType() == ExprType::ForLoop);
+ return static_cast<ForLoop*>(expr);
+}
+
+const TensorView* asConstTV(const Val* const val) {
+ TORCH_INTERNAL_ASSERT(isTV(val));
+ return static_cast<const TensorView*>(val);
+}
+
+bool isUnrolledFor(const Expr* expr) {
+ if (expr->getExprType() != ExprType::ForLoop) {
+ return false;
+ }
+ return static_cast<const ForLoop*>(expr)->iter_domain()->parallel_method() ==
+ ParallelType::Unroll;
+}
+
+} // namespace ir_utils
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h
new file mode 100644
index 0000000..208cd7d
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.h
@@ -0,0 +1,77 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+// Provides utilities for dealing with nested ForLoop and IfThenElse scopes
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+namespace scope_utils {
+
+// Grab the index variables of the active loop nest
+std::vector<Val*> getLoopIndices(Expr* scope);
+
+// Grab the iterDomains of the active loops
+std::vector<IterDomain*> getLoopIterDomains(Expr* scope);
+
+// Track how far our for loop scope is
+unsigned int computeForDepth(Expr* scope);
+
+// Push back an expr to scope
+void pushBack(Expr* scope, Expr* expr);
+
+// Insert expr in scope before ref
+void insertBefore(Expr* scope, Expr* ref, Expr* expr);
+
+// Return the parent of the active scope
+Expr* getParent(Expr* scope);
+
+// Open a new inner most for loop
+ForLoop* openFor(Expr* scope, IterDomain*);
+
+// Close the inner most for loop
+Expr* closeScope(Expr* scope);
+
+// Clear all expressions from the scope
+Expr* clearScope(Expr* scope);
+
+// Provide a new for loop matching the one provided
+ForLoop* cloneLoopNest(ForLoop* to_clone, Expr* parent_scope);
+
+// Run through a scope and replace expressions inside with replacement_map
+void replaceExprsInScope(
+ Expr* scope,
+ std::unordered_map<Expr*, Expr*> replacement_map);
+
+Expr* firstInnerMostScope(Expr* scope);
+
+} // namespace scope_utils
+
+namespace ir_utils {
+
+bool isTV(const Val* const);
+
+bool isTVOp(const Expr*);
+
+void ASSERT_EXPR(Statement*);
+
+bool isScope(const Expr*);
+
+Expr* asExpr(Statement*);
+
+TensorView* asTV(Val*);
+
+ForLoop* asForLoop(Statement*);
+
+const TensorView* asConstTV(const Val* const);
+
+bool isUnrolledFor(const Expr*);
+
+} // namespace ir_utils
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp
index 2fe12a8..d152552 100644
--- a/torch/csrc/jit/codegen/cuda/manager.cpp
+++ b/torch/csrc/jit/codegen/cuda/manager.cpp
@@ -1,11 +1,10 @@
#include <torch/csrc/jit/codegen/cuda/manager.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/parser.h>
-
#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
+#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <torch/csrc/jit/passes/canonicalize.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
#include <unordered_map>
@@ -16,6 +15,15 @@
namespace {
+KernelArgsReq expandSizeSupport(const at::IntArrayRef sizes) {
+ KernelArgsReq req;
+ for (auto size : sizes) {
+ req.low_.push_back(size);
+ req.hi_.push_back(size);
+ }
+ return req;
+}
+
// CudaFusionManager holds compiled `CudaKernel` and handles all interfacing
// including compilation and execution.
//
@@ -38,9 +46,12 @@
// is even more restricting in a good way)
int32_t registerOrGetCacheId(std::shared_ptr<Graph>& graph) {
std::lock_guard<std::mutex> guard(mutex_);
+
// prepare graph for lowering;
+ // TODO: this is needed. Otherwise caching on tensor size would not work, as
+ // different tensor size would result in unique string representation.
+ EraseShapeInformation(graph);
Canonicalize(graph, false);
- // EraseShapeInformation(graph);
auto repr = graph->toString(false);
// create new graph_cache_ entry;
@@ -49,29 +60,53 @@
graph_cache_[repr] = kernel_id;
- Fusion fusion;
- // lower torch::jit::Graph to torch::jit::fuser::cuda::fusion
- parseJitIR(graph, fusion);
+ // create entry for cached kernel;
+ kernel_cache_.insert({kernel_id, CudaKernelCache()});
- // default constructor via accessing empty key;
- compileKernel(fusion, kernel_cache_[kernel_id]);
-
- return kernel_id;
- } else {
- return graph_cache_[repr];
+ // TODO: we should compile here using profiled information:
+ // size (range) / stride (contiguity)
}
+
+ return graph_cache_[repr];
};
void runFusionNode(
int32_t kernel_id,
+ std::shared_ptr<Graph>& graph,
const at::ArrayRef<IValue> inputs,
std::vector<at::Tensor> outputs) {
+ std::lock_guard<std::mutex> guard(mutex_);
TORCH_CHECK(
kernel_cache_.count(kernel_id) != 0, "kernel id not recognized");
- CudaKernel& cuda_kernel_entry = kernel_cache_[kernel_id];
+ // TODO: temporary hack
+ auto cuda_kernel =
+ kernel_cache_[kernel_id].getKernelPtr(outputs[0].sizes());
+ if (cuda_kernel) {
+ // TODO: update launch config for specific sizes;
+ // maybe we should store it in CudaKernel and compute it later
+ runKernel(*cuda_kernel, inputs, outputs);
+ } else {
+ // major HACK!
+ auto kernel_arg_req = expandSizeSupport(outputs[0].sizes());
+ cuda_kernel =
+ kernel_cache_[kernel_id].allocateKernelInCache(kernel_arg_req);
- runKernel(cuda_kernel_entry, inputs, outputs);
+ // lower torch::jit::Graph to torch::jit::fuser::cuda::fusion
+ Fusion fusion;
+ // TODO: pass contiguity infor as well as size req, so we can apply proper
+ // transform to computation
+ // we should propagate more information back:
+ // 1. device;
+ // 2. launch config;
+ parseJitIR(graph, fusion);
+ cuda_kernel.value()->device_ = 0;
+
+ // NVRTC compile kernel
+ compileKernel(fusion, cuda_kernel.value());
+
+ runKernel(*cuda_kernel, inputs, outputs);
+ }
}
private:
@@ -87,7 +122,7 @@
};
std::unordered_map<std::string, int32_t> graph_cache_;
- std::unordered_map<int64_t, CudaKernel> kernel_cache_;
+ std::unordered_map<int64_t, CudaKernelCache> kernel_cache_;
int32_t next_unique_id_ = 0;
};
@@ -119,10 +154,49 @@
int32_t kernel_id = fusion_node->i(attr::cache_id);
// Currently we just construct I/O tensors for static graph;
- const std::shared_ptr<Graph> graph = fusion_node->g(attr::Subgraph);
+ std::shared_ptr<Graph> graph = fusion_node->g(attr::Subgraph);
+
const auto nInputs = graph->inputs().size();
at::ArrayRef<IValue> inputs = last(stack, nInputs);
+ // shape inference in graph
+ bool matched_static_inputs = true;
+ for (int i = 0; i < nInputs; i++) {
+ auto& static_input = graph->inputs()[i];
+ auto& dynamic_input = inputs[i]; // this is FILO stack
+ if ((*dynamic_input.type()) != (*static_input->type())) {
+ matched_static_inputs = false;
+ break;
+ }
+ if (dynamic_input.isTensor()) {
+ at::Tensor inp_tensor = dynamic_input.toTensor();
+ // we need to return use shape inference when static shape is not complete
+ // even though it is compatible with profiling graph.
+ // TODO: we could relax on a bunch of checks here, like strides & gradient
+ if (!static_input->type()->cast<TensorType>()->sizes().isComplete() ||
+ !static_input->type()
+ ->cast<TensorType>()
+ ->isCompatibleWithInCurrentExecutionContext(inp_tensor)) {
+ matched_static_inputs = false;
+ break;
+ }
+ }
+ }
+
+ // TODO: expose the API to populate shape inference. This allows separate CI
+ // tests
+ // matched_static_inputs = false;
+ if (!matched_static_inputs) {
+ // update shape information per the new inputs;
+ // shape inference done through PyTorch JIT shape propagation;
+ EraseShapeInformation(graph);
+ for (int i = 0; i < nInputs; i++) {
+ graph->inputs()[i]->setType(inputs[i].type());
+ }
+ // shape inference
+ PropagateInputShapes(graph);
+ }
+
// we need to construct outputs;
std::vector<at::Tensor> outputs;
for (const auto* const output : graph->outputs()) {
@@ -149,7 +223,8 @@
outputs.push_back(tensor);
}
- CudaFusionManager::getManager().runFusionNode(kernel_id, inputs, outputs);
+ CudaFusionManager::getManager().runFusionNode(
+ kernel_id, graph, inputs, outputs);
drop(stack, inputs.size());
stack.insert(
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index dad5a78..65eee53 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -29,20 +29,21 @@
// MUTATE FUNCTIONS FOR VALS
Statement* OptOutMutator::mutate(IterDomain* id) {
- Val* s = mutateAsVal(id->size())->asVal();
- if (!s->sameAs(id->size())) {
- Val* mutated_val =
- new IterDomain(s, id->parallel_method(), id->isReduction());
- registerMutation(id, mutated_val);
- return mutated_val;
- }
- return id;
+ Val* s = mutateAsVal(id->start())->asVal();
+ Val* e = mutateAsVal(id->extent())->asVal();
+ if (s->sameAs(id->start()) && e->sameAs(id->extent()))
+ return id;
+
+ Val* mutated_val =
+ new IterDomain(s, e, id->parallel_method(), id->isReduction());
+ registerMutation(id, mutated_val);
+ return mutated_val;
}
Statement* OptOutMutator::mutate(TensorDomain* td) {
std::vector<IterDomain*> dom;
bool mutated = false;
- for (decltype(td->size()) i = 0; i < td->size(); i++) {
+ for (decltype(td->nDims()) i = 0; i < td->nDims(); i++) {
IterDomain* id = static_cast<IterDomain*>(mutateAsVal(td->axis(i)));
dom.push_back(id);
if (!id->sameAs(td->axis(i)))
@@ -175,12 +176,76 @@
return new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs);
}
-Statement* OptOutMutator::mutate(ForLoop* n) {
- return n;
+Statement* OptOutMutator::mutate(ForLoop* fl) {
+ Val* index = mutateAsVal(fl->index())->asVal();
+ Val* val_id = mutateAsVal(fl->iter_domain())->asVal();
+
+ TORCH_INTERNAL_ASSERT(val_id->getValType() == ValType::IterDomain);
+ IterDomain* id = static_cast<IterDomain*>(val_id);
+
+ bool is_mutated = !index->sameAs(fl->index());
+ is_mutated = is_mutated | !id->sameAs(fl->iter_domain());
+
+ std::vector<Expr*> mutated_exprs;
+ for (auto expr : fl->body().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ TORCH_INTERNAL_ASSERT(
+ mutated_stmt->isExpr(),
+ "While mutating a for loop, received a non-expression for a body entry.");
+ Expr* mutated_expr = static_cast<Expr*>(mutated_stmt);
+ mutated_exprs.push_back(mutated_expr);
+ // could use sameAs here, but we'd have to check the output value separately
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
+
+ if (is_mutated) {
+ auto newFL = new ForLoop(index, id, mutated_exprs, fl->parentScope());
+ return newFL;
+ }
+
+ return fl;
}
-Statement* OptOutMutator::mutate(IfThenElse* n) {
- return n;
+Statement* OptOutMutator::mutate(IfThenElse* ite) {
+ Val* val_cond = mutateAsVal(ite->cond())->asVal();
+ TORCH_INTERNAL_ASSERT(
+ val_cond->getValType().value() == ValType::Scalar &&
+ val_cond->getDataType().value() == DataType::Int);
+ Int* cond = static_cast<Int*>(val_cond);
+
+ bool is_mutated = !cond->sameAs(ite->cond());
+
+ std::vector<Expr*> mutated_exprs;
+ for (auto expr : ite->body().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ TORCH_INTERNAL_ASSERT(
+ mutated_stmt->isExpr(),
+ "While mutating a for loop, received a non-expression for a body entry.");
+ Expr* mutated_expr = static_cast<Expr*>(mutated_stmt);
+ mutated_exprs.push_back(mutated_expr);
+ // could use sameAs here, but we'd have to check the output value separately
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
+
+ std::vector<Expr*> mutated_else_exprs;
+ for (auto expr : ite->elseBody().exprs()) {
+ Statement* mutated_stmt = mutate(expr);
+ TORCH_INTERNAL_ASSERT(
+ mutated_stmt->isExpr(),
+ "While mutating a for loop, received a non-expression for a body entry.");
+ Expr* mutated_expr = static_cast<Expr*>(mutated_stmt);
+ mutated_else_exprs.push_back(mutated_expr);
+ // could use sameAs here, but we'd have to check the output value separately
+ is_mutated = is_mutated | (mutated_expr != expr);
+ }
+
+ if (is_mutated) {
+ auto newITE = new IfThenElse(
+ cond, ite->body().exprs(), ite->elseBody().exprs(), ite->parentScope());
+ return newITE;
+ }
+
+ return ite;
}
// START REPLACE ALL
diff --git a/torch/csrc/jit/codegen/cuda/mutator.h b/torch/csrc/jit/codegen/cuda/mutator.h
index a8f7723..c8b4c6f 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.h
+++ b/torch/csrc/jit/codegen/cuda/mutator.h
@@ -4,7 +4,6 @@
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
#include <unordered_map>
diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index ab88a01..9a0677a 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -1,8 +1,9 @@
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
+
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/constants.h>
@@ -19,14 +20,18 @@
namespace {
-typedef Val CgValue;
-typedef Expr CgOp;
+typedef Val* CgValue;
+typedef Expr* CgOp;
typedef void (
- *ParseFuncPtr)(const Node* const, std::unordered_map<size_t, CgValue*>&);
+ *ParseFuncPtr)(const Node* const, std::unordered_map<size_t, CgValue>&);
// TODO: add a mutex to make it thread safe.
class IrParser {
+ private:
+ static const int nthreads = 128;
+ static const int unroll_factor = 4;
+
public:
IrParser(std::shared_ptr<Graph> graph, Fusion& fusion)
: graph_(std::move(graph)), fusion_(&fusion) {
@@ -36,15 +41,23 @@
}
}
+ // Fuses pointwise ops with loop unrolling (factor = 4).
void parse() {
FusionGuard fg(fusion_);
auto block = graph_->block();
+ // in case of broadcast, we don't support explicit broadcast, so we need to
+ // convert/expand all inputs tensors to comply to the broadcasted size.
+ // This supports very limited case, which we try to accomodate in graph
+ // partition, that we only merge nodes with identical output shapes.
+ int broadcast_dim =
+ block->outputs()[0]->type()->cast<TensorType>()->dim().value();
+
// register all inputs;
// shape propagation during parsing is effctively done in parsing rules, as
// we only explicitly register inputs in the graph.
for (auto val : block->inputs()) {
- TORCH_CHECK(registerValue(val));
+ TORCH_CHECK(registerValue(val, broadcast_dim));
fusion_->addInput(value_maps_[val->unique()]);
}
@@ -62,26 +75,39 @@
// Merge all dimensions because we're only supporting pointwise
while (out->nDims() > 1)
out->merge(0);
- // Split into 128 so we can map blocks/threads
- out->split(0, 128);
+ // Split into 128 which will be bockDim.x
+ out->split(0, nthreads);
+ // Split by another 4 which will be our unroll factor
+ out->split(0, unroll_factor);
// Map blocks/threads
out->axis(0)->parallelize(ParallelType::BIDx);
+ out->axis(1)->parallelize(ParallelType::Unroll);
out->axis(-1)->parallelize(ParallelType::TIDx);
}
- for (auto jit_input : block->inputs()) {
- TensorView* inp =
- static_cast<TensorView*>(value_maps_[jit_input->unique()]);
- for (auto jit_output : block->outputs()) {
- TensorView* out =
- static_cast<TensorView*>(value_maps_[jit_output->unique()]);
- if (DependencyCheck::isDependencyOf(inp, out)) {
- inp->computeAt(out, -1);
- break;
- }
+ // Run through outputs, grab all inputs of outputs
+ // squeeze with computeAt to set overall structure.
+ for (auto jit_output : block->outputs()) {
+ TensorView* out =
+ static_cast<TensorView*>(value_maps_[jit_output->unique()]);
+
+ for (TensorView* inp : fusion_->inputsOf(out)) {
+ inp->computeAt(out, 1);
}
}
+
+ // Run through intermediates, unroll, and bind their axes
+ for (auto entry : value_maps_) {
+ CgValue val = entry.second;
+ if (fusion_->hasInput(val) || fusion_->hasOutput(val))
+ continue;
+ if (val->getValType().value() != ValType::TensorView)
+ continue;
+ TensorView* tv = static_cast<TensorView*>(val);
+ tv->axis(-2)->parallelize(ParallelType::Unroll);
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
+ }
}
static bool canParseNode(const Node* const node) {
@@ -111,10 +137,10 @@
.push_back(std::make_pair(op, fn));
}
- protected:
+ private:
static void parseBinaryOpWithAlpha(
const Node* const node,
- std::unordered_map<size_t, CgValue*>& value_maps) {
+ std::unordered_map<size_t, CgValue>& value_maps) {
static std::unordered_map<Symbol, BinaryOpType> op_mapping({
{aten::add, BinaryOpType::Add},
{aten::sub, BinaryOpType::Sub},
@@ -128,7 +154,7 @@
static void parseBinaryOp(
const Node* const node,
- std::unordered_map<size_t, CgValue*>& value_maps) {
+ std::unordered_map<size_t, CgValue>& value_maps) {
static std::unordered_map<Symbol, BinaryOpType> op_mapping({
{aten::mul, BinaryOpType::Mul},
{aten::div, BinaryOpType::Div},
@@ -193,13 +219,13 @@
}
}
- bool registerValue(const JitValue* val) {
- return registerTensor(val) || registerScalar(val);
+ bool registerValue(const JitValue* val, int broadcast_dim = -1) {
+ return registerTensor(val, broadcast_dim) || registerScalar(val);
}
bool registerScalar(const JitValue* val) {
if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(FloatType::get()))) {
- CgValue* cg_val;
+ CgValue cg_val;
if (auto ival = constant_as<float>(val)) {
cg_val = new Float(ival.value());
} else {
@@ -209,7 +235,7 @@
return true;
} else if (val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(IntType::get()))) {
- CgValue* cg_val;
+ CgValue cg_val;
if (auto ival = constant_as<int>(val)) {
cg_val = new Float(ival.value());
} else {
@@ -221,12 +247,16 @@
return false;
}
- bool registerTensor(const JitValue* val) {
- CgValue* cg_val;
+ bool registerTensor(const JitValue* val, int broadcast_dim = -1) {
+ CgValue cg_val;
if (val->isCompleteTensor()) {
+ auto tensor_type = val->type()->cast<TensorType>();
+ if (broadcast_dim >= 0) {
+ tensor_type = tensor_type->withDim(broadcast_dim);
+ }
// TODO: make this a static function in Tensor class;
// create tensor;
- cg_val = new TensorView(val->type()->cast<TensorType>());
+ cg_val = new TensorView(tensor_type);
value_maps_.emplace(val->unique(), cg_val);
return true;
}
@@ -237,7 +267,7 @@
Fusion* fusion_;
// maps from JitValue::unique() to fusion Val;
- std::unordered_map<size_t, CgValue*> value_maps_;
+ std::unordered_map<size_t, CgValue> value_maps_;
// parsing rule registry.
static std::unordered_map<
Symbol,
diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp
index 7a7fbc4..c12f3c1 100644
--- a/torch/csrc/jit/codegen/cuda/partition.cpp
+++ b/torch/csrc/jit/codegen/cuda/partition.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/codegen/cuda/partition.h>
+#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
namespace torch {
@@ -54,6 +55,23 @@
return (isNodeParsible(node) || node->kind() == prim::CudaFusionGroup);
}
+// TODO: how would symbolic shape from profiling executor play with this?
+static bool compatible_broadcast_shape(
+ const c10::VaryingShape& e,
+ const c10::VaryingShape& a) {
+ if (e.isComplete() && a.isComplete()) {
+ auto e_size = e.concrete_sizes().value();
+ auto a_size = a.concrete_sizes().value();
+ for (size_t i = 0; i < e_size.size(); i++) {
+ if (e_size[i] != a_size[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
} // namespace
bool isFusableCudaFusionGroup(const Node* const node) {
@@ -69,7 +87,30 @@
if (isFusableNode(node)) {
auto device = getDevice(fusion);
- return (device.has_value() && isFusableDevice(node, device.value()));
+ auto tensor_type = fusion->outputs()[0]->type()->cast<TensorType>();
+ if (tensor_type) {
+ for (auto output : node->outputs()) {
+ // We only check shape of tensor output
+ auto output_type = output->type()->cast<TensorType>();
+ if (output_type) {
+ bool output_tensor = false;
+ for (auto use : output->uses()) {
+ if (use.user != fusion) {
+ output_tensor = true;
+ break;
+ }
+ }
+ // if the output is not used by outside, there's no need to check its
+ // shape
+ if (output_tensor &&
+ !compatible_broadcast_shape(
+ tensor_type->sizes(), output_type->sizes())) {
+ return false;
+ }
+ }
+ }
+ return (device.has_value() && isFusableDevice(node, device.value()));
+ }
}
return false;
}
diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp
index dd55f87..c059f12 100644
--- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp
@@ -23,11 +23,11 @@
const TensorView* tv = ti->view();
TensorDomain* root = tv->getRootDomain();
- TORCH_CHECK(root->size() == ti->size());
- for (decltype(ti->size()) i{0}; i < ti->size(); i++)
+ TORCH_CHECK(root->nDims() == ti->nDims());
+ for (decltype(ti->nDims()) i{0}; i < ti->nDims(); i++)
if (FusionGuard::getCurFusion()->origin(ti->index(i)) != nullptr) {
- Val* pred = lt(ti->index(i), root->axis(i)->size());
+ Val* pred = lt(ti->index(i), root->axis(i)->extent());
TORCH_CHECK(
pred->getValType().value() == ValType::Scalar &&
pred->getDataType().value() == DataType::Int);
diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.h b/torch/csrc/jit/codegen/cuda/predicate_compute.h
index 35e4334..0e5811b 100644
--- a/torch/csrc/jit/codegen/cuda/predicate_compute.h
+++ b/torch/csrc/jit/codegen/cuda/predicate_compute.h
@@ -1,7 +1,6 @@
#pragma once
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
/*
* Predicate compute takes a TensorView and set of indices. The number of
diff --git a/torch/csrc/jit/codegen/cuda/tensor.h b/torch/csrc/jit/codegen/cuda/tensor.h
deleted file mode 100644
index 80d8705..0000000
--- a/torch/csrc/jit/codegen/cuda/tensor.h
+++ /dev/null
@@ -1,73 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
-
-/*
- * This file currently contains items associated with tensors, tensor domains,
- * tensor views and transforms associated with them (split, merge, reorder,
- * compute_at).
- *
- * Tensor is our link to the tensors described and used in the JIT. We create
- * our own wrapper version as a stepping stone into our IR structure, this
- * allows us to link our concept of tensors with that of the JIT.
- *
- * IterDomain for now is an annotated size. The size is a range for us to
- * iterate over (number of elements, not including stride). The annotations are
- * associated with if there's a parallelization mechanism associated with the
- * iter domain, and if we need to reduce over it.
- *
- * TensorDomain holds a vector (could be changed to an array) of IterDomains. It
- * holds an IterDomain for every logical axis in its associated tensor.
- * TensorDomain does not directly hold the Tensor it is associated.
- * TensorDomain's primary responsibility is to hold the history of
- * transformations that were used to generate it. This is done through the
- * normal interaction of Expr/Val in Fusion. i.e. if we want to know the
- * previous operation generating a particular TensorDomain we can simply call
- * FusionGuard::getCurFusion()->origin(a_tensor_domain) which should give us an
- * operation in the list [split, merge, reorder] or similar operations that take
- * in a TensorDomain, applies a transformation and outputs a tensor domain.
- *
- * TensorView is the glue between TensorDomain and Tensor. TensorView is
- * intended to be used directly in mathematical operations. TensorView is
- * directly used in the "what" is being computed. TensorView holds a reference
- * to the Tensor it's a view of, as well as the TensorDomain of that particular
- * view. TensorView provides the history of the what is being computed and that
- * history can be accessed, similar to the mechanism TensorDomain uses, through
- * normal Expr/Val interactions in Fusion. i.e.
- * FusionGuard::getCurFusion()->origin(a_tensor_view) which should give us an
- * operation that takes in a TensorView, other inputs (other TensorViews, or
- * Scalars) applies a mathematical operation and outputs a TensorView (and other
- * outputs?).
- *
- * The reason we need TensorView and TensorDomain is that we need to have a
- * record of both what is being computed and how it is being computed. For
- * Example we may have the operation: TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
- * The mathematical operationss here are on the tensor views TV1, TV2, and TV3.
- * This operation is a pointwise operation. To compute this pointwise operation
- * we iterate over the 3D TensorDomain [I, J, K], where K is the fastest
- * changing dimension.
- *
- * For now the functions split, merge, reorder, and compute_at are also in this
- * file and its associated .cpp file. However, they may be moved later.
- *
- */
-
-namespace torch {
-namespace jit {
-namespace fuser {
-
-struct TransformReplay;
-struct TensorView;
-
-TORCH_CUDA_API TensorView* split_(TensorView*, int axis, int factor);
-TORCH_CUDA_API TensorView* merge_(TensorView*, int axis);
-TORCH_CUDA_API TensorView* reorder_(
- TensorView*,
- const std::unordered_map<int, int>&);
-
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/tensor_meta.cpp b/torch/csrc/jit/codegen/cuda/tensor_meta.cpp
index c180459..7643ed9 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_meta.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_meta.cpp
@@ -51,6 +51,7 @@
*/
// debug print. remove this guy!
+#ifdef TC_DEBUG
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& data) {
os << "(";
@@ -60,6 +61,7 @@
}
return os << ")";
}
+#endif
TensorContiguity::TensorContiguity(
const std::vector<int64_t>& sizes,
diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
index 680643a..dd1e5de 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -28,171 +28,6 @@
} // namespace
-TensorView* split_(TensorView* tv, int axis, int factor) {
- TensorDomain* td = tv->domain();
-
- if (axis < 0)
- axis += td->size();
-
- assert(axis >= 0 && axis < td->size());
-
- IterDomain* id = td->axis(axis);
-
- if (id->parallel_method() != ParallelType::Serial)
- TORCH_CHECK(
- false,
- "Splitting an axis of non-Serial iteration is not supported at this time."
- " Parallelization strategy must be set after calling split.");
-
- if (tv->getComputeAtView() != nullptr)
- if (axis < tv->getComputeAtAxis())
- TORCH_CHECK(false, "Cannot split axis within the compute at range.");
-
- std::vector<IterDomain*> new_domain;
-
- Int* fact = new Int(factor);
- Int* one = new Int(1);
-
- for (decltype(td->size()) i = 0; i < td->size(); i++) {
- if (i != axis)
- new_domain.push_back(td->axis(i));
- else {
- // outer loop size
- Val* vo = ceilDiv(id->size(), fact);
- Int* so = static_cast<Int*>(vo);
-
- // outer loop IterDomain
- IterDomain* ido =
- new IterDomain(so, id->parallel_method(), id->isReduction());
- new_domain.push_back(ido);
-
- // inner loop IterDomain
- IterDomain* idi =
- new IterDomain(fact, id->parallel_method(), id->isReduction());
- new_domain.push_back(idi);
- }
- }
- TensorDomain* split_td = new TensorDomain(new_domain);
- Split* split_node = new Split(split_td, td, axis, fact); // For record keeping
- tv->setDomain(split_td);
- return tv;
-}
-
-TensorView* merge_(TensorView* tv, int axis) {
- TensorDomain* td = tv->domain();
-
- if (axis < 0)
- axis += td->size();
-
- assert(axis >= 0 && axis + 1 < td->size());
-
- if (tv->getComputeAtView() != nullptr)
- if (axis < tv->getComputeAtAxis())
- TORCH_CHECK(false, "Cannot split axis within compute at range.");
-
- IterDomain* first = td->axis(axis);
- IterDomain* second = td->axis(axis + 1);
-
- assert(first->isReduction() == second->isReduction());
- assert(first->parallel_method() == second->parallel_method());
-
- Val* merged_id_size = mul(first->size(), second->size());
- IterDomain* merged_id = new IterDomain(
- static_cast<Int*>(merged_id_size),
- first->parallel_method(),
- first->isReduction());
-
- std::vector<IterDomain*> new_domain;
- for (decltype(td->size()) i = 0; i < td->size(); i++) {
- if (i < axis || i > axis + 1)
- new_domain.push_back(td->axis(i));
- else if (i == axis) {
- new_domain.push_back(merged_id);
- }
- }
- TensorDomain* merged_td = new TensorDomain(new_domain);
- Merge* merge_node = new Merge(merged_td, td, axis); // For record keeping
- tv->setDomain(merged_td);
- return tv;
-}
-
-/*
- * Takes axis2pos map, axis2pos[old_pos] = new_pos, to modify the ordering of
- * the iter axes.
- */
-TensorView* reorder_(
- TensorView* tv,
- const std::unordered_map<int, int>& axis2pos) {
- TensorDomain* td = tv->domain();
- auto ndims = td->size();
- // Map to save from previous order, to new order.
- std::vector<int> pos2axis(ndims, -1);
-
- // Go through each old and new position, make sure they're within 0-ndims
- for (std::pair<int, int> elem : axis2pos) {
- int old_pos = elem.first;
- int new_pos = elem.second;
-
- if (old_pos < 0)
- old_pos += ndims;
- if (new_pos < 0)
- new_pos += ndims;
-
- assert(old_pos >= 0 && old_pos < ndims && new_pos >= 0 && new_pos < ndims);
-
- if (pos2axis[new_pos] != -1)
- TORCH_CHECK(false, "Reorder found duplicate destination positions.");
-
- pos2axis[new_pos] = old_pos;
- }
-
- std::set<int> old_positions(pos2axis.begin(), pos2axis.end());
- old_positions.erase(-1);
-
- if (old_positions.size() != axis2pos.size())
- TORCH_INTERNAL_ASSERT(
- false, "Reorder found duplicate destination positions.");
-
- std::set<int> all_positions;
- for (decltype(ndims) i{0}; i < ndims; i++)
- all_positions.insert(i);
-
- // Check what positions haven't been specified.
- std::set<int> positions_left;
- std::set_difference(
- all_positions.begin(),
- all_positions.end(),
- old_positions.begin(),
- old_positions.end(),
- std::inserter(positions_left, positions_left.end()));
-
- // Fill in positions that weren't specified, in relative order,
- // in empty spots in the set of new positions.
- // pos2axis[new_position] = old_position
- auto it = positions_left.begin(); // old positions left
- for (decltype(pos2axis.size()) i = 0; i < pos2axis.size(); i++) {
- if (pos2axis[i] == -1)
- pos2axis[i] = *it++;
- }
-
- // pos2axis is now filled
- if (tv->getComputeAtView() != nullptr) {
- for (int i = 0; i < tv->getComputeAtAxis(); i++) {
- if (pos2axis[i] != i)
- TORCH_CHECK(false, "Cannot reorder axis within compute at range.");
- }
- }
-
- std::vector<IterDomain*> reordered_domain;
- for (int entry : pos2axis)
- reordered_domain.push_back(td->axis(entry));
-
- TensorDomain* reordered_td = new TensorDomain(reordered_domain);
- Reorder* merge_node = new Reorder(reordered_td, td, pos2axis);
- tv->setDomain(reordered_td);
- return tv;
-}
-
TensorView::TensorView(TensorDomain* _domain, DataType dtype)
: Val(ValType::TensorView, dtype), domain_(_domain) {}
@@ -202,7 +37,7 @@
TORCH_CHECK(
tensor_type->dim().has_value(), "Requires static rank for Tensor");
for (int i = 0; i < tensor_type->dim().value(); i++) {
- sizes.push_back(new IterDomain(new Int()));
+ sizes.push_back(new IterDomain(new Int(0), new Int()));
}
domain_ = new TensorDomain(sizes);
}
@@ -221,7 +56,8 @@
// consumers and we're copying over a producer.
if (this->axis(i)->isReduction())
continue;
- domain_copy.push_back(new IterDomain(this->axis(i)->size()));
+ domain_copy.push_back(
+ new IterDomain(this->axis(i)->start(), this->axis(i)->extent()));
}
TensorDomain* td = new TensorDomain(domain_copy);
return new TensorView(td, dtype);
@@ -238,14 +74,14 @@
}
std::vector<IterDomain*>::size_type TensorView::nDims() const {
- return domain()->size();
+ return domain()->nDims();
}
IterDomain* TensorView::axis(int pos) const {
if (pos < 0)
- pos += domain()->size();
+ pos += domain()->nDims();
TORCH_CHECK(
- pos >= 0 && pos < domain()->size(),
+ pos >= 0 && pos < domain()->nDims(),
"Tried to access position ",
pos,
" in domain: ",
@@ -255,7 +91,7 @@
void TensorView::copyDomain(const TensorDomain* td) {
std::vector<IterDomain*> idv;
- for (decltype(td->size()) i = 0; i < td->size(); i++)
+ for (decltype(td->nDims()) i = 0; i < td->nDims(); i++)
idv.push_back(td->axis(i));
setDomain(new TensorDomain(idv));
}
@@ -344,6 +180,117 @@
return this;
}
+TensorView* TensorView::split(int axis, int factor) {
+ if (axis < 0)
+ axis += domain()->nDims();
+
+ TORCH_CHECK(
+ axis >= 0 && axis < domain()->nDims(),
+ "Trying to split axis outside of TensorView's range.");
+
+ if (getComputeAtView() != nullptr)
+ if (axis < getComputeAtAxis())
+ TORCH_CHECK(false, "Cannot split axis within compute at range.");
+
+ setDomain(domain()->split(axis, factor));
+ return this;
+}
+
+// Merge "axis" and "axis+1" into 1 dimension
+TensorView* TensorView::merge(int axis) {
+ if (axis < 0)
+ axis += domain()->nDims();
+
+ TORCH_CHECK(
+ axis >= 0 && axis + 1 < domain()->nDims(),
+ "Trying to merge axis outside of TensorView's range.");
+
+ if (getComputeAtView() != nullptr)
+ if (axis + 1 < getComputeAtAxis())
+ TORCH_CHECK(false, "Cannot merge axis within compute at range.");
+
+ setDomain(domain()->merge(axis));
+ return this;
+}
+
+// Reorder axes according to map[old_pos] = new_pos
+TensorView* TensorView::reorder(const std::unordered_map<int, int>& axis2pos_) {
+ // START VALIDATION CHECKS
+ // adjust based on negative values (any negative values gets nDims added to
+ // it)
+ std::unordered_map<int, int> axis2pos;
+ auto ndims = nDims();
+ std::transform(
+ axis2pos_.begin(),
+ axis2pos_.end(),
+ std::inserter(axis2pos, axis2pos.begin()),
+ [ndims](std::unordered_map<int, int>::value_type entry) {
+ return std::unordered_map<int, int>::value_type({
+ entry.first < 0 ? entry.first + ndims : entry.first,
+ entry.second < 0 ? entry.second + ndims : entry.second,
+ });
+ });
+
+ // Check if any adjusted values are < 0, or >= nDims, which are invalid
+ bool out_of_range = std::any_of(
+ axis2pos.begin(),
+ axis2pos.end(),
+ [ndims](std::unordered_map<int, int>::value_type entry) {
+ return entry.first < 0 || entry.first >= ndims || entry.second < 0 ||
+ entry.second >= ndims;
+ });
+
+ TORCH_CHECK(
+ !out_of_range,
+ "TensorView reorder axes are outside the number of dimensions in the TensorView.")
+
+ // Going to use sets, to see if any duplicate values are in the map.
+
+ std::set<int> old_pos_set;
+ std::transform(
+ axis2pos.begin(),
+ axis2pos.end(),
+ std::inserter(old_pos_set, old_pos_set.begin()),
+ [](std::unordered_map<int, int>::value_type entry) {
+ return entry.first;
+ });
+
+ std::set<int> new_pos_set;
+ std::transform(
+ axis2pos.begin(),
+ axis2pos.end(),
+ std::inserter(new_pos_set, new_pos_set.begin()),
+ [](std::unordered_map<int, int>::value_type entry) {
+ return entry.first;
+ });
+
+ // Error out if duplicate values are found.
+ TORCH_CHECK(
+ old_pos_set.size() == axis2pos.size() &&
+ new_pos_set.size() == axis2pos.size(),
+ "Duplicate entries in transformation map sent to TensorView reorder.");
+
+ // Check if we're trying to reorder any values outside of the computeAt axis
+
+ if (hasComputeAt()) {
+ auto compute_at_axis = getComputeAtAxis();
+ bool outside_computeat = std::any_of(
+ axis2pos.begin(),
+ axis2pos.end(),
+ [compute_at_axis](std::unordered_map<int, int>::value_type entry) {
+ return entry.first < compute_at_axis ||
+ entry.second < compute_at_axis;
+ });
+ TORCH_CHECK(
+ !outside_computeat,
+ "Cannot reorder dimensions that are outside computeAt axis.");
+ }
+ // END VALIDATION CHECKS
+ setDomain(domain()->reorder(axis2pos_));
+
+ return this;
+}
+
} // namespace fuser
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
index 10d636a..7771609 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
@@ -80,41 +80,41 @@
return root;
}
-TensorView* TransformIter::replay(Split* expr, TensorView* tv) {
- return tv->split(
+TensorDomain* TransformIter::replay(Split* expr, TensorDomain* td) {
+ return td->split(
expr->axis(), static_cast<Int*>(expr->factor())->value().value());
}
-TensorView* TransformIter::replay(Merge* expr, TensorView* tv) {
- return tv->merge(expr->axis());
+TensorDomain* TransformIter::replay(Merge* expr, TensorDomain* td) {
+ return td->merge(expr->axis());
}
-TensorView* TransformIter::replay(Reorder* expr, TensorView* tv) {
+TensorDomain* TransformIter::replay(Reorder* expr, TensorDomain* td) {
std::unordered_map<int, int> axis2pos;
for (decltype(expr->pos2axis().size()) i{0}; i < expr->pos2axis().size(); i++)
axis2pos[expr->pos2axis()[i]] = i;
- return tv->reorder(axis2pos);
+ return td->reorder(axis2pos);
}
-TensorView* TransformIter::replay(Expr* expr, TensorView* tv) {
+TensorDomain* TransformIter::replay(Expr* expr, TensorDomain* td) {
TORCH_INTERNAL_ASSERT(expr->isExpr());
switch (*(expr->getExprType())) {
case (ExprType::Split):
- return replay(static_cast<Split*>(expr), tv);
+ return replay(static_cast<Split*>(expr), td);
case (ExprType::Merge):
- return replay(static_cast<Merge*>(expr), tv);
+ return replay(static_cast<Merge*>(expr), td);
case (ExprType::Reorder):
- return replay(static_cast<Reorder*>(expr), tv);
+ return replay(static_cast<Reorder*>(expr), td);
default:
TORCH_INTERNAL_ASSERT(false, "Could not detect expr type in replay.");
}
}
-TensorView* TransformIter::runReplay(TensorView* tv) {
+TensorDomain* TransformIter::runReplay(TensorDomain* td) {
for (auto it = record.begin(); it < record.end(); ++it) {
- tv = TransformIter::replay(*it, tv);
+ td = TransformIter::replay(*it, td);
}
- return tv;
+ return td;
}
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h
index d93353c..7a68174 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.h
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.h
@@ -4,7 +4,6 @@
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
namespace torch {
namespace jit {
@@ -32,15 +31,15 @@
// order operations root->td.
TensorDomain* runBackward(TensorDomain* td, bool generate_record);
- virtual TensorView* replay(Split* expr, TensorView* tv);
- virtual TensorView* replay(Merge* expr, TensorView* tv);
- virtual TensorView* replay(Reorder* expr, TensorView* tv);
+ virtual TensorDomain* replay(Split* expr, TensorDomain* tv);
+ virtual TensorDomain* replay(Merge* expr, TensorDomain* tv);
+ virtual TensorDomain* replay(Reorder* expr, TensorDomain* tv);
// dispatch
- TensorView* replay(Expr* expr, TensorView* tv);
+ TensorDomain* replay(Expr* expr, TensorDomain* tv);
// Runs through operations recorded in record from root-> present
- TensorView* runReplay(TensorView* tv);
+ TensorDomain* runReplay(TensorDomain* tv);
// Forward record from root, to replay_ref/ref_root
std::vector<Expr*> record;
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
index 80ff45e..b21c362 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -47,28 +47,35 @@
TensorDomain* TransformReplay::replayBackward(
TensorDomain* td,
bool create_record) {
- influence = std::vector<bool>(td->size(), false);
+ influence = std::vector<bool>(td->nDims(), false);
for (int i = 0; i < compute_at_axis; i++)
influence[i] = true;
return TransformIter::runBackward(td, create_record);
}
/*
- * Replay functions, takes a TensorView and steps through the operations in
+ * Replay functions, takes a TensorDomain and steps through the operations in
* "record" based on influence axes. Will also update influence and propagate
* it forward.
*/
-TensorView* TransformReplay::replay(Split* expr, TensorView* tv) {
+TensorDomain* TransformReplay::replay(Split* expr, TensorDomain* td) {
int axis = expr->axis();
+ bool run_split = influence[axis];
+
+ // Propagate influence
+ influence.insert(influence.begin() + axis + 1, influence[axis]);
+
// Forward prop influence
- if (influence[axis]) {
+ if (run_split) {
// Make sure split axis is real.
int real_axis = axis_map[expr->axis()];
TORCH_INTERNAL_ASSERT(
real_axis != -1,
"During transformation replay attempted to split an imaginary axis.");
- // Replay split
- tv->split(real_axis, *(expr->factor()->value()));
+ TORCH_INTERNAL_ASSERT(
+ td->axis(real_axis)->start()->isZeroInt(),
+ "Transform Replay tried to split an IterDomain with a start value that is not 0,",
+ " this is not currently supported.");
// Inserted a real axis, push everything in axis_map over to the right
// after this inserted axis
for (decltype(axis_map.size()) i{0}; i < axis_map.size(); i++)
@@ -78,31 +85,20 @@
axis_map.insert(
axis_map.begin() + expr->axis() + 1,
real_axis + 1); // insert axis at position axis.
+
+ // Replay split
+ return td->split(real_axis, *(expr->factor()->value()));
} else {
// Fake it
axis_map.insert(axis_map.begin() + expr->axis() + 1, -1);
}
- influence.insert(influence.begin() + axis + 1, influence[axis]);
-
- return tv;
+ return td;
}
-TensorView* TransformReplay::replay(Merge* expr, TensorView* tv) {
+TensorDomain* TransformReplay::replay(Merge* expr, TensorDomain* td) {
int axis = expr->axis();
-
- if (influence[axis] || influence[axis + 1]) {
- // Make sure both merge axes are real.
- TORCH_INTERNAL_ASSERT(
- axis_map[axis] != -1 && axis_map[axis + 1] != -1,
- "During transformation replay attempted to merge an imaginary axis.");
- // Replay merge
- tv->merge(axis_map[axis]);
- } else {
- // If we aren't applying the merge, we won't change any following axis
- // Doesn't matter which axis we propagate for the merge in the axis_map
- assert(axis_map[axis + 1] == -1);
- }
+ bool merge = influence[axis] || influence[axis + 1];
axis_map.erase(axis_map.begin() + expr->axis() + 1);
for (decltype(axis_map.size()) i = expr->axis() + 1; i < axis_map.size(); i++)
@@ -113,10 +109,27 @@
influence[axis] = influence[axis] || influence[axis + 1];
influence.erase(influence.begin() + axis + 1);
- return tv;
+ if (merge) {
+ // Make sure both merge axes are real.
+ TORCH_INTERNAL_ASSERT(
+ axis_map[axis] != -1 && axis_map[axis + 1] != -1,
+ "During transformation replay attempted to merge an imaginary axis.");
+ // Replay merge
+ TORCH_INTERNAL_ASSERT(
+ td->axis(axis)->start()->isZeroInt() &&
+ td->axis(axis + 1)->start()->isZeroInt(),
+ "Transform Replay tried to Merge IterDomains with a start value that is not 0,",
+ " this is not currently supported.");
+ return td->merge(axis_map[axis]);
+ } else {
+ // If we aren't applying the merge, we won't change any following axis
+ // Doesn't matter which axis we propagate for the merge in the axis_map
+ assert(axis_map[axis + 1] == -1);
+ return td;
+ }
}
-TensorView* TransformReplay::replay(Reorder* expr, TensorView* tv) {
+TensorDomain* TransformReplay::replay(Reorder* expr, TensorDomain* td) {
// axis2pos[old_pos] = new_pos is sent to reorder, Reorder holds
// pos2axis[new_pos] = old_pos Generate new axis2pos map
std::unordered_map<int, int> axis2pos;
@@ -161,13 +174,13 @@
axis2pos[entry.first] = axis++;
}
- for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
+ for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) {
if (axis2pos.find(i) == axis2pos.end())
axis2pos[i] = axis++;
}
// replay reorder
- tv->reorder(axis2pos);
+ TensorDomain* reordered_td = td->reorder(axis2pos);
// Fake transform:
for (decltype(pos2axis.size()) i = 0; i < pos2axis.size(); i++) {
@@ -181,7 +194,7 @@
influence = reordered_influence;
axis_map = reordered_axis_map;
- return tv;
+ return reordered_td;
}
/*
@@ -219,16 +232,18 @@
* outside of compute_at_axis.
*
*/
-TensorView* TransformReplay::runReplay(
- TensorView* replay_ref,
- TensorView* replay_target,
+TensorDomain* TransformReplay::runReplay(
+ TensorDomain* replay_ref,
+ TensorDomain* replay_target,
int compute_at_axis) {
if (compute_at_axis < 0)
compute_at_axis += int(replay_ref->nDims()) + 1;
TORCH_CHECK(
compute_at_axis >= 0 && compute_at_axis < int(replay_ref->nDims()) + 1,
- "Transform replay cannot be performed as the compute_at_axis is not in the valid range.");
+ "Transform replay cannot be performed as the compute_at_axis is not in the valid range, it should be 0 or greater, and less than ",
+ int(replay_ref->nDims()) + 1,
+ ".");
this->compute_at_axis = compute_at_axis;
@@ -236,7 +251,7 @@
// Reset the tensor domain of the target, this is the only way we can be
// certain That we can actually replay the ops of ref.
// Trace back to the root TensorDomain's of ref and target
- replay_target->resetView();
+ replay_target = replay_target->rootDomain();
/* STEP 2 */
// Mark compute_at_axis and below as "influenced", trace back through
@@ -244,7 +259,7 @@
// produce these axis
// As we trace the ref, record the operations to go from replay_ref ->
// ref_root, save in "record"
- TensorDomain* ref_root = replayBackward(replay_ref->domain(), true);
+ TensorDomain* ref_root = replayBackward(replay_ref, true);
// We're going to save a copy of this vector, class member influnce will be
// used during replay to forward propagate influence.
std::vector<bool> root_influence_vector = influence;
@@ -257,7 +272,7 @@
axis_map.push_back(i);
// Domain sizes must match at root for replay.
- if (axis_map.size() != ref_root->size()) {
+ if (axis_map.size() != ref_root->nDims()) {
std::stringstream err_msg;
err_msg
<< "Transforms cannot be replayed as source and destinations do not have the same root sizes."
@@ -266,9 +281,10 @@
}
/*
- * TODO: Decide if the following check is reasonable, when we're parsing the
- * JIT graph, we are using symbolic sizes for each tensor individually, so
- * they won't all have the same size.
+ * TODO: The JIT graph has symbolic sizes, so inputs may actually have the
+ * same sizes (assuming no broadcasts/reductions), we at some point want to
+ * have some size matching, and sizes should actually match at this point, but
+ * the check below won't work.
*/
// for (decltype(axis_map.size()) i{0}; i < axis_map.size(); i++) {
@@ -288,7 +304,7 @@
// actually execute those based on influence. If we didn't track all
// axes, we wouldn't know what axis split/merge/reorder are referencing
// as they're relative to the "full" replay that produced the reference.
- TensorView* replayed = TransformIter::runReplay(replay_target);
+ TensorDomain* replayed = TransformIter::runReplay(replay_target);
for (decltype(replayed->nDims()) i{0}; i < compute_at_axis; i++)
if (replayed->axis(i)->isReduction())
@@ -299,6 +315,16 @@
return replayed;
}
+TensorView* TransformReplay::runReplay(
+ TensorView* replay_ref,
+ TensorView* replay_target,
+ int compute_at_axis) {
+ TensorDomain* td =
+ runReplay(replay_ref->domain(), replay_target->domain(), compute_at_axis);
+ replay_target->setDomain(td);
+ return replay_target;
+}
+
TensorView* TransformReplay::replay(
TensorView* replay_ref,
TensorView* replay_target,
@@ -312,8 +338,14 @@
TensorView* replay_ref,
TensorView* replay_target) {
TransformReplay tr;
- tr.runReplay(replay_ref, replay_target, -1);
- return replay_target;
+ return tr.runReplay(replay_ref, replay_target, -1);
+}
+
+TensorDomain* TransformReplay::fullReplay(
+ TensorDomain* replay_ref,
+ TensorDomain* replay_target) {
+ TransformReplay tr;
+ return tr.runReplay(replay_ref, replay_target, -1);
}
} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h
index 7f0be72..5248a04 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.h
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.h
@@ -3,7 +3,6 @@
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <algorithm>
@@ -136,16 +135,26 @@
* "record" based on influence axes. Will also update influence and propagate
* it forward.
*/
- TensorView* replay(Split* expr, TensorView* tv);
- TensorView* replay(Merge* expr, TensorView* tv);
- TensorView* replay(Reorder* expr, TensorView* tv);
+ TensorDomain* replay(Split* expr, TensorDomain* tv);
+ TensorDomain* replay(Merge* expr, TensorDomain* tv);
+ TensorDomain* replay(Reorder* expr, TensorDomain* tv);
/*
* Takes replay_ref and replays its transformations on replay_target
* Replays from begining of both TensorDomains. could be more efficient to try
- * and find a common ancestor to start from, but that's outside the scope of
- * this work for now.
- *
+ * and find a common ancestor to start from, but likely not a worthwhile
+ * optimization.
+ */
+ TensorDomain* runReplay(
+ TensorDomain* replay_ref,
+ TensorDomain* replay_target,
+ int compute_at_axis);
+
+ /*
+ * Takes replay_ref and replays its transformations on replay_target
+ * Replays from begining of both TensorDomains. could be more efficient to try
+ * and find a common ancestor to start from, but likely not a worthwhile
+ * optimization.
*/
TensorView* runReplay(
TensorView* replay_ref,
@@ -175,6 +184,10 @@
static TensorView* fullReplay(
TensorView* replay_ref,
TensorView* replay_target);
+
+ static TensorDomain* fullReplay(
+ TensorDomain* replay_ref,
+ TensorDomain* replay_target);
};
} // namespace fuser