| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ |
| #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/tf2xla/xla_compiler.h" |
| #include "tensorflow/compiler/tf2xla/xla_context.h" |
| #include "tensorflow/compiler/xla/client/local_client.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/core/common_runtime/device.h" |
| #include "tensorflow/core/common_runtime/device_mgr.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/thread_annotations.h" |
| |
| namespace tensorflow { |
| |
| // The XlaCompilationCache class caches the results of the XlaCompiler class, |
| // which converts a Tensorflow graph into a compiled XLA compilation. |
| // |
| // Since XLA computations must have static shapes, the cache generates a new |
| // XLA computation for each new set of input shapes. |
| // |
| // Currently no cache eviction policy is implemented and the cache grows without |
| // bound. |
| class XlaCompilationCache : public ResourceBase { |
| public: |
| XlaCompilationCache(xla::LocalClient* client, DeviceType device_type); |
| ~XlaCompilationCache() override; |
| |
| enum class CompileMode { |
| kLazy, |
| kStrict, |
| }; |
| |
| // Compiles a function into a XlaCompiler::CompilationResult that can be used |
| // to execute an XLA Computation. Compilation results are cached. |
| // `function` is the name of a Tensorflow function to compile. |
| // `args` is a description of the arguments to the computation. |
| // |
| // `compile_mode` controls the behavior of the compilation cache on a cache |
| // miss. If `compile_mode` is `kLazy` then, based on some profitability |
| // heuristics, the compilation cache may decide not to compile the cluster at |
| // this time. In this case it returns null into both `out_compilation_result` |
| // and `out_executable`. If `compile_mode` is `kStrict` then the compilation |
| // cache always attempts the compilation on a cache miss. |
| // |
| // The result of compilation is written to `*out_compilation_result`, which |
| // must be non-null. If `out_executable` is non-null, also builds an |
| // xla::LocalExecutable and sets `out_executable` to point to it. The |
| // resulting executable pointer may be null if the computation has no |
| // non-constant outputs. |
| Status Compile(const XlaCompiler::Options& options, |
| const NameAttrList& function, |
| absl::Span<const XlaCompiler::Argument> args, |
| const XlaCompiler::CompileOptions& compile_options, |
| CompileMode compile_mode, |
| const XlaCompiler::CompilationResult** out_compilation_result, |
| xla::LocalExecutable** out_executable); |
| |
| // As above, but calls XlaCompiler::CompileSingleOp instead of |
| // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto |
| // in OpKernelContext, then uses MLIR bridge for compilation instead of |
| // XlaCompiler, if possible. |
| Status CompileSingleOp( |
| const XlaCompiler::Options& options, |
| absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx, |
| const XlaCompiler::CompileOptions& compile_options, |
| const XlaCompiler::CompilationResult** out_compilation_result, |
| xla::LocalExecutable** out_executable); |
| |
| xla::LocalClient* client() const { return client_; } |
| const DeviceType& device_type() const { return device_type_; } |
| |
| string DebugString() const override; |
| |
| // Describes the types, shapes and any compile-time constant arguments |
| // to a kernel. Key that uniquely identifies a compilation output. |
| struct Signature { |
| string name; |
| |
| // List of Tensor types & shapes for compile-time constant arguments to the |
| // compilation, ordered by argument number. |
| absl::InlinedVector<std::pair<DataType, absl::InlinedVector<int64, 4>>, 4> |
| arg_shapes; |
| |
| // List of Tensor values for compile-time constant arguments to the |
| // compilation, ordered by argument number. Tensors must be in host memory. |
| absl::InlinedVector<Tensor, 4> arg_values; |
| |
| bool operator==(const Signature& other) const; |
| |
| struct Hash { |
| uint64 operator()(const Signature& signature) const; |
| }; |
| |
| // Returns a human-readable description of the signature. |
| string HumanString() const; |
| }; |
| |
| // Builds the signature for a compilation. |
| static xla::StatusOr<Signature> BuildSignature( |
| const NameAttrList& function, |
| absl::Span<const XlaCompiler::Argument> args); |
| |
| private: |
| // Common implementation of Compile and CompileSingleOp. |
| Status CompileImpl( |
| const XlaCompiler::Options& options, const NameAttrList& function, |
| absl::Span<const XlaCompiler::Argument> args, |
| const std::function<Status(XlaCompiler* compiler, |
| XlaCompiler::CompilationResult*)>& compile_fn, |
| absl::optional<int64> compile_threshold, |
| const XlaCompiler::CompilationResult** out_compilation_result, |
| xla::LocalExecutable** out_executable); |
| |
| // Takes `result` which has been compiled from a Tensorflow subgraph to a |
| // XLA computation already, and generates an XLA LocalExecutable `executable`. |
| Status BuildExecutable(const XlaCompiler::Options& options, |
| const XlaCompiler::CompilationResult& result, |
| std::unique_ptr<xla::LocalExecutable>* executable); |
| |
| xla::LocalClient* const client_; |
| const DeviceType device_type_; |
| |
| // The value associated with a cache entry. |
| struct Entry { |
| mutex mu; |
| |
| // Have we tried compiling this entry? |
| bool compiled = false; |
| |
| // The number of times a compilation with this signature has been requested. |
| int64 request_count = 0; |
| |
| // Did compilation succeed? |
| Status compilation_status TF_GUARDED_BY(mu); |
| |
| // Output of the XlaCompiler. |
| XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu); |
| |
| // The XLA executable compiled from <computation>. May be null if no |
| // executable has been built. |
| std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu); |
| }; |
| |
| mutex compile_cache_mu_; |
| absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ |
| TF_GUARDED_BY(compile_cache_mu_); |
| |
| struct ClusterCompileStats { |
| // Number of times the cluster has been (re-)compiled. |
| int64 compile_count = 0; |
| |
| // The number of times this cluster has been executed. |
| int64 execution_count = 0; |
| |
| // Cumulative time spent compiling the cluster. |
| int64 cumulative_compile_time_us = 0; |
| |
| // True if we have decided that this cluster is too dynamic (i.e. its shapes |
| // change too frequently) to profitably JIT compile. Once a cluster is |
| // tagged megamorphic, it stays megamorphic forever. |
| bool is_megamorphic = false; |
| }; |
| |
| mutex cluster_compile_stats_mu_; |
| |
| // Maps cluster names to compilation statistics for said cluster. |
| absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_ |
| TF_GUARDED_BY(cluster_compile_stats_mu_); |
| |
| // The number of times a lazy compilation must be requested for a specific |
| // signature before we attempt to compile it. |
| static constexpr int64 kDefaultCompilationThreshold = 2; |
| |
| TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); |
| }; |
| |
| } // namespace tensorflow |
| |
| #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ |