// blob: 05475f114474364d396878428626c41d466440d4 [file] [log] [blame]
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <benchmark/benchmark.h>
#include <cuda_runtime.h>
#include <benchmarks/cpp/nvfuser/utils.h>
using namespace torch::jit::fuser::cuda;
// Populates `fusion` with the graph that is benchmarked: instance_norm
// followed by ReLU.
//
// fusion:            fusion to fill in; it is made current via a FusionGuard
//                    for the duration of this function.
// dtype:             element type of input, weight, bias and of the output.
//                    Must be DataType::Float or DataType::Half. The running
//                    statistics are always DataType::Float.
// channels_last_3d:  when true the input is a 5-D tensor and the flag is
//                    forwarded to instance_norm; otherwise the input is 4-D.
//
// Fusion inputs are registered in this order: input, weight, bias,
// running_mean, running_var. The single fusion output is the ReLU result.
static void setupInstanceNorm(
    Fusion* fusion,
    DataType dtype,
    bool channels_last_3d = false) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  FusionGuard fg(fusion);

  // Create the input exactly once with the rank this configuration needs,
  // instead of creating a 4-D tensor and then replacing it with a 5-D one
  // (which left an unused tensor behind in the fusion).
  auto input = makeContigTensor(channels_last_3d ? 5 : 4, dtype);
  auto weight = makeContigTensor(1, dtype);
  auto bias = makeContigTensor(1, dtype);
  auto running_mean = makeContigTensor(1, DataType::Float);
  auto running_var = makeContigTensor(1, DataType::Float);

  fusion->addInput(input);
  fusion->addInput(weight);
  fusion->addInput(bias);
  fusion->addInput(running_mean);
  fusion->addInput(running_var);

  // Half-precision operands are widened to Float before normalization; the
  // result is narrowed back to Half just before it is registered as output.
  if (dtype == DataType::Half) {
    input = castOp(DataType::Float, input);
    weight = castOp(DataType::Float, weight);
    bias = castOp(DataType::Float, bias);
  }

  const bool kTraining = true;
  // Kept as float on purpose: Baseline_InstanceNorm declares the same two
  // float constants, so both sides see identical (float-rounded) values.
  const float kMomentum = 0.1;
  const float kEps = 1e-5;

  // Scalars are passed to instance_norm as IR values.
  auto momentum_ptr = IrBuilder::create<Double>(kMomentum);
  auto eps_ptr = IrBuilder::create<Double>(kEps);

  auto norm = instance_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      kTraining,
      momentum_ptr,
      eps_ptr,
      channels_last_3d);

  auto output = unaryOp(UnaryOpType::Relu, norm.output);

  if (dtype == DataType::Half) {
    output = castOp(DataType::Half, output);
  }

  fusion->addOutput(output);
}
//------------------------------------------------------------------------------
// Runs one configuration of the nvFuser instance-norm + ReLU benchmark.
//
// benchmark_state:        Google Benchmark state. range(0) is the leading
//                         (batch) extent, range(1) the extent used for every
//                         spatial dimension, range(2) the channel count.
// fusion_executor_cache:  cache holding the fusion to execute; forwarded
//                         unchanged to runBenchmarkIterations.
// dtype:                  element type of x, weight and bias; must be Float
//                         or Half. Running mean/var are always fp32.
// channels_last_3d:       false -> x is {batch, channels, spatial, spatial}
//                         true  -> x is {batch, spatial, spatial, spatial,
//                                        channels}
//
// After the run, the processed-bytes counter is set from the tensor sizes.
static void NvFuserScheduler_InstanceNorm(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype,
    bool channels_last_3d = false) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  const int64_t batch = benchmark_state.range(0);
  const int64_t spatial = benchmark_state.range(1);
  const int64_t channels = benchmark_state.range(2);

  // Only the shape that is actually used is materialized.
  const std::vector<int64_t> input_shape = channels_last_3d
      ? std::vector<int64_t>{batch, spatial, spatial, spatial, channels}
      : std::vector<int64_t>{batch, channels, spatial, spatial};

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  auto fp32_options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor at_x = at::randn(input_shape, options);
  at::Tensor at_weight = at::ones({channels}, options);
  at::Tensor at_bias = at::zeros({channels}, options);
  at::Tensor at_mean = at::zeros({channels}, fp32_options);
  at::Tensor at_var = at::ones({channels}, fp32_options);

  // Order must match the addInput order used by setupInstanceNorm.
  std::vector<c10::IValue> aten_inputs = {
      at_x, at_weight, at_bias, at_mean, at_var};

  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);

  const size_t kChannels = benchmark_state.range(2);

  // Read: x, weight, bias, running_mean, running_var
  // Write: y, running_mean, running_var
  //   dtype-sized traffic: weight + bias (2 * C) and x + y (2 * numel)
  //   fp32-sized traffic:  running_mean + running_var, each read and written
  benchmark_state.SetBytesProcessed(
      benchmark_state.iterations() *
      ((kChannels * 2 + at_x.numel() * 2) * dataTypeSize(dtype) +
       (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
}
// ATen baseline for the instance-norm + ReLU benchmark: times
// at::instance_norm followed by at::relu with a CudaKernelTimer.
//
// benchmark_state:   Google Benchmark state. range(0) is the batch extent,
//                    range(1) the extent used for every spatial dimension,
//                    range(2) the channel count.
// dtype:             element type of x, weight and bias; must be Float or
//                    Half. Running mean/var are always fp32.
// channels_last_3d:  false -> x is a 4-D {batch, channels, spatial, spatial}
//                             tensor.
//                    true  -> x is a 5-D {batch, channels, spatial, spatial,
//                             spatial} tensor stored in ChannelsLast3d
//                             memory format.
static void Baseline_InstanceNorm(
    benchmark::State& benchmark_state,
    DataType dtype,
    bool channels_last_3d = false) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  const int64_t batch = benchmark_state.range(0);
  const int64_t spatial = benchmark_state.range(1);
  const int64_t channels = benchmark_state.range(2);

  const float kMomentum = 0.1;
  const float kEps = 1e-5;
  const auto aten_dtype = data_type_to_aten(dtype);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
  auto fp32_options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  // Allocate only the tensor this configuration needs.
  at::Tensor at_x;
  if (channels_last_3d) {
    const std::vector<int64_t> input_shape_3d{
        batch, channels, spatial, spatial, spatial};
    // at::randn takes no memory_format argument, so a format set only on the
    // TensorOptions is not reliably applied to the allocation. Convert
    // explicitly; this is a no-op if the tensor already has that layout.
    at_x = at::randn(input_shape_3d, options)
               .contiguous(c10::MemoryFormat::ChannelsLast3d);
  } else {
    const std::vector<int64_t> input_shape{batch, channels, spatial, spatial};
    at_x = at::randn(input_shape, options);
  }

  at::Tensor at_weight = at::ones({channels}, options);
  at::Tensor at_bias = at::zeros({channels}, options);
  at::Tensor at_mean = at::zeros({channels}, fp32_options);
  at::Tensor at_var = at::ones({channels}, fp32_options);

  // at::instance_norm takes weight/bias/running stats as optional tensors.
  auto ato_weight = c10::optional<at::Tensor>(at_weight);
  auto ato_bias = c10::optional<at::Tensor>(at_bias);
  auto ato_running_mean = c10::optional<at::Tensor>(at_mean);
  auto ato_running_var = c10::optional<at::Tensor>(at_var);

  // clearL2Cache() is a project helper; it is invoked before the first
  // iteration and again after every iteration, each time followed by a
  // device-wide synchronization so it is not part of the timed region.
  clearL2Cache();
  C10_CUDA_CHECK(cudaDeviceSynchronize());

  for (auto _ : benchmark_state) {
    CudaKernelTimer timer;

    auto norm = at::instance_norm(
        at_x,
        ato_weight,
        ato_bias,
        ato_running_mean,
        ato_running_var,
        /*use_input_stats=*/true,
        kMomentum,
        kEps,
        /*cudnn_enabled=*/false);
    auto output = at::relu(norm);

    // SetIterationTime expects seconds; the division by 1000 assumes
    // timer.elapsed() reports milliseconds -- TODO confirm against
    // CudaKernelTimer.
    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);

    C10_CUDA_CHECK(cudaDeviceSynchronize());
    clearL2Cache();
    C10_CUDA_CHECK(cudaDeviceSynchronize());
  }

  const size_t kChannels = benchmark_state.range(2);

  // Read: x, weight, bias, running_mean, running_var
  // Write: y, running_mean, running_var
  //   dtype-sized traffic: weight + bias (2 * C) and x + y (2 * numel)
  //   fp32-sized traffic:  running_mean + running_var, each read and written
  benchmark_state.SetBytesProcessed(
      benchmark_state.iterations() *
      ((kChannels * 2 + at_x.numel() * 2) * dataTypeSize(dtype) +
       (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
}
//------------------------------------------------------------------------------
// Google Benchmark entry point: ATen baseline with fp32 data and the 4-D
// input layout.
static void Baseline_InstanceNorm_fp32(benchmark::State& benchmark_state) {
  constexpr bool kChannelsLast3d = false;
  Baseline_InstanceNorm(benchmark_state, DataType::Float, kChannelsLast3d);
}
// Google Benchmark entry point: ATen baseline with fp16 data and the 4-D
// input layout.
static void Baseline_InstanceNorm_fp16(benchmark::State& benchmark_state) {
  constexpr bool kChannelsLast3d = false;
  Baseline_InstanceNorm(benchmark_state, DataType::Half, kChannelsLast3d);
}
// Google Benchmark entry point: ATen baseline with fp32 data and the 5-D
// channels-last-3d input.
static void Baseline_InstanceNorm_fp32_channels_last_3d(
    benchmark::State& benchmark_state) {
  constexpr bool kChannelsLast3d = true;
  Baseline_InstanceNorm(benchmark_state, DataType::Float, kChannelsLast3d);
}
//------------------------------------------------------------------------------
// fp32 nvFuser benchmark: pairs setupInstanceNorm (fusion construction) with
// NvFuserScheduler_InstanceNorm (runner), both receiving DataType::Float.
// NVFUSER_BENCHMARK_DEFINE / NVFUSER_BENCHMARK_RUN are macros defined outside
// this file.
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_InstanceNorm_fp32,
setupInstanceNorm,
NvFuserScheduler_InstanceNorm,
DataType::Float);
// Arguments are {range(0), range(1), range(2)}, which the runner reads as
// {batch, spatial extent, channels}. RangeMultiplier(2) is disabled below, so
// the channel sweep uses Google Benchmark's default multiplier.
// Reported times are the manually recorded iteration times, in microseconds.
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm_fp32)
// ->RangeMultiplier(2)
->Ranges({{8, 8}, {640, 640}, {64, 128}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// fp16 nvFuser benchmark: same setup/runner pair as the fp32 variant, but
// with DataType::Half.
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_InstanceNorm_fp16,
setupInstanceNorm,
NvFuserScheduler_InstanceNorm,
DataType::Half);
// Arguments are {range(0), range(1), range(2)}, which the runner reads as
// {batch, spatial extent, channels}. The channel sweep goes up to 256 here.
// RangeMultiplier(2) is disabled, so Google Benchmark's default multiplier
// applies. Reported times are the manually recorded iteration times, in
// microseconds.
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm_fp16)
// ->RangeMultiplier(2)
->Ranges({{8, 8}, {640, 640}, {64, 256}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// fp32 nvFuser benchmark on the 5-D channels-last input: the trailing `true`
// is passed as channels_last_3d to both setupInstanceNorm and
// NvFuserScheduler_InstanceNorm.
NVFUSER_BENCHMARK_DEFINE(
NvFuserScheduler_InstanceNorm3d_channels_last_fp32,
setupInstanceNorm,
NvFuserScheduler_InstanceNorm,
DataType::Float,
true);
// The same benchmark is registered several times, once per
// {spatial extent, channels} configuration. In every registration range(0)
// (batch) is swept from 1 to 8 in powers of two. Reported times are the
// manually recorded iteration times, in microseconds.
// spatial = 128, channels = 32
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm3d_channels_last_fp32)
->RangeMultiplier(2)
->Ranges({{1, 8}, {128, 128}, {32, 32}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// spatial = 64, channels = 64
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm3d_channels_last_fp32)
->RangeMultiplier(2)
->Ranges({{1, 8}, {64, 64}, {64, 64}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// spatial = 32, channels = 128
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm3d_channels_last_fp32)
->RangeMultiplier(2)
->Ranges({{1, 8}, {32, 32}, {128, 128}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// spatial = 16, channels = 256
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm3d_channels_last_fp32)
->RangeMultiplier(2)
->Ranges({{1, 8}, {16, 16}, {256, 256}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// spatial = 4 and 8, channels = 320
NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm3d_channels_last_fp32)
->RangeMultiplier(2)
->Ranges({{1, 8}, {4, 8}, {320, 320}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
//------------------------------------------------------------------------------
// ATen fp32 baseline, registered with the same argument ranges as
// NvFuserScheduler_InstanceNorm_fp32: {batch, spatial extent, channels}.
// RangeMultiplier(2) is disabled, so Google Benchmark's default multiplier
// applies. Times come from SetIterationTime in the benchmark body and are
// reported in microseconds.
BENCHMARK(Baseline_InstanceNorm_fp32)
// ->RangeMultiplier(2)
->Ranges({{8, 8}, {640, 640}, {64, 128}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// ATen fp16 baseline, registered with the same argument ranges as
// NvFuserScheduler_InstanceNorm_fp16: {batch, spatial extent, channels}.
// RangeMultiplier(2) is disabled, so Google Benchmark's default multiplier
// applies. Times come from SetIterationTime in the benchmark body and are
// reported in microseconds.
BENCHMARK(Baseline_InstanceNorm_fp16)
// ->RangeMultiplier(2)
->Ranges({{8, 8}, {640, 640}, {64, 256}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// ATen fp32 baseline on the 5-D channels-last-3d input, registered once per
// {spatial extent, channels} configuration. range(0) (batch) is swept from 2
// to 8 in powers of two. Times come from SetIterationTime in the benchmark
// body and are reported in microseconds.
// NOTE(review): the matching nvFuser registrations sweep batch from 1 (not 2)
// and also include a {32, 32}, {128, 128} configuration that has no baseline
// here -- confirm whether these differences are intentional.
// spatial = 128, channels = 32
BENCHMARK(Baseline_InstanceNorm_fp32_channels_last_3d)
->RangeMultiplier(2)
->Ranges({{2, 8}, {128, 128}, {32, 32}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// spatial = 64, channels = 64
BENCHMARK(Baseline_InstanceNorm_fp32_channels_last_3d)
->RangeMultiplier(2)
->Ranges({{2, 8}, {64, 64}, {64, 64}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// spatial = 16, channels = 256
BENCHMARK(Baseline_InstanceNorm_fp32_channels_last_3d)
->RangeMultiplier(2)
->Ranges({{2, 8}, {16, 16}, {256, 256}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
// spatial = 4 and 8, channels = 320
BENCHMARK(Baseline_InstanceNorm_fp32_channels_last_3d)
->RangeMultiplier(2)
->Ranges({{2, 8}, {4, 8}, {320, 320}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();
//------------------------------------------------------------------------------