blob: 2ccacf1d116ac9f34f388bbbb96b3431df0380fd [file] [log] [blame]
#pragma once

#include <torch/csrc/jit/ir.h>
#include "torch/csrc/jit/DisallowCopy.h"

#include "ATen/ATen.h"

#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch { namespace jit {
// Computes the contiguity pattern for the given sizes/strides: result[i] is
// true iff dim i is contiguous with dim i + 1, and result.back() is true iff
// the innermost stride is 1 (see the TensorDesc comment below, which consumes
// this vector).
std::vector<bool> findContiguous(
at::IntList sizes,
at::IntList strides);
// Type information needed by the compiler for inputs/outputs.
// contiguity[i] is true if dim i is contiguous with dim i + 1.
// contiguity.back() == true means strides.back() == 1.
struct TensorDesc {
  at::ScalarType scalar_type;
  std::vector<bool> contiguity;

  // Taken by value so callers can move a temporary in; lvalue callers get
  // exactly one copy, same as the previous const& overload.
  TensorDesc(at::ScalarType scalar_type, std::vector<bool> contiguity)
  : scalar_type(scalar_type), contiguity(std::move(contiguity)) {
    calcDim();
  }
  // Derives the descriptor directly from a tensor's type/sizes/strides.
  TensorDesc(const at::Tensor & t)
  : scalar_type(t.type().scalarType())
  , contiguity(findContiguous(t.sizes(), t.strides())) {
    calcDim();
  }
  // number of dimensions after contiguity compression
  size_t nDim() const {
    return nDim_;
  }
  // do we have inner stride == 1?
  bool lastIsContiguous() const {
    return contiguity.size() == 0 || contiguity.back();
  }
private:
  size_t nDim_;
  // Every 'false' entry starts a new compressed dimension; the innermost run
  // contributes one extra dimension only when it is contiguous. A rank-0
  // tensor (empty contiguity) therefore compresses to 1 dim.
  void calcDim() {
    nDim_ = std::count(contiguity.begin(), contiguity.end(), false)
          + (lastIsContiguous() ? 1 : 0);
  }
};
// short-term storage only, so it borrows Graph.
// this type is probably temporary.
// it will be replaced when the needed TensorDesc information is encoded
// directly in the information in the IR (e.g. in the Type object)
struct AnnotatedGraph {
// non-owning: the caller must keep the Graph alive while this object is used
Graph* graph;
// descriptors for the graph's input and output tensors
// (pairing with graph values is presumably positional — confirm at use sites)
std::vector<TensorDesc> input_desc;
std::vector<TensorDesc> output_desc;
};
// A single fused kernel compiled to PTX and loaded as a CUDA module.
// Non-copyable; owns the CUmodule/CUfunction handles.
struct CompiledFusionFunction {
TH_DISALLOW_COPY_AND_ASSIGN(CompiledFusionFunction);
CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph);
~CompiledFusionFunction();
// Launches the kernel on the given tensors (presumably these must match the
// input_desc/output_desc this function was compiled with — confirm in the .cpp).
void launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);
private:
// low-level entry point: raw argument pointers plus the element count
void launch(uint32_t numel, void ** arguments);
std::string name;
// we keep these around for debugging
// NOTE(review): member name is misspelled ("compliation"); renaming it would
// also require updating the .cpp that defines this class.
std::string compliation_unit;
std::vector<char> ptx;
CUmodule module;
CUfunction function;
// we record device/prop here so they are available for launch heuristics;
// querying device properties at launch time is too slow.
int device;
cudaDeviceProp prop;
int blockSize = 128;
int maxBlocks;
std::vector<TensorDesc> input_desc;
std::vector<TensorDesc> output_desc;
};
// caching compiler: compiles each distinct fusion group once and hands out
// shared handles to the compiled result on subsequent requests.
struct FusionCompiler {
TH_DISALLOW_COPY_AND_ASSIGN(FusionCompiler);
FusionCompiler();
// Returns the compiled function for agraph, compiling and caching it on miss.
std::shared_ptr<CompiledFusionFunction> getOrCompile(AnnotatedGraph & agraph);
// debugging function that lets you do everything from compilation to execution
// in one step.
// this should not be used in the hot path of execution because it has to serialize
// the graph each time
void debugLaunchGraph(Graph & graph, at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);
private:
// keyed by string — presumably a serialization of the graph (the comment on
// debugLaunchGraph above mentions serializing; confirm the key in the .cpp)
std::unordered_map<std::string, std::shared_ptr<CompiledFusionFunction> > cache;
};
}}