| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ |
| #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ |
| |
| #include "absl/strings/string_view.h" |
| #include "absl/utility/utility.h" |
| #include "tensorflow/compiler/xla/layout_util.h" |
| #include "tensorflow/compiler/xla/literal_util.h" |
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| |
| namespace xla { |
| |
| // A pattern matcher for HloInstructions, Shapes, and Layouts. |
| // |
| // The Match function's first argument must be HloInstruction*, Shape*, or |
| // Layout*. The second argument is a pattern that will be matched against the |
| // first argument, as described below. |
| // |
| // Patterns are constructed using the match::Op, match::Shape, or match::Layout |
| // functions. By default, the returned patterns will match any HloInstruction, |
| // Shape, or Layout, respectively. However the match can be made more specific |
| // by using the pattern's modifier methods, for example: |
| // |
| // match::Op().WithOpcode(HloOpcode::kAdd).WithOperand( |
| // 0, match::Op().WithOpcode(HloOpcode::kConstant)) |
| // |
| // This pattern will match Add instructions whose first operand is a constant. |
| // |
| // Each pattern type has the following modifiers: |
| // |
| // Op(): |
| // - WithName: match operations with the given name |
| // - WithOpcode: match operations with the given opcode |
| // - WithShape: match operations whose shape matches the given pattern |
| // - WithOperand: match operations whose operand matches the given pattern |
| // |
| // Shape(): |
| // - EqualTo: matches shapes that are equal to the argument |
| // - CompatibleTo: matches shapes that are compatible to the argument |
| // - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes |
| // - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format |
| // - WithLayout: match shapes whose layout matches the given pattern |
| // - WithLayoutEqualTo: matches shapes whose layouts equal the argument |
| // - WithSubshape: matches tuple shapes whose subshape matches the given |
| // pattern |
//    - WithSubshapeEqualTo: matches shapes with a subshape equal to the
//      argument
| // - WithElementType: matches array/scalar shapes with the given element |
| // type |
| // - WithRank: matches array/scalar types with the given rank |
| // |
| // Layout(): |
| // - EqualTo: matches layouts that are equal to the argument |
| // - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse |
| // format |
| // |
| // Op(), Shape(), and Layout() may be passed an argument of type |
| // HloInstruction**, Shape**, or Layout**, respectively, or const versions of |
| // these pointers. If the pattern is matched, the address of the matched value |
| // will be "captured" and stored at this location. |
| // |
| // For example: |
| // HloInstruction* foo = ...; |
| // HloInstruction* matched_operand; |
| // CHECK(Match(foo, |
| // match::Op().WithOperand(0, match::Op(&matched_operand)))); |
| // |
| // Helpers are provided for common nullary, unary, binary, and ternary |
| // instructions. These helpers can be called with no arguments, in which case |
| // they will match any instruction matching the opcode. They may also be called |
| // with matches for the operands and with an optional capture. (The capture must |
| // be the first argument.) Some examples of these helpers and their equivalents |
| // are provided below. |
| // |
| // Example nullary instruction: |
| // Param() == Op().WithOpcode(HloOpcode::kParam) |
| // Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) |
| // |
| // Example unary instruction: |
| // Abs() == Op().WithOpcode(HloOpcode::kAbs) |
//   Abs(Op(&a))                    == Op().WithOpcode(HloOpcode::kAbs)
//                                         .WithOperand(0, Op(&a))
| // Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) |
| // .WithOperand(0, Op(&b)) |
| // |
| // Example binary instruction: |
| // Add() == Op().WithOpcode(HloOpcode::kAdd) |
| // Add(Op(&a), Op(&b)) == Op().WithOpcode(HloOpcode::kAdd) |
| // .WithOperand(0, Op(&a)) |
| // .WithOperand(1, Op(&b)) |
| // Add(&a, Op(&b), Op(&c)) == Op(&a).WithOpcode(HloOpcode::kAdd) |
| // .WithOperand(0, Op(&b)) |
| // .WithOperand(1, Op(&c)) |
| // |
| // Example ternary instruction: |
| // Clamp() == Op().WithOpcode(HloOpcode::kClamp) |
| // Clamp(Op(&a), Op(&b), Op(&c)) == Op().WithOpcode(HloOpcode::kClamp) |
| // .WithOperand(0, Op(&a)) |
| // .WithOperand(1, Op(&b)) |
| // .WithOperand(2, Op(&c)) |
| // Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp) |
| // .WithOperand(0, Op(&b)) |
| // .WithOperand(1, Op(&c)) |
| // .WithOperand(2, Op(&d)) |
| // |
| template <typename Value, typename Pattern> |
| bool Match(Value* value, const Pattern& pattern) { |
| return pattern.Match(value); |
| } |
| |
| namespace match { |
| |
| namespace detail { |
| |
| template <typename LayoutType, typename Impl> |
| class LayoutPattern; |
| |
| // The base LayoutPattern implementation. Matches only if the layout is not |
| // nullptr. |
| class LayoutPatternBaseImpl { |
| public: |
| bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } |
| }; |
| |
| // A LayoutPattern implementation that matches only if the layout equals a |
| // Layout proto. |
| template <typename Previous> |
| class LayoutPatternEqualImpl { |
| public: |
| explicit constexpr LayoutPatternEqualImpl(const Previous& previous, |
| const ::xla::Layout* layout) |
| : previous_(previous), layout_(layout) {} |
| |
| bool Match(const ::xla::Layout* layout) const { |
| return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); |
| } |
| |
| private: |
| Previous previous_; |
| const ::xla::Layout* layout_; |
| }; |
| |
| // A LayoutPattern implementation that matches only if the layout has a given |
| // format. |
| template <typename Previous> |
| class LayoutPatternFormatImpl { |
| public: |
| explicit constexpr LayoutPatternFormatImpl(const Previous& previous, |
| Format format) |
| : previous_(previous), format_(format) {} |
| |
| bool Match(const ::xla::Layout* layout) const { |
| return previous_.Match(layout) && layout->format() == format_; |
| } |
| |
| private: |
| Previous previous_; |
| Format format_; |
| }; |
| |
| // A pattern that matches Layouts. |
| template <typename LayoutType, typename Impl> |
| class LayoutPattern { |
| public: |
| explicit constexpr LayoutPattern(const Impl& impl, |
| LayoutType** matched_layout) |
| : impl_(impl), matched_layout_(matched_layout) {} |
| |
| // Returns true and captures the layout iff it matches the pattern. |
| bool Match(const ::xla::Layout* layout) const { |
| if (impl_.Match(layout)) { |
| if (matched_layout_) { |
| *matched_layout_ = layout; |
| } |
| return true; |
| } |
| return false; |
| } |
| |
| // Returns true and captures the layout iff it matches the pattern. |
| bool Match(::xla::Layout* layout) const { |
| if (impl_.Match(layout)) { |
| if (matched_layout_) { |
| *matched_layout_ = layout; |
| } |
| return true; |
| } |
| return false; |
| } |
| |
| // Modifies the pattern to match only if the layout equals the given proto. |
| // The layout must outlive the returned pattern. |
| constexpr LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>> EqualTo( |
| const ::xla::Layout* layout) const { |
| return LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>>( |
| LayoutPatternEqualImpl<Impl>(impl_, layout), matched_layout_); |
| } |
| |
| // Modifies the pattern to match only if the layout has a dense format. |
| constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>> |
| WithDenseFormat() const { |
| return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>( |
| LayoutPatternFormatImpl<Impl>(impl_, DENSE), matched_layout_); |
| } |
| |
| // Modifies the pattern to match only if the layout has a sparse format. |
| constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>> |
| WithSparseFormat() const { |
| return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>( |
| LayoutPatternFormatImpl<Impl>(impl_, SPARSE), matched_layout_); |
| } |
| |
| private: |
| Impl impl_; |
| LayoutType** matched_layout_; |
| }; |
| |
| template <typename Item, typename... Patterns> |
| class AnyOfPattern { |
| public: |
| explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} |
| |
| bool Match(const Item* item) const { |
| return MatchImpl(item, std::integral_constant<size_t, 0>()); |
| } |
| |
| bool Match(Item* item) const { |
| return MatchImpl(item, std::integral_constant<size_t, 0>()); |
| } |
| |
| private: |
| template <typename ItemType, size_t index> |
| bool MatchImpl(ItemType* item, std::integral_constant<size_t, index>) const { |
| return std::get<index>(patterns_).Match(item) || |
| MatchImpl(item, std::integral_constant<size_t, index + 1>()); |
| } |
| |
| template <typename ItemType> |
| bool MatchImpl(ItemType* item, |
| std::integral_constant<size_t, sizeof...(Patterns)>) const { |
| return false; |
| } |
| |
| std::tuple<Patterns...> patterns_; |
| }; |
| } // namespace detail |
| |
| // Returns a pattern that represents the logical disjunction of the input |
| // patterns. The returned pattern matches from left to right, and stops on the |
| // first match. |
| template <typename Item, typename... Patterns> |
| detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf( |
| const Patterns&... patterns) { |
| return detail::AnyOfPattern<typename std::remove_const<Item>::type, |
| Patterns...>(patterns...); |
| } |
| |
| // Creates a layout pattern that will capture the matched layout in the |
| // argument. |
| inline constexpr detail::LayoutPattern<const ::xla::Layout, |
| detail::LayoutPatternBaseImpl> |
| Layout(const ::xla::Layout** matched_layout = nullptr) { |
| return detail::LayoutPattern<const ::xla::Layout, |
| detail::LayoutPatternBaseImpl>( |
| detail::LayoutPatternBaseImpl(), matched_layout); |
| } |
| |
| // Creates a layout pattern that will capture the matched layout in the |
| // argument. |
| inline constexpr detail::LayoutPattern<::xla::Layout, |
| detail::LayoutPatternBaseImpl> |
| Layout(::xla::Layout** matched_layout) { |
| return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>( |
| detail::LayoutPatternBaseImpl(), matched_layout); |
| } |
| |
| namespace detail { |
| |
| template <typename ShapeType, typename Impl> |
| class ShapePattern; |
| |
| // The base ShapePattern implementation. Matches only if the shape is not |
| // nullptr. |
| class ShapePatternBaseImpl { |
| public: |
| bool Match(const ::xla::Shape* shape) const { return shape != nullptr; } |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape equals a Shape |
| // proto. |
| template <typename Previous> |
| class ShapePatternEqualImpl { |
| public: |
| explicit constexpr ShapePatternEqualImpl(const Previous& previous, |
| const ::xla::Shape* shape) |
| : previous_(previous), shape_(shape) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); |
| } |
| |
| private: |
| Previous previous_; |
| const ::xla::Shape* shape_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape is compatible to |
| // a Shape proto. |
| template <typename Previous> |
| class ShapePatternCompatibleImpl { |
| public: |
| explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, |
| const ::xla::Shape* shape) |
| : previous_(previous), shape_(shape) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); |
| } |
| |
| private: |
| Previous previous_; |
| const ::xla::Shape* shape_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape has a given |
| // element type. |
| template <typename Previous> |
| class ShapePatternElementTypeImpl { |
| public: |
| explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, |
| PrimitiveType element_type) |
| : previous_(previous), element_type_(element_type) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && shape->element_type() == element_type_; |
| } |
| |
| private: |
| Previous previous_; |
| PrimitiveType element_type_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape is scalar. |
| template <typename Previous> |
| class ShapePatternIsScalarImpl { |
| public: |
| explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) |
| : previous_(previous) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); |
| } |
| |
| private: |
| Previous previous_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape is an array |
| template <typename Previous> |
| class ShapePatternIsArrayImpl { |
| public: |
| explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) |
| : previous_(previous) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::IsArray(*shape); |
| } |
| |
| private: |
| Previous previous_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape is a tuple. |
| template <typename Previous> |
| class ShapePatternIsTupleImpl { |
| public: |
| explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) |
| : previous_(previous) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); |
| } |
| |
| private: |
| Previous previous_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape has a given |
| // rank. |
| template <typename Previous> |
| class ShapePatternRankImpl { |
| public: |
| explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) |
| : previous_(previous), rank_(rank) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; |
| } |
| |
| private: |
| Previous previous_; |
| int64 rank_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape has a layout |
| // that matches a given pattern. |
| template <typename Previous, typename LayoutType, typename LayoutImpl> |
| class ShapePatternLayoutImpl { |
| public: |
| explicit constexpr ShapePatternLayoutImpl( |
| const Previous& previous, |
| const LayoutPattern<LayoutType, LayoutImpl>& layout) |
| : previous_(previous), layout_(layout) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && |
| layout_.Match(&shape->layout()); |
| } |
| |
| bool Match(Shape* shape) const { |
| return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && |
| layout_.Match(shape->mutable_layout()); |
| } |
| |
| private: |
| Previous previous_; |
| LayoutPattern<LayoutType, LayoutImpl> layout_; |
| }; |
| |
| // A ShapePattern implementation that matches only if the shape has a subshape |
| // that matches a given pattern. |
| template <typename Previous, typename SubshapeType, typename SubshapeImpl> |
| class ShapePatternSubshapeImpl { |
| public: |
| explicit ShapePatternSubshapeImpl( |
| const Previous& previous, ShapeIndexView index, |
| const ShapePattern<SubshapeType, SubshapeImpl>& subshape) |
| : previous_(previous), index_(index), subshape_(subshape) {} |
| |
| bool Match(const ::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && |
| subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); |
| } |
| |
| bool Match(::xla::Shape* shape) const { |
| return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && |
| subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); |
| } |
| |
| private: |
| Previous previous_; |
| ShapeIndexView index_; |
| ShapePattern<SubshapeType, SubshapeImpl> subshape_; |
| }; |
| |
| // A pattern that matches Shapes. |
| template <typename ShapeType, typename Impl> |
| class ShapePattern { |
| public: |
| explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) |
| : impl_(impl), matched_shape_(matched_shape) {} |
| |
| // Returns true and captures the shape iff it matches the pattern. |
| bool Match(const ::xla::Shape* shape) const { |
| if (impl_.Match(shape)) { |
| if (matched_shape_) { |
| *matched_shape_ = shape; |
| } |
| return true; |
| } |
| return false; |
| } |
| |
| // Returns true and captures the shape iff it matches the pattern. |
| bool Match(::xla::Shape* shape) const { |
| if (impl_.Match(shape)) { |
| if (matched_shape_) { |
| *matched_shape_ = shape; |
| } |
| return true; |
| } |
| return false; |
| } |
| |
| // Modifies the pattern to match only if the shape equals the given proto. |
| // The layout must outlive the returned pattern. |
| constexpr ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>> EqualTo( |
| const ::xla::Shape* shape) const { |
| return ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>>( |
| ShapePatternEqualImpl<Impl>(impl_, shape), matched_shape_); |
| } |
| |
| // Modifies the pattern to match only if the shape is compatible to the given |
| // proto. The layout must outlive the returned pattern. |
| constexpr ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>> |
| CompatibleTo(const ::xla::Shape* shape) const { |
| return ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>( |
| ShapePatternCompatibleImpl<Impl>(impl_, shape), matched_shape_); |
| } |
| |
| // Modifies the pattern to match only if the shape has the given element type. |
| constexpr ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>> |
| WithElementType(PrimitiveType element_type) const { |
| return ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>( |
| ShapePatternElementTypeImpl<Impl>(impl_, element_type), matched_shape_); |
| } |
| |
| // Modifies the pattern to match only if the shape is scalar. |
| constexpr ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>> IsScalar() |
| const { |
| return ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>>( |
| ShapePatternIsScalarImpl<Impl>(impl_), matched_shape_); |
| } |
| |
| // Modifies the pattern to match only if the shape is an array. |
| constexpr ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>> IsArray() |
| const { |
| return ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>>( |
| ShapePatternIsArrayImpl<Impl>(impl_), matched_shape_); |
| } |
| |
| // Modifies the pattern to match only if the shape is a tuple. |
| constexpr ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>> IsTuple() |
| const { |
| return ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>>( |
| ShapePatternIsTupleImpl<Impl>(impl_), matched_shape_); |
| } |
| |
| // Modifies the pattern to match only if the shape has the given rank. |
| constexpr ShapePattern<ShapeType, ShapePatternRankImpl<Impl>> WithRank( |
| int64 rank) const { |
| return ShapePattern<ShapeType, ShapePatternRankImpl<Impl>>( |
| ShapePatternRankImpl<Impl>(impl_, rank), matched_shape_); |
| } |
| |
| // Modifies the pattern to match only if the shape has a layout that matches |
| // the given pattern. |
| template <typename LayoutType, typename LayoutImpl> |
| constexpr ShapePattern<ShapeType, |
| ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>> |
| WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const { |
| return ShapePattern<ShapeType, |
| ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>( |
| ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>(impl_, layout), |
| matched_shape_); |
| } |
| |
| constexpr ShapePattern< |
| ShapeType, |
| ShapePatternLayoutImpl<Impl, const ::xla::Layout, |
| LayoutPatternEqualImpl<LayoutPatternBaseImpl>>> |
| WithLayoutEqualTo(const ::xla::Layout* layout) const { |
| return WithLayout(Layout().EqualTo(layout)); |
| } |
| |
| constexpr ShapePattern< |
| ShapeType, |
| ShapePatternLayoutImpl<Impl, const ::xla::Layout, |
| LayoutPatternFormatImpl<LayoutPatternBaseImpl>>> |
| IsDenseArray() const { |
| return WithLayout(Layout().WithDenseFormat()); |
| } |
| |
| constexpr ShapePattern< |
| ShapeType, |
| ShapePatternLayoutImpl<Impl, const ::xla::Layout, |
| LayoutPatternFormatImpl<LayoutPatternBaseImpl>>> |
| IsSparseArray() const { |
| return WithLayout(Layout().WithSparseFormat()); |
| } |
| |
| // Modifies the pattern to match only if the shape has a subshape that matches |
| // the given pattern. |
| template <typename SubshapeType, typename SubshapeImpl> |
| ShapePattern<ShapeType, |
| ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>> |
| WithSubshape(ShapeIndexView index, |
| const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const { |
| return ShapePattern< |
| ShapeType, ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>( |
| ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>(impl_, index, |
| subshape), |
| matched_shape_); |
| } |
| |
| ShapePattern<ShapeType, ShapePatternSubshapeImpl< |
| Impl, const ::xla::Shape, |
| ShapePatternEqualImpl<ShapePatternBaseImpl>>> |
| WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { |
| return WithSubshape(index, |
| ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>( |
| ShapePatternBaseImpl(), nullptr) |
| .EqualTo(shape)); |
| } |
| |
| ShapePattern<ShapeType, ShapePatternSubshapeImpl< |
| Impl, const ::xla::Shape, |
| ShapePatternCompatibleImpl<ShapePatternBaseImpl>>> |
| WithSubshapeCompatibleTo(ShapeIndexView index, |
| const ::xla::Shape* shape) const { |
| return WithSubshape(index, |
| ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>( |
| ShapePatternBaseImpl(), nullptr) |
| .CompatibleTo(shape)); |
| } |
| |
| private: |
| Impl impl_; |
| ShapeType** matched_shape_; |
| }; |
| |
| } // namespace detail |
| |
| // Creates a shape pattern that will capture the matched layout in the argument. |
| inline constexpr detail::ShapePattern<const ::xla::Shape, |
| detail::ShapePatternBaseImpl> |
| Shape(const ::xla::Shape** matched_shape = nullptr) { |
| return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>( |
| detail::ShapePatternBaseImpl(), matched_shape); |
| } |
| |
| // Creates a shape pattern that will capture the matched layout in the argument. |
| inline constexpr detail::ShapePattern<::xla::Shape, |
| detail::ShapePatternBaseImpl> |
| Shape(::xla::Shape** matched_shape) { |
| return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>( |
| detail::ShapePatternBaseImpl(), matched_shape); |
| } |
| |
| namespace detail { |
| |
| template <typename HloInstructionType, typename Impl> |
| class HloInstructionPattern; |
| |
| // The base HloInstructionPattern implementation. Matches only if the |
| // instruction is not nullptr. |
| class HloInstructionPatternBaseImpl { |
| public: |
| bool Match(const ::xla::HloInstruction* inst) const { |
| return inst != nullptr; |
| } |
| }; |
| |
| // An HloInstructionPattern implementation that matches only if the instruction |
| // has a given name. |
| template <typename Previous> |
| class HloInstructionPatternNameImpl { |
| public: |
| explicit HloInstructionPatternNameImpl(const Previous& previous, |
| absl::string_view name) |
| : previous_(previous), name_(name) {} |
| |
| bool Match(const ::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && inst->name() == name_; |
| } |
| |
| private: |
| Previous previous_; |
| absl::string_view name_; |
| }; |
| |
| // An HloInstructionPattern implementation that matches only if the instruction |
| // has a given opcode. |
| template <typename Previous> |
| class HloInstructionPatternOpcodeImpl { |
| public: |
| explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, |
| HloOpcode opcode, |
| bool invert) |
| : previous_(previous), opcode_(opcode), invert_(invert) {} |
| |
| bool Match(const ::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); |
| } |
| |
| private: |
| Previous previous_; |
| HloOpcode opcode_; |
| bool invert_; |
| }; |
| |
| // An HloInstructionPattern implementation that matches only if the instruction |
| // has a shape that matches a given pattern. |
| template <typename Previous, typename ShapeType, typename ShapeImpl> |
| class HloInstructionPatternShapeImpl { |
| public: |
| explicit constexpr HloInstructionPatternShapeImpl( |
| const Previous& previous, const ShapePattern<ShapeType, ShapeImpl>& shape) |
| : previous_(previous), shape_(shape) {} |
| |
| bool Match(const ::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && shape_.Match(&inst->shape()); |
| } |
| |
| bool Match(::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); |
| } |
| |
| private: |
| Previous previous_; |
| ShapePattern<ShapeType, ShapeImpl> shape_; |
| }; |
| |
| // An HloInstructionPattern implementation that matches only if the instruction |
| // has an operand that matches a given pattern. |
| template <typename Previous, typename OperandType, typename OperandImpl> |
| class HloInstructionPatternOperandImpl { |
| public: |
| explicit constexpr HloInstructionPatternOperandImpl( |
| const Previous& previous, int64 operand_index, |
| const HloInstructionPattern<OperandType, OperandImpl>& operand) |
| : previous_(previous), operand_index_(operand_index), operand_(operand) {} |
| |
| bool Match(const ::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && operand_index_ < inst->operand_count() && |
| operand_.Match(inst->operand(operand_index_)); |
| } |
| |
| bool Match(::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && operand_index_ < inst->operand_count() && |
| operand_.Match(inst->mutable_operand(operand_index_)); |
| } |
| |
| private: |
| Previous previous_; |
| int64 operand_index_; |
| HloInstructionPattern<OperandType, OperandImpl> operand_; |
| }; |
| |
| // An HloInstructionPattern implementation that matches only if the instruction |
| // is a fusion node with a particular kind. |
| template <typename Previous> |
| class HloInstructionPatternFusionKindImpl { |
| public: |
| explicit constexpr HloInstructionPatternFusionKindImpl( |
| const Previous& previous, ::xla::HloInstruction::FusionKind kind) |
| : previous_(previous), kind_(kind) {} |
| |
| bool Match(const ::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && |
| inst->fusion_kind() == kind_; |
| } |
| |
| bool Match(::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && |
| inst->fusion_kind() == kind_; |
| } |
| |
| private: |
| Previous previous_; |
| ::xla::HloInstruction::FusionKind kind_; |
| }; |
| |
| // An HloInstructionPattern implementation that matches only if the instruction |
| // is a kGetTupleElement with a particular tuple index. |
| template <typename Previous> |
| class HloInstructionPatternTupleIndexImpl { |
| public: |
| explicit constexpr HloInstructionPatternTupleIndexImpl( |
| const Previous& previous, int64 tuple_index) |
| : previous_(previous), tuple_index_(tuple_index) {} |
| |
| bool Match(const ::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && |
| inst->opcode() == HloOpcode::kGetTupleElement && |
| inst->tuple_index() == tuple_index_; |
| } |
| |
| bool Match(::xla::HloInstruction* inst) const { |
| return previous_.Match(inst) && |
| inst->opcode() == HloOpcode::kGetTupleElement && |
| inst->tuple_index() == tuple_index_; |
| } |
| |
| private: |
| Previous previous_; |
| int64 tuple_index_; |
| }; |
| |
| template <typename Previous, typename ItemType, typename Predicate> |
| class HloPredicatePatternImpl { |
| public: |
| explicit HloPredicatePatternImpl(const Previous& previous, Predicate pred) |
| : previous_(previous), pred_(std::move(pred)) {} |
| |
| bool Match(const ItemType* item) const { |
| return previous_.Match(item) && pred_(item); |
| } |
| |
| bool Match(ItemType* item) const { |
| return previous_.Match(item) && pred_(item); |
| } |
| |
| private: |
| Previous previous_; |
| Predicate pred_; |
| }; |
| |
| struct PatternFriend; |
| |
| // A pattern that matches HloInstructions. |
| template <typename HloInstructionType, typename Impl> |
| class HloInstructionPattern { |
| public: |
| explicit constexpr HloInstructionPattern(const Impl& impl, |
| HloInstructionType** matched_inst) |
| : impl_(impl), matched_inst_(matched_inst) {} |
| |
| // Returns true and captures the instruction iff it matches the pattern. |
| bool Match(const ::xla::HloInstruction* inst) const { |
| if (impl_.Match(inst)) { |
| if (matched_inst_) { |
| *matched_inst_ = inst; |
| } |
| return true; |
| } |
| return false; |
| } |
| |
| // Returns true and captures the instruction iff it matches the pattern. |
| bool Match(::xla::HloInstruction* inst) const { |
| if (impl_.Match(inst)) { |
| if (matched_inst_) { |
| *matched_inst_ = inst; |
| } |
| return true; |
| } |
| return false; |
| } |
| |
| // Modifies the pattern to match only if the instruction has the given name. |
| HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>> |
| WithName(absl::string_view name) const { |
| return HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternNameImpl<Impl>>( |
| HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_); |
| } |
| |
| // Modifies the pattern to match only if the instruction has the given opcode. |
| constexpr HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternOpcodeImpl<Impl>> |
| WithOpcode(HloOpcode opcode) const { |
| return HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternOpcodeImpl<Impl>>( |
| HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, false), |
| matched_inst_); |
| } |
| |
| // Modifies the pattern to match only if the instruction does not have the |
| // given opcode. |
| constexpr HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternOpcodeImpl<Impl>> |
| WithoutOpcode(HloOpcode opcode) const { |
| return HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternOpcodeImpl<Impl>>( |
| HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, true), |
| matched_inst_); |
| } |
| |
| // Modifies the pattern to match only if the instruction is a constant. |
| constexpr HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternOpcodeImpl<Impl>> |
| IsConstant() const { |
| return WithOpcode(HloOpcode::kConstant); |
| } |
| |
| // Modifies the pattern to match only if the instruction is not a constant. |
| constexpr HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternOpcodeImpl<Impl>> |
| IsNonConstant() const { |
| return WithoutOpcode(HloOpcode::kConstant); |
| } |
| |
| // Modifies the pattern to match only if the instruction has a shape that |
| // matches the given pattern. |
| template <typename ShapeType, typename ShapeImpl> |
| constexpr HloInstructionPattern< |
| HloInstructionType, |
| HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>> |
| WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) const { |
| return HloInstructionPattern< |
| HloInstructionType, |
| HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>( |
| HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>(impl_, |
| shape), |
| matched_inst_); |
| } |
| |
| // Modifies the pattern to match only if the instruction has an operand that |
| // matches the given pattern. |
| template <typename OperandType, typename OperandImpl> |
| constexpr HloInstructionPattern< |
| HloInstructionType, |
| HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>> |
| WithOperand( |
| int64 operand_index, |
| const HloInstructionPattern<OperandType, OperandImpl>& operand) const { |
| return HloInstructionPattern< |
| HloInstructionType, |
| HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>( |
| HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>( |
| impl_, operand_index, operand), |
| matched_inst_); |
| } |
| |
| // Modifies the pattern to match only if the instruction is a fusion node with |
| // the given kind. |
| constexpr HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternFusionKindImpl<Impl>> |
| WithFusionKind(HloInstruction::FusionKind kind) const { |
| return HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternFusionKindImpl<Impl>>( |
| HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_); |
| } |
| |
| // Modifies the pattern to match only if the instruction is a |
| // get-tuple-element with the given tuple index. |
| constexpr HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternTupleIndexImpl<Impl>> |
| WithTupleIndex(int64 tuple_index) const { |
| return HloInstructionPattern<HloInstructionType, |
| HloInstructionPatternTupleIndexImpl<Impl>>( |
| HloInstructionPatternTupleIndexImpl<Impl>(impl_, tuple_index), |
| matched_inst_); |
| } |
| |
| private: |
| template <typename Predicate> |
| constexpr HloInstructionPattern< |
| HloInstructionType, |
| HloPredicatePatternImpl< |
| Impl, typename std::remove_const<HloInstructionType>::type, |
| Predicate>> |
| WithPredicate(Predicate pred) const { |
| using NewImplType = HloPredicatePatternImpl< |
| Impl, typename std::remove_const<HloInstructionType>::type, Predicate>; |
| return HloInstructionPattern<HloInstructionType, NewImplType>( |
| NewImplType(impl_, std::move(pred)), matched_inst_); |
| } |
| |
| friend struct PatternFriend; |
| |
| Impl impl_; |
| HloInstructionType** matched_inst_; |
| }; |
| |
| } // namespace detail |
| |
| // Creates an instruction pattern that will capture the matched instruction in |
| // the argument. |
| inline constexpr detail::HloInstructionPattern< |
| const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> |
| Op(const ::xla::HloInstruction** matched_inst = nullptr) { |
| return detail::HloInstructionPattern<const ::xla::HloInstruction, |
| detail::HloInstructionPatternBaseImpl>( |
| detail::HloInstructionPatternBaseImpl(), matched_inst); |
| } |
| |
| // Creates an instruction pattern that will capture the matched instruction in |
| // the argument. |
| inline constexpr detail::HloInstructionPattern< |
| ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> |
| Op(::xla::HloInstruction** matched_inst) { |
| return detail::HloInstructionPattern<::xla::HloInstruction, |
| detail::HloInstructionPatternBaseImpl>( |
| detail::HloInstructionPatternBaseImpl(), matched_inst); |
| } |
| |
| // Helpers for nullary instructions. |
| #define XLA_NULLOP_PATTERN(NAME) \ |
| inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ |
| return Op().WithOpcode(HloOpcode::k##NAME); \ |
| } \ |
| \ |
| template <typename HloInstructionType> \ |
| inline auto NAME(HloInstructionType** matched_inst) \ |
| ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \ |
| return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ |
| } |
| XLA_NULLOP_PATTERN(Constant) |
| XLA_NULLOP_PATTERN(Parameter) |
| XLA_NULLOP_PATTERN(Iota) |
| #undef XLA_NULLOP_PATTERN |
| |
| // Helpers for unary instructions. |
| #define XLA_UNOP_PATTERN(NAME) \ |
| inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ |
| return Op().WithOpcode(HloOpcode::k##NAME); \ |
| } \ |
| \ |
| template <typename Arg> \ |
| inline auto NAME(Arg&& arg)->decltype( \ |
| Op().WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg>(arg))) { \ |
| return Op() \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg>(arg)); \ |
| } \ |
| \ |
| template <typename HloInstructionType, typename Arg> \ |
| inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \ |
| ->decltype(Op(matched_inst) \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg>(arg))) { \ |
| return Op(matched_inst) \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg>(arg)); \ |
| } |
| XLA_UNOP_PATTERN(Abs) |
| XLA_UNOP_PATTERN(RoundNearestAfz) |
| XLA_UNOP_PATTERN(Bitcast) |
| XLA_UNOP_PATTERN(Broadcast) |
| XLA_UNOP_PATTERN(Ceil) |
| XLA_UNOP_PATTERN(Copy) |
| XLA_UNOP_PATTERN(Cos) |
| XLA_UNOP_PATTERN(Exp) |
| XLA_UNOP_PATTERN(Fft) |
| XLA_UNOP_PATTERN(Floor) |
| XLA_UNOP_PATTERN(GetTupleElement) |
| XLA_UNOP_PATTERN(Imag) |
| XLA_UNOP_PATTERN(Infeed) |
| XLA_UNOP_PATTERN(IsFinite) |
| XLA_UNOP_PATTERN(Log) |
| XLA_UNOP_PATTERN(Not) |
| XLA_UNOP_PATTERN(Negate) |
| XLA_UNOP_PATTERN(Real) |
| XLA_UNOP_PATTERN(Recv) |
| XLA_UNOP_PATTERN(RecvDone) |
| XLA_UNOP_PATTERN(Reduce) |
| XLA_UNOP_PATTERN(ReducePrecision) |
| XLA_UNOP_PATTERN(Reshape) |
| XLA_UNOP_PATTERN(Reverse) |
| XLA_UNOP_PATTERN(SendDone) |
| XLA_UNOP_PATTERN(Sign) |
| XLA_UNOP_PATTERN(Sin) |
| XLA_UNOP_PATTERN(Sort) |
| XLA_UNOP_PATTERN(Tanh) |
| XLA_UNOP_PATTERN(Transpose) |
| #undef XLA_UNOP_PATTERN |
| |
| // Helpers for binary instructions. |
| #define XLA_BINOP_PATTERN(NAME) \ |
| inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ |
| return Op().WithOpcode(HloOpcode::k##NAME); \ |
| } \ |
| \ |
| template <typename Lhs, typename Rhs> \ |
| inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ |
| ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Lhs>(lhs)) \ |
| .WithOperand(1, std::forward<Rhs>(rhs))) { \ |
| return Op() \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Lhs>(lhs)) \ |
| .WithOperand(1, std::forward<Rhs>(rhs)); \ |
| } \ |
| \ |
| template <typename HloInstructionType, typename Lhs, typename Rhs> \ |
| inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ |
| ->decltype(Op(matched_inst) \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Lhs>(lhs)) \ |
| .WithOperand(1, std::forward<Rhs>(rhs))) { \ |
| return Op(matched_inst) \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Lhs>(lhs)) \ |
| .WithOperand(1, std::forward<Rhs>(rhs)); \ |
| } |
| |
| #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ |
| XLA_BINOP_PATTERN(NAME) \ |
| \ |
| template <typename Lhs, typename Rhs> \ |
| inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ |
| ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) { \ |
| return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs)); \ |
| } \ |
| \ |
| template <typename HloInstructionType, typename Lhs, typename Rhs> \ |
| inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ |
| Rhs&& rhs) \ |
| ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \ |
| NAME(matched_inst, rhs, lhs))) { \ |
| return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \ |
| NAME(matched_inst, rhs, lhs)); \ |
| } |
| XLA_COMMUTATIVE_BINOP_PATTERN(Add) |
| XLA_BINOP_PATTERN(Atan2) |
| XLA_BINOP_PATTERN(Divide) |
| XLA_BINOP_PATTERN(Complex) |
| XLA_BINOP_PATTERN(Dot) |
| XLA_COMMUTATIVE_BINOP_PATTERN(Eq) |
| XLA_BINOP_PATTERN(Gather) |
| XLA_BINOP_PATTERN(Ge) |
| XLA_BINOP_PATTERN(Gt) |
| XLA_BINOP_PATTERN(Le) |
| XLA_BINOP_PATTERN(Lt) |
| XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) |
| XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) |
| XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) |
| XLA_COMMUTATIVE_BINOP_PATTERN(Ne) |
| XLA_BINOP_PATTERN(Outfeed) |
| XLA_BINOP_PATTERN(Power) |
| XLA_BINOP_PATTERN(Remainder) |
| XLA_BINOP_PATTERN(Send) |
| XLA_BINOP_PATTERN(Subtract) |
| XLA_COMMUTATIVE_BINOP_PATTERN(And) |
| XLA_COMMUTATIVE_BINOP_PATTERN(Or) |
| XLA_BINOP_PATTERN(ShiftLeft) |
| XLA_BINOP_PATTERN(ShiftRightArithmetic) |
| XLA_BINOP_PATTERN(ShiftRightLogical) |
| #undef XLA_COMMUTATIVE_BINOP_PATTERN |
| #undef XLA_BINOP_PATTERN |
| |
| // Helpers for ternary instructions. |
| #define XLA_TERNOP_PATTERN(NAME) \ |
| inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ |
| return Op().WithOpcode(HloOpcode::k##NAME); \ |
| } \ |
| \ |
| template <typename Arg0, typename Arg1, typename Arg2> \ |
| inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \ |
| ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg0>(arg0)) \ |
| .WithOperand(1, std::forward<Arg1>(arg1)) \ |
| .WithOperand(2, std::forward<Arg2>(arg2))) { \ |
| return Op() \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg0>(arg0)) \ |
| .WithOperand(1, std::forward<Arg1>(arg1)) \ |
| .WithOperand(2, std::forward<Arg2>(arg2)); \ |
| } \ |
| \ |
| template <typename HloInstructionType, typename Arg0, typename Arg1, \ |
| typename Arg2> \ |
| inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ |
| Arg1&& arg1, Arg2&& arg2) \ |
| ->decltype(Op(matched_inst) \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg0>(arg0)) \ |
| .WithOperand(1, std::forward<Arg1>(arg1)) \ |
| .WithOperand(2, std::forward<Arg2>(arg2))) { \ |
| return Op(matched_inst) \ |
| .WithOpcode(HloOpcode::k##NAME) \ |
| .WithOperand(0, std::forward<Arg0>(arg0)) \ |
| .WithOperand(1, std::forward<Arg1>(arg1)) \ |
| .WithOperand(2, std::forward<Arg2>(arg2)); \ |
| } |
| XLA_TERNOP_PATTERN(Clamp); |
| XLA_TERNOP_PATTERN(Select); |
| #undef XLA_TERNOP_PATTERN |
| |
| namespace detail { |
| struct PatternFriend { |
| template <typename T> |
| static auto ConstantScalar(T constant) -> decltype( |
| Constant() |
| .WithShape(match::Shape().IsScalar()) |
| .WithPredicate( |
| std::declval<std::function<bool(const HloInstruction*)>>())) { |
| std::function<bool(const HloInstruction*)> pred = |
| [constant](const HloInstruction* instr) { |
| const auto& literal = Cast<HloConstantInstruction>(instr)->literal(); |
| auto status_or_const = LiteralUtil::CreateR0(constant).Convert( |
| literal.shape().element_type()); |
| return status_or_const.ok() && |
| literal == status_or_const.ConsumeValueOrDie(); |
| }; |
| |
| return Constant() |
| .WithShape(match::Shape().IsScalar()) |
| .WithPredicate(std::move(pred)); |
| } |
| }; |
| } // namespace detail |
| |
| // Helpers for matching non-constant instructions. |
| inline auto NonConstant() -> decltype(Op().IsNonConstant()) { |
| return Op().IsNonConstant(); |
| } |
| |
// Capturing overload: returns Op(matched_inst) restricted by IsNonConstant(),
// so a successful match is stored through `matched_inst`.
template <typename HloInstructionType>
inline auto NonConstant(HloInstructionType** matched_inst)
    -> decltype(Op(matched_inst).IsNonConstant()) {
  auto capturing_op = Op(matched_inst);
  return capturing_op.IsNonConstant();
}
| |
| // Add overloads for GetTupleElement which take a int64 specifying which tuple |
| // element is selected. |
| template <typename Arg> |
| inline auto GetTupleElement(Arg&& arg, int64 tuple_index) |
| -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement) |
| .WithOperand(0, std::forward<Arg>(arg)) |
| .WithTupleIndex(tuple_index)) { |
| return Op() |
| .WithOpcode(HloOpcode::kGetTupleElement) |
| .WithOperand(0, std::forward<Arg>(arg)) |
| .WithTupleIndex(tuple_index); |
| } |
| |
| template <typename HloInstructionType, typename Arg> |
| inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, |
| int64 tuple_index) |
| -> decltype(Op(matched_inst) |
| .WithOpcode(HloOpcode::kGetTupleElement) |
| .WithOperand(0, std::forward<Arg>(arg)) |
| .WithTupleIndex(tuple_index)) { |
| return Op(matched_inst) |
| .WithOpcode(HloOpcode::kGetTupleElement) |
| .WithOperand(0, std::forward<Arg>(arg)) |
| .WithTupleIndex(tuple_index); |
| } |
| |
| template <typename T> |
| inline auto ConstantScalar(T constant) |
| -> decltype(detail::PatternFriend::ConstantScalar(constant)) { |
| return detail::PatternFriend::ConstantScalar(constant); |
| } |
| |
| } // namespace match |
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ |