#pragma once

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/stack.h>
#include "torch/csrc/utils/disallow_copy.h"
#include "torch/csrc/utils/hash.h"

#include "ATen/ATen.h"

#include <algorithm>
#include <atomic>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch { namespace jit {

struct FusedKernel;
struct FusionCompiler;

// Type information the compiler needs about kernel inputs/outputs.
// contiguity[i] is true if dim i is contiguous with dim i + 1;
// contiguity.back() == true means strides.back() == 1.
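//
// For example (illustrative values): sizes [4, 3, 2] with strides [6, 2, 1]
// give contiguity [true, true, true], while strides [12, 2, 1] leave dim 0
// padded, giving [false, true, true].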
struct TensorDesc {
  at::ScalarType scalar_type;
  std::vector<bool> contiguity;

  TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
      : scalar_type(type), contiguity(contiguity) {
    if (contiguity.empty()) {
      nDim_ = 0;
    } else {
      nDim_ = std::count(contiguity.begin(), contiguity.end(), false) +
          (lastIsContiguous() ? 1 : 0);
    }
  }

  TensorDesc(const at::ScalarType& type, const at::IntList& sizes, const at::IntList& strides)
      : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
  TensorDesc(const at::Tensor& t)
      : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
  TensorDesc(CompleteTensorTypePtr type)
      : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}

  // Number of dimensions after contiguity compression.
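  // For example (illustrative): contiguity [false, true, true] has one padded
  // dim plus one contiguous block, so nDim() == 2.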
  size_t nDim() const {
    return nDim_;
  }

  // Do we have inner stride == 1?
  bool lastIsContiguous() const {
    return contiguity.empty() || contiguity.back();
  }

  static std::vector<bool> findContiguous(
      const at::IntList& sizes,
      const at::IntList& strides);

  bool operator==(const TensorDesc& desc) const {
    return scalar_type == desc.scalar_type && contiguity == desc.contiguity;
  }
  bool operator!=(const TensorDesc& desc) const {
    return !(*this == desc);
  }
  static size_t hash(const TensorDesc& spec) {
    return torch::get_hash(
        spec.scalar_type, spec.nDim_, std::hash<std::vector<bool>>{}(spec.contiguity));
  }

private:
  size_t nDim_;
};

inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) {
  out << d.scalar_type << "[";
  for (auto b : d.contiguity)
    out << b << ";";
  out << "]";
  return out;
}
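// Example output (illustrative): a contiguous 3-dim float descriptor prints as
// "Float[1;1;1;]", since std::vector<bool> elements render as 1/0.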

struct FusedKernelArgSpec {
  FusedKernelArgSpec(at::TensorList inputs)
      : descs_(fmap<TensorDesc>(inputs))
      , hash_code_(torch::get_hash(inputs.size(), descs_)) {}

  bool operator==(const FusedKernelArgSpec& spec) const {
    return hash_code_ == spec.hash_code_ && descs_ == spec.descs_;
  }
  bool operator!=(const FusedKernelArgSpec& spec) const {
    return !(*this == spec);
  }
  static size_t hash(const FusedKernelArgSpec& spec) {
    return spec.hash_code_;
  }
  const std::vector<TensorDesc>& descs() const {
    return descs_;
  }

private:
  std::vector<TensorDesc> descs_;
  size_t hash_code_;
};

constexpr int kCPUDevice = -1;

struct AnnotatedGraph {
  // Short-term storage only, so it borrows the Graph.
  AnnotatedGraph(Graph& graph, int device)
      : graph(&graph), device(device) {}

  Graph* graph = nullptr; // TODO: this should really be const
  int device = kCPUDevice;
  std::vector<TensorDesc> input_desc;
  std::vector<TensorDesc> output_desc;
};

// FusionCompiler has very limited shape information available when getOrCompile
// is called, so it cannot actually compile the kernels at that point. Instead,
// it returns this object, which takes care of matching run-time shapes against
// whatever kernels we have compiled already.
//
// Two configurations are eligible for the same fused kernel if:
//   - the shapes satisfy the graph invariants our fused code relies on (e.g. that
//     all intermediate shapes are the same - see fusion_compiler.cpp for details);
//   - their FusedKernelArgSpecs compare equal.
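//
// Typical use (an illustrative sketch, not a prescribed API flow):
//
//   Node* fusion_group = ...;  // a fusion group node produced by the JIT passes
//   std::shared_ptr<FusedKernelCache> cache =
//       sharedFusionCompiler().getOrCompile(fusion_group);
//   Stack stack = ...;         // inputs for the fused subgraph
//   cache->run(stack);         // picks (or compiles) a kernel matching the shapes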
struct FusedKernelCache {
  FusedKernelCache(FusionCompiler& compiler, std::shared_ptr<Graph> graph, int device);

  void run(Stack& inputs);

private:
  struct PartitionInfo {
    PartitionInfo(int64_t nsub, int64_t dim)
        : nSubtensors(nsub), dim(dim) {}

    int64_t nSubtensors;
    int64_t dim;
  };

  void runFallback(Stack& stack);
  void expandArgs(std::vector<at::Tensor>& args, std::vector<int64_t>& map_size);
  at::optional<std::vector<int64_t>> canRunKernel(at::TensorList args);
  at::optional<std::vector<int64_t>> getMapSize(at::TensorList args, at::IntList arg_subset);
  std::vector<std::vector<int64_t>> getInputBroadcastGroups();
  std::vector<PartitionInfo> getInputChunkDescriptors();
  std::unique_ptr<FusedKernel> compileSpec(
      const FusedKernelArgSpec& spec, const std::vector<int64_t>& map_size);

  static std::atomic<size_t> next_kernel_id;

  int device;
  Code fallback_code;
  FusionCompiler& compiler;
  std::shared_ptr<Graph> graph;
  std::vector<std::vector<int64_t>> input_broadcast_groups;
  std::vector<PartitionInfo> input_chunks;
  std::unordered_map<FusedKernelArgSpec, std::unique_ptr<FusedKernel>,
                     torch::hash<FusedKernelArgSpec>> kernels;
};

struct FusionCompilerConfig {
  std::string cxx = "g++"; // compiler location
  bool debug = false;      // emit debugging information about fusions
  bool openmp = true;
};

// Caching compiler.
struct FusionCompiler {
  friend struct FusedKernelCache;

  FusionCompiler();
  TH_DISALLOW_COPY_AND_ASSIGN(FusionCompiler);

  // Uses the type annotations in fusion_group to create an AnnotatedGraph.
  std::shared_ptr<FusedKernelCache> getOrCompile(Node* fusion_group);

  // Debugging function that lets you do everything from compilation to
  // execution in one step. It should not be used on the hot path of execution,
  // because it has to serialize the graph each time it is called.
  std::vector<at::Tensor> debugLaunchGraph(Graph& graph, int device, at::ArrayRef<at::Tensor> inputs);
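  //
  // For example (illustrative; assumes `g` is a fusible Graph and `a`, `b` are
  // appropriately typed input tensors):
  //
  //   auto outputs = sharedFusionCompiler().debugLaunchGraph(g, kCPUDevice, {a, b});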
  bool canCompileOnCPU() const {
    return !config_.cxx.empty();
  }

private:
  FusionCompilerConfig config_;
  std::unordered_map<std::string, std::shared_ptr<FusedKernelCache>> cache_map;
};

FusionCompiler& sharedFusionCompiler();

}} // namespace torch::jit