#include <torch/csrc/jit/fuser/codegen.h>
#include <ATen/ATen.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/config.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/tensor_info.h>
#include <torch/csrc/jit/ir.h>
#if USE_CUDA_FUSER
#include <torch/csrc/jit/fuser/cuda/resource_strings.h>
#endif
#if USE_CPU_FUSER
#include <torch/csrc/jit/fuser/cpu/resource_strings.h>
#endif
#include <cmath>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <tuple>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
// Template for computing the offset into the tensor to access a value
static auto dim_calc = CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");
static std::string valueName(const Value* n) {
return "n" + std::to_string(n->unique());
}
static std::string scalarValue(const int64_t v) {
return std::to_string(v);
}
static std::string scalarValue(const bool v) {
return std::to_string(v);
}
// Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific
// implementations of these special values. These macros are found in the
// resource strings for each device.
static std::string scalarValue(const double v) {
std::ostringstream out;
if (std::isnan(v)) {
out << "NAN";
} else if (std::isinf(v)) {
if (v < 0) {
out << "NEG_INFINITY";
} else {
out << "POS_INFINITY";
}
} else {
out << std::scientific << v << "f";
}
return out.str();
}
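// Example (illustrative): scalarValue(0.5) yields "5.000000e-01f", while an
// infinite input yields "POS_INFINITY" or "NEG_INFINITY" as appropriate.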
// Note: Half is special-cased to avoid returning at::Half
static const char* scalarTypeName(const at::ScalarType type) {
if (type == at::ScalarType::Half) {
return "half";
}
switch (type) {
#define DEFINE_CASE(ctype, name, _) \
case at::ScalarType::name: \
return #ctype;
AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("unknown scalar type");
}
}
static const char* calcScalarTypeName(const at::ScalarType type) {
if (type == at::ScalarType::Half) {
return "float";
}
return scalarTypeName(type);
}
static std::string variableType(const std::shared_ptr<c10::Type>& t) {
if (t->kind() == TypeKind::IntType) {
return "int";
} else if (t->kind() == TypeKind::FloatType) {
return "float";
} else if (t->kind() == TypeKind::BoolType) {
return "bool";
} else if (t->kind() == TypeKind::TensorType) {
auto const tt = t->cast<TensorType>();
return calcScalarTypeName(tt->scalarType());
}
// something went wrong with the type analysis during shape propagation
throw std::runtime_error(
"unknown scalar type during JIT fusion code generation");
}
static std::string typeCastedValueName(
const std::shared_ptr<c10::Type>& t,
const at::ScalarType outtype,
const std::string& vn) {
if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
if (!isIntegralType(outtype)) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
return vn;
} else if (t->kind() == TypeKind::FloatType) {
if (!isFloatingType(outtype)) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
return vn;
} else if (t->kind() == TypeKind::TensorType) {
auto const tt = t->cast<TensorType>();
if (tt->scalarType() != outtype) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
return vn;
}
// something went wrong with the type analysis during shape propagation
throw std::runtime_error(
"unknown scalar type during JIT fusion code generation");
}
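// Example (illustrative; value name is hypothetical): an IntType input "n1"
// feeding an op whose output scalar type is Float is rendered as "((float) n1)";
// when the types already agree, the value name is returned unchanged.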
// Writes "simple mappable" ops
static std::string encodeRHS(const Node* n) {
static std::unordered_map<NodeKind, std::string> simple_map_ops = {
// unary
{aten::_cast_Float, "static_cast<float>(${0})"},
{aten::abs, "fabs(${0})"},
{aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
{aten::relu, "${0} < 0 ? 0.f : ${0} "},
{aten::threshold,
"${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
{aten::log, "logf(${0})"},
{aten::log10, "log10f(${0})"},
{aten::log1p, "log1pf(${0})"},
{aten::log2, "log2f(${0})"},
{aten::lgamma, "lgammaf(${0})"},
{aten::exp, "expf(${0})"},
{aten::expm1, "expm1f(${0})"},
{aten::erf, "erff(${0})"},
{aten::erfc, "erfcf(${0})"},
{aten::cos, "cosf(${0})"},
{aten::acos, "acosf(${0})"},
{aten::cosh, "coshf(${0})"},
{aten::sin, "sinf(${0})"},
{aten::asin, "asinf(${0})"},
{aten::sinh, "sinhf(${0})"},
{aten::tan, "tanf(${0})"},
{aten::atan, "atanf(${0})"},
{aten::tanh, "tanhf(${0})"},
{aten::sqrt, "sqrtf(${0})"},
{aten::rsqrt, "rsqrtf(${0})"},
{aten::ceil, "ceilf(${0})"},
{aten::floor, "floorf(${0})"},
{aten::round, "roundf(${0})"},
{aten::trunc, "truncf(${0})"},
{aten::frac, "fracf(${0})"},
{aten::reciprocal, "1.f/(${0})"},
{aten::neg, "-${0}"},
// simple binary
{aten::atan2, "atan2(${0}, ${1})"},
{aten::min, "fminf(${0}, ${1})"},
{aten::max, "fmaxf(${0}, ${1})"},
// binary with other
// TODO: some of these ops will not get generated because
// we only work on float inputs/outputs, but they are here to record
// that they are valid mappable ops once we handle more types
{aten::__and__, "${0} && ${1}"},
{aten::__lshift__, "${0} << ${1}"},
{aten::__or__, "${0} || ${1}"},
{aten::__rshift__, "${0} >> ${1}"},
{aten::__xor__, "${0} ^ ${1}"},
{aten::div, "${cast_0} / ${cast_1}"},
{aten::eq, "${0} == ${1}"},
{aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
{aten::ge, "(${0} >= ${1})"},
{aten::gt, "${0} > ${1}"},
{aten::le, "(${0} <= ${1})"},
{aten::lt, "${0} < ${1}"},
{aten::type_as, "(${cast_0})"},
{aten::mul, "${cast_0} * ${cast_1}"},
{aten::ne, "${0} != ${1}"},
{aten::remainder, "remainderf(${0}, ${1})"},
{aten::pow, "powf(${cast_0}, ${cast_1})"},
// alpha
{aten::add, "${cast_0} + ${cast_2}*${cast_1}"},
{aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"},
{aten::rand_like, "uniform(rnd())"},
// min, max
// It may seem unusual to have the bounds as the first case below;
// this is so that if min or max is NaN, they are "ignored",
// and when the input is NaN, the output is NaN, too
{aten::clamp, "(${0}<${1}?${1}:(${0}>${2}?${2}:${0}))"},
// where
{aten::where, "(${0} ? ${1} : ${2})"},
// simple derivatives
{aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
{aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
};
if (n->kind() == prim::Constant) {
const auto val = toIValue(n->output()).value();
if (val.isDouble()) {
return scalarValue(val.toDouble());
} else if (val.isBool()) {
return scalarValue(val.toBool());
} else {
JIT_ASSERT(val.isInt());
return scalarValue(val.toInt());
}
}
TemplateEnv env;
size_t i = 0;
auto outtype =
n->output()->type()->expect<c10::TensorType const>()->scalarType();
for (auto in : n->inputs()) {
// PyTorch converts (scalar) argument types to the result type before applying
// the operator, e.g. 1.4 - torch.tensor(3) = -2 because 1.4 is first truncated
// to the int result type
env.s(std::to_string(i), valueName(in));
env.s(
std::string("cast_") + std::to_string(i),
typeCastedValueName(in->type(), outtype, valueName(in)));
i++;
}
const auto& str = simple_map_ops.at(n->kind());
return format(str, env);
}
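// Example (illustrative; value names are hypothetical): for an aten::add node
// with float output, tensor inputs n1 and n2, and integer alpha n3, the
// "${cast_0} + ${cast_2}*${cast_1}" template above expands to
// "n1 + ((float) n3)*n2"; casts are inserted only where the input type differs
// from the output scalar type.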
// If there is a single user of a node and it's a chunk operation, returns
// that user. Returns nullptr otherwise.
static Node* usedInFusedChunk(const Value* input) {
auto uses = input->uses();
if (uses.size() == 1) {
Node* user = uses[0].user;
if (user->kind() == prim::ConstantChunk) {
return user;
}
}
return nullptr;
}
static void emitIndexingFor(
std::ostream& out,
const std::string& tensor,
const int ndim,
const bool last_is_cont) {
TemplateEnv env;
env.s("tensor", tensor);
out << format("IndexType ${tensor}_offset = 0;\n", env);
out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
for (int d = ndim - 1; d >= 0; --d) {
env.d("d", d);
env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
env.s(
"times_stride",
(d < ndim - 1 || !last_is_cont)
? format("* ${tensor}.strides[${d}]", env)
: "");
out << dim_calc.format(env);
if (d > 0) {
out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
}
}
}
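// Example (illustrative; tensor name is hypothetical): for a 2-D tensor "t0"
// whose last dimension is contiguous, emitIndexingFor produces roughly
//   IndexType t0_offset = 0;
//   IndexType t0_linearIndex = linearIndex;
//   size_t t0_dimIndex1 = t0_linearIndex % t0.sizes[1];
//   t0_offset += t0_dimIndex1;
//   t0_linearIndex /= t0.sizes[1];
//   size_t t0_dimIndex0 = t0_linearIndex;
//   t0_offset += t0_dimIndex0 * t0.strides[0];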
// TODO: handle cases where we need to generate > 2^32 element tensors
std::tuple<
std::string,
std::vector<PartitionDesc>,
std::vector<PartitionDesc>,
bool>
generateKernel(
const std::string& name,
const Graph& graph,
const std::vector<TensorDesc>& input_desc,
const std::vector<TensorDesc>& output_desc,
const bool use_cuda) {
TemplateEnv env;
env.s("kernelName", name);
env.s(
"IndexType",
"unsigned int"); // Note: not uint32_t to avoid including cstdint
std::stringstream body;
std::stringstream tensorOffsets;
std::vector<std::string> formals;
std::vector<std::string> argument_loads;
// Lambda for writing arguments
auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
std::string tensor =
"t" +
std::to_string(
formals.size()); // can't be unique() because Param may be an output
const auto nDim = desc.nDim();
emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
env.s("tensor", tensor);
env.d(
"formal_index",
formals.size() +
1); // + 1 because the first argument is the linearIndex
env.d("nDim", nDim);
env.s("scalar_type", scalarTypeName(desc.scalar_type));
formals.push_back(
format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
argument_loads.push_back(format(
"*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
env));
};
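// Example (illustrative; assumes a float, 2-D first formal): emitFormal produces
// the parameter string "TensorInfo<float,2> t0" and the matching argument load
// "*static_cast<TensorInfo<float,2>*>(args[1])" (offset by one because the first
// argument is the linearIndex).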
// Writes input parameters and creates flattened inputs
std::vector<PartitionDesc> chunk_desc;
std::vector<std::pair<const Value*, const TensorDesc&>> flat_inputs;
{
size_t input_index = 0;
for (const auto& p : graph.inputs()) {
if (const Node* chunk = usedInFusedChunk(p)) {
int64_t dim = chunk->i(attr::dim);
int64_t chunks = chunk->i(attr::chunks);
chunk_desc.emplace_back(input_desc[input_index++], chunks, dim);
for (const auto* o : chunk->outputs()) {
flat_inputs.emplace_back(o, *chunk_desc.back().subTensorDesc());
}
} else {
chunk_desc.emplace_back();
flat_inputs.emplace_back(p, input_desc[input_index++]);
}
}
for (const auto& input : flat_inputs) {
emitFormal(input.first, input.second);
}
}
// Writes output parameters and creates flattened outputs
std::vector<PartitionDesc> concat_desc;
std::vector<std::pair<const Value*, TensorDesc>> flat_output_nodes;
{
size_t i = 0;
for (const auto& o : graph.outputs()) {
const auto& desc = output_desc[i++];
if (o->node()->kind() != prim::FusedConcat) {
emitFormal(o, desc);
concat_desc.emplace_back();
flat_output_nodes.emplace_back(o, desc);
} else {
const auto cat = o->node();
concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
for (const auto& c : cat->inputs()) {
emitFormal(c, *concat_desc.back().subTensorDesc());
flat_output_nodes.emplace_back(c, desc);
}
}
}
}
// Acquires input values
bool has_half_tensor = false;
size_t formal_count = 0;
for (const auto input : flat_inputs) {
auto p = input.first;
env.s("node", valueName(p));
env.d("formal", formal_count++);
// Acquires and converts (if needed) inputs
// Note: conversion from half is only supported for CUDA kernels.
// The conversion immediately converts fp16 inputs to float.
// Access for other types is common to CUDA and CPU kernels.
const auto is_half = (input.second.scalar_type == at::ScalarType::Half);
if (is_half) {
JIT_ASSERT(use_cuda);
env.s(
"access",
format("__half2float(t${formal}.data[t${formal}_offset])", env));
has_half_tensor = true;
} else {
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
}
env.s("lhs_type", calcScalarTypeName(input.second.scalar_type));
body << format("${lhs_type} ${node} = ${access};\n", env);
}
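// Example (illustrative; names are hypothetical): a float input is read as
//   float n1 = t0.data[t0_offset];
// while a half input on CUDA is read as
//   float n1 = __half2float(t0.data[t0_offset]);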
bool has_random = false;
// Generates code for intermediate nodes
// Note: Concat and Chunk are implicitly generated
// Note: Random number generation is only supported for CUDA kernels.
for (const auto& n : graph.nodes()) {
// Note: FusedConcat nodes work by narrowing the output Tensors before the
// kernel runs
if (n->kind() == prim::FusedConcat)
continue;
if (n->kind() == prim::ConstantChunk)
continue;
if (n->kind() == aten::rand_like) {
JIT_ASSERT(use_cuda);
has_random = true;
}
env.s("node", valueName(n->output()));
env.s("rhs", encodeRHS(n));
env.s("lhs_type", variableType(n->output()->type()));
body << format("${lhs_type} ${node} = ${rhs};\n", env);
}
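// Example (illustrative; names are hypothetical): an aten::mul of float values
// n1 and n2 producing n3 is emitted as
//   float n3 = n1 * n2;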
// Generates writes to output tensors
for (const auto& output : flat_output_nodes) {
const auto& o = output.first;
env.d("formal", formal_count++);
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
env.s("node", valueName(o));
// Acquires and converts (if needed) outputs
// Note: conversion to half is only supported for CUDA kernels.
const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
if (is_half) {
JIT_ASSERT(use_cuda);
body << format("${access} = __float2half(${node});\n", env);
has_half_tensor = true;
} else {
body << format("${access} = ${node};\n", env);
}
}
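// Example (illustrative; names are hypothetical): a float output is written as
//   t2.data[t2_offset] = n3;
// and a half output on CUDA as
//   t2.data[t2_offset] = __float2half(n3);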
// Includes headers
// Note: CUDA kernels support half tensors and random number generation; CPU
// kernels do not
#if USE_CUDA_FUSER
if (has_half_tensor) {
env.s("HalfHeader", cuda::half_support_literal);
} else {
env.s("HalfHeader", "");
}
if (has_random) {
env.s("RandHeader", cuda::rand_support_literal);
env.s("RandParam", cuda::rand_param);
env.s("RandInit", cuda::rand_init);
} else {
env.s("RandHeader", "");
env.s("RandParam", "");
env.s("RandInit", "");
}
#endif // USE_CUDA_FUSER
// Instantiates the CUDA or CPU-specific templates
env.s("tensorOffsets", tensorOffsets.str());
env.s("kernelBody", body.str());
env.v("formals", formals);
env.v("argument_loads", argument_loads);
std::string code_string;
if (use_cuda) {
#if USE_CUDA_FUSER
env.s("type_declarations", cuda::type_declarations_template.format(env));
code_string = cuda::cuda_compilation_unit_template.format(env);
#else
throw std::runtime_error("CUDA Fusion requested but not supported.");
#endif // USE_CUDA_FUSER
} else {
#if USE_CPU_FUSER
env.s("type_declarations", cpu::type_declarations_template.format(env));
code_string = cpu::cpu_compilation_unit_template.format(env);
#else
throw std::runtime_error("CPU Fusion requested but not supported");
#endif // USE_CPU_FUSER
}
if (debugFuser()) {
std::cerr << "fusion code:" << code_string << std::endl;
}
return std::make_tuple(
code_string, std::move(chunk_desc), std::move(concat_desc), has_random);
}
} // namespace fuser
} // namespace jit
} // namespace torch