| #if defined(USE_CUDA) |
| #include <gtest/gtest.h> |
| |
| #include <torch/csrc/jit/codegen/cuda/arith.h> |
| #include <torch/csrc/jit/codegen/cuda/codegen.h> |
| #include <torch/csrc/jit/codegen/cuda/disjoint_set.h> |
| #include <torch/csrc/jit/codegen/cuda/executor.h> |
| #include <torch/csrc/jit/codegen/cuda/executor_launch_params.h> |
| #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h> |
| #include <torch/csrc/jit/codegen/cuda/fusion.h> |
| #include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h> |
| #include <torch/csrc/jit/codegen/cuda/interface.h> |
| #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h> |
| #include <torch/csrc/jit/codegen/cuda/ir_graphviz.h> |
| #include <torch/csrc/jit/codegen/cuda/ir_iostream.h> |
| #include <torch/csrc/jit/codegen/cuda/ir_utils.h> |
| #include <torch/csrc/jit/codegen/cuda/iter_visitor.h> |
| #include <torch/csrc/jit/codegen/cuda/kernel_cache.h> |
| #include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h> |
| #include <torch/csrc/jit/codegen/cuda/kernel_ir.h> |
| #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h> |
| #include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h> |
| #include <torch/csrc/jit/codegen/cuda/lower2device.h> |
| #include <torch/csrc/jit/codegen/cuda/mutator.h> |
| #include <torch/csrc/jit/codegen/cuda/root_domain_map.h> |
| #include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h> |
| #include <torch/csrc/jit/codegen/cuda/scheduler/utils.h> |
| #include <torch/csrc/jit/codegen/cuda/transform_replay.h> |
| #include <torch/csrc/jit/codegen/cuda/transform_rfactor.h> |
| |
| // fuser and IR parser |
| #include "test_gpu_validator.h" |
| |
| #include <ATen/cuda/Exceptions.h> |
| #include <c10/cuda/CUDAStream.h> |
| |
| #include <algorithm> |
| #include <iostream> |
| |
| // Tests go in torch::jit |
| namespace torch { |
| namespace jit { |
| |
| using namespace torch::jit::fuser::cuda; |
| using namespace at::indexing; |
| |
| namespace { |
| |
| // Make a tensor that is known to be fully contiguous of dimensionality=ndims, |
| // but unknown sizes |
| TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { |
| return TensorViewBuilder() |
| .ndims(ndims) |
| .dtype(dtype) |
| .contiguity(std::vector<bool>(ndims, true)) |
| .build(); |
| } |
| |
| // Make a tensor that is known to be non-contiguous of dimensionality=ndims, |
| // but unknown sizes |
| TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) { |
| return TensorViewBuilder().ndims(ndims).dtype(dtype).build(); |
| } |
| |
| // Make a non-contiguous tensor of compile-time known sizes |
| TensorView* makeConcreteTensor( |
| std::vector<int64_t> shape, |
| DataType dtype = DataType::Float) { |
| return TensorViewBuilder().shape(shape).dtype(dtype).build(); |
| } |
| |
| void checkIntValue( |
| ExpressionEvaluator& evaluator, |
| Val* val, |
| Int::ScalarType expected_value) { |
| TORCH_CHECK(val->isAnInt()); |
| const auto actual_value = evaluator.evaluate(val); |
| TORCH_CHECK(actual_value.has_value()); |
| TORCH_CHECK(actual_value.value() == expected_value); |
| } |
| |
| void checkIntValue( |
| kir::ExpressionEvaluator& evaluator, |
| const kir::Val* val, |
| kir::Int::ScalarType expected_value) { |
| const auto actual_value = evaluator.evaluate(val); |
| TORCH_CHECK(actual_value.has_value()); |
| TORCH_CHECK(actual_value.value() == expected_value); |
| } |
| |
| // ATen version of tensor shifting |
| auto shift(at::Tensor tensor, const std::vector<int>& offsets) { |
| TORCH_INTERNAL_ASSERT(tensor.ndimension() == offsets.size()); |
| at::Tensor t = tensor; |
| for (size_t i = 0; i < offsets.size(); ++i) { |
| const auto offset = offsets[i]; |
| if (offset == 0) { |
| continue; |
| } |
| t = t.roll(offsets[i], i); |
| std::vector<at::indexing::TensorIndex> indices( |
| tensor.ndimension(), at::indexing::Slice(0, at::indexing::None)); |
| if (offset > 0) { |
| indices[i] = at::indexing::Slice(0, offset); |
| } else { |
| indices[i] = at::indexing::Slice(offset, at::indexing::None); |
| } |
| t.index(indices) = 0; |
| } |
| return t; |
| } |
| |
| // ATen version of tensor shifting |
| auto gather( |
| at::Tensor tensor, |
| const std::vector<int>& window_shape, |
| const std::vector<std::vector<int>>& pad_width) { |
| TORCH_CHECK( |
| tensor.ndimension() == window_shape.size(), |
| "Invalid window shape: ", |
| window_shape, |
| ". Size of the window shape is different from the tensor dimension."); |
| TORCH_CHECK( |
| tensor.ndimension() == pad_width.size(), |
| "Invalid pad width: ", |
| pad_width, |
| ". Size of the pad width is different from the tensor dimension."); |
| at::Tensor t = tensor; |
| for (size_t i = 0; i < window_shape.size(); ++i) { |
| const auto w_size = window_shape[i]; |
| TORCH_CHECK(w_size != 0); |
| const auto& pad = pad_width[i]; |
| TORCH_CHECK(pad.size() == 2); |
| at::Tensor concat_tensor; |
| for (int w = 0; w < w_size; ++w) { |
| std::vector<int> shift_offsets(t.ndimension(), 0); |
| shift_offsets[i] = pad[0] - w; |
| auto shifted = shift(t, shift_offsets); |
| shifted = shifted.unsqueeze(-1); |
| if (w == 0) { |
| concat_tensor = shifted; |
| } else { |
| concat_tensor = at::cat({concat_tensor, shifted}, -1); |
| } |
| } |
| t = concat_tensor; |
| } |
| return t; |
| } |
| |
| } // namespace |
| |
| // Shift an input tensor |
| TEST(NVFuserTest, FusionShift1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| |
| auto tv1 = shift(tv0, {-1, 0}); |
| fusion.addOutput(tv1); |
| |
| auto tv2 = shift(tv0, {0, 1}); |
| fusion.addOutput(tv2); |
| |
| auto tv3 = shift(tv0, {2, 2}); |
| fusion.addOutput(tv3); |
| |
| auto tv4 = shift(tv0, {-2, -2}); |
| fusion.addOutput(tv4); |
| |
| int numel_x = 9; |
| int numel_y = 11; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = shift(t0, {-1, 0}); |
| TORCH_CHECK(t1.equal(outputs[0])); |
| |
| auto t2 = shift(t0, {0, 1}); |
| TORCH_CHECK(t2.equal(outputs[1])); |
| |
| auto t3 = shift(t0, {2, 2}); |
| TORCH_CHECK(t3.equal(outputs[2])); |
| |
| auto t4 = shift(t0, {-2, -2}); |
| TORCH_CHECK(t4.equal(outputs[3])); |
| } |
| |
| // Shifts an intermediate tensor |
| TEST(NVFuserTest, FusionShift2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {-1, 0}); |
| fusion.addOutput(tv2); |
| |
| // make it a little more complex |
| auto tv3 = add(tv0, new Double(3)); |
| auto tv4 = add(tv3, new Double(4)); |
| auto tv5 = shift(tv4, {-1, 0}); |
| auto tv6 = shift(tv4, {0, -1}); |
| auto tv7 = shift(tv4, {1, 0}); |
| auto tv8 = shift(tv4, {0, 0}); |
| auto tv9 = add(tv5, tv6); |
| auto tv10 = add(tv9, tv7); |
| auto tv11 = add(tv10, tv8); |
| fusion.addOutput(tv11); |
| |
| for (auto tv : {tv1, tv2, tv3, tv4, tv5, tv6, tv7, tv8, tv9, tv10, tv11}) { |
| tv->setMemoryType(MemoryType::Global); |
| } |
| |
| // t1 allocation: (t1.size[0] + 1) * (t1.size[1]) |
| // t3 allocation: (t3.size[0] + 2) * (t3.size[1] + 1) |
| // t4 allocation: (t3.size[0] + 2) * (t3.size[1] + 1) |
| GpuLower gpulw(&fusion); |
| |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1 || tensor_name == 3 || tensor_name == 4) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| if (tensor_name == 1 && i == 1) { |
| TORCH_CHECK(alloc->shape().at(i)->isA<kir::NamedScalar>()); |
| continue; |
| } |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add); |
| TORCH_CHECK(def->as<kir::BinaryOp>()->lhs()->isA<kir::NamedScalar>()); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| if (tensor_name == 1) { |
| TORCH_CHECK(i == 0); |
| TORCH_CHECK(rhs_value == 1); |
| } else { |
| if (i == 0) { |
| TORCH_CHECK(rhs_value == 2); |
| } else { |
| TORCH_CHECK(rhs_value == 1); |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| int numel_x = 9; |
| int numel_y = 11; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {-1, 0}); |
| |
| auto t3 = t0 + 3; |
| auto t4 = t3 + 4; |
| auto t5 = shift(t4, {-1, 0}); |
| auto t6 = shift(t4, {0, -1}); |
| auto t7 = shift(t4, {1, 0}); |
| auto t8 = shift(t4, {0, 0}); |
| auto t9 = t5 + t6; |
| auto t10 = t9 + t7; |
| auto t11 = t10 + t8; |
| |
| testValidate(&fusion, outputs, inputs, {t2, t11}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {0, 1}); |
| fusion.addOutput(tv2); |
| |
| tv0->computeAt(tv2, -2); |
| |
| tv1->setMemoryType(MemoryType::Global); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 100; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {0, 1}); |
| |
| TORCH_CHECK(t2.allclose(outputs[0])); |
| } |
| |
| TEST(NVFuserTest, FusionShiftLeftOfCA_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv1, new Double(1)); |
| auto tv3 = shift(tv2, {-1, 0}); |
| auto tv4 = add(tv3, new Double(1)); |
| fusion.addOutput(tv4); |
| |
| tv0->computeAt(tv4, -1); |
| |
| // Lowering should trigger an assertion failure as a shifted axis is |
| // found inside an allocation position. |
| ASSERT_ANY_THROW(fusion.printKernel()); |
| } |
| |
| TEST(NVFuserTest, FusionShiftSplit1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {0, 1}); |
| auto tv3 = shift(tv1, {0, -2}); |
| fusion.addOutput(tv2); |
| fusion.addOutput(tv3); |
| |
| int split_factor = 4; |
| tv2->split(-1, split_factor); |
| tv3->split(-1, split_factor); |
| |
| tv0->computeAt(tv2, -2); |
| tv0->computeAt(tv3, -2); |
| |
| // t1 allocation: (4 + 3) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor && rhs_value == 3); |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 9; |
| int numel_y = 11; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {0, 1}); |
| auto t3 = shift(t1, {0, -2}); |
| |
| testValidate(&fusion, outputs, inputs, {t2, t3}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftSplit2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv1, new Double(1)); |
| auto tv3 = shift(tv2, {0, -1}); |
| auto tv4 = shift(tv2, {0, 1}); |
| auto tv5 = add(tv3, tv4); |
| fusion.addOutput(tv5); |
| |
| auto tv6 = add(tv0, new Double(1)); |
| auto tv7 = shift(tv6, {0, 0}); |
| auto tv8 = add(tv7, new Double(1)); |
| fusion.addOutput(tv8); |
| |
| int split_factor = 4; |
| |
| tv5->split(-1, split_factor); |
| tv8->split(-1, split_factor); |
| |
| tv0->computeAt(tv5, -2); |
| tv0->computeAt(tv8, -2); |
| |
| // t1 and t2 allocation: (4 + 2) |
| // t4 allocation: (4) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1 || tensor_name == 2) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); |
| } else if (tensor_name == 4) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| auto size = dynamic_cast<kir::Int*>(alloc->shape().at(0)); |
| TORCH_CHECK(size != nullptr && size->isConst()); |
| int size_value = *size->value(); |
| TORCH_CHECK(size_value == split_factor); |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 9; |
| int numel_y = 11; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 2; |
| auto t3 = shift(t1, {0, -1}); |
| auto t4 = shift(t1, {0, 1}); |
| auto t5 = t3 + t4; |
| |
| auto t6 = t0 + 1; |
| auto t7 = t6; |
| auto t8 = t7 + 1; |
| |
| testValidate(&fusion, outputs, inputs, {t5, t8}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv1, new Double(2)); |
| auto tv3 = shift(tv2, {0, 1}); |
| fusion.addOutput(tv3); |
| |
| int split_factor1 = 8; |
| int split_factor2 = 4; |
| |
| tv3->split(-1, split_factor1); |
| |
| tv0->computeAt(tv3, -2); |
| |
| tv1->split(-1, split_factor2); |
| |
| // t1: [i1, i2/8, 8/4, 4] |
| // t2: [i1, i2/8, 8] |
| // t3: [i1, i2/8, 8] |
| |
| // t1 and t2 allocation: (split_factor1 + 1) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1 || tensor_name == 2) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 3; |
| auto ref = shift(t1, {0, 1}); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // 3-pt stencil |
| auto tv0 = makeSymbolicTensor(1); |
| fusion.addInput(tv0); |
| |
| std::vector<std::vector<int>> offsets = {{-1}, {1}}; |
| |
| std::vector<TensorView*> tvs; |
| for (const auto& offset : offsets) { |
| tvs.push_back(shift(tv0, offset)); |
| } |
| |
| auto tv_out = tv0; |
| |
| for (auto tv : tvs) { |
| tv_out = add(tv_out, tv); |
| } |
| |
| tv_out = div(tv_out, new Double(tvs.size() + 1)); |
| |
| fusion.addOutput(tv_out); |
| |
| int split_factor = 4; |
| |
| tv_out->split(0, split_factor); |
| |
| // This seems fine but not verified yet |
| // tv_out->axis(-1)->parallelize(ParallelType::Unswitch); |
| |
| auto cache = tv0->cache_after(); |
| |
| tv0->computeAt(tv_out, 1); |
| |
| // Inline completely except for the cache |
| for (auto tv : tvs) { |
| tv->computeAt(tv_out, -1); |
| } |
| |
| // cache allocation: (split_factor + 2) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == cache->name()) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // 5-pt stencil |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; |
| |
| std::vector<TensorView*> tvs; |
| for (const auto& offset : offsets) { |
| tvs.push_back(shift(tv0, offset)); |
| } |
| |
| auto tv_out = tv0; |
| |
| for (auto tv : tvs) { |
| tv_out = add(tv_out, tv); |
| } |
| |
| tv_out = div(tv_out, new Double(tvs.size() + 1)); |
| |
| fusion.addOutput(tv_out); |
| |
| std::vector<int> split_factor({4, 8}); |
| |
| tv_out->split(-1, split_factor[1]); |
| tv_out->split(0, split_factor[0]); |
| tv_out->reorder({{1, 2}, {2, 1}}); |
| |
| auto cache = tv0->cache_after(); |
| |
| tv0->computeAt(tv_out, 2); |
| |
| // Inline completely except for the cache |
| for (auto tv : tvs) { |
| tv->computeAt(tv_out, -1); |
| } |
| |
| // cache allocation: (split_factor + 2) * (split_factor + 2) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == cache->name()) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = t0; |
| for (const auto& offset : offsets) { |
| ref = ref + shift(t0, offset); |
| } |
| ref = ref / int(offsets.size() + 1); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // 9-pt stencil |
| std::vector<std::vector<int>> offsets; |
| for (int i = -1; i < 2; ++i) { |
| for (int j = -1; j < 2; ++j) { |
| if (i == 0 && j == 0) { |
| continue; |
| } |
| offsets.push_back({i, j}); |
| } |
| } |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| std::vector<TensorView*> tvs; |
| for (const auto& offset : offsets) { |
| tvs.push_back(shift(tv0, offset)); |
| } |
| |
| auto tv_out = tv0; |
| |
| for (auto tv : tvs) { |
| tv_out = add(tv_out, tv); |
| } |
| |
| tv_out = div(tv_out, new Double(tvs.size() + 1)); |
| |
| fusion.addOutput(tv_out); |
| |
| std::vector<int> split_factor({4, 8}); |
| tv_out->split(-1, split_factor[1]); |
| tv_out->split(0, split_factor[0]); |
| tv_out->reorder({{1, 2}, {2, 1}}); |
| |
| auto cache = tv0->cache_after(); |
| |
| tv0->computeAt(tv_out, 2); |
| |
| // Inline completely except for the cache |
| for (auto tv : tvs) { |
| tv->computeAt(tv_out, -1); |
| } |
| |
| // This seems fine but not yet verified |
| // tv_out->axis(-1)->parallelize(ParallelType::Unswitch); |
| |
| // cache allocation: (split_factor + 2) * (split_factor + 2) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == cache->name()) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = t0; |
| for (const auto& offset : offsets) { |
| ref = ref + shift(t0, offset); |
| } |
| ref = ref / int(offsets.size() + 1); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {0, 1}); |
| fusion.addOutput(tv2); |
| |
| int smem_block_factor = 32; |
| |
| tv2->split(-1, smem_block_factor); |
| |
| tv0->computeAt(tv2, -2); |
| |
| tv1->axis(-1)->parallelize(ParallelType::TIDx); |
| tv2->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| tv1->setMemoryType(MemoryType::Shared); |
| |
| // tv1 allocation: (split_factor + 1) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == tv1->name()) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| for (int i = 0; i < 1; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == smem_block_factor && rhs_value == 1); |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 100; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {0, 1}); |
| auto ref = t2; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // 3-pt stencil |
| auto tv0 = makeSymbolicTensor(1); |
| fusion.addInput(tv0); |
| std::vector<TensorView*> tvs; |
| tvs.push_back(shift(tv0, {-1})); |
| tvs.push_back(shift(tv0, {1})); |
| |
| auto tv_out = tv0; |
| |
| for (auto tv : tvs) { |
| tv_out = add(tv_out, tv); |
| } |
| |
| tv_out = div(tv_out, new Double(tvs.size() + 1)); |
| |
| fusion.addOutput(tv_out); |
| |
| int smem_block_factor = 32; |
| |
| tv_out->split(0, smem_block_factor); |
| // tv_out->axis(-1)->parallelize(ParallelType::Unswitch); |
| |
| auto tv0_cache = tv0->cache_after(); |
| |
| tv0->computeAt(tv_out, 1); |
| |
| for (auto tv : tvs) { |
| tv->computeAt(tv_out, -1); |
| } |
| |
| tv0_cache->setMemoryType(MemoryType::Shared); |
| tv_out->axis(-1)->parallelize(ParallelType::TIDx); |
| tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // 5-pt stencil |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; |
| |
| std::vector<TensorView*> tvs; |
| for (const auto& offset : offsets) { |
| tvs.push_back(shift(tv0, offset)); |
| } |
| |
| auto tv_out = tv0; |
| |
| for (auto tv : tvs) { |
| tv_out = add(tv_out, tv); |
| } |
| |
| tv_out = div(tv_out, new Double(tvs.size() + 1)); |
| |
| fusion.addOutput(tv_out); |
| |
| int smem_block_factor = 32; |
| |
| tv_out->split(-1, smem_block_factor); |
| tv_out->split(0, smem_block_factor); |
| |
| tv_out->reorder({{1, 2}, {2, 1}}); |
| |
| auto tv0_cache = tv0->cache_after(); |
| |
| tv0->computeAt(tv_out, 2); |
| |
| for (auto tv : tvs) { |
| tv->computeAt(tv_out, -1); |
| } |
| |
| tv_out->axis(-1)->parallelize(ParallelType::TIDx); |
| tv_out->axis(-2)->parallelize(ParallelType::TIDy); |
| tv_out->axis(-3)->parallelize(ParallelType::BIDx); |
| tv_out->axis(-4)->parallelize(ParallelType::BIDy); |
| |
| tv0_cache->setMemoryType(MemoryType::Shared); |
| tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); |
| tv0_cache->axis(-2)->parallelize(ParallelType::TIDy); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = t0; |
| for (const auto& offset : offsets) { |
| ref = ref + shift(t0, offset); |
| } |
| ref = ref / int(offsets.size() + 1); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftMerge1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {-1, 1}); |
| fusion.addOutput(tv2); |
| |
| int split_factor = 4; |
| |
| tv2->split(-1, split_factor); |
| tv2->split(0, split_factor); |
| tv2->reorder({{1, 2}, {2, 1}}); |
| tv2->merge(2, 3); |
| |
| tv0->computeAt(tv2, 2); |
| |
| // t1 allocation: (split_factor + 1) * (split_factor + 1) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor && rhs_value == 1); |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {-1, 1}); |
| auto ref = t2; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftMerge2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {1, -1}); |
| auto tv3 = shift(tv1, {-1, 1}); |
| auto tv4 = add(tv2, tv3); |
| fusion.addOutput(tv4); |
| |
| int split_factor = 4; |
| |
| tv4->split(-1, split_factor); |
| tv4->split(0, split_factor); |
| tv4->reorder({{1, 2}, {2, 1}}); |
| tv4->merge(2, 3); |
| |
| tv0->computeAt(tv4, -2); |
| |
| // t1 allocation: (split_factor + 2) * (split_factor + 2) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {1, -1}); |
| auto t3 = shift(t1, {-1, 1}); |
| auto t4 = t2 + t3; |
| |
| TORCH_CHECK(t4.allclose(outputs[0])); |
| } |
| |
| TEST(NVFuserTest, FusionShiftGlobal_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {0, 1}); |
| auto tv3 = shift(tv1, {-1, 0}); |
| auto tv4 = add(tv2, tv3); |
| fusion.addOutput(tv4); |
| |
| tv1->split(-1, 4); |
| tv2->split(-1, 8); |
| tv3->split(-1, 2); |
| tv4->split(-1, 3); |
| |
| tv1->merge(-2, -1); |
| |
| tv1->setMemoryType(MemoryType::Global); |
| tv2->setMemoryType(MemoryType::Global); |
| tv3->setMemoryType(MemoryType::Global); |
| |
| // t1 allocation: (t1.size[0] + 1) * (t1.size[1] + 1) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add); |
| TORCH_CHECK(def->as<kir::BinaryOp>()->lhs()->isA<kir::NamedScalar>()); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(rhs_value == 1); |
| } |
| } |
| } |
| } |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {0, 1}); |
| auto t3 = shift(t1, {-1, 0}); |
| auto t4 = t2 + t3; |
| auto ref = t4; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv1, new Double(2)); |
| auto tv3 = shift(tv2, {0, 1}); |
| fusion.addOutput(tv3); |
| |
| int split_factor1 = 8; |
| int split_factor2 = 4; |
| |
| tv3->split(-1, split_factor1); |
| |
| tv0->computeAt(tv3, -2); |
| |
| tv1->split(-1, split_factor2); |
| tv1->merge(-2, -1); |
| |
| // t1 and t2 allocation: (split_factor1 + 1) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1 || tensor_name == 2) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 3; |
| auto ref = shift(t1, {0, 1}); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv1, new Double(2)); |
| auto tv3 = shift(tv2, {1, 1}); |
| fusion.addOutput(tv3); |
| |
| auto out = tv3; |
| |
| int split_factor1 = 32; |
| int split_factor2 = 4; |
| |
| out->split(-1, split_factor1); |
| out->split(-1, split_factor2); |
| out->split(0, split_factor1); |
| out->split(1, split_factor2); |
| out->reorder({{3, 1}, {1, 2}, {4, 3}, {2, 4}}); |
| out->merge(2, 3); |
| out->merge(2, 3); |
| out->merge(2, 3); |
| out->merge(0, 1); |
| |
| TransformPropagator::from(out); |
| |
| tv0->computeAt(out, 1); |
| |
| out->axis(0)->parallelize(ParallelType::BIDx); |
| out->axis(1)->parallelize(ParallelType::TIDx); |
| |
| scheduler_utils::parallelizeAllLike(out, {tv1, tv2}); |
| |
| for (auto tv : {tv1, tv2}) { |
| tv->setMemoryType(MemoryType::Shared); |
| } |
| |
| // t1 and t2 allocation: (split_factor1 + 1) * (split_factor1 + 1) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1 || tensor_name == 2) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = shift(t0 + 1 + 2, {1, 1}); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // 5-pt stencil |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; |
| |
| std::vector<TensorView*> tvs; |
| for (const auto& offset : offsets) { |
| tvs.push_back(shift(tv0, offset)); |
| } |
| |
| auto tv_out = tv0; |
| |
| for (auto tv : tvs) { |
| tv_out = add(tv_out, tv); |
| } |
| |
| tv_out = div(tv_out, new Double(tvs.size() + 1)); |
| |
| fusion.addOutput(tv_out); |
| |
| std::vector<int> split_factor({4, 32}); |
| |
| tv_out->split(-1, split_factor[1]); |
| tv_out->split(0, split_factor[0]); |
| tv_out->reorder({{1, 2}, {2, 1}}); |
| |
| auto tv0_cache = tv0->cache_after(); |
| |
| // Merge the inner-most two axes and create |
| // a 1D thread block of split_factor1*split_factor2 threads |
| tv_out->merge(-2, -1); |
| |
| tv0->computeAt(tv_out, 2); |
| |
| // Inline completely except for the cache |
| for (auto tv : tvs) { |
| tv->computeAt(tv_out, -1); |
| } |
| |
| tv0_cache->merge(-2, -1); |
| |
| tv_out->axis(-1)->parallelize(ParallelType::TIDx); |
| tv_out->axis(1)->parallelize(ParallelType::BIDx); |
| tv_out->axis(0)->parallelize(ParallelType::BIDy); |
| |
| tv0_cache->setMemoryType(MemoryType::Shared); |
| tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| // cache allocation: (split_factor1 + 2) * (split_factor2 + 2) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == tv0_cache->name()) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = t0; |
| for (const auto& offset : offsets) { |
| ref = ref + shift(t0, offset); |
| } |
| ref = ref / int(offsets.size() + 1); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftChain1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = shift(tv0, {0, 1}); |
| auto tv2 = shift(tv1, {0, 1}); |
| fusion.addOutput(tv2); |
| |
| int split_factor = 4; |
| tv2->split(-1, split_factor); |
| |
| tv0->computeAt(tv2, -2); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = shift(shift(t0, {0, 1}), {0, 1}); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftChain2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = shift(tv0, {0, 1}); |
| auto tv2 = shift(tv1, {0, -1}); |
| fusion.addOutput(tv2); |
| |
| tv2->split(-1, 4); |
| |
| tv0->computeAt(tv2, -2); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = shift(shift(t0, {0, 1}), {0, -1}); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftChain3_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = shift(tv1, {0, 1}); |
| auto tv3 = shift(tv2, {0, 1}); |
| fusion.addOutput(tv3); |
| |
| int split_factor = 4; |
| tv3->split(-1, split_factor); |
| |
| tv0->computeAt(tv3, -2); |
| |
| // Halo size of tv1 is 2 as it needs to account for both of the two |
| // shift operations , while that of tv2 is still just 1 |
| |
| // tv1: (split_factor + 2) |
| // tv2: (split_factor + 1) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1 || tensor_name == 2) { |
| TORCH_CHECK(alloc->shape().size() == 1); |
| for (int i = 0; i < 1; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor); |
| if (tensor_name == 1) { |
| TORCH_CHECK(rhs_value == 2); |
| } else if (tensor_name == 2) { |
| TORCH_CHECK(rhs_value == 1); |
| } |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = shift(t1, {0, 1}); |
| auto t3 = shift(t2, {0, 1}); |
| auto ref = t3; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftChain4_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = shift(tv0, {1, -1}); |
| auto tv2 = shift(tv1, {2, -2}); |
| auto tv3 = shift(tv2, {3, -3}); |
| auto tv4 = shift(tv3, {4, -4}); |
| auto tv_out = tv4; |
| |
| fusion.addOutput(tv_out); |
| |
| int split_factor = 4; |
| |
| tv_out->split(-1, split_factor); |
| tv_out->split(0, split_factor); |
| tv_out->reorder({{1, 2}, {2, 1}}); |
| |
| tv0->computeAt(tv_out, 2); |
| |
| tv1->merge(-2, -1); |
| tv2->merge(-2, -1); |
| tv3->merge(-2, -1); |
| |
| // tv1: (split_factor + 9) * (split_factor + 9) |
| // tv2: (split_factor + 7) * (split_factor + 7) |
| // tv3: (split_factor + 4) * (split_factor + 4) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == 1 || tensor_name == 2) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor); |
| if (tensor_name == 1) { |
| TORCH_CHECK(rhs_value == 9); |
| } else if (tensor_name == 2) { |
| TORCH_CHECK(rhs_value == 7); |
| } else if (tensor_name == 3) { |
| TORCH_CHECK(rhs_value == 4); |
| } |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = shift(t0, {1, -1}); |
| auto t2 = shift(t1, {2, -2}); |
| auto t3 = shift(t2, {3, -3}); |
| auto t4 = shift(t3, {4, -4}); |
| auto ref = t4; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; |
| |
| // First stencil: 5pt stencil |
| // stencil1 = (tv0 + tv0[+1][0] + tv0[-1][0] + tv0[0][+1] + tv0[0][-1]) / 5 |
| std::vector<TensorView*> tv_stencil1_shifts; |
| for (const auto& offset : offsets) { |
| tv_stencil1_shifts.push_back(shift(tv0, offset)); |
| } |
| |
| auto tv_stencil1 = tv0; |
| for (auto tv : tv_stencil1_shifts) { |
| tv_stencil1 = add(tv_stencil1, tv); |
| } |
| |
| tv_stencil1 = div(tv_stencil1, new Double(tv_stencil1_shifts.size() + 1)); |
| |
| // Second stencil: Same 5pt stencil |
| std::vector<TensorView*> tv_stencil2_shifts; |
| for (const auto& offset : offsets) { |
| tv_stencil2_shifts.push_back(shift(tv_stencil1, offset)); |
| } |
| |
| auto tv_stencil2 = tv_stencil1; |
| for (auto tv : tv_stencil2_shifts) { |
| tv_stencil2 = add(tv_stencil2, tv); |
| } |
| |
| tv_stencil2 = div(tv_stencil2, new Double(tv_stencil2_shifts.size() + 1)); |
| |
| auto tv_out = tv_stencil2; |
| |
| fusion.addOutput(tv_out); |
| |
| auto tv0_cache = tv0->cache_after(); |
| |
| std::vector<int> split_factor({16, 16}); |
| |
| tv_out->split(-1, split_factor[1]); |
| tv_out->split(0, split_factor[0]); |
| tv_out->reorder({{1, 2}, {2, 1}}); |
| |
| tv0->computeAt(tv_out, 2); |
| |
| // Inline completely all inputs to the first stencil output, except for the |
| // tv0 cache |
| for (auto tv : tv_stencil1_shifts) { |
| tv->computeAt(tv_stencil1, -1); |
| } |
| |
| // Inline completely all inputs to the second stencil output, except |
| // for the first stencil output |
| for (auto tv : tv_stencil2_shifts) { |
| tv->computeAt(tv_stencil2, -1); |
| } |
| |
| tv_out->axis(1)->parallelize(ParallelType::BIDx); |
| tv_out->axis(0)->parallelize(ParallelType::BIDy); |
| |
| auto all_values = DependencyCheck::getAllValsBetween( |
| {fusion.inputs().begin(), fusion.inputs().end()}, fusion.outputs()); |
| for (auto tv : ir_utils::filterByType<TensorView>(all_values)) { |
| tv->axis(-1)->parallelize(ParallelType::TIDx); |
| tv->axis(-2)->parallelize(ParallelType::TIDy); |
| } |
| |
| tv0_cache->setMemoryType(MemoryType::Shared); |
| tv_stencil1->setMemoryType(MemoryType::Shared); |
| |
| // tv0_cache: (split_factor + 4) * (split_factor + 4) |
| // tv_stencil1: (split_factor + 2) * (split_factor + 2) |
| GpuLower gpulw(&fusion); |
| for (const auto& kir_node : gpulw.kernel()->irNodes()) { |
| if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) { |
| auto tensor_name = alloc->buffer()->name(); |
| if (tensor_name == tv0_cache->name() || |
| tensor_name == tv_stencil1->name()) { |
| TORCH_CHECK(alloc->shape().size() == 2); |
| for (int i = 0; i < 2; ++i) { |
| auto def = |
| dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition()); |
| auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs()); |
| TORCH_CHECK(lhs != nullptr && lhs->isConst()); |
| int lhs_value = *lhs->value(); |
| auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs()); |
| TORCH_CHECK(rhs != nullptr && rhs->isConst()); |
| int rhs_value = *rhs->value(); |
| TORCH_CHECK(lhs_value == split_factor[i]); |
| if (tensor_name == tv0_cache->name()) { |
| TORCH_CHECK(rhs_value == 4); |
| } else if (tensor_name == tv_stencil1->name()) { |
| TORCH_CHECK(rhs_value == 2); |
| } |
| } |
| } |
| } |
| } |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto stencil1 = t0; |
| for (const auto& offset : offsets) { |
| stencil1 = stencil1 + shift(t0, offset); |
| } |
| stencil1 = stencil1 / int(offsets.size() + 1); |
| auto stencil2 = stencil1; |
| for (const auto& offset : offsets) { |
| stencil2 = stencil2 + shift(stencil1, offset); |
| } |
| stencil2 = stencil2 / int(offsets.size() + 1); |
| auto ref = stencil2; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| // Shift a reduced tensor |
| TEST(NVFuserTest, FusionShiftReduction1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = sum(tv1, {1}); |
| auto tv3 = shift(tv2, {1}); |
| fusion.addOutput(tv3); |
| |
| tv3->split(0, 4); |
| tv0->computeAt(tv3, 1); |
| tv0->computeAt(tv2, -1); |
| |
| const int numel_x = 9; |
| const int numel_y = 11; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = sum(t1, {1}); |
| auto t3 = shift(t2, {1}); |
| auto ref = t3; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| // Parallelized version of FusionShiftReduction1 |
| TEST(NVFuserTest, FusionShiftReduction2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = sum(tv1, {1}); |
| auto tv3 = shift(tv2, {1}); |
| fusion.addOutput(tv3); |
| |
| tv3->split(0, 4); |
| tv0->computeAt(tv3, 1); |
| |
| tv2->split(-1, 32); |
| tv0->computeAt(tv2, -1); |
| |
| tv2->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| tv2->setMemoryType(MemoryType::Shared); |
| |
| const int numel_x = 201; |
| const int numel_y = 301; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = sum(t1, {1}); |
| auto t3 = shift(t2, {1}); |
| auto ref = t3; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftRfactor1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = sum(tv1, {1}); |
| auto tv3 = shift(tv2, {1}); |
| fusion.addOutput(tv3); |
| |
| tv3->split(0, 4); |
| tv0->computeAt(tv3, 1); |
| |
| tv2->split(-1, 32); |
| auto rf = tv2->rFactor({-2}); |
| tv0->computeAt(tv2, -1); |
| tv0->computeAt(rf, -1); |
| |
| tv2->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| tv2->setMemoryType(MemoryType::Shared); |
| |
| const int numel_x = 201; |
| const int numel_y = 301; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = sum(t1, {1}); |
| auto t3 = shift(t2, {1}); |
| auto ref = t3; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftBcast1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(1); |
| fusion.addInput(tv0); |
| auto tv1 = makeSymbolicTensor(2); |
| fusion.addInput(tv1); |
| auto tv2 = broadcast(tv0, {false, true}); |
| auto tv3 = shift(tv2, {0, 1}); |
| auto tv4 = add(tv3, tv1); |
| fusion.addOutput(tv4); |
| |
| tv0->computeAt(tv4, -1); |
| tv1->computeAt(tv4, -1); |
| |
| const int numel_x = 9; |
| const int numel_y = 11; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x}, options); |
| at::Tensor t1 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0, t1}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t4 = t0.unsqueeze(-1).expand({numel_x, numel_y}) + t1; |
| auto ref = t4; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftBcast2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(1); |
| fusion.addInput(tv0); |
| auto tv1 = makeSymbolicTensor(2); |
| fusion.addInput(tv1); |
| auto tv2 = broadcast(tv0, {false, true}); |
| auto tv3 = shift(tv2, {1, 0}); |
| auto tv4 = add(tv3, tv1); |
| fusion.addOutput(tv4); |
| |
| tv4->split(0, 4); |
| tv0->computeAt(tv4, 1); |
| |
| const int numel_x = 9; |
| const int numel_y = 11; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x}, options); |
| at::Tensor t1 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0, t1}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y}); |
| auto t3 = shift(t2, {1, 0}); |
| auto ref = t3 + t1; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| // Combine ShiftBcast1 and ShiftBcast2 with parallelization |
| TEST(NVFuserTest, FusionShiftBcast3_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(1); |
| fusion.addInput(tv0); |
| auto tv1 = makeSymbolicTensor(2); |
| fusion.addInput(tv1); |
| auto tv2 = broadcast(tv0, {false, true}); |
| auto tv3 = shift(tv2, {1, 0}); |
| auto tv4 = shift(tv2, {0, 1}); |
| auto tv5 = shift(tv2, {-1, -1}); |
| auto tv6 = add(tv3, tv4); |
| auto tv7 = add(tv6, tv5); |
| auto tv8 = add(tv7, tv1); |
| fusion.addOutput(tv8); |
| |
| tv8->split(0, 4); |
| tv8->split(-1, 4); |
| tv0->computeAt(tv8, 1); |
| |
| tv8->axis(-1)->parallelize(ParallelType::TIDx); |
| for (auto tv : {tv8, tv7, tv6, tv5, tv4, tv3, tv2}) { |
| tv->axis(1)->parallelize(ParallelType::TIDy); |
| } |
| |
| tv2->setMemoryType(MemoryType::Shared); |
| |
| const int numel_x = 101; |
| const int numel_y = 201; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x}, options); |
| at::Tensor t1 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0, t1}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y}); |
| auto t3 = shift(t2, {1, 0}); |
| auto t4 = t2; |
| auto t5 = shift(t2, {-1, 0}); |
| auto ref = t3 + t4 + t5 + t1; |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| // See issue #893 |
| TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv0, new Double(2)); |
| auto tv3 = add(tv1, tv2); |
| auto tv4 = shift(tv3, {0, 1}); |
| fusion.addOutput(tv4); |
| |
| tv4->split(1, 8); |
| tv0->computeAt(tv4, 2); |
| |
| tv2->computeAt(tv3, -1); |
| |
| tv1->setMemoryType(MemoryType::Shared); |
| tv3->setMemoryType(MemoryType::Shared); |
| |
| tv1->axis(-1)->parallelize(ParallelType::TIDx); |
| tv3->axis(-1)->parallelize(ParallelType::TIDx); |
| tv4->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| int numel_y = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x, numel_y}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = t0 + 2; |
| auto t3 = add(t1, t2); |
| auto t4 = shift(t3, {0, 1}); |
| |
| testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); |
| } |
| |
| // See issue #893. Top-level placement. |
| TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(1); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv0, new Double(2)); |
| auto tv3 = add(tv1, tv2); |
| auto tv4 = shift(tv3, {1}); |
| fusion.addOutput(tv4); |
| |
| tv2->computeAt(tv3, -1); |
| |
| tv1->setMemoryType(MemoryType::Shared); |
| tv3->setMemoryType(MemoryType::Shared); |
| |
| tv1->axis(-1)->parallelize(ParallelType::TIDx); |
| tv3->axis(-1)->parallelize(ParallelType::TIDx); |
| tv4->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 99; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({numel_x}, options); |
| std::vector<IValue> inputs = {t0}; |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = t0 + 2; |
| auto t3 = add(t1, t2); |
| auto t4 = shift(t3, {1}); |
| |
| testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionShiftSyncPlacement3_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(1); |
| fusion.addInput(tv0); |
| auto tv1 = add(tv0, new Double(1)); |
| auto tv2 = add(tv1, new Double(2)); |
| auto tv3 = shift(tv2, {1}); |
| fusion.addOutput(tv3); |
| |
| // This doesn't work. syncthreads is needed between tv1 and tv2, but |
| // both the loop extent of both tv1 and tv2 has halo, so the loop is |
| // not eliminated even though it is parallelized. Moving syncthreads |
| // out of the loop would make it placed before tv1, which would make |
| // it meaningless. |
| // Ideally, an exception should be thrown at this computeAt, but at |
| // this point, the fusion is not yet parallelized, nor memory type |
| // is set, so this computeAt itself is not an error yet. |
| tv1->computeAt(tv2, -1); |
| |
| tv1->setMemoryType(MemoryType::Shared); |
| tv2->setMemoryType(MemoryType::Shared); |
| |
| tv1->axis(-1)->parallelize(ParallelType::TIDx); |
| tv2->axis(-1)->parallelize(ParallelType::TIDx); |
| tv3->axis(-1)->parallelize(ParallelType::TIDx); |
| |
| // The error should be detected when the fusion is lowered. |
| ASSERT_ANY_THROW(fusion.printKernel()); |
| } |
| |
// Based on original CUDA provided by Vishal Mehta.
// Major differences from the original version:
// - Boundary processing. We always pad by zero. The original version
//   is only defined for the interior domain.
// - The original version uses 2 additional warps to load the halos
//   along the Y dimension. The other 10 warps are used to load a 32x10
//   tile, and all warps will do coalesced loads. No such optimization
//   is done in the fuser version.
| TEST(NVFuserTest, FusionHorizontalDiffusion_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto inp = makeSymbolicTensor(3); |
| fusion.addInput(inp); |
| auto coeff = makeSymbolicTensor(3); |
| fusion.addInput(coeff); |
| |
| std::vector<std::vector<int>> offsets{ |
| {0, 1, 0}, {0, -1, 0}, {0, 0, 1}, {0, 0, -1}}; |
| |
| // T2, T3, T4, T5 |
| std::vector<TensorView*> inp_neighbors; |
| for (const auto& offset : offsets) { |
| inp_neighbors.push_back(shift(inp, offset)); |
| } |
| |
| // T8 |
| TensorView* sum_of_neighbors = nullptr; |
| for (auto inp_neighbor : inp_neighbors) { |
| if (sum_of_neighbors == nullptr) { |
| sum_of_neighbors = inp_neighbor; |
| } else { |
| sum_of_neighbors = add(sum_of_neighbors, inp_neighbor); |
| } |
| } |
| |
| // T9 = T0 * 4 |
| // T10 = T9 - T8 |
| auto lap = sub(mul(inp, new Double(4)), sum_of_neighbors); |
| |
| // T11 = shift(T10) |
| // T12 = T11 - T10 |
| auto flx = sub(shift(lap, {0, 0, -1}), lap); |
| // T14 = T13 - T0 |
| // T15 = T12 * T14 |
| // T16 = T15 > 0 |
| // T17 = T16 ? 0 : T12 |
| auto flx_cond = gt(mul(flx, sub(shift(inp, {0, 0, -1}), inp)), new Double(0)); |
| auto flx0 = where(flx_cond, new Double(0), flx); |
| |
| // T18 = shift(T10) |
| // T19 = T18 - T10 |
| auto fly = sub(shift(lap, {0, -1, 0}), lap); |
| // T20 = shift(T0) |
| // T21 = T20 - T0 |
| // T22 = T19 * T21 |
| // T23 = T22 > 0 |
| auto fly_cond = gt(mul(fly, sub(shift(inp, {0, -1, 0}), inp)), new Double(0)); |
| // T24 = T23 ? 0 : T19 |
| auto fly0 = where(fly_cond, new Double(0), fly); |
| |
| // T25 = shift(flx0) |
| // T26 = T17 - T25 |
| // T27 = shift(fly0) |
| // T28 = T24 - T27 |
| // T29 = T26 + T28 |
| // T30 = T1 * T29 |
| // T31 = T0 - T30 |
| auto out = |
| sub(inp, |
| mul(coeff, |
| add(sub(flx0, shift(flx0, {0, 0, 1})), |
| sub(fly0, shift(fly0, {0, 1, 0}))))); |
| |
| fusion.addOutput(out); |
| |
| ///////////////////////////////// |
| // Scheduling |
| ///////////////////////////////// |
| |
| // Step 1: 2D Tiling |
| |
| const int tile_x = 32; |
| const int tile_y = 8; |
| |
| out->split(-1, tile_x); |
| out->split(-3, tile_y); |
| out->reorder({{-2, -3}}); |
| inp->computeAt(out, -3); |
| coeff->computeAt(out, -3); |
| |
| // Step 2: Inlining |
| |
| // Inline inputs to lap |
| auto lap_vals = DependencyCheck::getAllValsBetween({inp}, {lap}); |
| for (auto val : ir_utils::filterByType<TensorView>(lap_vals)) { |
| if (val != lap && val != inp) { |
| val->computeAt(lap, -1); |
| } |
| } |
| |
| // Inline inputs to flx0 |
| auto flx0_vals = DependencyCheck::getAllValsBetween({lap, inp}, {flx0}); |
| for (auto val : ir_utils::filterByType<TensorView>(flx0_vals)) { |
| if (val != lap && val != flx0 && val != inp) { |
| val->computeAt(flx0, -1); |
| } |
| } |
| |
| // Inline inputs to fly0 |
| auto flxy_vals = DependencyCheck::getAllValsBetween({lap, inp}, {fly0}); |
| for (auto val : ir_utils::filterByType<TensorView>(flxy_vals)) { |
| if (val != lap && val != fly0 && val != inp) { |
| val->computeAt(fly0, -1); |
| } |
| } |
| |
| // Inline inputs to out |
| auto out_vals = DependencyCheck::getAllValsBetween({flx0, fly0}, {out}); |
| for (auto val : ir_utils::filterByType<TensorView>(out_vals)) { |
| if (val != flx0 && val != fly0 && val != out) { |
| val->computeAt(out, -1); |
| } |
| } |
| |
| // Step 3: Parallelization |
| |
| // Block parallelization |
| out->axis(0)->parallelize(ParallelType::BIDz); |
| out->axis(1)->parallelize(ParallelType::BIDy); |
| out->axis(2)->parallelize(ParallelType::BIDx); |
| |
| // Thread parallelization |
| for (auto tv : {out, flx0, fly0, lap}) { |
| tv->axis(3)->parallelize(ParallelType::TIDy); |
| tv->axis(4)->parallelize(ParallelType::TIDx); |
| if (tv != out) { |
| tv->setMemoryType(MemoryType::Shared); |
| } |
| } |
| |
| ///////////////////////////////// |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| int numel_x = 101; |
| int numel_y = 99; |
| int numel_z = 10; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options); |
| at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options); |
| std::vector<IValue> inputs = {inp_at, coeff_at}; |
| auto outputs = fe.runFusion(inputs); |
| |
| { |
| at::Tensor zeros = at::zeros({numel_z, numel_y, numel_x}, options); |
| auto lap = inp_at * 4 - |
| (shift(inp_at, {0, 1, 0}) + shift(inp_at, {0, -1, 0}) + |
| shift(inp_at, {0, 0, 1}) + shift(inp_at, {0, 0, -1})); |
| auto flx = shift(lap, {0, 0, -1}) - lap; |
| auto flx_cond = (flx * (shift(inp_at, {0, 0, -1}) - inp_at)) > 0; |
| auto flx0 = at::where(flx_cond, zeros, flx); |
| auto fly = shift(lap, {0, -1, 0}) - lap; |
| auto fly_cond = (fly * (shift(inp_at, {0, -1, 0}) - inp_at)) > 0; |
| auto fly0 = at::where(fly_cond, zeros, fly); |
| |
| auto ref = inp_at - |
| coeff_at * |
| ((flx0 - shift(flx0, {0, 0, 1})) + (fly0 - shift(fly0, {0, 1, 0}))); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| } |
| |
| // 3x3 max pooling |
| TEST(NVFuserTest, FusionMaxPooling_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // Format: CHW |
| auto inp = makeSymbolicTensor(3); |
| fusion.addInput(inp); |
| |
| // 3x3 pooling of the HW spatial domain |
| std::vector<std::vector<int>> offsets; |
| for (int i = -1; i <= 1; ++i) { |
| for (int j = -1; j <= 1; ++j) { |
| if (i == 0 && j == 0) { |
| continue; |
| } |
| offsets.push_back({i, j}); |
| } |
| } |
| |
| std::vector<TensorView*> inp_tile({inp}); |
| for (auto offset : offsets) { |
| offset.insert(offset.begin(), 0); |
| inp_tile.push_back(shift(inp, offset)); |
| } |
| |
| TensorView* max_tensor = nullptr; |
| for (auto tv : inp_tile) { |
| if (max_tensor == nullptr) { |
| max_tensor = tv; |
| } else { |
| max_tensor = binaryOp(BinaryOpType::Max, max_tensor, tv); |
| } |
| } |
| |
| fusion.addOutput(max_tensor); |
| |
| //////////////////////////////////// |
| |
| // Cache the input and weight tensors |
| auto inp_cache = inp->cache_after(); |
| |
| // Tiling the spatial domain |
| const int tile_x = 32; |
| const int tile_y = 8; |
| |
| max_tensor->split(-2, tile_y); |
| max_tensor->axis(-2)->parallelize(ParallelType::TIDy); |
| max_tensor->split(-1, tile_x); |
| max_tensor->axis(-1)->parallelize(ParallelType::TIDx); |
| max_tensor->reorder({{-3, -2}}); |
| |
| inp_cache->computeAt(max_tensor, 3); |
| inp_cache->axis(-2)->parallelize(ParallelType::TIDy); |
| inp_cache->axis(-1)->parallelize(ParallelType::TIDx); |
| inp_cache->setMemoryType(MemoryType::Shared); |
| |
| auto max_tensor_dep = |
| DependencyCheck::getAllValsBetween({inp_cache}, {max_tensor}); |
| for (auto tv : ir_utils::filterByType<TensorView>(max_tensor_dep)) { |
| if (tv == inp_cache || tv == max_tensor) { |
| continue; |
| } |
| tv->computeAt(max_tensor, -1); |
| } |
| |
| max_tensor->axis(0)->parallelize(ParallelType::BIDx); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| const int hw = 50; |
| const int num_channels = 20; |
| const int pooling_window = 3; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor aten_inp = at::randn({num_channels, hw, hw}, options); |
| // shift always pads by zero, so if all surrounding values are |
| // negative, max pooling would pick a padded value, which isn't the |
| // correct behavior. We need to be able to choose the value of |
| // padding. In this case, padding by the minimum value would not |
| // have this problem. For now, avoid the problem by making sure all |
| // values are not negative. |
| aten_inp = at::abs(aten_inp); |
| std::vector<IValue> inputs = {aten_inp}; |
| |
| auto outputs = fe.runFusion(inputs); |
| |
| auto ref = at::max_pool2d( |
| aten_inp, {pooling_window, pooling_window}, {1, 1}, {1, 1}); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionGatherPadding1_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| |
| const std::vector<int> window_shape = {1, 3}; |
| const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}}; |
| |
| auto tv1 = gather(tv0, window_shape, padding_width); |
| |
| fusion.addOutput(tv1); |
| |
| const int s1 = 11; |
| const int s2 = 13; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({s1, s2}, options); |
| |
| auto ref = gather(t0, window_shape, padding_width); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion({t0}); |
| |
| TORCH_CHECK(ref.equal(outputs[0])); |
| } |
| |
| TEST(NVFuserTest, FusionGatherPadding2_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| const std::vector<int> window_shape = {1, 3}; |
| const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}}; |
| |
| auto tv0 = makeSymbolicTensor(2); |
| fusion.addInput(tv0); |
| |
| auto tv1 = add(tv0, new Double(1)); |
| |
| auto tv2 = gather(tv1, window_shape, padding_width); |
| |
| auto tv3 = sum(tv2, {-1}); |
| |
| fusion.addOutput(tv3); |
| |
| tv3->split(1, 32); |
| tv0->computeAt(tv3, 2); |
| tv2->computeAt(tv3, -1); |
| |
| tv3->axis(0)->parallelize(ParallelType::BIDy); |
| tv3->axis(1)->parallelize(ParallelType::BIDx); |
| tv3->axis(2)->parallelize(ParallelType::TIDx); |
| tv1->axis(2)->parallelize(ParallelType::TIDx); |
| |
| tv1->setMemoryType(MemoryType::Shared); |
| |
| const int s1 = 99; |
| const int s2 = 101; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::Tensor t0 = at::randn({s1, s2}, options); |
| std::vector<IValue> inputs = {t0}; |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| auto outputs = fe.runFusion(inputs); |
| |
| auto t1 = t0 + 1; |
| auto t2 = gather(t1, window_shape, padding_width); |
| auto ref = sum(t2, {-1}); |
| |
| testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionConv2DStatic_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // Input: [C, H, W] |
| auto inp = makeSymbolicTensor(3); |
| fusion.addInput(inp); |
| |
| // Weights: [K, C, 3, 3] |
| auto w = makeSymbolicTensor(4); |
| fusion.addInput(w); |
| |
| // Gather a neighbor tile of [3, 3] with padding size of 1 for each |
| // side of the spatial dimensions |
| auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}}); |
| // inp_tile: [C, H, W, 1, 3, 3] |
| |
| auto inp_bc = |
| broadcast(inp_tile, {true, false, false, false, false, false, false}); |
| auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); |
| |
| auto inp_times_w = mul(inp_bc, w_bc); |
| |
| // Reduce the channel and neighbor tile dimensions |
| auto out = sum(inp_times_w, {1, 4, 5, 6}); |
| |
| fusion.addOutput(out); |
| |
| //////////////////////////////////// |
| |
| // Cache the input and weight tensors |
| auto inp_cache = inp->cache_after(); |
| |
| // Blocking the spatial dimensions |
| const int block_w = 16; |
| const int block_h = 4; |
| // Blocking the channel dimension |
| const int block_c = 8; |
| |
| out->split(2, block_h); |
| out->split(4, block_w); |
| out->reorder({{3, 4}}); |
| // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] |
| |
| out->split(1, block_c); |
| // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] |
| |
| auto out_rf = out->rFactor({1, -3, -2, -1}); |
| // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] |
| // out_rf: [K, Ci, Ho, Wo, Hi, Wi] |
| |
| // Create a [block_x, block_y] tile on smem |
| inp_cache->computeAt(out, 4); |
| // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] |
| inp_cache->setMemoryType(MemoryType::Shared); |
| |
| // Move Ci forward |
| out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); |
| inp_cache->computeAt(out_rf, 5); |
| |
| inp_tile->computeAt(out_rf, -1); |
| w->computeAt(out_rf, -1); |
| |
| out->axis(0)->parallelize(ParallelType::BIDx); |
| out->axis(1)->parallelize(ParallelType::TIDz); |
| out->axis(4)->parallelize(ParallelType::TIDy); |
| out->axis(5)->parallelize(ParallelType::TIDx); |
| |
| scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| const int dim_h = 99; |
| const int dim_w = 101; |
| const int dim_c = 10; |
| const int dim_f = 20; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::manual_seed(0); |
| at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); |
| at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); |
| std::vector<IValue> inputs = {at_inp, at_w}; |
| |
| auto cg_outputs = fe.runFusion(inputs); |
| |
| at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis |
| auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); |
| at_out = at_out.squeeze(0); // drop the N axis |
| |
| testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); |
| } |
| |
| // Mostly the same as the static conv test, but the shape of the weights, |
| // 3x3 in this case, is given dynamically |
| TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // Input: [C, H, W] |
| auto inp = makeSymbolicTensor(3); |
| fusion.addInput(inp); |
| |
| // Weights: [K, C, S, T] |
| auto w = makeSymbolicTensor(4); |
| fusion.addInput(w); |
| |
| auto w_h = new Int(); |
| fusion.addInput(w_h); |
| auto w_w = new Int(); |
| fusion.addInput(w_w); |
| |
| auto pad_h = new Int(); |
| fusion.addInput(pad_h); |
| auto pad_w = new Int(); |
| fusion.addInput(pad_w); |
| |
| // Gather a neighbor tile of [w_dim_h, w_dim_w] with padding |
| auto inp_tile = gather( |
| inp, |
| {new Int(1), w_h, w_w}, |
| {{new Int(0), new Int(0)}, {pad_h, pad_h}, {pad_w, pad_w}}); |
| // inp_tile: [C, 1, H - w_h + 1, W - w_w + 1, w_h, w_w] |
| |
| auto inp_bc = |
| broadcast(inp_tile, {true, false, false, false, false, false, false}); |
| auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); |
| |
| auto inp_times_w = mul(inp_bc, w_bc); |
| |
| // Reduce the channel and neighbor tile dimensions |
| auto out = sum(inp_times_w, {1, 4, 5, 6}); |
| |
| fusion.addOutput(out); |
| |
| //////////////////////////////////// |
| // Cache the input and weight tensors |
| auto inp_cache = inp->cache_after(); |
| |
| // Blocking the spatial dimensions |
| const int block_w = 16; |
| const int block_h = 4; |
| // Blocking the channel dimension |
| const int block_c = 8; |
| |
| out->split(2, block_h); |
| out->split(4, block_w); |
| out->reorder({{3, 4}}); |
| // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] |
| |
| out->split(1, block_c); |
| // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] |
| |
| auto out_rf = out->rFactor({1, -3, -2, -1}); |
| // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] |
| // out_rf: [K, Ci, Ho, Wo, Hi, Wi] |
| |
| // Create a [block_x, block_y] tile on smem |
| inp_cache->computeAt(out, 4); |
| // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] |
| inp_cache->setMemoryType(MemoryType::Shared); |
| |
| // Move Ci forward |
| out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); |
| inp_cache->computeAt(out_rf, 5); |
| |
| inp_tile->computeAt(out_rf, -1); |
| w->computeAt(out_rf, -1); |
| |
| out->axis(0)->parallelize(ParallelType::BIDx); |
| out->axis(1)->parallelize(ParallelType::TIDz); |
| out->axis(4)->parallelize(ParallelType::TIDy); |
| out->axis(5)->parallelize(ParallelType::TIDx); |
| |
| scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| const int dim_h = 99; |
| const int dim_w = 101; |
| const int dim_c = 10; |
| const int dim_f = 20; |
| const int dim_w_h = 3; |
| const int dim_w_w = 3; |
| const int dim_pad_h = (dim_w_h - 1) / 2; |
| const int dim_pad_w = (dim_w_w - 1) / 2; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::manual_seed(0); |
| at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); |
| at::Tensor at_w = at::randn({dim_f, dim_c, dim_w_h, dim_w_w}, options); |
| std::vector<IValue> inputs = { |
| at_inp, at_w, dim_w_h, dim_w_w, dim_pad_h, dim_pad_w}; |
| |
| auto cg_outputs = fe.runFusion(inputs); |
| |
| at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis |
| auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); |
| at_out = at_out.squeeze(0); // drop the N axis |
| |
| testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); |
| } |
| |
| // 5x5 followed by 3x3 |
| TEST(NVFuserTest, FusionConv2DDynamicChain_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // Input: [K1, H, W] |
| auto inp = makeSymbolicTensor(3); |
| fusion.addInput(inp); |
| |
| // Weights: [K2, K1, S1, T1] |
| auto w1 = makeSymbolicTensor(4); |
| fusion.addInput(w1); |
| |
| // Weights: [K3, K2, S2, T2] |
| auto w2 = makeSymbolicTensor(4); |
| fusion.addInput(w2); |
| |
| auto w1_h = new Int(); |
| fusion.addInput(w1_h); |
| auto w1_w = new Int(); |
| fusion.addInput(w1_w); |
| |
| auto w2_h = new Int(); |
| fusion.addInput(w2_h); |
| auto w2_w = new Int(); |
| fusion.addInput(w2_w); |
| |
| auto pad_h1 = new Int(); |
| fusion.addInput(pad_h1); |
| auto pad_w1 = new Int(); |
| fusion.addInput(pad_w1); |
| |
| auto pad_h2 = new Int(); |
| fusion.addInput(pad_h2); |
| auto pad_w2 = new Int(); |
| fusion.addInput(pad_w2); |
| |
| // Gather a neighbor tile of [w1_h, w1_w] with padding |
| auto inp_tile = gather( |
| inp, |
| {new Int(1), w1_h, w1_w}, |
| {{new Int(0), new Int(0)}, {pad_h1, pad_h1}, {pad_w1, pad_w1}}); |
| // inp_tile: [C, 1, H - w1_h + 1, W - w1_w + 1, w1_h, w1_w] |
| |
| auto inp_bc = |
| broadcast(inp_tile, {true, false, false, false, false, false, false}); |
| auto w1_bc = broadcast(w1, {false, false, true, true, true, false, false}); |
| |
| auto inp_times_w1 = mul(inp_bc, w1_bc); |
| |
| // Reduce the channel and neighbor tile dimensions |
| auto out1 = sum(inp_times_w1, {1, 4, 5, 6}); |
| |
| // Second conv |
| auto out1_tile = gather( |
| out1, |
| {new Int(1), w2_h, w2_w}, |
| {{new Int(0), new Int(0)}, {pad_h2, pad_h2}, {pad_w2, pad_w2}}); |
| |
| auto out1_bc = |
| broadcast(out1_tile, {true, false, false, false, false, false, false}); |
| auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false}); |
| |
| auto out1_times_w2 = mul(out1_bc, w2_bc); |
| |
| auto out2 = sum(out1_times_w2, {1, 4, 5, 6}); |
| |
| fusion.addOutput(out2); |
| |
| //////////////////////////////////// |
| // Cache the input and weight tensors |
| auto inp_cache = inp->cache_after(); |
| |
| // Blocking the spatial dimensions |
| const int block_w = 16; |
| const int block_h = 4; |
| |
| out2->split(2, block_h); |
| out2->split(4, block_w); |
| out2->reorder({{3, 4}}); |
| // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3] |
| |
| // Create a [block_x, block_y] tile on smem |
| inp_cache->computeAt(out2, 4); |
| // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] |
| inp_cache->setMemoryType(MemoryType::Shared); |
| |
| // Move Ci forward |
| out1->reorder({{5, 3}, {3, 4}, {4, 5}}); |
| out1->setMemoryType(MemoryType::Shared); |
| |
| inp_cache->computeAt(out1, 4); |
| |
| inp_tile->computeAt(out1, -1); |
| w1->computeAt(out1, -1); |
| |
| out1_tile->computeAt(out2, -1); |
| w2->computeAt(out2, -1); |
| |
| out2->axis(0)->parallelize(ParallelType::BIDx); |
| out2->axis(4)->parallelize(ParallelType::TIDy); |
| out2->axis(5)->parallelize(ParallelType::TIDx); |
| |
| scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1}); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| const int dim_h = 99; |
| const int dim_w = 101; |
| const int dim_k1 = 3; |
| const int dim_k2 = 5; |
| const int dim_k3 = 7; |
| const int dim_w1_h = 5; |
| const int dim_w1_w = 5; |
| const int dim_pad1_h = (dim_w1_h - 1) / 2; |
| const int dim_pad1_w = (dim_w1_w - 1) / 2; |
| const int dim_w2_h = 3; |
| const int dim_w2_w = 3; |
| const int dim_pad2_h = (dim_w2_h - 1) / 2; |
| const int dim_pad2_w = (dim_w2_w - 1) / 2; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::manual_seed(0); |
| at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, options); |
| at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options); |
| at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options); |
| std::vector<IValue> inputs = { |
| at_inp, |
| at_w1, |
| at_w2, |
| dim_w1_h, |
| dim_w1_w, |
| dim_w2_h, |
| dim_w2_w, |
| dim_pad1_h, |
| dim_pad1_w, |
| dim_pad2_h, |
| dim_pad2_w}; |
| |
| auto cg_outputs = fe.runFusion(inputs); |
| |
| at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis |
| auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2); |
| auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1); |
| at_out2 = at_out2.squeeze(0); // drop the N axis |
| |
| testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__); |
| } |
| |
| TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { |
| Fusion fusion; |
| FusionGuard fg(&fusion); |
| |
| // Input: [C, H, W] |
| auto inp = makeSymbolicTensor(3); |
| fusion.addInput(inp); |
| |
| // Weights: [K, C, 2, 2] |
| auto w = makeSymbolicTensor(4); |
| fusion.addInput(w); |
| |
| // Gather a neighbor tile of [2, 2] with padding size of 1 only for |
| // the right side of the spatial dimensions. The left padding is |
| // zero so that the output axis stays the same. |
| auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}}); |
| // inp_tile: [C, H, W, 1, 2, 2] |
| |
| auto inp_bc = |
| broadcast(inp_tile, {true, false, false, false, false, false, false}); |
| auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); |
| |
| auto inp_times_w = mul(inp_bc, w_bc); |
| |
| // Reduce the channel and neighbor tile dimensions |
| auto out = sum(inp_times_w, {1, 4, 5, 6}); |
| |
| fusion.addOutput(out); |
| |
| //////////////////////////////////// |
| |
| // Cache the input and weight tensors |
| auto inp_cache = inp->cache_after(); |
| |
| // Blocking the spatial dimensions |
| const int block_w = 16; |
| const int block_h = 4; |
| // Blocking the channel dimension |
| const int block_c = 8; |
| |
| out->split(2, block_h); |
| out->split(4, block_w); |
| out->reorder({{3, 4}}); |
| // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2] |
| |
| out->split(1, block_c); |
| // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] |
| |
| auto out_rf = out->rFactor({1, -3, -2, -1}); |
| // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] |
| // out_rf: [K, Ci, Ho, Wo, Hi, Wi] |
| |
| // Create a [block_x, block_y] tile on smem |
| inp_cache->computeAt(out, 4); |
| // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] |
| inp_cache->setMemoryType(MemoryType::Shared); |
| |
| // Move Ci forward |
| out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); |
| inp_cache->computeAt(out_rf, 5); |
| |
| inp_tile->computeAt(out_rf, -1); |
| w->computeAt(out_rf, -1); |
| |
| out->axis(0)->parallelize(ParallelType::BIDx); |
| out->axis(1)->parallelize(ParallelType::TIDz); |
| out->axis(4)->parallelize(ParallelType::TIDy); |
| out->axis(5)->parallelize(ParallelType::TIDx); |
| |
| scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); |
| |
| FusionExecutor fe; |
| fe.compileFusion(&fusion); |
| |
| const int dim_h = 99; |
| const int dim_w = 101; |
| const int dim_c = 10; |
| const int dim_f = 20; |
| |
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| at::manual_seed(0); |
| at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); |
| at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options); |
| std::vector<IValue> inputs = {at_inp, at_w}; |
| |
| auto cg_outputs = fe.runFusion(inputs); |
| |
| at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis |
| auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); |
| at_out = at_out.squeeze(0); // drop the N axis |
| // The shape of the spatial domain is (dim_h+1)x(dim_w+1), whereas |
| // the fuser output has dim_h*dim_w. Drop the first elements to make |
| // it match with the fuser output. |
| std::vector<at::indexing::TensorIndex> indices{ |
| at::indexing::Slice(0, at::indexing::None), |
| at::indexing::Slice(1, at::indexing::None), |
| at::indexing::Slice(1, at::indexing::None)}; |
| ; |
| at_out = at_out.index(indices); |
| |
| testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); |
| } |
| |
| } // namespace jit |
| } // namespace torch |
| #endif // #if defined(USE_CUDA) |