#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <benchmark/benchmark.h>
#include <cuda_runtime.h>
#include "utils.h"
using namespace torch::jit::fuser::cuda;
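// Benchmarks for nvFuser softmax and softmax+dropout fusions, compared against
// the corresponding eager-mode ATen baselines (Baseline_*). Each fusion is
// built by a setup* function and driven through a FusionExecutorCache by the
// matching NvFuserScheduler_* function below.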
//------------------------------------------------------------------------------
static void setupSoftmax(
    Fusion* fusion,
    DataType dtype,
    const int reduction_axis) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  FusionGuard fg(fusion);

  // setup fusion
  auto input = makeContigTensor(2, dtype);
  fusion->addInput(input);

  if (dtype == DataType::Half) {
    input = castOp(DataType::Float, input);
  }

  auto output = softmax(input, reduction_axis);

  if (dtype == DataType::Half) {
    output = castOp(DataType::Half, output);
  }

  fusion->addOutput(output);
}
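// For reference, the fusion above corresponds roughly to the following
// eager-mode ATen computation (illustrative sketch only; the helper name is
// made up, and the half-precision path mirrors the casts inserted above):
//
//   at::Tensor softmax_reference(const at::Tensor& input, int64_t axis) {
//     auto x = input.scalar_type() == at::kHalf ? input.to(at::kFloat) : input;
//     auto y = at::_softmax(x, axis, /*half_to_float=*/false);
//     return input.scalar_type() == at::kHalf ? y.to(at::kHalf) : y;
//   }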
static void NvFuserScheduler_Softmax(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype,
    const int reduction_axis) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  std::vector<int64_t> input_shape{
      benchmark_state.range(1), benchmark_state.range(0)};

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn(input_shape, options);
  std::vector<c10::IValue> aten_inputs({aten_input});

  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
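  // One read of the input plus one write of the same-sized output, hence
  // 2 * numel * sizeof(dtype) bytes per iteration. As an illustrative example,
  // the largest fp32 case below (32768 x 656) moves roughly
  // 2 * 32768 * 656 * 4 B ~= 172 MB per iteration.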
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) *
      (2 * aten_input.numel() * int64_t(dataTypeSize(dtype))));
}
//------------------------------------------------------------------------------
static void Baseline_Softmax(
    benchmark::State& benchmark_state,
    DataType dtype) {
  std::vector<int64_t> input_shape{
      benchmark_state.range(1), benchmark_state.range(0)};
  const int kReductionAxis = benchmark_state.range(2);

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor aten_input = at::randn(input_shape, options);

  cudaDeviceSynchronize();
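  // Manual timing loop: each iteration is timed with CudaKernelTimer
  // (elapsed() appears to report milliseconds, hence the / 1000.0 below to
  // convert to seconds for SetIterationTime), and the L2 cache is cleared
  // between iterations so the input is re-read from global memory each time.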
  for (auto _ : benchmark_state) {
    CudaKernelTimer timer;

    auto output = at::_softmax(aten_input, kReductionAxis, false);

    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
    cudaDeviceSynchronize();
    clearL2Cache();
    cudaDeviceSynchronize();
  }

  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) *
      (2 * aten_input.numel() * int64_t(dataTypeSize(dtype))));
}
static void Baseline_Softmax_fp32(benchmark::State& benchmark_state) {
  Baseline_Softmax(benchmark_state, DataType::Float);
}

static void Baseline_Softmax_fp16(benchmark::State& benchmark_state) {
  Baseline_Softmax(benchmark_state, DataType::Half);
}
//------------------------------------------------------------------------------
static void setupSoftmaxDropout(
    Fusion* fusion,
    DataType dtype,
    const int kReductionAxis) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  FusionGuard fg(fusion);

  constexpr int kHiddenSize = 768;
  constexpr int kNumAttentionHeads = 12;
  constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;
  constexpr float kDropoutProbability = 0.9;
  constexpr float kScale = 1.0f / kDropoutProbability;

  // setup fusion
  auto attention_scores = makeContigTensor(4, dtype);
  auto attention_mask = makeContigTensor(4, dtype);
  Double* divisor = new Double();

  fusion->addInput(attention_scores);
  fusion->addInput(attention_mask);
  fusion->addInput(divisor);

  if (dtype == DataType::Half) {
    attention_scores = castOp(DataType::Float, attention_scores);
    attention_mask = castOp(DataType::Float, attention_mask);
  }

  attention_scores = div(attention_scores, divisor);
  attention_scores = add(attention_scores, attention_mask);
  auto attention_probs = softmax(attention_scores, kReductionAxis);

  auto prob = new Double(kDropoutProbability);
  auto scale = new Double(kScale);
  auto dropout_results = dropout(attention_probs, prob, scale);
  auto output = dropout_results.output;

  if (dtype == DataType::Half) {
    attention_scores = castOp(DataType::Half, attention_scores);
    attention_probs = castOp(DataType::Half, attention_probs);
    output = castOp(DataType::Half, output);
  }

  fusion->addOutput(attention_scores);
  fusion->addOutput(attention_probs);
  fusion->addOutput(output);
  fusion->addOutput(dropout_results.mask);
}
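// For reference, the fusion above computes roughly the following eager-mode
// sequence (illustrative sketch only; it mirrors the Baseline_Softmax_Dropout
// loop further down, and the divisor input is sqrt(kAttentionHeadSize) at run
// time):
//
//   auto scores = attention_scores / divisor + attention_mask;
//   auto probs = at::_softmax(scores, kReductionAxis, /*half_to_float=*/false);
//   auto out = at::dropout(probs, /*p=*/0.1, /*train=*/true);
//
// The fusion additionally returns the dropout mask as an output.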
static void NvFuserScheduler_SoftmaxDropout(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype,
    const int kReductionAxis) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  // [256, 12, 100, benchmark_state.range(0)], reduced across kReductionAxis
  std::vector<int64_t> input_shape{256, 12, 100, benchmark_state.range(0)};

  constexpr int kHiddenSize = 768;
  constexpr int kNumAttentionHeads = 12;
  constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;
  constexpr float kDropoutProbability = 0.9;
  constexpr float kScale = 1.0f / kDropoutProbability;

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor at_scores = at::randn(input_shape, options);
  at::Tensor at_mask = at::randn(input_shape, options);
  std::vector<c10::IValue> aten_inputs(
      {at_scores, at_mask, sqrt(kAttentionHeadSize)});

  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);

  // 5 dtype: attention_scores + attention_mask + attention_scores_out +
  // attention_probs_out + output
  // 1 bool: dropout_results.mask
  // All the same size
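  // Illustrative arithmetic: at Arg(128) each tensor has
  // 256 * 12 * 100 * 128 = 39,321,600 elements, so the fp16 variant is
  // charged roughly 39.3M * (5 * 2 B + 1 B) ~= 433 MB per iteration.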
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) * 5 * at_scores.numel() *
          int64_t(dataTypeSize(dtype)) +
      // bool mask
      int64_t(benchmark_state.iterations()) * at_scores.numel() *
          int64_t(dataTypeSize(DataType::Bool)));
}
//------------------------------------------------------------------------------
static void Baseline_Softmax_Dropout(
    benchmark::State& benchmark_state,
    const int kReductionAxis,
    DataType dtype) {
  std::vector<int64_t> input_shape{256, 12, 100, benchmark_state.range(0)};

  constexpr int kHiddenSize = 768;
  constexpr int kNumAttentionHeads = 12;
  constexpr float kDropoutProbability = 0.1;
  constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor attention_scores = at::randn(input_shape, options);
  at::Tensor at_y = at::randn(input_shape, options);

  cudaDeviceSynchronize();

  for (auto _ : benchmark_state) {
    // Create
    CudaKernelTimer timer;

    // Run
    attention_scores = attention_scores / sqrt(kAttentionHeadSize);
    attention_scores = attention_scores + at_y;
    auto attention_probs =
        at::_softmax(attention_scores, kReductionAxis, false);
    attention_probs = at::dropout(attention_probs, kDropoutProbability, true);

    // Record
    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
    cudaDeviceSynchronize();
    clearL2Cache();
    cudaDeviceSynchronize();
  }

  // 5 dtype: attention_scores + attention_mask + attention_scores_out +
  // attention_probs_out + output
  // 1 bool: dropout_results.mask
  // All the same size
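  // The baseline is charged the same nominal byte count as the fused kernel so
  // the reported bandwidth numbers are directly comparable, even though the
  // unfused ATen sequence above also materializes intermediate tensors.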
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) * 5 * attention_scores.numel() *
          int64_t(dataTypeSize(dtype)) +
      // bool mask
      int64_t(benchmark_state.iterations()) * attention_scores.numel() *
          int64_t(dataTypeSize(DataType::Bool)));
}
//------------------------------------------------------------------------------
static void Baseline_Softmax_Dropout_Inner_fp32(
    benchmark::State& benchmark_state) {
  Baseline_Softmax_Dropout(benchmark_state, 3, DataType::Float);
}

static void Baseline_Softmax_Dropout_Outer_fp32(
    benchmark::State& benchmark_state) {
  Baseline_Softmax_Dropout(benchmark_state, 1, DataType::Float);
}

static void Baseline_Softmax_Dropout_Inner_fp16(
    benchmark::State& benchmark_state) {
  Baseline_Softmax_Dropout(benchmark_state, 3, DataType::Half);
}

static void Baseline_Softmax_Dropout_Outer_fp16(
    benchmark::State& benchmark_state) {
  Baseline_Softmax_Dropout(benchmark_state, 1, DataType::Half);
}
//------------------------------------------------------------------------------
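// nvFuser benchmark registrations. Each NVFUSER_BENCHMARK_DEFINE binds a setup
// function, a benchmark body, and the (dtype, reduction axis) arguments passed
// to both. For the 2-D softmax benchmarks below, the input shape is
// {benchmark_state.range(1), benchmark_state.range(0)}: the second range
// sweeps one dimension over powers of two from 8 to 8 << 12 (= 32768) while
// the other dimension stays at 656. The _Outer variants reduce over axis 0 and
// the _Inner variants over axis 1.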
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Outer_fp32,
setupSoftmax,
NvFuserScheduler_Softmax,
DataType::Float,
0);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Outer_fp32)
->RangeMultiplier(2)
->Ranges({{656, 656}, {8, 8 << 12}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Inner_fp32,
setupSoftmax,
NvFuserScheduler_Softmax,
DataType::Float,
1);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Inner_fp32)
->RangeMultiplier(2)
->Ranges({{656, 656}, {8, 8 << 12}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Outer_fp16,
setupSoftmax,
NvFuserScheduler_Softmax,
DataType::Half,
0);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Outer_fp16)
->RangeMultiplier(2)
->Ranges({{656, 656}, {8, 8 << 12}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Inner_fp16,
setupSoftmax,
NvFuserScheduler_Softmax,
DataType::Half,
1);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Inner_fp16)
->RangeMultiplier(2)
->Ranges({{656, 656}, {8, 8 << 12}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Dropout_Inner_fp32,
setupSoftmaxDropout,
NvFuserScheduler_SoftmaxDropout,
DataType::Float,
3);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Inner_fp32)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Dropout_Outer_fp32,
setupSoftmaxDropout,
NvFuserScheduler_SoftmaxDropout,
DataType::Float,
1);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Outer_fp32)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Dropout_Inner_fp16,
setupSoftmaxDropout,
NvFuserScheduler_SoftmaxDropout,
DataType::Half,
3);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Inner_fp16)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_Softmax_Dropout_Outer_fp16,
setupSoftmaxDropout,
NvFuserScheduler_SoftmaxDropout,
DataType::Half,
1);
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Outer_fp16)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();
//------------------------------------------------------------------------------
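// Baseline registrations. Baseline_Softmax takes a third range {0, 1} that
// selects the reduction axis via benchmark_state.range(2), so each shape is
// timed with both an outer- and an inner-axis reduction.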
BENCHMARK(Baseline_Softmax_fp32)
->RangeMultiplier(2)
->Ranges({{656, 656}, {8, 8 << 12}, {0, 1}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
BENCHMARK(Baseline_Softmax_fp16)
->RangeMultiplier(2)
->Ranges({{656, 656}, {8, 8 << 12}, {0, 1}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
BENCHMARK(Baseline_Softmax_Dropout_Inner_fp32)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();
BENCHMARK(Baseline_Softmax_Dropout_Outer_fp32)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();
BENCHMARK(Baseline_Softmax_Dropout_Inner_fp16)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();
BENCHMARK(Baseline_Softmax_Dropout_Outer_fp16)
->Arg(8)
->Arg(16)
->Arg(24)
->Arg(32)
->Arg(40)
->Arg(48)
->Arg(56)
->Arg(64)
->Arg(72)
->Arg(80)
->Arg(88)
->Arg(96)
->Arg(104)
->Arg(112)
->Arg(120)
->Arg(128)
->Unit(benchmark::kMicrosecond)
->UseManualTime();