#ifndef _WIN32
#include "torch/csrc/jit/fusion_compiler.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/code_template.h"
#include "torch/csrc/jit/resource_guard.h"
#include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/utils/disallow_copy.h"
#include "torch/csrc/variable_tensor_functions.h"
#include "torch/csrc/utils/hash.h"
#include <torch/csrc/jit/assertions.h>
#include "ATen/ATen.h"
#include "ATen/ExpandUtils.h"
#include "ATen/WrapDimUtils.h"
#ifdef USE_CUDA
#include "ATen/cuda/CUDAContext.h"
#include "THC/THC.h"
#include <THC/THCGenerator.hpp>
#include "torch/csrc/cuda/cuda_check.h"
#include <nvrtc.h>
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#include <string>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <iostream>
#include <dlfcn.h>
#include <unistd.h>
#ifdef USE_CUDA
THCGenerator* THCRandom_getGenerator(THCState* state);
#endif
namespace torch { namespace jit {
std::vector<bool> TensorDesc::findContiguous(
const at::IntList& sizes,
const at::IntList& strides) {
JIT_ASSERT(sizes.size() == strides.size());
std::vector<bool> cont(sizes.size());
for(size_t i = 0; i < sizes.size(); ++i) {
int64_t expected_stride = (i + 1 < sizes.size()) ? sizes[i+1]*strides[i+1] : 1;
cont[i] = strides[i] == expected_stride;
}
return cont;
}
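// For illustration only (not used by the code), a worked example of findContiguous.
// Given sizes = [2, 3, 4]:
//   strides = [12, 4, 1] -> [true, true, true]   (fully contiguous)
//   strides = [24, 4, 1] -> [false, true, true]  (dim 0 has a stride gap: 24 != 3*4)
// The innermost dimension always compares its stride against 1.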
// Descriptor for chunking an input tensor into subtensors
// OR concatenating an output tensor from subtensors
struct PartitionDesc {
size_t nSubtensors; // == 1 for tensors that should not be operated on via chunk/cat
size_t dim; // dimension along which the chunk/concat occurs
std::unique_ptr<TensorDesc> subtensorDesc; // descriptor for the subtensor, if it exists
PartitionDesc()
: nSubtensors(1), dim(0) {}
PartitionDesc(const TensorDesc & desc, size_t nSubtensors, size_t dim)
: nSubtensors(nSubtensors), dim(dim) {
JIT_ASSERT(nSubtensors > 1);
std::vector<bool> cont = desc.contiguity;
if(dim > 0) {
// when we narrow the concatenated output/chunked input
// we make the size[dim] smaller while keeping the stride[dim] the same,
// meaning: stride[dim - 1] != stride[dim]*size[dim]
// so dim - 1 is no longer contiguous
cont[dim - 1] = false;
}
subtensorDesc.reset(new TensorDesc(desc.scalar_type, cont));
}
bool isNoop() const {
return nSubtensors == 1;
}
};
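// Example (illustrative only): chunking a contiguous input of sizes [2, 6, 4]
// (strides [24, 4, 1]) into 3 subtensors along dim 1 yields subtensors of sizes
// [2, 2, 4] that keep the original strides. Their dim 0 is no longer contiguous
// (24 != 2 * 4), which is exactly the cont[dim - 1] = false case above.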
struct FusedKernel {
TH_DISALLOW_COPY_AND_ASSIGN(FusedKernel);
FusedKernel(const std::string & name, AnnotatedGraph & agraph);
virtual ~FusedKernel() = default;
// expects outputs to be pre-allocated
void launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);
// creates new tensors for outputs
void launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs);
const std::vector<TensorDesc> & outputDescriptors() const {
return output_desc;
}
protected:
virtual at::Backend backend() const = 0;
// arguments is a list of pointers to the arguments for the compiled CUDA/CPU
// code.
// The format of arguments is suitable for directly passing to a call to
// cuLaunchKernel as the kernel arguments.
// Currently the first argument is a pointer to numel (for passing to
// CUDA code), and the remainder are pointers to the TensorInfo<T> structs
// that compiled code uses to load Tensor data.
// launch_with_tensors handles packing at::Tensors into this arguments array.
// CPU code uses the same convention so that launch_with_tensors can be shared.
virtual void launch_raw(uint32_t numel, void ** arguments) = 0;
virtual uint64_t get_rand_offset(uint32_t numel) = 0;
bool has_random;
std::string name;
// We keep these around for debugging
std::string compilation_unit;
std::vector<TensorDesc> input_desc;
std::vector<TensorDesc> output_desc;
// same size as output_desc, describes whether
// an output is actually a concatenation of
// many subtensors that the fusion group produces
std::vector<PartitionDesc> concat_desc;
// same size as input_desc, describes whether an
// input should be broken into subtensors (chunks)
// to be consumed by the fusion group
std::vector<PartitionDesc> chunk_desc;
};
namespace {
#ifdef USE_CUDA
static int ceilDiv(int a, int b) {
return (a + b - 1) / b;
}
#endif
Node* usedInFusedChunk(Value * input) {
auto uses = input->uses();
if (uses.size() == 1) {
Node *user = uses[0].user;
if (user->kind() == prim::FusedChunk) {
return user;
}
}
return nullptr;
}
////////////////////////////////////////////////////////////////////////////////
// Code generation
namespace codegen {
/* Because type_as does not check the type of its input, a fusion group can have a non-fp32 tensor as input.
Correct code is generated for this case; however, nvrtc does not know how to handle the int*_t integer types,
so these typedefs help it handle those cases. */
auto type_declarations_template = CodeTemplate(R"(
#if defined(__CUDACC_RTC__)
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef short int int16_t;
typedef long long int int64_t;
${HalfHeader}
${RandHeader}
#endif
typedef ${IndexType} IndexType;
template<typename T, size_t N>
struct TensorInfo {
T * data;
IndexType sizes[N];
IndexType strides[N];
};
)");
// We rewrite the code for philox RNG from curand as nvrtc couldn't resolve the
// curand header correctly.
constexpr auto rand_support_literal = R"(
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
counter = make_uint4(0, 0, 0, 0);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);
STATE = 0;
incr_n(offset / 4);
}
__device__ inline unsigned long operator()() {
if(STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
for(int i = 0; i < 9; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
unsigned long ret;
switch(STATE) {
case 0: ret = output.x; break;
case 1: ret = output.y; break;
case 2: ret = output.z; break;
case 3: ret = output.w; break;
}
STATE = (STATE + 1) % 4;
return ret;
}
private:
uint4 counter;
uint4 output;
uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ inline void incr() {
if (++counter.x)
return;
if (++counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a*b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float uniform(unsigned int x) {
return x * M_RAN_INVM32;
}
)";
constexpr auto rand_param = ",unsigned long long seed, unsigned long long offset";
constexpr auto rand_init = R"(
int idx = blockIdx.x*blockDim.x + threadIdx.x;
Philox rnd(seed, idx, offset);
)";
auto cuda_compilation_unit_template = CodeTemplate(R"(
${type_declarations}
extern "C" __global__
void ${kernelName}(IndexType totalElements, ${formals} ${RandParam}) {
${RandInit}
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < totalElements;
linearIndex += gridDim.x * blockDim.x) {
// Convert `linearIndex` into an offset of tensor:
${tensorOffsets}
// calculate the results
${kernelBody}
}
}
)");
auto cpu_compilation_unit_template = CodeTemplate(R"(
#include <cstddef>
#include <cstdint>
#include <math.h>
${type_declarations}
#define OMP_THRESHOLD 100000
static void ${kernelName}_kernel(IndexType totalElements, ${formals}) {
#pragma omp parallel for if(totalElements > OMP_THRESHOLD)
for (IndexType linearIndex = 0;
linearIndex < totalElements;
linearIndex += 1) {
// Convert `linearIndex` into an offset of tensor:
${tensorOffsets}
// calculate the results
${kernelBody}
}
}
extern "C"
void ${kernelName}(IndexType totalElements, void ** args) {
${kernelName}_kernel(totalElements ${,argument_loads});
}
)");
// This snippet enables half support in the jit. Following the pattern for
// reductions, fp16 input data is immediately upconverted to float
// with __half2float(). All mathematical operations are done on float
// values, and if needed the intermediate float representation is
// converted to half with __float2half() when writing to a half tensor.
constexpr auto half_support_literal = R"(
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#if defined(__cplusplus)
struct __align__(2) __half {
__host__ __device__ __half() { }
protected:
unsigned short __x;
};
/* All intrinsic functions are only available to nvcc compilers */
#if defined(__CUDACC__)
/* Definitions of intrinsics */
__device__ __half __float2half(const float f) {
__half val;
asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
return val;
}
__device__ float __half2float(const __half h) {
float val;
asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
return val;
}
#endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __HALF_TO_US
#undef __HALF_TO_CUS
typedef __half half;
)";
// curDimIndex = linearId % sizes[i]; // % sizes[i] is not needed for d == 0, because we already guard for numel outside the index calculation
// offset += curDimIndex*strides[i]; // *strides[i] is optional if last_is_cont because strides.back() == 1
// linearId /= sizes[i];
auto dim_calc = CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");
static void emitIndexingFor(std::ostream & out, const std::string & tensor, int ndim, bool last_is_cont) {
TemplateEnv env;
env.s("tensor",tensor);
out << format("IndexType ${tensor}_offset = 0;\n",env);
out << format("IndexType ${tensor}_linearIndex = linearIndex;\n",env);
for(int d = ndim - 1; d >= 0; --d) {
env.d("d",d);
env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]",env) : "");
env.s("times_stride",(d < ndim - 1 || !last_is_cont) ?
format("* ${tensor}.strides[${d}]",env) : "");
out << dim_calc.format(env);
if(d > 0) {
out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n",env);
}
}
}
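// Illustrative only: for a 2-D tensor named "t0" with last_is_cont == true,
// emitIndexingFor produces code along the lines of
//   IndexType t0_offset = 0;
//   IndexType t0_linearIndex = linearIndex;
//   size_t t0_dimIndex1 = t0_linearIndex % t0.sizes[1];
//   t0_offset += t0_dimIndex1;             // stride multiply elided: last dim is contiguous
//   t0_linearIndex /= t0.sizes[1];
//   size_t t0_dimIndex0 = t0_linearIndex;
//   t0_offset += t0_dimIndex0 * t0.strides[0];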
static std::string valueName(Value * n) {
return "n" + std::to_string(n->unique());
}
static std::string scalarValue(int64_t v) {
return std::to_string(v);
}
static std::string scalarValue(double v) {
std::ostringstream out;
out << std::scientific << v << "f";
return out.str();
}
static const char * scalarTypeName(at::ScalarType type) {
if (type == at::ScalarType::Half) {
return "half";
}
switch(type) {
#define DEFINE_CASE(ctype,name,_) \
case at::ScalarType::name: return #ctype;
AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("unknown scalar type");
}
}
std::string encodeRHS(Node * n) {
static std::unordered_map<NodeKind, std::string> simple_map_ops = {
// unary
{aten::abs, "absf(${0})"},
{aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
{aten::relu, "${0} < 0 ? 0.f : ${0} "},
{aten::log, "logf(${0})"},
{aten::log10, "log10f(${0})"},
{aten::log1p, "log1pf(${0})"},
{aten::log2, "log2f(${0})"},
{aten::lgamma, "lgammaf(${0})"},
{aten::exp, "expf(${0})"},
{aten::expm1, "expm1f(${0})"},
{aten::cos, "cosf(${0})"},
{aten::acos, "acosf(${0})"},
{aten::cosh, "coshf(${0})"},
{aten::sin, "sinf(${0})"},
{aten::asin, "asinf(${0})"},
{aten::sinh, "sinhf(${0})"},
{aten::tan, "tanf(${0})"},
{aten::atan, "atanf(${0})"},
{aten::tanh, "tanhf(${0})"},
{aten::sqrt, "sqrtf(${0})"},
{aten::rsqrt, "rsqrtf(${0})"},
{aten::ceil, "ceilf(${0})"},
{aten::floor, "floorf(${0})"},
{aten::round, "roundf(${0})"},
{aten::trunc, "truncf(${0})"},
{aten::frac, "fracf(${0})"},
{aten::reciprocal, "reciprocalf(${0})"},
{aten::neg, "-${0}"},
//simple binary
{aten::atan2, "atan2(${0}, ${1})"},
{aten::min, "fminf(${0}, ${1})"},
{aten::max, "fmaxf(${0}, ${1})"},
//binary with other
// TODO: some of these ops will not get generated because
// we only work on float inputs/outputs, but they are here to record
// that they are valid mappable ops once we handle more types
{aten::__and__, "${0} && ${1}"},
{aten::__lshift__, "${0} << ${1}"},
{aten::__or__, "${0} || ${1}"},
{aten::__rshift__, "${0} >> ${1}"},
{aten::__xor__, "${0} ^ ${1}"},
{aten::div, "${0} / ${1}"},
{aten::eq, "${0} == ${1}"},
{aten::fmod, "fmodf(${0}, ${1})"},
{aten::ge, "(${0} >= ${1})"},
{aten::gt, "${0} > ${1}"},
{aten::le, "(${0} <= ${1})"},
{aten::lt, "${0} < ${1}"},
{aten::type_as, "(${0})"}, //everything is implicitly convertible to float
{aten::mul, "${0} * ${1}"},
{aten::ne, "${0} != ${1}"},
{aten::remainder, "remainderf(${0}, ${1})"},
{aten::pow, "powf(${0}, ${1})"},
//alpha
{aten::add, "${0} + ${2}*${1}"},
{aten::sub, "(${0} - ${2}*${1})"},
{aten::rand_like, "uniform(rnd())"},
// simple derivatives
{aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
{aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
};
if (n->kind() == prim::Constant) {
auto val = toIValue(n->output()).value();
if (val.isDouble()) {
return scalarValue(val.toDouble());
} else {
JIT_ASSERT(val.isInt());
return scalarValue(val.toInt());
}
}
TemplateEnv env;
size_t i = 0;
for(auto in : n->inputs()) {
env.s(std::to_string(i++), valueName(in));
}
const auto & str = simple_map_ops.at(n->kind());
return format(str, env);
}
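// Illustrative example (hypothetical value names): for a node produced from
// aten::mul(n0, n1) this returns "n0 * n1"; for aten::add(n0, n1, n2), where
// n2 is the alpha scalar, it returns "n0 + n2*n1", matching the
// "${0} + ${2}*${1}" template above.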
// Returns: (input chunk metadata, output concat metadata, is_random)
std::tuple<std::vector<PartitionDesc>,std::vector<PartitionDesc>,bool> emitCompilationUnit(
std::ostream& out,
const std::string& name,
AnnotatedGraph& agraph,
bool use_cuda) {
bool has_random = false;
Graph& subgraph = *agraph.graph;
TemplateEnv env;
env.s("kernelName",name);
// TODO: handle cases where we need to generate > 2^32 element tensors
env.s("IndexType","unsigned int"); //avoiding slow header includes to get uint32_t
std::stringstream body;
std::stringstream tensorOffsets;
std::vector<std::string> formals;
std::vector<std::string> argument_loads;
auto emitFormal = [&](Value * n, const TensorDesc & desc) {
std::string tensor = "t" + std::to_string(formals.size()); //can't be unique() because Param may be an output
size_t nDim = desc.nDim();
emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
env.s("tensor",tensor);
env.d("formal_index", formals.size() + 1); // + 1 because the first argument is the linearIndex
env.d("nDim",nDim);
env.s("scalar_type",scalarTypeName(desc.scalar_type));
formals.push_back(format("TensorInfo<${scalar_type},${nDim}> ${tensor}",env));
argument_loads.push_back(format("*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",env));
};
std::vector<PartitionDesc> chunk_desc;
std::vector<std::pair<Value*,TensorDesc&>> flat_inputs;
{
size_t input_index = 0;
for(auto p : subgraph.inputs()) {
if (Node * chunk = usedInFusedChunk(p)) {
int64_t dim = chunk->i(attr::dim);
int64_t chunks = chunk->i(attr::chunks);
chunk_desc.emplace_back(agraph.input_desc[input_index++], chunks, dim);
for (auto * o : chunk->outputs()) {
flat_inputs.emplace_back(o, *chunk_desc.back().subtensorDesc);
}
} else {
chunk_desc.emplace_back();
flat_inputs.emplace_back(p, agraph.input_desc[input_index++]);
}
}
for (auto & input : flat_inputs) {
emitFormal(input.first, input.second);
}
}
std::vector<PartitionDesc> concat_desc;
std::vector<std::pair<Value*,TensorDesc>> flat_output_nodes;
{
size_t i = 0;
for(auto o : subgraph.outputs()) {
auto & desc = agraph.output_desc[i++];
if(o->node()->kind() != prim::FusedConcat) {
emitFormal(o, desc);
concat_desc.emplace_back();
flat_output_nodes.emplace_back(o, desc);
} else {
auto cat = o->node();
concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
for(auto c : cat->inputs()) {
emitFormal(c, *concat_desc.back().subtensorDesc);
flat_output_nodes.emplace_back(c, desc);
}
}
}
}
bool has_half_tensor = false;
size_t formal_count = 0;
for(auto input : flat_inputs) {
auto p = input.first;
env.s("node",valueName(p));
env.d("formal",formal_count++);
// Acquires and converts (if needed) inputs
bool is_half = input.second.scalar_type == at::ScalarType::Half;
if (is_half) {
AT_ASSERT(use_cuda);
env.s(
"access"
, format("__half2float(t${formal}.data[t${formal}_offset])", env));
has_half_tensor = true;
} else {
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
}
//TODO: actual type propagation rather than relying on auto..
body << format("auto ${node} = ${access};\n",env);
}
for(auto n : subgraph.nodes()) {
// FusedConcat nodes work by narrowing the output Tensors before the kernel runs
if (n->kind() == prim::FusedConcat)
continue;
if (n->kind() == prim::FusedChunk)
continue;
if(n->kind() == aten::rand_like) {
has_random = true;
if(!use_cuda)
throw std::runtime_error("Fusion doesn't support rand on CPU");
}
env.s("node",valueName(n->output()));
env.s("rhs", encodeRHS(n));
body << format("auto ${node} = ${rhs};\n",env);
}
for(auto output : flat_output_nodes) {
auto o = output.first;
env.d("formal",formal_count++);
env.s("access",format("t${formal}.data[t${formal}_offset]",env));
env.s("node",valueName(o));
// Acquires and converts (if needed) outputs
bool is_half = output.second.scalar_type == at::ScalarType::Half;
if (is_half) {
AT_ASSERT(use_cuda);
body << format("${access} = __float2half(${node});\n",env);
has_half_tensor = true;
} else {
body << format("${access} = ${node};\n",env);
}
}
// Includes half support if any half tensors are involved
if (has_half_tensor) {
env.s("HalfHeader", half_support_literal);
} else {
env.s("HalfHeader", "");
}
if (has_random) {
env.s("RandHeader", rand_support_literal);
env.s("RandParam", rand_param);
env.s("RandInit", rand_init);
} else {
env.s("RandHeader", "");
env.s("RandParam", "");
env.s("RandInit", "");
}
env.s("tensorOffsets",tensorOffsets.str());
env.s("kernelBody",body.str());
env.v("formals",formals);
env.v("argument_loads",argument_loads);
env.s("type_declarations", type_declarations_template.format(env));
if(use_cuda) {
out << cuda_compilation_unit_template.format(env);
} else {
out << cpu_compilation_unit_template.format(env);
}
return std::make_tuple(std::move(chunk_desc), std::move(concat_desc), has_random);
}
////////////////////////////////////////////////////////////////////////////////
} // codegen namespace
} // anonymous namespace
////////////////////////////////////////////////////////////////////////////////
// FusedKernel
// Host-side view of TensorInfo (the version visible to the kernel is defined above).
// Note sizes_strides[0] below - we dynamically allocate storage for the dims.
struct TensorInfo {
void * data;
#pragma GCC diagnostic ignored "-Wpedantic"
uint32_t sizes_strides[0];
#pragma GCC diagnostic pop
uint32_t* sizes(size_t nDim) { return &sizes_strides[0]; }
uint32_t* strides(size_t nDim) { return &sizes_strides[nDim]; }
};
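// Layout note (illustrative): for a kernel-side TensorInfo<T, 2>, the
// sizes_strides payload here is laid out as [size0, size1, stride0, stride1],
// so sizes(2) points at element 0 and strides(2) points at element 2, mirroring
// the sizes[N]/strides[N] arrays of the device-side struct above.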
FusedKernel::FusedKernel(const std::string & name, AnnotatedGraph & agraph)
: name(name)
, input_desc(agraph.input_desc)
, output_desc(agraph.output_desc) {}
namespace {
// Tries to compress sizes and strides according to cont. Emits the result into
// c_sizes and c_strides, and throws an error on failure (if they can't be compressed).
void compressContiguous(
at::IntList sizes,
at::IntList strides,
const std::vector<bool> & cont,
uint32_t * c_sizes,
uint32_t * c_strides) {
size_t compressed_dims = 0;
size_t cur = 0;
size_t ndim = sizes.size();
while(cur < ndim) {
size_t total_size = sizes[cur];
cur++;
while(cont[cur-1] && cur < ndim) {
JIT_ASSERT(strides[cur-1] == sizes[cur]*strides[cur]);
total_size *= sizes[cur];
cur++;
}
// cur starts pointing at the beginning of run to compress
// cur ends one _after_ the terminating false or end of list.
// total_size is the size of all dimensions [begin,end)
// examples:
// f = not cont.
// t = cont.
// x = don't care, including past end of list
// s = start of cur
// e = end of cur
// f x x x
// s e
// t f x x
// s e
// t t f x
// s e
c_sizes[compressed_dims] = total_size;
c_strides[compressed_dims] = strides[cur-1];
compressed_dims++;
}
JIT_ASSERT(!cont.back() || strides.back() == 1);
}
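// Worked example (illustrative only): with sizes = [2, 3, 4], strides = [24, 4, 1]
// and cont = [false, true, true], the contiguous run starting at dim 1 collapses:
//   c_sizes = [2, 12], c_strides = [24, 1]
// A fully contiguous tensor (cont all true) compresses to a single dimension,
// e.g. sizes = [2, 3, 4], strides = [12, 4, 1] -> c_sizes = [24], c_strides = [1].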
} // anonymous namespace
// XXX: Assumes that after at::chunk, all chunks have the same size
static std::vector<int64_t> computeMapSize(
const at::Tensor& tensor,
const PartitionDesc& chunkDesc) {
std::vector<int64_t> sizes(tensor.sizes().begin(), tensor.sizes().end());
// Should have been checked in graph fuser
JIT_ASSERT(sizes[chunkDesc.dim] % chunkDesc.nSubtensors == 0);
sizes[chunkDesc.dim] /= chunkDesc.nSubtensors;
return sizes;
}
// XXX: this code assumes that inputs are 32-bit addressable
static uint32_t computeNumel(at::ArrayRef<int64_t> sizes) {
uint32_t result = 1;
if (sizes.size() == 0) {
return 1; // scalar tensor
}
for (int64_t size : sizes) {
result *= size;
}
return result;
}
void FusedKernel::launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
at::DeviceGuard device_guard(inputs);
JIT_ASSERT(inputs.size() == input_desc.size());
JIT_ASSERT(outputs.size() == output_desc.size());
size_t flat_inputs_size = 0;
size_t flat_outputs_size = 0;
for(auto & c : chunk_desc)
flat_inputs_size += c.nSubtensors;
for(auto & c : concat_desc)
flat_outputs_size += c.nSubtensors;
// XXX: this code assumes that inputs are 32-bit addressable
// XXX: this code assumes that all inputs are of the same size
JIT_ASSERT(inputs[0].numel() <= std::numeric_limits<uint32_t>::max());
// Compute map_size, numel from the first input
at::IntList map_size;
uint32_t numel;
std::vector<int64_t> keep_alive_size;
if (chunk_desc[0].isNoop()) {
map_size = inputs[0].sizes();
numel = inputs[0].numel();
} else {
keep_alive_size = computeMapSize(inputs[0], chunk_desc[0]);
map_size = keep_alive_size;
numel = computeNumel(map_size);
}
// Compute the storage needed to store TensorInfo structs for inputs and outputs.
size_t uncompressedDim = input_desc.at(0).contiguity.size();
size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
std::vector<char> buffer(maxPossibleBufferSize);
char * buffer_next = buffer.data();
// A vector of arguments to the kernel. It's (numel, *input_descs, *output_descs)
std::vector<void*> arguments;
arguments.reserve(3 + flat_inputs_size + flat_outputs_size);
auto addTensorInfoRaw = [&](TensorDesc & desc, void* data_ptr, at::IntList sizes, at::IntList strides) {
size_t nDim = desc.nDim(); // NOTE: this is the compressed dim
JIT_ASSERT(nDim <= uncompressedDim); // We'd overflow the space otherwise
auto ti = reinterpret_cast<TensorInfo*>(buffer_next);
ti->data = data_ptr;
compressContiguous(sizes, strides, desc.contiguity, ti->sizes(nDim), ti->strides(nDim));
buffer_next += maxPossibleTensorInfoSize;
arguments.push_back(ti);
};
// Asserts that t's dims can be compressed in the same way as in desc
// (that's what the kernel assumes), and appends it to the arguments vector.
auto addTensorInfo = [&](TensorDesc & desc, const at::Tensor & t) {
addTensorInfoRaw(desc, t.data_ptr(), t.sizes(), t.strides());
};
arguments.push_back(&numel);
for (size_t i = 0; i < input_desc.size(); ++i) {
auto & chunk = chunk_desc[i];
const at::Tensor& tensor = inputs[i];
if (chunk.isNoop()) {
addTensorInfo(input_desc[i], tensor);
} else {
size_t chunk_offset = map_size[chunk.dim] * tensor.stride(chunk.dim) * elementSize(tensor.type().scalarType());
char * data_ptr = reinterpret_cast<char*>(tensor.data_ptr());
for (size_t chunks = 0; chunks < chunk.nSubtensors; ++chunks) {
addTensorInfoRaw(*chunk.subtensorDesc, data_ptr, map_size, tensor.strides());
data_ptr += chunk_offset;
}
}
}
for (size_t i = 0; i < output_desc.size(); ++i) {
auto & c = concat_desc[i];
at::Tensor o = outputs[i];
if(c.isNoop()) {
o.resize_(map_size);
addTensorInfo(output_desc[i], outputs[i]);
} else {
size_t small_size = map_size[c.dim];
std::vector<int64_t> concat_size(map_size.begin(), map_size.end());
concat_size[c.dim] = small_size * c.nSubtensors;
o.resize_(concat_size);
size_t offset = 0;
for(size_t j = 0; j < c.nSubtensors; ++j) {
// because the concatenated output stays live, the underlying data
// in this view remains live through the end of this function,
// so there is no need to hold onto this tensor
auto view = o.narrow(c.dim, offset, small_size);
addTensorInfo(*c.subtensorDesc, view);
offset += small_size;
}
}
}
// If the kernel call contains a random op, we need to pass in random seeds as
// well.
#ifdef USE_CUDA
if(has_random && this->backend() == at::Backend::CUDA) {
auto gen_ = THCRandom_getGenerator(at::globalContext().getTHCState());
uint64_t offset =
gen_->state.philox_seed_offset.fetch_add(this->get_rand_offset(numel));
arguments.push_back(&gen_->state.initial_seed);
arguments.push_back(&offset);
}
#endif
launch_raw(numel, arguments.data());
}
void FusedKernel::launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs) {
at::DeviceGuard guard(inputs.back());
JIT_ASSERT(inputs.size() > 0);
auto & ref_type = inputs[0].type();
outputs.clear();
outputs.reserve(outputDescriptors().size());
for(auto & od : outputDescriptors()) {
outputs.push_back(ref_type.toScalarType(od.scalar_type).tensor());
}
launch_with_tensors(inputs, outputs);
}
////////////////////////////////////////////////////////////////////////////////
// CUDAFusedKernel
#ifdef USE_CUDA
void checkCUDAVersion(const cudaDeviceProp & prop) {
if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
(prop.major >= 7 && CUDA_VERSION < 9000)) {
std::stringstream err_string;
err_string << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: "
<< CUDA_VERSION << " for the current GPU device " << prop.name
<< " with device capability " << prop.major << "." << prop.minor;
throw std::runtime_error(err_string.str());
}
}
struct CUDAFusedKernel : public FusedKernel {
CUDAFusedKernel(const std::string & name, AnnotatedGraph & agraph)
: FusedKernel(name, agraph) {
at::DeviceGuard device_guard(agraph.device);
TORCH_CUDA_CHECK(cudaGetDeviceProperties(&prop, agraph.device));
checkCUDAVersion(prop);
std::stringstream cu;
std::tie(chunk_desc, concat_desc, has_random) = codegen::emitCompilationUnit(cu, name, agraph, true);
compilation_unit = cu.str();
nvrtcProgram program;
TORCH_NVRTC_CHECK(nvrtcCreateProgram(&program, compilation_unit.c_str(), NULL, 0, nullptr, nullptr));
std::string compute = "--gpu-architecture=compute_" + std::to_string(prop.major) + std::to_string(prop.minor);
std::vector<const char *> args = {"--std=c++11", compute.c_str(), "-default-device"};
nvrtcResult result = nvrtcCompileProgram(program, args.size(), args.data());
if (result == NVRTC_ERROR_COMPILATION) {
size_t logsize;
nvrtcGetProgramLogSize(program, &logsize);
std::vector<char> log(logsize);
nvrtcGetProgramLog(program, log.data());
cu << log.data();
throw std::runtime_error(cu.str());
}
ResourceGuard holdProgram([&] {
TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program));
});
TORCH_NVRTC_CHECK(result);
size_t ptx_size;
TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
ptx.resize(ptx_size);
TORCH_NVRTC_CHECK(nvrtcGetPTX(program, ptx.data()));
TORCH_CU_CHECK(cuModuleLoadData(&module, ptx.data()));
TORCH_CU_CHECK(cuModuleGetFunction(&function, module, name.c_str()));
TORCH_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor(
&maxBlocks, function, 128, 0));
maxBlocks *= prop.multiProcessorCount;
}
virtual ~CUDAFusedKernel() override {
TORCH_CU_CHECK(cuModuleUnload(module));
}
protected:
virtual at::Backend backend() const override {
return at::Backend::CUDA;
}
virtual uint64_t get_rand_offset(uint32_t numel) override {
int numBlocks = std::min(maxBlocks, ceilDiv(numel, blockSize));
return 4 * (ceil(numel/(4 * blockSize * numBlocks)) + 1);
}
virtual void launch_raw(uint32_t numel, void ** arguments) override {
int numBlocks = std::min(maxBlocks, ceilDiv(numel, blockSize));
//std::cout << "maxBlocks = " << maxBlocks << " needed blocks: " << ceilDiv(numel,blockSize)
// << " numblocks = " << numBlocks;
// it is possible that this is the first CUDA call on this thread,
// so make sure we initialize the Driver API's context;
// cudaFree(0) accomplishes this.
CUcontext pctx = 0;
TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
std::unique_lock<std::mutex> cudaFreeMutexLock(
*(THCCachingAllocator_getCudaFreeMutex()));
cudaFree(0);
}
CUstream stream = at::cuda::getCurrentCUDAStream();
TORCH_CU_CHECK(cuLaunchKernel(
function,
numBlocks, 1, 1,
blockSize, 1, 1,
0, stream,
arguments,
nullptr));
}
std::vector<char> ptx;
CUmodule module;
CUfunction function;
// we record prop/device so they are available for launch heuristics;
// querying device properties at launch time is too slow.
int device;
cudaDeviceProp prop;
int blockSize = 128;
int maxBlocks;
};
#endif
////////////////////////////////////////////////////////////////////////////////
// CPUFusedKernel
struct TempFile {
TH_DISALLOW_COPY_AND_ASSIGN(TempFile);
TempFile(const std::string & t, int suffix) {
// mkstemps edits its first argument in place
// so we make a copy of the string here, including null terminator
std::vector<char> tt(t.c_str(), t.c_str() + t.size() + 1);
int fd = mkstemps(tt.data(), suffix);
JIT_ASSERT(fd != -1);
file_ = fdopen(fd, "r+");
// - 1 because tt.size() includes the null terminator,
// but std::string does not expect one
name_ = std::string(tt.begin(), tt.end() - 1);
}
const std::string & name() const {
return name_;
}
void sync() {
fflush(file_);
}
void write(const std::string & str) {
size_t result = fwrite(str.c_str(), 1, str.size(), file_);
JIT_ASSERT(str.size() == result);
}
FILE* file() {
return file_;
}
~TempFile() {
if(file_ != nullptr) {
// unlink first to ensure another mkstemps doesn't
// race between close and unlink
unlink(name_.c_str());
fclose(file_);
}
}
private:
FILE * file_ = nullptr;
std::string name_;
};
static void* checkDL(void * x) {
if(!x) {
AT_ERROR("error in dlopen or dlsym: ", dlerror());
}
return x;
}
struct DynamicLibrary {
TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
DynamicLibrary(const char * name) {
handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW));
}
void * sym(const char * name) {
JIT_ASSERT(handle);
return checkDL(dlsym(handle, name));
}
~DynamicLibrary() {
if(!handle) return;
dlclose(handle);
}
private:
void * handle = nullptr;
};
static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so";
static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp";
// NB: -march=native not supported on PPC64 g++. It's a bit annoying
// to do a configure-style test to decide whether or not the g++
// actually supports it or not, so we heuristically use the host
// compiler to predict if the runtime compiler supports the option we
// want. This probably won't work if you're cross-compiling.
// NB: -march=native is disabled because it has caused problems where
// compiler and assembler do not agree on what native instruction they
// understand for AVX512. When we need better CPU performance this
// optimization can be re-enabled by tracking down the platforms where
// this error occurs and only selectively disabling it.
static const std::string compile_string =
"\"${cxx}\" -O3 -g "
#ifndef __PPC64__
// "-march=native "
#endif
"-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
static void runCompiler(FusionCompilerConfig & config, const std::string & cpp_file, const std::string & so_file) {
TemplateEnv env;
env.s("cxx", config.cxx);
env.s("fopenmp", config.openmp ? "-fopenmp" : "");
env.s("cpp_file",cpp_file);
env.s("so_file",so_file);
std::string result = format(compile_string,env);
int r = system(result.c_str());
if(config.openmp && r != 0) {
std::cerr << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
config.openmp = false; // disable for future compiles
return runCompiler(config, cpp_file, so_file);
}
JIT_ASSERTM(r == 0, "Failed to compile a fused CPU kernel");
}
static const std::string disas_string =
"objdump -M intel -d \"${so_file}\"";
static void disas(const std::string & so_file) {
TemplateEnv env;
env.s("so_file", so_file);
std::string cmd = format(disas_string, env);
int r = system(cmd.c_str());
JIT_ASSERT(r == 0);
}
struct CPUFusedKernel : public FusedKernel {
CPUFusedKernel(const std::string & name, AnnotatedGraph & agraph, FusionCompilerConfig & config)
: FusedKernel(name, agraph) {
TempFile so_file(so_template, 3);
TempFile cpp_file(cpp_template, 4);
std::stringstream cu;
std::tie(chunk_desc, concat_desc, has_random) = codegen::emitCompilationUnit(cu, name, agraph, false);
JIT_ASSERT(!has_random);
compilation_unit = cu.str();
cpp_file.write(compilation_unit);
cpp_file.sync();
runCompiler(config, cpp_file.name(), so_file.name());
if(config.debug) {
disas(so_file.name());
}
so_lib.reset(new DynamicLibrary(so_file.name().c_str()));
#pragma GCC diagnostic ignored "-Wpedantic"
kernel = reinterpret_cast<void(*)(uint32_t, void**)>(so_lib->sym(name.c_str()));
#pragma GCC diagnostic pop
}
protected:
virtual at::Backend backend() const override {
return at::Backend::CPU;
}
virtual uint64_t get_rand_offset(uint32_t numel) override {
return numel;
}
virtual void launch_raw(uint32_t numel, void ** arguments) override {
kernel(numel, arguments);
}
std::unique_ptr<DynamicLibrary> so_lib;
void (*kernel)(uint32_t, void**) = nullptr;
};
////////////////////////////////////////////////////////////////////////////////
// FusedKernelCache
// Note [Run-time shape checking code]
// There are multiple assumptions that our codegen makes, which we can't check
// in the fusion pass, because we don't have the shape information. Most notably,
// that all values (post-input-chunk, and pre-output-concat) have the same shape
// (hereinafter referred to as map size). One way to check this would be to run
// shape propagation for every size configuration we get as an input, but that
// requires a full graph traversal, and might incur unnecessary overhead. The code
// below uses a few nice properties of broadcasting rules and their interactions with
// pointwise operations, and takes a smarter approach to quickly verify the validity of
// the kernel.
//
// Notation:
// - a.s when a is a tensor is a shorthand for a.shape.
// - B is a shorthand for the broadcasting/expanding function. It is used as a
// vararg function.
// - E is a shorthand for expand function.
// - Every pointwise operation can be equivalently rewritten as
// f(a, b) = f^(E(a, B(a.s, b.s)), E(b, B(a.s, b.s))),
// where f^ is a non-broadcasting version of f.
// - A set of inputs that are used to produce a certain graph output is referred to
// as the output's broadcasting group (see Lemma 2. for explanation why).
//
// Lemma 1. Set of lists of integers (shapes) + { _|_ (bottom/error marker) }, with the
// operation of broadcasting (returning bottom upon shape mismatch) forms a monoid.
// In simpler terms: broadcasting is associative, i.e. B(a, B(b, c)) == B(B(a, b), c).
//
// Proof. Satisfies all monoid laws:
// - Closed under broadcasting (trivial)
// - Empty shape is the identity element: B(a, []) == B([], a) == a
// - Associativity: A simple visual proof is that you can expand 3 tensors
// at the same time by stacking their sizes (with alignment to the right),
// just as you'd do in the case of 2 tensors, but with an intermediate
// (the algorithm ends up being pretty much the same).
//
// Lemma 2. Shape of an output of an arbitrary DAG of pointwise ops depends only on the set
// of inputs used in this DAG and is equal to B([i.shape for i in used_inputs]).
//
// Proof. Let G be any DAG of pointwise ops and < be any valid topological
// ordering on nodes of G. Proof by induction over <.
// Base case (graph input):
// Trivial (input is also an output).
// Step (n = f(q, r)):
// Let QS (RS) be the set of shapes of inputs that q (r) depends on.
// Note that the set of inputs that n depends on is exactly QS + RS.
// shape(n) == shape(f(q, r))
// (def of f)
// == shape(f^(E(q, B(q.s, r.s)), E(r, B(q.s, r.s))))
// (output shape of f^ is equal to either of argument shapes)
// == shape(E(q, B(q.s, r.s)))
// (property of expand)
// == B(q.s, r.s)
// (induction assumption)
// == B(B(QS...), B(RS...))
// (Lemma 1.)
// == B(QS..., RS...)
// (repeated shapes don't matter for broadcasting)
// == B((QS + RS)...)
//
// Lemma 3. Expands are distributive over pointwise ops, i.e. E(f(a, b), s) = f(E(a, s), E(b, s))
// Lemma 4. Expands can be collapsed, i.e. E(E(x, s1), s2) = E(x, B(s1, s2)).
// Proof. A simple exercise for the reader :)
//
// Theorem. If all (pre-concat-)outputs have equal shapes, then we can push the expands to
// (post-chunk-)inputs, and have all intermediates of the same shape
// (no broadcasting happening in the body).
//
// Proof. Using the above lemmas we can easily show that a graph with a single output
// can be easily rewritten by taking the shape given by B applied to all input
// shapes, expanding inputs to it, and using only non-broadcasting operations.
// Example:
//
// let d = f(a, b) in
// let e = h(b, c) in
// g(d, e)
//
// (By def. of broadcasting pointwise ops applied to g, f and h)
// (Lemma 2. for a closed formula for the size of g = gs)
//
// let gs = B(a.s, b.s, c.s) in
// let d' = E(f^(E(a, B(a.s, b.s)), E(b, B(a.s, b.s))), gs) in
// let e' = E(h^(E(b, B(b.s, c.s)), E(c, B(b.s, c.s))), gs) in
// g^(d', e')
//
// (Lemma 3.)
//
// let gs = B(a.s, b.s, c.s) in
// let d' = f^(E(E(a, B(a.s, b.s)), gs), E(E(b, B(a.s, b.s)), gs)) in
// let e' = h^(E(E(b, B(b.s, c.s)), gs), E(E(c, B(b.s, c.s)), gs)) in
// g^(d', e')
//
// (Lemma 4. + Lemma 1. to simplify broadcasting function)
//
// let gs = B(a.s, b.s, c.s) in
// let d' = f^(E(a, gs), E(b, gs)) in
// let e' = h^(E(b, gs), E(c, gs)) in
// g^(d', e')
//
// (Simple rewrite)
//
// let gs = B(a.s, b.s, c.s) in
// let a' = E(a, gs) in
// let b' = E(b, gs) in
// let c' = E(c, gs) in
// let d' = f^(a', b') in
// let e' = h^(b', c') in
// g^(d', e')
//
// This example can be easily formalized to arbitrary DAGs using induction
// over topological ordering, similar to Lemma 2. Now, if broadcasting groups
// for all outputs have the same shape, then performing an expand to this size
// on all inputs will ensure that all intermediates on all paths to outputs
// will have the same shape, proving that the body of the kernel is valid.
//
// This shows the part until post-chunk-inputs. Extending it to pre-chunk-inputs
// is straightforward (needs a simple lemma for moving expands through chunks).
// Register implementations of fused operators, so that we can reuse the fused graph
// to generate fallback code.
RegisterOperators reg_fused_operators({
Operator(
prim::FusedChunk,
[](Node* node) {
int64_t dim = node->i(attr::dim);
int64_t chunks = node->outputs().size();
return [dim, chunks](Stack& stack) {
auto result = at::chunk(std::move(peek(stack, 0, 1)).toTensor(), chunks, dim);
drop(stack, 1);
pack(stack, std::move(result));
return 0;
};
}),
Operator(
prim::FusedConcat,
[](Node* node) {
int64_t dim = node->i(attr::dim);
int64_t num_inputs = node->inputs().size();
return [dim, num_inputs](Stack& stack) {
auto result = at::cat(
fmap(last(stack, num_inputs), [](const IValue& i) { return i.toTensor(); }),
dim
);
drop(stack, num_inputs);
pack(stack, std::move(result));
return 0;
};
})
});
FusedKernelCache::FusedKernelCache(FusionCompiler& compiler, std::shared_ptr<Graph> _graph, int device)
: device(device)
, fallback_code(_graph)
, compiler(compiler)
, graph(std::move(_graph))
, input_broadcast_groups(getInputBroadcastGroups())
, input_chunks(getInputChunkDescriptors())
, kernels() {}
std::atomic<size_t> FusedKernelCache::next_kernel_id {0};
auto FusedKernelCache::getInputChunkDescriptors() -> std::vector<PartitionInfo> {
std::vector<PartitionInfo> descs;
descs.reserve(graph->inputs().size());
for (Value * input : graph->inputs()) {
if (Node * chunk = usedInFusedChunk(input)) {
descs.emplace_back(chunk->i(attr::chunks), chunk->i(attr::dim));
} else {
descs.emplace_back(1, 0);
}
}
return descs;
}
// NB: this vector is really a set, but we want to keep it contiguous in memory for faster access
static std::vector<int64_t> getInputDependencies(Value* output) {
// Run a DFS traversal to find all inputs that affect a given output value
std::vector<Value*> queue { output };
std::unordered_set<Value*> inputs;
std::unordered_set<Value*> seen;
while (!queue.empty()) {
Value * val = queue.back(); queue.pop_back();
Node * producer = val->node();
if (producer->kind() == prim::Param) {
inputs.insert(val);
continue;
}
for (Value * input : producer->inputs()) {
if (/*bool inserted = */seen.insert(input).second) {
queue.push_back(input);
}
}
}
// Convert Value* into offsets into the graph's input list
std::vector<int64_t> offsets;
offsets.reserve(inputs.size());
for (Value * input : inputs) {
offsets.push_back(input->offset());
}
std::sort(offsets.begin(), offsets.end());
return offsets;
}
std::vector<std::vector<int64_t>> FusedKernelCache::getInputBroadcastGroups() {
std::unordered_set<std::vector<int64_t>, torch::hash<std::vector<int64_t>>> broadcast_groups;
for (Value * output : graph->outputs()) {
broadcast_groups.insert(getInputDependencies(output));
}
return std::vector<std::vector<int64_t>>{ broadcast_groups.begin(), broadcast_groups.end() };
}
void FusedKernelCache::run(Stack& stack) {
int64_t num_inputs = graph->inputs().size();
auto args = fmap(last(stack, num_inputs), [](const IValue& i) {
return i.toTensor();
});
auto maybe_map_size = canRunKernel(args);
if (!maybe_map_size) {
return runFallback(stack);
}
expandArgs(args, *maybe_map_size);
FusedKernelArgSpec spec { args };
auto it = kernels.find(spec);
if (it == kernels.end()) {
std::tie(it, std::ignore) = kernels.emplace(spec, compileSpec(spec, *maybe_map_size));
}
auto & fn = it->second;
std::vector<at::Tensor> outputs;
fn->launch(args, outputs);
drop(stack, num_inputs);
pack(stack, std::move(outputs));
}
at::optional<std::vector<int64_t>> FusedKernelCache::getMapSize(at::TensorList args, at::IntList arg_subset) {
int64_t dim_after_broadcast = 0;
for (int64_t arg_idx : arg_subset) {
dim_after_broadcast = std::max(dim_after_broadcast, args[arg_idx].dim());
}
// TODO: this keeps reallocating map_size at every iteration, but we know
// exactly how much storage we need, so this could be fixed in-place at
// every step. We're just missing a few functions in ATen, but the fix
// should be straightforward.
// NB: we leave this uninitialized, because an empty size is trivially
// broadcastable to any other size.
std::vector<int64_t> map_size;
for (size_t i = 0; i < arg_subset.size(); ++i) {
auto & arg = args.at(arg_subset[i]);
auto & chunk_desc = input_chunks.at(arg_subset[i]);
if (chunk_desc.nSubtensors == 1) {
try {
map_size = at::infer_size(map_size, arg.sizes());
} catch (std::exception& e) {
return at::nullopt;
}
} else {
auto tensor_sizes = arg.sizes().vec();
int64_t num_chunks = chunk_desc.nSubtensors;
int64_t dim = at::maybe_wrap_dim(chunk_desc.dim, tensor_sizes.size());
if (tensor_sizes[dim] % num_chunks != 0) {
return at::nullopt;
}
tensor_sizes[dim] /= num_chunks;
try {
map_size = at::infer_size(map_size, tensor_sizes);
} catch (std::exception& e) {
return at::nullopt;
}
}
}
return {map_size};
}
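// Illustrative example (hypothetical shapes): if input 0 has sizes [4, 6] and is
// chunked into 2 subtensors along dim 1, its per-chunk sizes are [4, 3]; if
// input 1 has sizes [3], then map_size = infer_size([4, 3], [3]) = [4, 3].
// A chunked input whose split dimension is not divisible (e.g. [4, 5] into 2)
// makes this return at::nullopt.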
// See Note [Run-time shape checking code] for more explanation on the algorithm.
at::optional<std::vector<int64_t>> FusedKernelCache::canRunKernel(at::TensorList args) {
AT_CHECK(args.size() == input_chunks.size(),
"Expected ", input_chunks.size(), " arguments, but got ", args.size());
at::optional<std::vector<int64_t>> map_size;
for (const auto & broadcast_group : input_broadcast_groups) {
if (!map_size) {
map_size = getMapSize(args, broadcast_group);
if (!map_size) {
return at::nullopt;
}
} else {
auto group_map_size = getMapSize(args, broadcast_group);
// NB: this checks that group_map_size is defined AND equal to map_size
if (map_size != group_map_size) {
return at::nullopt;
}
}
}
return map_size;
}
void FusedKernelCache::runFallback(Stack& stack) {
InterpreterState(fallback_code).runOneStage(stack);
}
// NB: args are mutated in this call. map_size is mutated too, but is restored to its original
// value before this function returns (it's an optimization).
void FusedKernelCache::expandArgs(std::vector<at::Tensor>& args, std::vector<int64_t>& map_size) {
for (size_t i = 0; i < args.size(); ++i) {
auto & arg = args[i];
auto & pdesc = input_chunks[i];
if (pdesc.nSubtensors == 1) {
if (arg.sizes().equals(map_size)) continue;
arg = arg.expand(map_size);
} else {
map_size.at(pdesc.dim) *= pdesc.nSubtensors;
if (!arg.sizes().equals(map_size)) {
arg = arg.expand(map_size);
}
map_size.at(pdesc.dim) /= pdesc.nSubtensors;
}
}
}
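// Illustrative example: with map_size = [4, 3] and an input whose descriptor has
// nSubtensors = 2, dim = 1, the pre-chunk tensor is expected to have (or be
// expandable to) sizes [4, 6]; map_size is temporarily scaled to [4, 6] for the
// expand and then restored.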
std::unique_ptr<FusedKernel> FusedKernelCache::compileSpec(
const FusedKernelArgSpec& spec, const std::vector<int64_t>& map_size) {
AnnotatedGraph agraph {*graph, device};
agraph.input_desc = spec.descs();
// XXX: this assumes that fused kernels only operate on floating-point values inside
at::optional<at::ScalarType> scalar_type;
for (TensorDesc& desc : agraph.input_desc) {
if (isFloatingType(desc.scalar_type)) {
scalar_type = desc.scalar_type;
break;
}
}
JIT_ASSERT(scalar_type);
for (Value * output : graph->outputs()) {
std::vector<int64_t> sizes = map_size;
if (output->node()->kind() == prim::FusedConcat) {
sizes.at(output->node()->i(attr::dim)) *= output->node()->inputs().size();
}
auto type = CompleteTensorType::create(*scalar_type, device, sizes);
agraph.output_desc.emplace_back(std::move(type));
}
std::string name = "kernel_" + std::to_string(next_kernel_id++);
FusedKernel * raw_func;
if (device != kCPUDevice) {
#ifdef USE_CUDA
raw_func = new CUDAFusedKernel(name, agraph);
#else
throw std::runtime_error("cannot compile a CUDA fusion group, CUDA is not enabled.");
#endif
} else {
JIT_ASSERT(compiler.canCompileOnCPU());
raw_func = new CPUFusedKernel(name, agraph, compiler.config_);
}
return std::unique_ptr<FusedKernel>(raw_func);
}
////////////////////////////////////////////////////////////////////////////////
// FusionCompiler
std::shared_ptr<FusedKernelCache> FusionCompiler::getOrCompile(Node* fusion_group) {
int device = fusion_group->i(attr::device);
if (device == kCPUDevice) {
JIT_ASSERT(canCompileOnCPU());
} else {
#ifndef USE_CUDA
throw std::runtime_error("cannot compile a CUDA fusion group - CUDA is not enabled.");
#endif
}
auto graph = fusion_group->g(attr::Subgraph)->copy();
EraseShapeInformation(*graph);
std::stringstream key;
key << "device " << device << "\n";
key << *graph << "\n";
std::string key_ = key.str();
auto it = cache_map.find(key_);
if (it == cache_map.end()) {
std::tie(it, std::ignore) = cache_map.emplace(key_, std::make_shared<FusedKernelCache>(*this, graph, device));
}
return it->second;
}
std::vector<at::Tensor> FusionCompiler::debugLaunchGraph(Graph & graph, int device, at::ArrayRef<at::Tensor> inputs) {
auto wrapper_graph = std::make_shared<Graph>();
Node * fusion_group = wrapper_graph->insertNode(wrapper_graph->createFusionGroup(device));
fusion_group->g_(attr::Subgraph, graph.copy());
for (size_t i = 0; i < graph.inputs().size(); ++i) {
fusion_group->addInput(wrapper_graph->addInput());
}
for (size_t i = 0; i < graph.outputs().size(); ++i) {
wrapper_graph->registerOutput(fusion_group->addOutput());
}
auto cache = getOrCompile(fusion_group);
Stack stack = fmap<IValue>(inputs);
cache->run(stack);
return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
}
static const std::string check_exists_string =
"which '${program}' > /dev/null";
static bool programExists(const std::string & program) {
TemplateEnv env;
env.s("program", program);
std::string cmd = format(check_exists_string, env);
return 0 == system(cmd.c_str());
}
FusionCompiler::FusionCompiler() {
const char * cxx_env = getenv("CXX");
if(cxx_env != nullptr) {
config_.cxx = cxx_env;
}
if(!programExists(config_.cxx)) {
config_.cxx = "";
}
const char * debug_env = getenv("PYTORCH_FUSION_DEBUG");
config_.debug = debug_env && atoi(debug_env) != 0;
}
//TODO: thread safety
FusionCompiler & sharedFusionCompiler() {
static FusionCompiler compiler;
return compiler;
}
}}
#else
// dummy implementations for windows
#include "torch/csrc/jit/fusion_compiler.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/code_template.h"
#include "torch/csrc/jit/resource_guard.h"
#include "torch/csrc/utils/disallow_copy.h"
#include "ATen/ATen.h"
#ifdef USE_CUDA
#include "torch/csrc/cuda/cuda_check.h"
#include <nvrtc.h>
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#include <string>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include <sstream>
#include <iostream>
namespace torch { namespace jit {
struct FusedKernel {
char padding;
};
FusedKernelCache::FusedKernelCache(FusionCompiler& compiler, std::shared_ptr<Graph> graph, int device)
: compiler(compiler) {}
void FusedKernelCache::run(Stack& inputs) {}
void FusedKernelCache::runFallback(Stack& stack) {}
void FusedKernelCache::expandArgs(std::vector<at::Tensor>& args, std::vector<int64_t>& map_size) {}
at::optional<std::vector<int64_t>> FusedKernelCache::canRunKernel(at::TensorList args) { return at::nullopt; }
at::optional<std::vector<int64_t>> FusedKernelCache::getMapSize(at::TensorList args, at::IntList arg_subset) { return at::nullopt; }
std::vector<std::vector<int64_t>> FusedKernelCache::getInputBroadcastGroups() { return {}; }
auto FusedKernelCache::getInputChunkDescriptors() -> std::vector<PartitionInfo> { return {}; }
std::unique_ptr<FusedKernel> FusedKernelCache::compileSpec(
const FusedKernelArgSpec& spec, const std::vector<int64_t>& map_size) { return nullptr; }
std::atomic<size_t> FusedKernelCache::next_kernel_id {0};
FusionCompiler::FusionCompiler() {}
std::shared_ptr<FusedKernelCache> FusionCompiler::getOrCompile(Node* fusion_group) { return nullptr; }
std::vector<at::Tensor> FusionCompiler::debugLaunchGraph(Graph & graph, int device, at::ArrayRef<at::Tensor> inputs) { return {}; }
FusionCompiler & sharedFusionCompiler() {
throw std::runtime_error("NYI: fuser is not supported on Windows.");
}
}}
#endif