#pragma once

#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/stack.h>

#include "torch/csrc/utils/disallow_copy.h"
#include "torch/csrc/utils/hash.h"

#include "ATen/ATen.h"

#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch { namespace jit {

struct FusedKernel;
struct FusionCompiler;

// Type information the fusion compiler needs about kernel inputs and outputs.
// contiguity[i] is true if dim i is contiguous with dim i + 1.
// contiguity.back() == true means strides.back() == 1.
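//
// Illustrative example (not from the source, but following the rule above): a
// contiguous float tensor with sizes [2, 3, 4] and strides [12, 4, 1] yields
// contiguity = [true, true, true], which compresses to nDim() == 1, i.e. the
// kernel may treat it as a flat 1-D buffer.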
struct TensorDesc {
  at::ScalarType scalar_type;
  std::vector<bool> contiguity;

  TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
    : scalar_type(type), contiguity(contiguity) {
    if (contiguity.size() == 0) {
      nDim_ = 0;
    } else {
      nDim_ = std::count(contiguity.begin(), contiguity.end(), false) +
          (lastIsContiguous() ? 1 : 0);
    }
  }

  TensorDesc(const at::ScalarType& type, const at::IntList& sizes, const at::IntList& strides)
    : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
  TensorDesc(const at::Tensor& t)
    : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
  TensorDesc(CompleteTensorTypePtr type)
    : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}

  // number of dimensions after contiguity compression
  size_t nDim() const {
    return nDim_;
  }

  // do we have inner stride == 1?
  bool lastIsContiguous() const {
    return contiguity.size() == 0 || contiguity.back();
  }

  static std::vector<bool> findContiguous(
      const at::IntList& sizes,
      const at::IntList& strides);

  bool operator==(const TensorDesc& desc) const {
    return scalar_type == desc.scalar_type && contiguity == desc.contiguity;
  }
  bool operator!=(const TensorDesc& desc) const {
    return !(*this == desc);
  }
  static size_t hash(const TensorDesc& spec) {
    return torch::get_hash(spec.scalar_type, spec.nDim_, std::hash<std::vector<bool>>{}(spec.contiguity));
  }

private:
  size_t nDim_;
};

inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) {
  out << d.scalar_type << "[";
  for (auto b : d.contiguity)
    out << b << ";";
  out << "]";
  return out;
}
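
// For the example TensorDesc above this prints, approximately, "Float[1;1;1;]"
// (bools stream as 0/1); the exact ScalarType spelling comes from ATen's operator<<.

// FusedKernelArgSpec records the TensorDescs of the actual runtime inputs together
// with a precomputed hash, so it can serve as the key of the kernel cache below.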
struct FusedKernelArgSpec {
  FusedKernelArgSpec(at::TensorList inputs)
    : descs_(fmap<TensorDesc>(inputs))
    , hash_code_(torch::get_hash(inputs.size(), descs_)) {}

  bool operator==(const FusedKernelArgSpec& spec) const {
    return hash_code_ == spec.hash_code_ && descs_ == spec.descs_;
  }
  bool operator!=(const FusedKernelArgSpec& spec) const {
    return !(*this == spec);
  }
  static size_t hash(const FusedKernelArgSpec& spec) {
    return spec.hash_code_;
  }
  const std::vector<TensorDesc>& descs() const {
    return descs_;
  }

private:
  std::vector<TensorDesc> descs_;
  size_t hash_code_;
};
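
// Device convention (assumed, matching its use elsewhere in the JIT): non-negative
// values are CUDA device indices; kCPUDevice selects the CPU fusion path.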
constexpr int kCPUDevice = -1;

struct AnnotatedGraph {
  // short-term storage only, so it borrows Graph.
  AnnotatedGraph(Graph& graph, int device)
    : graph(&graph), device(device) {}

  Graph* graph = nullptr; // TODO: this should really be const
  int device = kCPUDevice;
  std::vector<TensorDesc> input_desc;
  std::vector<TensorDesc> output_desc;
};
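
// (AnnotatedGraph is presumably the borrowed, shape-annotated form handed to the
// code generator; see debugLaunchGraph below.)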

// FusionCompiler has very limited shape information available at the time getOrCompile
// is called, which is why it can't really prepare the kernels at that point. Instead,
// it returns this object, which takes care of matching the run-time shapes to whatever
// kernels we have compiled already.
//
// Two configurations are considered eligible for the same fused kernel if:
//   - the shapes satisfy the graph invariants our fused code assumes (e.g. that all
//     intermediate shapes are the same - see fusion_compiler.cpp for more details), and
//   - their FusedKernelArgSpecs compare equal.
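//
// A rough sketch of the dispatch run() is expected to perform, inferred from the
// members declared below (the real control flow lives in fusion_compiler.cpp):
//
//   auto args = /* tensors taken from the stack */;
//   auto map_size = canRunKernel(args);        // do the runtime shapes broadcast to one size?
//   if (!map_size) return runFallback(stack);  // interpret fallback_code instead
//   expandArgs(args, *map_size);               // materialize the broadcasts
//   FusedKernelArgSpec spec(args);
//   auto it = kernels.find(spec);              // cache hit?
//   if (it == kernels.end())
//     it = kernels.emplace(spec, compileSpec(spec, *map_size)).first;
//   // ...launch it->second on the expanded arguments...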
struct FusedKernelCache {
  FusedKernelCache(FusionCompiler& compiler, std::shared_ptr<Graph> graph, int device);

  void run(Stack& inputs);

private:
  struct PartitionInfo {
    PartitionInfo(int64_t nsub, int64_t dim)
      : nSubtensors(nsub), dim(dim) {}
    int64_t nSubtensors;
    int64_t dim;
  };

  void runFallback(Stack& stack);
  void expandArgs(std::vector<at::Tensor>& args, std::vector<int64_t>& map_size);
  at::optional<std::vector<int64_t>> canRunKernel(at::TensorList args);
  at::optional<std::vector<int64_t>> getMapSize(at::TensorList args, at::IntList arg_subset);
  std::vector<std::vector<int64_t>> getInputBroadcastGroups();
  std::vector<PartitionInfo> getInputChunkDescriptors();
  std::unique_ptr<FusedKernel> compileSpec(
      const FusedKernelArgSpec& spec, const std::vector<int64_t>& map_size);

  static std::atomic<size_t> next_kernel_id;

  int device;
  Code fallback_code;
  FusionCompiler& compiler;
  std::shared_ptr<Graph> graph;
  std::vector<std::vector<int64_t>> input_broadcast_groups;
  std::vector<PartitionInfo> input_chunks;
  std::unordered_map<FusedKernelArgSpec, std::unique_ptr<FusedKernel>, torch::hash<FusedKernelArgSpec>> kernels;
};

struct FusionCompilerConfig {
  std::string cxx = "g++"; // compiler location
  bool debug = false; // emit debugging information about fusions
  bool openmp = true;
};
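
// (How these fields are consumed is defined in fusion_compiler.cpp; presumably cxx is
// the host compiler invoked for CPU fusions, openmp toggles an -fopenmp-style flag on
// that invocation, and debug dumps the generated source.)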

// caching compiler
struct FusionCompiler {
  friend struct FusedKernelCache;

  FusionCompiler();
  TH_DISALLOW_COPY_AND_ASSIGN(FusionCompiler);

  // uses the type annotations in fusion_group to create an AnnotatedGraph
  std::shared_ptr<FusedKernelCache> getOrCompile(Node* fusion_group);

  // Debugging function that lets you do everything from compilation to execution
  // in one step. This should not be used in the hot path of execution because it
  // has to serialize the graph each time.
  std::vector<at::Tensor> debugLaunchGraph(Graph& graph, int device, at::ArrayRef<at::Tensor> inputs);

  bool canCompileOnCPU() const {
    return config_.cxx.size() > 0;
  }

private:
  FusionCompilerConfig config_;
  std::unordered_map<std::string, std::shared_ptr<FusedKernelCache>> cache_map;
};

FusionCompiler& sharedFusionCompiler();
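
// Illustrative call sequence (a sketch, not taken from this header; names other than
// those declared above are hypothetical):
//
//   Node* fusion_group = ...;  // a prim::FusionGroup node in some Graph
//   auto cache = sharedFusionCompiler().getOrCompile(fusion_group);
//   Stack stack = ...;         // inputs for the fused subgraph
//   cache->run(stack);         // outputs are left on the stack
//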
}}