blob: 4046907dcf5496f9d9b5c10a7460cd51cb832d02 [file] [log] [blame]
#include <torch/csrc/autograd/profiler_utils.h>
namespace torch { namespace autograd { namespace profiler {
static constexpr auto kConv2dStride = 3;
static constexpr auto kConv2dPadding = 4;
static constexpr auto kConv2dDilation = 5;
static constexpr auto kConv2dGroups = 6;
// List of supported operators
static constexpr auto kConv2dOp = "aten::conv2d";
static constexpr auto kMMOp = "aten::mm";
static constexpr auto kAddMMOp = "aten::addmm";
static constexpr auto kMulOp = "aten::mul";
static constexpr auto kAddOp = "aten::add";
static constexpr auto kBMMOp = "aten::bmm";
static constexpr auto kBAddBMMOp = "aten::baddbmm";
static constexpr auto kInputSize = "input_size";
static constexpr auto kWeightSize = "weight_size";
static constexpr auto kGroups = "groups";
static constexpr auto kPadding = "padding";
static constexpr auto kStride = "stride";
static constexpr auto kDilation = "dilation";
static constexpr auto kMatSize = "mat_size";
static constexpr auto kMat1Size = "mat1_size";
static constexpr auto kMat2Size = "mat2_size";
static bool validateInput(const std::string &op_name, size_t min_size,
const std::vector<c10::IValue>& inputs,
const std::vector<int>& should_be_tensor) {
std::stringstream ss;
if (inputs.size() < min_size) {
ss << "Failed to save extra arguments for flops compuation of op "
<< op_name
<< ", min size: " << min_size
<< ", actual size: " << inputs.size();
TORCH_WARN(ss.str());
return false;
}
for (auto index : should_be_tensor) {
if (!inputs[index].isTensor()) {
ss << "Failed to save extra arguments for flops compuation of op "
<< op_name
<< ", input[" << index
<< "] must be a tensor.";
TORCH_WARN(ss.str());
return false;
}
}
return true;
}
std::unordered_map<std::string, c10::IValue> saveExtraArgs(const at::RecordFunction& fn) {
// for specific types of fn, return the saved extra args for computing flops
std::unordered_map<std::string, c10::IValue> map;
std::vector<c10::IValue> inputs = fn.inputs();
std::string fname(fn.name());
if (inputs.empty()) {
// Input shape is unavailable, return empty map
return map;
}
if (fname == kConv2dOp) {
std::vector<int> tensors{0, 1};
bool check = validateInput(fname, kConv2dGroups + 1, inputs, tensors);
if (!check) {
return map;
}
at::Tensor input = inputs[0].toTensor();
at::Tensor weight = inputs[1].toTensor();
if (weight.sizes().size() != 4) {
TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor.");
return map;
}
map[kInputSize] = at::IValue(input.sizes());
map[kWeightSize] = at::IValue(weight.sizes());
map[kStride] = inputs[kConv2dStride];
map[kPadding] = inputs[kConv2dPadding];
map[kDilation] = inputs[kConv2dDilation];
map[kGroups] = inputs[kConv2dGroups];
} else if (fname == kMMOp) {
std::vector<int> tensors{0, 1};
bool check = validateInput(fname, 2, inputs, tensors);
if (!check) {
return map;
}
at::Tensor left = inputs[0].toTensor();
at::Tensor right = inputs[1].toTensor();
map[kMat1Size] = at::IValue(left.sizes());
map[kMat2Size] = at::IValue(right.sizes());
} else if (fname == kAddMMOp) {
std::vector<int> tensors{0, 1, 2};
bool check = validateInput(fname, 3, inputs, tensors);
if (!check) {
return map;
}
// Exact FLOP count depends on scaling factors alpha and beta but
// just assume these are +=1.
// (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
// "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
at::Tensor left = inputs[1].toTensor();
at::Tensor right = inputs[2].toTensor();
map[kMat1Size] = at::IValue(left.sizes());
map[kMat2Size] = at::IValue(right.sizes());
} else if (fname == kMulOp) {
std::vector<int> tensors{0};
bool check = validateInput(fname, 1, inputs, tensors);
if (!check) {
return map;
}
at::Tensor mat = inputs[0].toTensor();
map[kMatSize] = at::IValue(mat.sizes());
} else if (fname == kAddOp) {
std::vector<int> tensors{0};
bool check = validateInput(fname, 1, inputs, tensors);
if (!check) {
return map;
}
at::Tensor mat = inputs[0].toTensor();
map[kMatSize] = at::IValue(mat.sizes());
} else if (fname == kBMMOp) {
std::vector<int> tensors{0, 1};
bool check = validateInput(fname, 2, inputs, tensors);
if (!check) {
return map;
}
at::Tensor left = inputs[0].toTensor();
at::Tensor right = inputs[1].toTensor();
map[kMat1Size] = at::IValue(left.sizes());
map[kMat2Size] = at::IValue(right.sizes());
} else if (fname == kBAddBMMOp) {
std::vector<int> tensors{0, 1, 2};
bool check = validateInput(fname, 3, inputs, tensors);
if (!check) {
return map;
}
// Exact FLOP count depends on scaling factors alpha and beta but
// just assume these are +=1.
// (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
// "Operations Count for the BLAS and LAPACK", Table 3, SGEMM)
at::Tensor left = inputs[1].toTensor();
at::Tensor right = inputs[2].toTensor();
map[kMat1Size] = at::IValue(left.sizes());
map[kMat2Size] = at::IValue(right.sizes());
}
return map;
}
uint64_t computeFlops(const std::string &op_name, const std::unordered_map<std::string, c10::IValue> &extra_args) {
if (op_name == kConv2dOp) {
if (extra_args.find(kInputSize) == extra_args.end()
|| extra_args.find(kWeightSize) == extra_args.end()
|| extra_args.find(kGroups) == extra_args.end()
|| extra_args.find(kPadding) == extra_args.end()
|| extra_args.find(kStride) == extra_args.end()
|| extra_args.find(kDilation) == extra_args.end()) {
TORCH_WARN("Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, and weight_size in saved arguments.");
return 0;
}
auto input_sizes_ref = extra_args.at(kInputSize);
auto kernel_sizes_ref = extra_args.at(kWeightSize);
auto groups_ref = extra_args.at(kGroups);
auto padding_ref = extra_args.at(kPadding);
auto stride_ref = extra_args.at(kStride);
auto dilation_ref = extra_args.at(kDilation);
if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) {
TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes.");
return 0;
}
if (!padding_ref.isIntList() || !stride_ref.isIntList() || !dilation_ref.isIntList()) {
TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires padding, stride, and dilation values.");
return 0;
}
const std::vector<int64_t> input_sizes = input_sizes_ref.toIntVector();
const std::vector<int64_t> kernel_sizes = kernel_sizes_ref.toIntVector();
const uint64_t groups = groups_ref.toInt();
const std::vector<int64_t> padding = padding_ref.toIntVector();
const std::vector<int64_t> stride = stride_ref.toIntVector();
const std::vector<int64_t> dilation = dilation_ref.toIntVector();
if (input_sizes.size() != 4 || kernel_sizes.size() != 4) {
TORCH_WARN("Failed to compute flops for op aten::conv2d because both input and weight must be size 4.");
return 0;
}
if (!groups) {
TORCH_WARN("Failed to compute flops for op aten::conv2d because group size must not be 0.");
return 0;
}
if (padding.size() != 2 || dilation.size() != 2) {
TORCH_WARN("Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2.");
return 0;
}
if (stride.size() != 2 || (stride[0] * stride[1] == 0)) {
TORCH_WARN("Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0.");
return 0;
}
// format of the input is defined in torch.nn.quantized.functional.conv2d()
uint64_t minibatch = 0, in_channels = 0, input_h = 0, input_w = 0;
uint64_t out_channels = 0, kernel_h = 0, kernel_w = 0;
const uint64_t conv2d_multiply_factor = 2;
std::tie(minibatch, in_channels, input_h, input_w) = std::make_tuple(input_sizes[0], input_sizes[1],
input_sizes[2], input_sizes[3]);
std::tie(out_channels, std::ignore, kernel_h, kernel_w) = std::make_tuple(kernel_sizes[0], kernel_sizes[1],
kernel_sizes[2], kernel_sizes[3]);
uint64_t output_h = (input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) / stride[0] + 1;
uint64_t output_w = (input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) / stride[1] + 1;
return conv2d_multiply_factor * minibatch * output_h * output_w *
kernel_h * kernel_w * in_channels * out_channels / groups;
} else if (op_name == kMMOp || op_name == kAddMMOp) {
if (extra_args.find(kMat1Size) == extra_args.end()
|| extra_args.find(kMat2Size) == extra_args.end()) {
TORCH_WARN("Calculating flops for ", op_name, " requires mat1_size and mat2_size in saved arguments.");
return 0;
}
auto mat1_sizes_ref = extra_args.at(kMat1Size);
auto mat2_sizes_ref = extra_args.at(kMat2Size);
if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
TORCH_WARN("Failed to compute flops for op ", op_name, " because it requires mat1_size and mat2_size to be IntList.");
return 0;
}
std::vector<int64_t> mat1_size = mat1_sizes_ref.toIntVector();
std::vector<int64_t> mat2_size = mat2_sizes_ref.toIntVector();
if (mat1_size.size() == 0) {
return 0;
}
int64_t overlap_dim = mat1_size.back();
if (overlap_dim == 0) {
return 0;
}
const uint64_t gemm_multiply_factor = 2;
uint64_t flops = 1;
for(int64_t dim : mat1_size) {
flops *= dim;
}
flops /= overlap_dim;
for(int64_t dim : mat2_size) {
flops *= dim;
}
flops *= gemm_multiply_factor;
return flops;
} else if (op_name == kBMMOp || op_name == kBAddBMMOp) {
if (extra_args.find(kMat1Size) == extra_args.end()
|| extra_args.find(kMat2Size) == extra_args.end()) {
TORCH_WARN("Calculating flops for ", op_name, " requires mat1_size and mat2_size in saved arguments.");
return 0;
}
auto mat1_sizes_ref = extra_args.at(kMat1Size);
auto mat2_sizes_ref = extra_args.at(kMat2Size);
if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
TORCH_WARN("Failed to compute flops for op ", op_name, " because it requires mat1_size and mat2_size to be IntList.");
return 0;
}
std::vector<int64_t> mat1_size = mat1_sizes_ref.toIntVector();
std::vector<int64_t> mat2_size = mat2_sizes_ref.toIntVector();
if (mat1_size.size() == 0) {
return 0;
}
int64_t batch_size = mat1_size.front();
if (batch_size == 0) {
return 0;
}
int64_t overlap_dim = mat1_size.back();
if (overlap_dim == 0) {
return 0;
}
const uint64_t gemm_multiply_factor = 2;
uint64_t flops = 1;
for(int64_t dim : mat1_size) {
flops *= dim;
}
flops /= overlap_dim;
flops /= batch_size;
for(int64_t dim : mat2_size) {
flops *= dim;
}
flops *= gemm_multiply_factor;
return flops;
} else if (op_name == kMulOp) {
if (extra_args.find(kMatSize) == extra_args.end()) {
TORCH_WARN("Calculating flops for aten::mul.Tensor requires mat_size in saved arguments.");
return 0;
}
auto mat_sizes = extra_args.at(kMatSize);
if (!mat_sizes.isIntList()) {
TORCH_WARN("Failed to compute flops for op aten::mul because it requires mat_size to be IntList.");
return 0;
}
std::vector<int64_t> mat_size = mat_sizes.toIntVector();
uint64_t flops = 1;
for(int64_t dim : mat_size) {
flops *= dim;
}
return flops;
} else if (op_name == kAddOp) {
if (extra_args.find(kMatSize) == extra_args.end()) {
TORCH_WARN("Calculating flops for aten::add.Tensor requires mat_size in saved arguments.");
return 0;
}
auto mat_sizes = extra_args.at(kMatSize);
if (!mat_sizes.isIntList()) {
TORCH_WARN("Failed to compute flops for op aten::add because it requires mat_size to be IntList.");
return 0;
}
std::vector<int64_t> mat_size = mat_sizes.toIntVector();
uint64_t flops = 1;
for(int64_t dim : mat_size) {
flops *= dim;
}
return flops;
}
return 0;
}
} // namespace profiler
} // namespace autograd
} // namespace torch