| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ |
| #define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ |
| |
| #include <memory> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" |
| #include "tensorflow/core/lib/gtl/flatmap.h" |
| #include "tensorflow/core/util/ptr_util.h" |
| |
| namespace xla { |
| |
| // IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a |
| // gather from another array. It does this by mapping HLO instructions to |
| // instances of IndexedArrayAnalysis::Array, which can be inspected to discover |
| // whether said HLO is equivalent to a gather. |
| class IndexedArrayAnalysis { |
| public: |
| // IndexedArrayAnalysis maps each HLO instruction to an instance of an Array. |
| // Array is really just a sum type of the classes that inherit from it.  The |
| // meaning of each of the subtypes is documented on the subtype declaration. |
| // |
| // Array instances are immutable once created. |
| class Array { |
| public: |
| enum Kind { |
| kUnknown, |
| kConstant, |
| kReshaped, |
| kScalarIndexedConstant, |
| kScalarIndexed |
| }; |
| |
| virtual Kind kind() const = 0; |
| virtual const Shape& shape() const = 0; |
| |
| // Does a checked downcast from `Array` to `T` which must be one of its |
| // subtypes. |
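| // |
| // A minimal usage sketch (hypothetical; `array` stands for some Array* |
| // produced by this analysis): |
| // |
| //   if (array->kind() == Array::kConstant) { |
| //     const Literal* literal = array->as<ConstantArray>()->literal(); |
| //   } |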
| template <typename T> |
| T* as() { |
| static_assert((std::is_base_of<Array, T>::value), |
| "target type not derived from source type"); |
| // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. |
| #if !defined(__GNUC__) || defined(__GXX_RTTI) |
| CHECK_NE(dynamic_cast<T*>(this), nullptr); |
| #endif // !defined(__GNUC__) || defined(__GXX_RTTI) |
| |
| return static_cast<T*>(this); |
| } |
| |
| virtual ~Array() = default; |
| |
| Array& operator=(const Array& other) = delete; |
| }; |
| |
| // Represents an HLO instruction that was not analyzable by this |
| // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing |
| // HloInstruction. |
| class UnknownArray : public Array { |
| public: |
| Kind kind() const override { return kUnknown; } |
| const Shape& shape() const override { return instruction().shape(); } |
| const HloInstruction& instruction() const { return instruction_; } |
| |
| private: |
| explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} |
| |
| const HloInstruction& instruction_; |
| |
| friend class IndexedArrayAnalysis; |
| }; |
| |
| // Represents a constant value. This constant value may be present in the HLO |
| // module being analyzed, or it could have been created on the fly by the |
| // analysis. |
| class ConstantArray : public Array { |
| public: |
| Kind kind() const override { return kConstant; } |
| const Shape& shape() const override { return literal()->shape(); } |
| const Literal* literal() const { return literal_; } |
| |
| private: |
| explicit ConstantArray(const Literal* literal) : literal_(literal) {} |
| const Literal* literal_; |
| |
| friend class IndexedArrayAnalysis; |
| }; |
| |
| // Represents an Array that is a reshape of another Array. |
| class ReshapedArray : public Array { |
| public: |
| Kind kind() const override { return kReshaped; } |
| |
| // The array to reshape. |
| Array* operand() const { return operand_; } |
| |
| // The output shape. |
| const Shape& shape() const override { return shape_; } |
| |
| private: |
| explicit ReshapedArray(Array* operand, Shape shape) |
| : operand_(operand), shape_(shape) {} |
| |
| Array* operand_; |
| const Shape shape_; |
| |
| friend class IndexedArrayAnalysis; |
| }; |
| |
| // --------------------------------------------------------------------------- |
| // Indexed Array Overview |
| // --------------------------------------------------------------------------- |
| // |
| // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this |
| // analysis. ScalarIndexedConstantArray is just a specialization of |
| // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this |
| // overview. |
| // |
| // A ScalarIndexedArray represents an array that can be computed by indexing |
| // into a "source" array using an "indices" tensor. A simple example is a |
| // gather operation gathering 12 rows out of a [100,100] matrix -- such an |
| // operation will be represented by an instance of a ScalarIndexedArray with |
| // the [100,100] matrix as the "source" array and the [12]-shaped indices |
| // array as the "indices" tensor. The ScalarIndexedArray operation itself |
| // will be of shape [12,100] (assuming we were gathering with axis=0). |
| // |
| // Gather operations are not the only operations that map to |
| // ScalarIndexedArray instances (if that were true there would be little point |
| // in having a separate analysis). We can often infer ScalarIndexedArrays for |
| // other operations too. For instance, consider: |
| // |
| // %source = f32[100,100] constant |
| // %indices = s32[12] ... |
| // %gather = f32[12,100] ... gather from %source using %indices at axis 0 |
| // %dot = dot(%gather, other_constant) [canonical contracting dims] |
| // |
| // The dot operation itself is also a ScalarIndexedArray with source = |
| // dot(%source, other_constant) and indices = %indices.  Similarly, a reshape |
| // of %gather to [12,5,20] is a ScalarIndexedArray with source = an |
| // appropriately reshaped %source and indices = %indices. |
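| // |
| // For instance, assuming other_constant has shape f32[100,N], the dot above |
| // can be viewed (sketch, not actual HLO emitted by the analysis) as: |
| // |
| //   %new_source = dot(%source, other_constant)   // f32[100,N] |
| //   %dot = f32[12,N] ... gather from %new_source using %indices at axis 0 |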
| |
| // Represents the result of a gather operation. This gather operation may |
| // explicitly be present in the HLO module being analyzed, or it could have |
| // been created on the fly by the analysis. |
| // |
| // An instance of ScalarIndexedArray represents an array whose I'th element can |
| // be mapped to the J'th element of the `source` array (where I and J are |
| // multidimensional indices) in this way: |
| // |
| // I' = remove components at positions `output_dims` from I |
| // G' = remove components not at positions `output_dims` from I |
| // T = indices[G'] |
| // J = I' with T inserted at position `source_dim` |
| // |
| // For example, if source is of shape [11,13,17,19], indices is of shape |
| // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of |
| // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the |
| // input index [B,D,indices[A,C],E]. |
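| // |
| // Working through the mapping above for that example with I = [A,B,C,D,E]: |
| // |
| //   I' = [B,D,E]      (components at positions 0 and 2 of I removed) |
| //   G' = [A,C]        (only the components at positions 0 and 2 of I kept) |
| //   T  = indices[A,C] |
| //   J  = [B,D,T,E]    (T inserted at source_dim = 2) |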
| class ScalarIndexedArray : public Array { |
| public: |
| Kind kind() const override { return kScalarIndexed; } |
| const Shape& shape() const override { return shape_; } |
| |
| Array* source() const { return source_; } |
| Array* indices() const { return indices_; } |
| |
| // `source_dim` is the dimension in the source array that is being indexed |
| // over using indices from the `indices` array. See the class documentation |
| // and the overview for more details. |
| int64 source_dim() const { return source_dim_; } |
| |
| // `output_dims` are the dimensions in the output array that are being used |
| // to compute an index into the `indices` array. See the class |
| // documentation and the overview for more details. |
| absl::Span<const int64> output_dims() const { return output_dims_; } |
| |
| private: |
| explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, |
| std::vector<int64> output_dims, Shape shape) |
| : source_(source), |
| indices_(indices), |
| source_dim_(source_dim), |
| output_dims_(std::move(output_dims)), |
| shape_(std::move(shape)) {} |
| |
| Array* source_; |
| Array* indices_; |
| int64 source_dim_; |
| std::vector<int64> output_dims_; |
| Shape shape_; |
| |
| friend class IndexedArrayAnalysis; |
| }; |
| |
| // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to |
| // have a ConstantArray instance as the source. This is an ergonomic |
| // concession -- in theory it is possible to just keep ScalarIndexedArray and |
| // check source()->kind(). |
| class ScalarIndexedConstantArray : public ScalarIndexedArray { |
| public: |
| Kind kind() const override { return kScalarIndexedConstant; } |
| |
| const Literal& literal() const { |
| return *source()->as<ConstantArray>()->literal(); |
| } |
| |
| private: |
| explicit ScalarIndexedConstantArray(Array* source, Array* indices, |
| int64 source_dim, |
| std::vector<int64> output_dims, |
| Shape shape) |
| : ScalarIndexedArray(source, indices, source_dim, |
| std::move(output_dims), std::move(shape)) { |
| CHECK(dynamic_cast<ConstantArray*>(source)); |
| } |
| |
| friend class IndexedArrayAnalysis; |
| }; |
| |
| // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance |
| // keeps ownership of the returned Array instance. |
| // |
| // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO |
| // instructions to IndexedArrayAnalysis::Array instances.  The entire cache |
| // becomes stale, and the analysis may return incorrect results, if any |
| // transitive operand (up to the boundary of the containing computation) of an |
| // HLO instruction on which GetArrayFor has been invoked is modified. |
| // |
| // NB! By inspecting the implementation, you may be able to infer a stronger |
| // caching guarantee than what is mentioned above. Nevertheless, what is |
| // stated above is the contract. |
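| // |
| // A minimal usage sketch (names such as `instr` are illustrative; the |
| // enclosing function is assumed to return a Status or StatusOr): |
| // |
| //   IndexedArrayAnalysis analysis; |
| //   TF_ASSIGN_OR_RETURN(Array* array, analysis.GetArrayFor(instr)); |
| //   if (array->kind() == Array::kScalarIndexedConstant) { |
| //     auto* indexed = array->as<ScalarIndexedConstantArray>(); |
| //     // Inspect indexed->literal(), indexed->source_dim(), |
| //     // indexed->output_dims(), ... |
| //   } |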
| StatusOr<Array*> GetArrayFor(const HloInstruction* instr); |
| |
| // Pretty-prints the expression rooted at `root`. |
| string ToString(Array* root, bool print_constants = false); |
| |
| private: |
| // Helper function that ensures that every HLO instruction that is |
| // transitively used by `root` has an entry in `cache_`. |
| Status TraverseAndPopulateCache(const HloInstruction* root); |
| |
| // Creates an Array instance for `instr` under the assumption that all |
| // operands of `instr` are already present in `cache_`. |
| StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr); |
| |
| StatusOr<Array*> ComputeArrayForConstant(const Literal& literal); |
| |
| StatusOr<Array*> ComputeArrayForGather( |
| const Shape& shape, const GatherDimensionNumbers& dim_numbers, |
| absl::Span<const int64> slice_sizes, Array* source, Array* indices); |
| |
| StatusOr<Array*> ComputeArrayForDotWithIndexedLhs( |
| const Shape& shape, const DotDimensionNumbers& dim_numbers, |
| const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, |
| ConstantArray* rhs); |
| |
| StatusOr<Array*> ComputeArrayForDotWithIndexedRhs( |
| const Shape& shape, const DotDimensionNumbers& dim_numbers, |
| const PrecisionConfig& precision_config, ConstantArray* lhs, |
| ScalarIndexedConstantArray* rhs); |
| |
| StatusOr<Array*> ComputeArrayForDot(const Shape& shape, |
| const DotDimensionNumbers& dim_numbers, |
| const PrecisionConfig& precision_config, |
| Array* lhs, Array* rhs); |
| |
| // This tries to fold a ScalarIndexedArray which has another |
| // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a |
| // ScalarIndexedArray as indices. If `source` happened to be a |
| // ScalarIndexedConstantArray this can result in an expression that is more |
| // canonical. |
| // |
| // As an example, consider a gather operation, G0, gathering 7 elements from |
| // an array "Arr" of shape [100] resulting in an array of shape [7], and a |
| // second gather operation, G1, which gathers 3 elements out of the result of |
| // G0 resulting in an array of shape [3].  Let the indices used by G0 be I0 |
| // (of shape [7]) and the indices used by G1 be I1 (of shape [3]).  We can |
| // instead rewrite G1 to gather directly from "Arr", using the three indices |
| // of I0 that I1 selects.  In other words, we can rewrite: |
| // |
| // G0 = [Arr[i] for i in I0] |
| // G1 = [G0[i] for i in I1] |
| // |
| // into |
| // |
| // I2 = [I0[i] for i in I1] |
| // G1 = [Arr[i] for i in I2] |
| StatusOr<ScalarIndexedArray*> FoldGatherOfGather( |
| ScalarIndexedArray* source, Array* indices, int64 source_dim, |
| absl::Span<const int64> output_dims, Shape shape); |
| |
| // Reshapes a scalar-indexed node to remove the degenerate dimensions in its |
| // output. The result is always a scalar-indexed node. |
| StatusOr<ScalarIndexedArray*> ReshapeToRemoveDegenerateDims( |
| ScalarIndexedArray* operand); |
| |
| // Reshapes a scalar-indexed node such that the result has the degenerate |
| // dimensions `degenerate_dims`. The result is always a scalar-indexed node. |
| StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims( |
| ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims); |
| |
| StatusOr<ScalarIndexedArray*> FoldReshapeOfGather( |
| const Shape& shape, ScalarIndexedConstantArray* operand); |
| StatusOr<ScalarIndexedArray*> FoldReshapeOfGatherNoDegenerateDims( |
| const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); |
| StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand); |
| |
| StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, |
| Array* lhs, Array* rhs); |
| StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, |
| Array* operand); |
| |
| template <typename T, typename... Args> |
| T* Construct(Args&&... args) { |
| T* new_tensor = new T(std::forward<Args>(args)...); |
| owned_tensors_.push_back(std::unique_ptr<T>(new_tensor)); |
| return new_tensor; |
| } |
| |
| ScalarIndexedArray* ConstructScalarIndexedArray( |
| Array* source, Array* indices, int64 source_dim, |
| std::vector<int64> output_dims, Shape shape) { |
| if (source->kind() == Array::kConstant) { |
| return Construct<ScalarIndexedConstantArray>(source, indices, source_dim, |
| std::move(output_dims), |
| std::move(shape)); |
| } else { |
| return Construct<ScalarIndexedArray>(source, indices, source_dim, |
| std::move(output_dims), |
| std::move(shape)); |
| } |
| } |
| |
| Literal* TakeOwnership(Literal literal) { |
| owned_literals_.push_back(std::move(literal)); |
| return &owned_literals_.back(); |
| } |
| |
| StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) { |
| TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); |
| owned_literals_.push_back(std::move(literal)); |
| return &owned_literals_.back(); |
| } |
| |
| std::vector<std::unique_ptr<Array>> owned_tensors_; |
| std::vector<Literal> owned_literals_; |
| tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_; |
| }; |
| |
| // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. |
| // This pass is a no-op if !VLOG_IS_ON(2), so it should be safe to add it |
| // unconditionally to the regular HLO pass pipeline. |
| class IndexedArrayAnalysisPrinterPass : public HloModulePass { |
| public: |
| absl::string_view name() const override; |
| StatusOr<bool> Run(HloModule* module) override; |
| }; |
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ |