| #ifndef _WIN32 |
| #include "torch/csrc/jit/fusion_compiler.h" |
| #include "torch/csrc/jit/ir.h" |
| #include "torch/csrc/jit/code_template.h" |
| #include "torch/csrc/jit/resource_guard.h" |
| #include "torch/csrc/utils/disallow_copy.h" |
| #include "ATen/ATen.h" |
| #ifdef WITH_CUDA |
| #include "torch/csrc/cuda/cuda_check.h" |
| #include <nvrtc.h> |
| #include <cuda.h> |
| #include <cuda_runtime.h> |
| #endif |
| #include <string> |
| #include <algorithm> |
| #include <unordered_map> |
| #include <vector> |
| #include <sstream> |
| #include <iostream> |
| #include <dlfcn.h> |
| #include <unistd.h> |
| |
| namespace torch { namespace jit { |
| |
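// Map from node kind to a CodeTemplate snippet for the right-hand side of the
// generated statement. ${0}, ${1}, ... are positional placeholders for the
// node's inputs; named placeholders such as ${alpha} refer to 0-dim tensor
// attributes (see encodeRHS below). For illustration, ksigmoid applied to a
// value named n3 expands to "1.f / (1.f + expf(-n3))".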
| std::unordered_map<NodeKind, std::string> simple_map_ops = { |
| // unary |
| {kabs, "absf(${0})"}, |
| {ksigmoid, "1.f / (1.f + expf(-${0}))"}, |
| {klog, "logf(${0})"}, |
| {klog1p, "log1pf(${0})"}, |
| {klgamma, "lgammaf(${0})"}, |
| {kexp, "expf(${0})"}, |
| {kexpm1, "expm1f(${0})"}, |
| {kcos, "cosf(${0})"}, |
| {kacos, "acosf(${0})"}, |
| {kcosh, "coshf(${0})"}, |
| {ksin, "sinf(${0})"}, |
| {kasin, "asinf(${0})"}, |
| {ksinh, "sinhf(${0})"}, |
| {ktan, "tanf(${0})"}, |
| {katan, "atanf(${0})"}, |
| {ktanh, "tanhf(${0})"}, |
| {ksqrt, "sqrtf(${0})"}, |
| {krsqrt, "rsqrtf(${0})"}, |
| {kceil, "ceilf(${0})"}, |
| {kfloor, "floorf(${0})"}, |
| {kround, "roundf(${0})"}, |
| {ktrunc, "truncf(${0})"}, |
| {kfrac, "fracf(${0})"}, |
| {kreciprocal, "reciprocalf(${0})"}, |
| {kneg, "-${0}"}, |
| //simple binary |
| {katan2, "atan2(${0}, ${1})"}, |
| {kmin, "fminf(${0}, ${1})"}, |
| {kmax, "fmaxf(${0}, ${1})"}, |
| |
| //binary with other |
| // TODO: some of these ops will not get generated because |
| // we only work on float inputs/outputs, but they are here to record |
// that they are valid mappable ops once we handle more types
| {k__and__, "${0} && ${1}"}, |
| {k__lshift__, "${0} << ${1}"}, |
| {k__or__, "${0} || ${1}"}, |
| {k__rshift__, "${0} >> ${1}"}, |
| {k__xor__, "${0} ^ ${1}"}, |
| {kdiv, "${0} / ${1}"}, |
| {keq, "${0} == ${1}"}, |
| {kfmod, "fmodf(${0}, ${1})"}, |
| {kge, "${0} >= ${1})"}, |
| {kgt, "${0} > ${1}"}, |
| {kle, "${0} <= ${1})"}, |
| {klt, "${0} < ${1}"}, |
| {kmul, "${0} * ${1}"}, |
| {kne, "${0} != ${1}"}, |
| {kremainder, "remainderf(${0}, ${1})"}, |
| {kpow, "powf(${0}, ${1})"}, |
| |
| //alpha |
| {kadd, "${0} + ${alpha}*${1}"}, |
| {ksub, "(${0} - ${alpha}*${1})"}, |
| |
| // special |
| {klerp, "${0} + ${weight}*(${1} - ${0})"}, |
| {kclamp, "min(max(${0},${min}),${max})"}, |
| |
| // simple derivatives |
| {"_sigmoid_backward"_sym, "${0} * ${1} * (1.f - ${1})"}, |
| {"_tanh_backward"_sym, "${0} * (1.f - ${1} * ${1})"}, |
| }; |
| |
| std::vector<bool> TensorDesc::findContiguous( |
| const at::IntList& sizes, |
| const at::IntList& strides) { |
| JIT_ASSERT(sizes.size() == strides.size()); |
| std::vector<bool> cont(sizes.size()); |
| for(size_t i = 0; i < sizes.size(); ++i) { |
| int64_t expected_stride = (i + 1 < sizes.size()) ? sizes[i+1]*strides[i+1] : 1; |
| cont[i] = strides[i] == expected_stride; |
| } |
| return cont; |
| } |
| |
| namespace { |
| |
| static int ceilDiv(int a, int b) { |
| return (a + b - 1) / b; |
| } |
| |
| std::ostream& operator<<(std::ostream & out, const TensorDesc & d) { |
| out << d.scalar_type << "["; |
| for(auto b : d.contiguity) |
| out << b << ";"; |
| out << "]"; |
| return out; |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Code generation |
| |
| namespace codegen { |
| |
| auto type_declarations_template = CodeTemplate(R"( |
| typedef ${IndexType} IndexType; |
| template<typename T, size_t N> |
| struct TensorInfo { |
| T * data; |
| IndexType sizes[N]; |
| IndexType strides[N]; |
| }; |
| )"); |
| |
| auto cuda_compilation_unit_template = CodeTemplate(R"( |
| ${type_declarations} |
| |
| extern "C" __global__ |
| void ${kernelName}(IndexType totalElements, ${formals}) { |
| for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; |
| linearIndex < totalElements; |
| linearIndex += gridDim.x * blockDim.x) { |
// Convert `linearIndex` into per-tensor offsets:
| ${tensorOffsets} |
| // calculate the results |
| ${kernelBody} |
| } |
| } |
| )"); |
| |
| auto cpu_compilation_unit_template = CodeTemplate(R"( |
| #include <cstddef> |
| #include <math.h> |
| #include <iostream> |
| ${type_declarations} |
| |
| #define OMP_THRESHOLD 100000 |
| static void ${kernelName}_kernel(IndexType totalElements, ${formals}) { |
| #pragma omp parallel for if(totalElements > OMP_THRESHOLD) |
| for (IndexType linearIndex = 0; |
| linearIndex < totalElements; |
| linearIndex += 1) { |
// Convert `linearIndex` into per-tensor offsets:
| ${tensorOffsets} |
| // calculate the results |
| ${kernelBody} |
| } |
| } |
| |
| extern "C" |
| void ${kernelName}(IndexType totalElements, void ** args) { |
| ${kernelName}_kernel(totalElements ${,argument_loads}); |
| } |
| )"); |
| |
// curDimIndex = linearId % sizes[i]; // the % sizes[i] is not needed for d == 0, because we already guard for numel outside the index calculation
// offset += curDimIndex*strides[i]; // the *strides[i] can be skipped for the last dim when last_is_cont, because strides.back() == 1
| // linearId /= sizes[i]; |
| auto dim_calc = CodeTemplate(R"( |
| //printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]); |
| size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes}; |
| ${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride}; |
| )"); |
| |
| void emitIndexingFor(std::ostream & out, const std::string & tensor, int ndim, bool last_is_cont) { |
| TemplateEnv env; |
| env.s("tensor",tensor); |
| out << format("IndexType ${tensor}_offset = 0;\n",env); |
| out << format("IndexType ${tensor}_linearIndex = linearIndex;\n",env); |
| for(int d = ndim - 1; d >= 0; --d) { |
| env.d("d",d); |
| env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]",env) : ""); |
| env.s("times_stride",(d < ndim - 1 || !last_is_cont) ? |
| format("* ${tensor}.strides[${d}]",env) : ""); |
| out << dim_calc.format(env); |
| if(d > 0) { |
| out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n",env); |
| } |
| } |
| } |
| |
| std::string valueName(Value * n) { |
| return "n" + std::to_string(n->unique()); |
| } |
| |
| std::string scalarValue(const at::Tensor & t) { |
| auto s = at::Scalar(t); |
| return (s.isIntegral()) ? |
| std::to_string(s.toLong()) : |
| (std::to_string(s.toDouble()) + "f"); |
| } |
| |
| const char * scalarTypeName(at::ScalarType type) { |
| switch(type) { |
| #define DEFINE_CASE(ctype,name,_) \ |
| case at::ScalarType::name: return #ctype; |
| AT_FORALL_SCALAR_TYPES(DEFINE_CASE) |
| #undef DEFINE_CASE |
| default: |
| throw std::runtime_error("unknown scalar type"); |
| } |
| } |
| |
| std::string encodeRHS(Node * n) { |
| TemplateEnv env; |
| size_t i = 0; |
| for(auto in : n->inputs()) { |
| env.s(std::to_string(i++),valueName(in)); |
| } |
// ops like div can appear as a / b or a / 2, with the constant stored in the
// attribute 'other', so we add it as an extra input when it is present.
// 'pow' does the same but names its attribute 'exponent', so we handle that here as well.
if(n->hasAttribute(kother)) {
  env.s(std::to_string(i), scalarValue(n->t(kother)));
} else if(n->hasAttribute(kexponent)) {
  env.s(std::to_string(i), scalarValue(n->t(kexponent)));
}
| // we also add any other scalar tensors to the env for special ops |
| for(auto a : n->attributeNames()) { |
| if(n->kindOf(a) == AttributeKind::t) { |
| auto v = n->t(a); |
| if(v.dim() == 0) { |
| env.s(symbolToString(a), scalarValue(v)); |
| } |
| } |
| } |
| const auto & str = simple_map_ops.at(n->kind()); |
| return format(str, env); |
| } |
| |
| std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out, |
| const std::string & name, |
| AnnotatedGraph & agraph, |
| bool use_cuda) { |
| Graph& subgraph = *agraph.graph; |
| TemplateEnv env; |
| env.s("kernelName",name); |
| // TODO: handle cases where we need to generate > 2^32 element tensors |
| env.s("IndexType","unsigned int"); //avoiding slow header includes to get uint32_t |
| |
| std::stringstream body; |
| std::stringstream tensorOffsets; |
| std::vector<std::string> formals; |
| std::vector<std::string> argument_loads; |
| auto emitFormal = [&](Value * n, const TensorDesc & desc) { |
| std::string tensor = "t" + std::to_string(formals.size()); //can't be unique() because Param may be an output |
| size_t nDim = desc.nDim(); |
| emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous()); |
| env.s("tensor",tensor); |
| env.d("formal_index", formals.size() + 1); // + 1 because the first argument is the linearIndex |
| env.d("nDim",nDim); |
| env.s("scalar_type",scalarTypeName(desc.scalar_type)); |
| formals.push_back(format("TensorInfo<${scalar_type},${nDim}> ${tensor}",env)); |
| argument_loads.push_back(format("*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",env)); |
| }; |
| { |
| size_t i = 0; |
| for(auto p : subgraph.inputs()) |
| emitFormal(p,agraph.input_desc[i++]); |
| } |
| std::vector<ConcatDesc> concat_desc; |
| std::vector<Value*> flat_output_nodes; |
| { |
| size_t i = 0; |
| for(auto o : subgraph.outputs()) { |
| auto & desc = agraph.output_desc[i++]; |
| if(o->node()->kind() != kcat) { |
| emitFormal(o, desc); |
| concat_desc.emplace_back(); |
| flat_output_nodes.push_back(o); |
| } else { |
| auto cat = o->node(); |
| size_t nInputs = cat->inputs().size(); |
| concat_desc.emplace_back(desc, nInputs, cat->i(kdim)); |
| for(auto c : cat->inputs()) { |
| emitFormal(c, *concat_desc.back().subtensorDesc); |
| flat_output_nodes.push_back(c); |
| } |
| } |
| } |
| } |
| size_t formal_count = 0; |
| for(auto p : subgraph.inputs()) { |
| env.s("node",valueName(p)); |
| env.d("formal",formal_count++); |
| env.s("access",format("t${formal}.data[t${formal}_offset]",env)); |
| //TODO: actual type propagation rather than relying on auto.. |
| body << format("auto ${node} = ${access};\n",env); |
| } |
| for(auto n : subgraph.nodes()) { |
| if(n->kind() == kcat) |
continue; // Concat is handled by narrowing the output Tensors before the kernel runs
| env.s("node",valueName(n->output())); |
| env.s("rhs", encodeRHS(n)); |
| body << format("auto ${node} = ${rhs};\n",env); |
| } |
| for(auto o : flat_output_nodes) { |
| env.d("formal",formal_count++); |
| env.s("access",format("t${formal}.data[t${formal}_offset]",env)); |
| env.s("node",valueName(o)); |
| body << format("${access} = ${node};\n",env); |
| } |
| env.s("tensorOffsets",tensorOffsets.str()); |
| env.s("kernelBody",body.str()); |
| env.v("formals",formals); |
| env.v("argument_loads",argument_loads); |
| env.s("type_declarations", type_declarations_template.format(env)); |
| if(use_cuda) { |
| out << cuda_compilation_unit_template.format(env); |
| } else { |
| out << cpu_compilation_unit_template.format(env); |
| } |
| return concat_desc; |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| } // codegen namespace |
| } // anonymous namespace |
| |
// Host-side view of TensorInfo (the view visible to the kernel is defined above).
// Note the zero-length array: sizes and strides are allocated dynamically right after the struct.
| struct TensorInfo { |
| void * data; |
| uint32_t sizes_strides[0]; |
| |
| uint32_t* sizes(size_t nDim) { return &sizes_strides[0]; } |
| uint32_t* strides(size_t nDim) { return &sizes_strides[nDim]; } |
| }; |
| |
| CompiledFusionFunction::CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph) |
| : name(name) |
| , input_desc(agraph.input_desc) |
| , output_desc(agraph.output_desc) {} |
| |
| namespace { |
| |
// Tries to compress sizes and strides according to cont. Writes the result to
// c_sizes and c_strides, and fails a JIT_ASSERT if the dims cannot be compressed that way.
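// Worked example: sizes {4,3,2} with strides {6,2,1} and cont {true,true,true}
// compress to c_sizes {24}, c_strides {1}; with strides {12,2,1} and cont
// {false,true,true} they compress to c_sizes {4,6}, c_strides {12,1}.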
| void compressContiguous( |
| at::IntList sizes, |
| at::IntList strides, |
| const std::vector<bool> & cont, |
| uint32_t * c_sizes, |
| uint32_t * c_strides) { |
| size_t compressed_dims = 0; |
| size_t cur = 0; |
| size_t ndim = sizes.size(); |
| while(cur < ndim) { |
| size_t total_size = sizes[cur]; |
| cur++; |
| while(cont[cur-1] && cur < ndim) { |
| JIT_ASSERT(strides[cur-1] == sizes[cur]*strides[cur]); |
| total_size *= sizes[cur]; |
| cur++; |
| } |
// cur started at the beginning of the run to compress and now points
// one _past_ the terminating false (or one past the end of the list).
// total_size is the number of elements covered by the dimensions in [begin, end)
| // examples: |
| // f = not cont. |
| // t = cont. |
| // x = don't care, including past end of list |
| // s = start of cur |
| // e = end of cur |
| |
| |
| // f x x x |
| // s e |
| |
| // t f x x |
| // s e |
| |
| // t t f x |
| // s e |
| |
| c_sizes[compressed_dims] = total_size; |
| c_strides[compressed_dims] = strides[cur-1]; |
| compressed_dims++; |
| } |
| JIT_ASSERT(!cont.back() || strides.back() == 1); |
| } |
| |
| } // anonymous namespace |
| |
| void CompiledFusionFunction::launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) { |
| AutoGPU gpu_guard(inputs); |
| JIT_ASSERT(inputs.size() == input_desc.size()); |
| JIT_ASSERT(outputs.size() == output_desc.size()); |
| size_t flat_outputs_size = 0; |
| for(auto & c : concat_desc) |
| flat_outputs_size += c.nSubtensors; |
| // XXX: this code assumes that inputs are 32-bit addressable |
| // XXX: this code assumes that all inputs are of the same size |
| JIT_ASSERT(inputs[0].numel() <= std::numeric_limits<uint32_t>::max()); |
| uint32_t numel = inputs[0].numel(); |
| at::IntList map_size = inputs[0].sizes(); |
| // Compute the storage needed to store TensorInfo structs for inputs and outputs. |
| size_t uncompressedDim = input_desc.at(0).contiguity.size(); |
| size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim; |
| size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (inputs.size() + flat_outputs_size); |
| std::vector<char> buffer(maxPossibleBufferSize); |
| char * buffer_next = buffer.data(); |
| // A vector of arguments to the kernel. It's (numel, *input_descs, *output_descs) |
| std::vector<void*> arguments; |
| arguments.reserve(1 + inputs.size() + flat_outputs_size); |
| // Asserts that t's dims can be compressed in the same way as in desc |
| // (that's what the kernel assumes), and appends it to the arguments vector. |
| auto addTensorInfo = [&](TensorDesc & desc, const at::Tensor & t) { |
| size_t nDim = desc.nDim(); // NOTE: this is the compressed dim |
| JIT_ASSERT(nDim <= uncompressedDim); // We'd overflow the space otherwise |
| auto ti = reinterpret_cast<TensorInfo*>(buffer_next); |
| ti->data = t.data_ptr(); |
| compressContiguous(t.sizes(), t.strides(), desc.contiguity, ti->sizes(nDim), ti->strides(nDim)); |
| buffer_next += maxPossibleTensorInfoSize; |
| arguments.push_back(ti); |
| }; |
| arguments.push_back(&numel); |
| for (std::size_t i = 0; i < input_desc.size(); ++i) |
| addTensorInfo(input_desc[i], inputs[i]); |
| for (std::size_t i = 0; i < output_desc.size(); ++i) { |
| auto & c = concat_desc[i]; |
| at::Tensor o = outputs[i]; |
| if(c.nSubtensors == 1) { |
| o.resize_(map_size); |
| addTensorInfo(output_desc[i], outputs[i]); |
| } else { |
| size_t small_size = map_size[c.dim]; |
| std::vector<int64_t> concat_size(map_size.begin(), map_size.end()); |
| concat_size[c.dim] = small_size * c.nSubtensors; |
| o.resize_(concat_size); |
| size_t offset = 0; |
| for(size_t j = 0; j < c.nSubtensors; ++j) { |
// because the concatenated output `o` stays live, the underlying data
// in this view remains live through the end of this function,
// so there is no need to hold onto this tensor
| auto view = o.narrow(c.dim, offset, small_size); |
| addTensorInfo(*c.subtensorDesc, view); |
| offset += small_size; |
| } |
| } |
| } |
| launch_raw(numel, arguments.data()); |
| } |
| |
| void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs) { |
| AutoGPU guard(inputs.back()); |
| outputs.clear(); |
| outputs.reserve(outputDescriptors().size()); |
| for(auto & od : outputDescriptors()) { |
| outputs.push_back(at::getType(backend(),od.scalar_type).tensor()); |
| } |
| launch_with_tensors(inputs, outputs); |
| } |
| |
| #ifdef WITH_CUDA |
| |
| void checkCUDAVersion(const cudaDeviceProp & prop) { |
| if ((prop.major >= 6 && CUDA_VERSION < 8000) || |
| (prop.major >= 7 && CUDA_VERSION < 9000)) { |
| std::stringstream err_string; |
| err_string << "In CompiledFusionFunction, PyTorch compiled with insufficient CUDA version: " |
| << CUDA_VERSION << " for the current GPU device " << prop.name |
| << " with device capability " << prop.major << "." << prop.minor; |
| throw std::runtime_error(err_string.str()); |
| } |
| } |
| |
| struct CUDAFusionFunction : public CompiledFusionFunction { |
| CUDAFusionFunction(const std::string & name, AnnotatedGraph & agraph) |
| : CompiledFusionFunction(name, agraph) { |
| AutoGPU gpu_guard(agraph.device); |
| |
| TORCH_CUDA_CHECK(cudaGetDeviceProperties(&prop, agraph.device)); |
| checkCUDAVersion(prop); |
| |
| std::stringstream cu; |
| concat_desc = codegen::emitCompilationUnit(cu, name, agraph, true); |
| compilation_unit = cu.str(); |
| nvrtcProgram program; |
| TORCH_NVRTC_CHECK(nvrtcCreateProgram(&program, compilation_unit.c_str(), NULL, 0, nullptr, nullptr)); |
| |
| std::string compute = "--gpu-architecture=compute_" + std::to_string(prop.major) + std::to_string(prop.minor); |
| std::vector<const char *> args = {"--std=c++11", compute.c_str()}; |
| nvrtcResult result = nvrtcCompileProgram(program, args.size(), args.data()); |
| if (result == NVRTC_ERROR_COMPILATION) { |
| size_t logsize; |
| nvrtcGetProgramLogSize(program, &logsize); |
| std::vector<char> log(logsize); |
| nvrtcGetProgramLog(program, log.data()); |
| cu << log.data(); |
| throw std::runtime_error(cu.str()); |
| } |
| ResourceGuard holdProgram([&] { |
| TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program)); |
| }); |
| TORCH_NVRTC_CHECK(result); |
| |
| size_t ptx_size; |
| TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size)); |
| ptx.resize(ptx_size); |
| TORCH_NVRTC_CHECK(nvrtcGetPTX(program, ptx.data())); |
| |
| TORCH_CU_CHECK(cuModuleLoadData(&module, ptx.data())); |
| TORCH_CU_CHECK(cuModuleGetFunction(&function, module, name.c_str())); |
| |
| TORCH_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor( |
| &maxBlocks, function, 128, 0)); |
| maxBlocks *= prop.multiProcessorCount; |
| } |
| virtual ~CUDAFusionFunction() override { |
| TORCH_CU_CHECK(cuModuleUnload(module)); |
| } |
| protected: |
| virtual at::Backend backend() const override { |
| return at::kCUDA; |
| } |
| virtual void launch_raw(uint32_t numel, void ** arguments) override { |
| int numBlocks = std::min(maxBlocks, ceilDiv(numel, blockSize)); |
| //std::cout << "maxBlocks = " << maxBlocks << " needed blocks: " << ceilDiv(numel,blockSize) |
| // << " numblocks = " << numBlocks; |
| |
| // it is possible that this is the first cuda call on this thread |
| // so make sure we initialize the Driver API's context |
| // cudaFree(0) accomplishes this. |
| cudaFree(0); |
| |
| TORCH_CU_CHECK(cuLaunchKernel( |
| function, |
| numBlocks, 1, 1, |
| blockSize, 1, 1, |
| 0, nullptr, |
| arguments, |
| nullptr)); |
| } |
| std::vector<char> ptx; |
| CUmodule module; |
| CUfunction function; |
| |
// we record prop/device so they are available for launch heuristics;
// querying device properties at launch time is too slow
| int device; |
| cudaDeviceProp prop; |
| int blockSize = 128; |
| int maxBlocks; |
| }; |
| |
| #endif |
| |
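// RAII wrapper around mkstemps: creates a uniquely named temporary file and
// unlinks it on destruction. For example, TempFile(cpp_template, 4) keeps the
// 4-character ".cpp" suffix intact while mkstemps fills in the XXXXXX part.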
| struct TempFile { |
| TH_DISALLOW_COPY_AND_ASSIGN(TempFile); |
| TempFile(const std::string & t, int suffix) { |
// mkstemps edits its first argument in place,
// so we make a copy of the string here, including the null terminator
| std::vector<char> tt(t.c_str(), t.c_str() + t.size() + 1); |
| int fd = mkstemps(tt.data(), suffix); |
| JIT_ASSERT(fd != -1); |
| file_ = fdopen(fd, "r+"); |
| |
// - 1 because tt.size() includes the null terminator,
// but std::string does not expect one
| name_ = std::string(tt.begin(), tt.end() - 1); |
| } |
| const std::string & name() const { |
| return name_; |
| } |
| void sync() { |
| fflush(file_); |
| } |
| void write(const std::string & str) { |
| size_t result = fwrite(str.c_str(), 1, str.size(), file_); |
| JIT_ASSERT(str.size() == result); |
| } |
| FILE* file() { |
| return file_; |
| } |
| ~TempFile() { |
| if(file_ != nullptr) { |
| // unlink first to ensure another mkstemps doesn't |
| // race between close and unlink |
| unlink(name_.c_str()); |
| fclose(file_); |
| } |
| } |
| private: |
| FILE * file_ = nullptr; |
| std::string name_; |
| }; |
| |
| static void* checkDL(void * x) { |
| if(!x) { |
| barf("error in dlopen or dlsym: %s", dlerror()); |
| } |
| return x; |
| } |
| |
| struct DynamicLibrary { |
| TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); |
| DynamicLibrary(const char * name) { |
| handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW)); |
| } |
| void * sym(const char * name) { |
| JIT_ASSERT(handle); |
| return checkDL(dlsym(handle, name)); |
| } |
| ~DynamicLibrary() { |
| if(!handle) return; |
| int r = dlclose(handle); |
| if(r) { |
| barf("error in dlclose: %s", dlerror()); |
| } |
| } |
| private: |
| void * handle = nullptr; |
| }; |
| |
| static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so"; |
| static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp"; |
| |
| // NB: -march=native not supported on PPC64 g++. It's a bit annoying |
| // to do a configure-style test to decide whether or not the g++ |
| // actually supports it or not, so we heuristically use the host |
| // compiler to predict if the runtime compiler supports the option we |
| // want. This probably won't work if you're cross-compiling. |
| static const std::string compile_string = |
| "\"${cxx}\" -O3 -g " |
| #ifndef __PPC64__ |
| "-march=native " |
| #endif |
| "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\""; |
| |
| static void runCompiler(FusionCompilerConfig & config, const std::string & cpp_file, const std::string & so_file) { |
| TemplateEnv env; |
| env.s("cxx", config.cxx); |
| env.s("fopenmp", config.openmp ? "-fopenmp" : ""); |
| env.s("cpp_file",cpp_file); |
| env.s("so_file",so_file); |
| std::string result = format(compile_string,env); |
| int r = system(result.c_str()); |
| if(config.openmp && r != 0) { |
| std::cerr << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n"; |
| config.openmp = false; // disable for future compiles |
| return runCompiler(config, cpp_file, so_file); |
| } |
| JIT_ASSERT(r == 0); |
| } |
| |
| |
| static const std::string disas_string = |
| "objdump -M intel -d \"${so_file}\""; |
| static void disas(const std::string & so_file) { |
| TemplateEnv env; |
| env.s("so_file", so_file); |
| std::string cmd = format(disas_string, env); |
| int r = system(cmd.c_str()); |
| JIT_ASSERT(r == 0); |
| } |
| |
| struct CPUFusionFunction : public CompiledFusionFunction { |
| CPUFusionFunction(const std::string & name, AnnotatedGraph & agraph, FusionCompilerConfig & config) |
| : CompiledFusionFunction(name, agraph) { |
| TempFile so_file(so_template, 3); |
| TempFile cpp_file(cpp_template, 4); |
| |
| std::stringstream cu; |
| concat_desc = codegen::emitCompilationUnit(cu, name, agraph, false); |
| compilation_unit = cu.str(); |
| cpp_file.write(compilation_unit); |
| cpp_file.sync(); |
| runCompiler(config, cpp_file.name(), so_file.name()); |
| if(config.debug) { |
| std::cout << compilation_unit << "\n"; |
| disas(so_file.name()); |
| } |
| so_lib.reset(new DynamicLibrary(so_file.name().c_str())); |
| kernel = reinterpret_cast<void(*)(uint32_t, void**)>(so_lib->sym(name.c_str())); |
| } |
| protected: |
| virtual at::Backend backend() const override { |
| return at::kCPU; |
| } |
| virtual void launch_raw(uint32_t numel, void ** arguments) override { |
| kernel(numel, arguments); |
| } |
| std::unique_ptr<DynamicLibrary> so_lib; |
| void (*kernel)(uint32_t, void**) = nullptr; |
| }; |
| |
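// Kernels are cached on a textual key built from the printed graph, the device,
// and the input/output tensor descriptors (e.g. a descriptor line might read
// "Float[1;1;]" for a 2-d contiguous float tensor).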
| std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(AnnotatedGraph & agraph) { |
| std::stringstream key; |
| key << *agraph.graph << "\n"; |
| key << "device " << agraph.device << "\n"; |
| for(auto & i : agraph.input_desc) |
| key << i << "\n"; |
| for(auto & i : agraph.output_desc) |
| key << i << "\n"; |
| std::string key_ = key.str(); |
| |
| auto it = cache.find(key_); |
| if (it == cache.end()) { |
| std::string name = "kernel_" + std::to_string(cache.size()); |
| CompiledFusionFunction * raw_func; |
| if(agraph.device != kCPUDevice) { |
| #ifdef WITH_CUDA |
| raw_func = new CUDAFusionFunction(name, agraph); |
| #else |
| throw std::runtime_error("cannot compile a CUDA fusion group, CUDA is not enabled."); |
| #endif |
| } else { |
| JIT_ASSERT(canCompileOnCPU()); |
| raw_func = new CPUFusionFunction(name, agraph, config_); |
| } |
| it = cache.emplace(key_, std::shared_ptr<CompiledFusionFunction>(raw_func)).first; |
| } |
| return it->second; |
| } |
| |
| std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(Node* fusion_group) { |
| auto & graph = *fusion_group->g(kSubgraph); |
| AnnotatedGraph agraph(graph, fusion_group->i(kdevice)); |
| for(auto & input : graph.inputs()) { |
| auto t = input->type()->expect<TensorType>(); |
| agraph.input_desc.emplace_back(t); |
| } |
| for(auto & output : graph.outputs()) { |
| auto t = output->type()->expect<TensorType>(); |
| agraph.output_desc.emplace_back(t); |
| } |
| return getOrCompile(agraph); |
| } |
| |
| |
| std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(Graph & graph, |
| int device, |
| at::ArrayRef<at::Tensor> inputs, |
| at::ArrayRef<at::Tensor> outputs) { |
| AnnotatedGraph agraph(graph, device); |
| for(auto & i : inputs) { |
| agraph.input_desc.emplace_back(i); |
| } |
| for(auto & i : outputs) { |
| agraph.output_desc.emplace_back(i); |
| } |
| return getOrCompile(agraph); |
| } |
| |
| void FusionCompiler::debugLaunchGraph(Graph & graph, int device, at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) { |
| auto func = getOrCompile(graph, device, inputs, outputs); |
| func->launch_with_tensors(inputs, outputs); |
| } |
| |
| static const std::string check_exists_string = |
| "which '${program}' > /dev/null"; |
| |
| static bool programExists(const std::string & program) { |
| TemplateEnv env; |
| env.s("program", program); |
| std::string cmd = format(check_exists_string, env); |
| return 0 == system(cmd.c_str()); |
| } |
| |
| FusionCompiler::FusionCompiler() { |
| const char * cxx_env = getenv("CXX"); |
| if(cxx_env != nullptr) { |
| config_.cxx = cxx_env; |
| } |
| if(!programExists(config_.cxx)) { |
| config_.cxx = ""; |
| } |
| const char * debug_env = getenv("PYTORCH_FUSION_DEBUG"); |
| config_.debug = debug_env && atoi(debug_env) != 0; |
| } |
| |
| //TODO: thread safety |
| FusionCompiler & sharedFusionCompiler() { |
| static FusionCompiler compiler; |
| return compiler; |
| } |
| |
| }} |
| |
#else
// Dummy implementations for Windows.
| |
| #include "torch/csrc/jit/fusion_compiler.h" |
| #include "torch/csrc/jit/ir.h" |
| #include "torch/csrc/jit/code_template.h" |
| #include "torch/csrc/jit/resource_guard.h" |
| #include "torch/csrc/utils/disallow_copy.h" |
| #include "ATen/ATen.h" |
| #ifdef WITH_CUDA |
| #include "torch/csrc/cuda/cuda_check.h" |
| #include <nvrtc.h> |
| #include <cuda.h> |
| #include <cuda_runtime.h> |
| #endif |
| #include <string> |
| #include <algorithm> |
| #include <unordered_map> |
| #include <vector> |
| #include <sstream> |
| #include <iostream> |
| |
| namespace torch { namespace jit { |
| |
| CompiledFusionFunction::CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph) {} |
| |
| void CompiledFusionFunction::launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {} |
| |
| void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs) {} |
| |
| std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(AnnotatedGraph & agraph) { |
| return nullptr; |
| } |
| |
| std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(Node* fusion_group) { |
| return nullptr; |
| } |
| |
| |
| std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(Graph & graph, |
| int device, |
| at::ArrayRef<at::Tensor> inputs, |
| at::ArrayRef<at::Tensor> outputs) { |
| return nullptr; |
| } |
| |
| void FusionCompiler::debugLaunchGraph(Graph & graph, int device, at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {} |
| |
| FusionCompiler::FusionCompiler() {} |
| |
| FusionCompiler & sharedFusionCompiler() { |
| throw std::runtime_error("NYI: fuser is not supported on Windows."); |
| } |
| |
| }} |
| |
#endif