Use C++14 `auto` return type inference in pattern_matcher.h.
In C++11, we had to say
auto fn() -> decltype(...) { ... }
but in C++14, we can simply say
auto fn() { ... }
Do this in pattern_matcher.h.
PiperOrigin-RevId: 327085487
Change-Id: Ida65f273a45c8ac7a59a1876f1db1e5fc895a8f6
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index febbf92..eb29fa8 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -351,8 +351,7 @@
// Returns a pattern that represents the conjunction of all input patterns. All
// patterns need to match in order to have the AllOf pattern match.
template <typename Item, typename... Patterns>
-detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf(
- const Patterns&... patterns) {
+auto AllOf(const Patterns&... patterns) {
return detail::AllOfPattern<typename std::remove_const<Item>::type,
Patterns...>(patterns...);
}
@@ -361,10 +360,8 @@
//
// This transformation is necessary for good pretty-printing.
template <typename Item, typename... InnerPs, typename... OuterPs>
-detail::AllOfPattern<typename std::remove_const<Item>::type, InnerPs...,
- OuterPs...>
-AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
- const OuterPs&... outer_ps) {
+auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
+ const OuterPs&... outer_ps) {
// Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
auto make_all_of = [](const InnerPs&... inner_ps,
const OuterPs&... outer_ps) {
@@ -453,10 +450,7 @@
class LayoutPattern {
private:
template <typename NewImpl>
- auto AppendImpl(NewImpl new_impl) const
- -> LayoutPattern<LayoutType,
- decltype(AllOf<::xla::Layout>(std::declval<Impl>(),
- std::move(new_impl)))> {
+ auto AppendImpl(NewImpl new_impl) const {
auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
matched_layout_);
@@ -495,14 +489,12 @@
// Modifies the pattern to match only if the layout equals the given proto.
// The layout must outlive the returned pattern.
- constexpr auto EqualTo(const ::xla::Layout* layout) const
- -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) {
+ constexpr auto EqualTo(const ::xla::Layout* layout) const {
return AppendImpl(LayoutPatternEqualImpl(layout));
}
// Modifies the pattern to match only if the layout has a dense format.
- constexpr auto WithDenseFormat() const
- -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) {
+ constexpr auto WithDenseFormat() const {
return AppendImpl(LayoutPatternFormatImpl(DENSE));
}
@@ -626,17 +618,14 @@
// patterns. The returned pattern matches from left to right, and stops on the
// first match.
template <typename Item, typename... Patterns>
-detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
- const Patterns&... patterns) {
+auto AnyOf(const Patterns&... patterns) {
return detail::AnyOfPattern<typename std::remove_const<Item>::type,
Patterns...>(patterns...);
}
// Creates a layout pattern that will capture the matched layout in the
// argument.
-inline constexpr detail::LayoutPattern<const ::xla::Layout,
- detail::LayoutPatternBaseImpl>
-Layout(const ::xla::Layout** matched_layout = nullptr) {
+inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
return detail::LayoutPattern<const ::xla::Layout,
detail::LayoutPatternBaseImpl>(
detail::LayoutPatternBaseImpl(), matched_layout);
@@ -644,9 +633,7 @@
// Creates a layout pattern that will capture the matched layout in the
// argument.
-inline constexpr detail::LayoutPattern<::xla::Layout,
- detail::LayoutPatternBaseImpl>
-Layout(::xla::Layout** matched_layout) {
+inline constexpr auto Layout(::xla::Layout** matched_layout) {
return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
detail::LayoutPatternBaseImpl(), matched_layout);
}
@@ -939,10 +926,7 @@
class ShapePattern {
private:
template <typename NewImpl>
- auto AppendImpl(NewImpl new_impl) const
- -> ShapePattern<ShapeType,
- decltype(AllOf<::xla::Shape>(std::declval<Impl>(),
- std::move(new_impl)))> {
+ auto AppendImpl(NewImpl new_impl) const {
auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
matched_shape_);
@@ -988,80 +972,66 @@
// Modifies the pattern to match only if the shape equals the given proto.
// The layout must outlive the returned pattern.
- constexpr auto EqualTo(const ::xla::Shape* shape) const
- -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) {
+ constexpr auto EqualTo(const ::xla::Shape* shape) const {
return AppendImpl(ShapePatternEqualImpl(shape));
}
// Modifies the pattern to match only if the shape is compatible to the given
// proto. The layout must outlive the returned pattern.
- constexpr auto CompatibleTo(const ::xla::Shape* shape) const
- -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) {
+ constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
return AppendImpl(ShapePatternCompatibleImpl(shape));
}
// Modifies the pattern to match only if the shape has the given element type.
- constexpr auto WithElementType(PrimitiveType element_type) const
- -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) {
+ constexpr auto WithElementType(PrimitiveType element_type) const {
return AppendImpl(ShapePatternElementTypeImpl(element_type));
}
// Modifies the pattern to match only if the shape is scalar.
- constexpr auto IsScalar() const
- -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) {
+ constexpr auto IsScalar() const {
return AppendImpl(ShapePatternIsScalarImpl());
}
// Modifies the pattern to match only if the shape is an array.
- constexpr auto IsArray() const
- -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) {
+ constexpr auto IsArray() const {
return AppendImpl(ShapePatternIsArrayImpl());
}
// Modifies the pattern to match only if the shape is a tuple.
- constexpr auto IsTuple() const
- -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) {
+ constexpr auto IsTuple() const {
return AppendImpl(ShapePatternIsTupleImpl());
}
- constexpr auto IsEffectiveScalar() const
- -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) {
+ constexpr auto IsEffectiveScalar() const {
return AppendImpl(ShapePatternEffectiveScalarImpl());
}
// Modifies the pattern to match only if the shape has the given rank.
- constexpr auto WithRank(int64 rank) const
- -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) {
+ constexpr auto WithRank(int64 rank) const {
return AppendImpl(ShapePatternRankImpl(rank));
}
// Modifies the pattern to match only if the shape has a layout that matches
// the given pattern.
template <typename LayoutType, typename LayoutImpl>
- auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const
- -> decltype(this->AppendImpl(
- ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) {
+ auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
}
- constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const
- -> decltype(this->WithLayout(Layout().EqualTo(layout))) {
+ constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
return WithLayout(Layout().EqualTo(layout));
}
- constexpr auto IsDenseArray() const
- -> decltype(this->WithLayout(Layout().WithDenseFormat())) {
+ constexpr auto IsDenseArray() const {
return WithLayout(Layout().WithDenseFormat());
}
// Modifies the pattern to match only if the shape has a subshape that matches
// the given pattern.
template <typename SubshapeType, typename SubshapeImpl>
- auto WithSubshape(ShapeIndexView index,
- const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
- const -> decltype(this->AppendImpl(
- ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index,
- subshape))) {
+ auto WithSubshape(
+ ShapeIndexView index,
+ const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
return AppendImpl(
ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
}
@@ -1101,17 +1071,13 @@
} // namespace detail
// Creates a shape pattern that will capture the matched layout in the argument.
-inline constexpr detail::ShapePattern<const ::xla::Shape,
- detail::ShapePatternBaseImpl>
-Shape(const ::xla::Shape** matched_shape = nullptr) {
+inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
detail::ShapePatternBaseImpl(), matched_shape);
}
// Creates a shape pattern that will capture the matched layout in the argument.
-inline constexpr detail::ShapePattern<::xla::Shape,
- detail::ShapePatternBaseImpl>
-Shape(::xla::Shape** matched_shape) {
+inline constexpr auto Shape(::xla::Shape** matched_shape) {
return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
detail::ShapePatternBaseImpl(), matched_shape);
}
@@ -1797,9 +1763,7 @@
class HloInstructionPattern {
private:
template <typename NewImpl>
- auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern<
- HloInstructionType, decltype(AllOf<::xla::HloInstruction>(
- std::declval<Impl>(), std::move(new_impl)))> {
+ auto AppendImpl(NewImpl new_impl) const {
auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
std::move(new_allof), matched_inst_);
@@ -1837,51 +1801,38 @@
}
// Modifies the pattern to match only if the instruction has the given name.
- auto WithName(absl::string_view name) const
- -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) {
+ auto WithName(absl::string_view name) const {
return AppendImpl(HloInstructionPatternNameImpl(name));
}
// Modifies the pattern to match only if the instruction has the given opcode.
- auto WithOpcode(HloOpcode opcode) const
- -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
- false))) {
+ auto WithOpcode(HloOpcode opcode) const {
return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
}
// Modifies the pattern to match only the custom call with a given target.
- auto WithCustomCallTarget(absl::string_view custom_call_target) const
- -> decltype(this->AppendImpl(
- HloInstructionCustomCallTargetImpl(custom_call_target))) {
+ auto WithCustomCallTarget(absl::string_view custom_call_target) const {
return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target));
}
- auto WithNumOperands(int64 num_operands) const -> decltype(
- this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
+ auto WithNumOperands(int64 num_operands) const {
return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
}
// Modifies the pattern to match only if the instruction does not have the
// given opcode.
- auto WithoutOpcode(HloOpcode opcode) const
- -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
- true))) {
+ auto WithoutOpcode(HloOpcode opcode) const {
return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
}
- constexpr auto Is(const HloInstruction* instr) const
- -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) {
+ constexpr auto Is(const HloInstruction* instr) const {
return AppendImpl(HloInstructionIsImpl(instr));
}
// Modifies the pattern to match only if the instruction is a constant.
- constexpr auto IsConstant() const
- -> decltype(this->WithOpcode(HloOpcode::kConstant)) {
- return WithOpcode(HloOpcode::kConstant);
- }
+ constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }
- constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl(
- HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false))) {
+ constexpr auto IsConstantScalar() const {
return AppendImpl(
HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
}
@@ -1889,39 +1840,32 @@
// This does not check that T has the same type as the instruction, so e.g.
// IsConstantScalar(1.0) may match a constant of shape int32[].
template <typename ScalarTy>
- constexpr auto IsConstantScalar(const ScalarTy& val) const
- -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
- val, /*match_effective_scalar=*/false))) {
+ constexpr auto IsConstantScalar(const ScalarTy& val) const {
return AppendImpl(
HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
}
- constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl(
- HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true))) {
+ constexpr auto IsConstantEffectiveScalar() const {
return AppendImpl(
HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
}
template <typename ScalarTy>
- constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const
- -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
- val, /*match_effective_scalar=*/true))) {
+ constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
return AppendImpl(
HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
}
// Modifies the pattern to match only if the instruction is not a constant.
- constexpr auto IsNonConstant() const
- -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) {
+ constexpr auto IsNonConstant() const {
return WithoutOpcode(HloOpcode::kConstant);
}
// Modifies the pattern to match only if the instruction has a shape that
// matches the given pattern.
template <typename ShapeType, typename ShapeImpl>
- constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape)
- const -> decltype(this->AppendImpl(
- HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) {
+ constexpr auto WithShape(
+ const ShapePattern<ShapeType, ShapeImpl>& shape) const {
return AppendImpl(
HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
}
@@ -1929,16 +1873,14 @@
// Make this a templated function to work around gcc 4.9.4 template infinite
// recursion bug.
template <typename Dummy = void>
- constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const
- -> decltype(this->WithShape(Shape().EqualTo(shape))) {
+ constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
return WithShape(Shape().EqualTo(shape));
}
// Make this a templated function to work around gcc 4.9.4 template infinite
// recursion bug.
template <typename Dummy = void>
- constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const
- -> decltype(this->WithShape(Shape().CompatibleTo(shape))) {
+ constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
return WithShape(Shape().CompatibleTo(shape));
}
@@ -1947,10 +1889,7 @@
template <typename OperandType, typename OperandImpl>
constexpr auto WithOperand(
int64 operand_index,
- const HloInstructionPattern<OperandType, OperandImpl>& operand) const
- -> decltype(this->AppendImpl(
- HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
- operand_index, operand))) {
+ const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
return AppendImpl(
HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
operand_index, operand));
@@ -1960,11 +1899,7 @@
typename OperandImpl2>
constexpr auto WithBinaryOperandsAnyOrder(
const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
- const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const
- -> decltype(this->AppendImpl(
- HloInstructionPatternBinaryOperandsAnyOrderImpl<
- OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1,
- op2))) {
+ const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
return AppendImpl(
HloInstructionPatternBinaryOperandsAnyOrderImpl<
OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
@@ -1972,46 +1907,39 @@
// Modifies the pattern to match only if the instruction is a fusion node with
// the given kind.
- constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const
- -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) {
+ constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
}
// Modifies the pattern to match only if the instruction is a
// get-tuple-element with the given tuple index.
- constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype(
- this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) {
+ constexpr auto WithTupleIndex(int64 tuple_index) const {
return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
}
// Modifies the pattern to match only if the instruction is a parameter
// with the given parameter number.
- constexpr auto WithParameterNum(int64 parameter_num) const -> decltype(
- this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) {
+ constexpr auto WithParameterNum(int64 parameter_num) const {
return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
}
// Modifies the pattern to match if the instruction is used exactly once.
// Does not match if the instruction is used twice by the same user (e.g.
// multiply(x,x)).
- constexpr auto WithOneUse() const
- -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) {
+ constexpr auto WithOneUse() const {
return AppendImpl(HloInstructionPatternOneUseImpl());
}
// Modifies the pattern to match if the instruction is used by exactly one
// other instruction. Will match if the instruction is used twice, so long as
// it's by the same user (e.g. multiply(x,x)).
- constexpr auto WithOneUser() const
- -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) {
+ constexpr auto WithOneUser() const {
return AppendImpl(HloInstructionPatternOneUserImpl());
}
// Modifies the pattern to match only if the instruction has the given
// comparison direction.
- auto WithComparisonDirection(ComparisonDirection direction) const
- -> decltype(this->AppendImpl(
- HloInstructionPatternComparisonDirectionImpl(direction))) {
+ auto WithComparisonDirection(ComparisonDirection direction) const {
return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
}
@@ -2028,9 +1956,7 @@
// Creates an instruction pattern that will capture the matched instruction in
// the argument.
-inline constexpr detail::HloInstructionPattern<
- const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
-Op(const ::xla::HloInstruction** matched_inst = nullptr) {
+inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
return detail::HloInstructionPattern<const ::xla::HloInstruction,
detail::HloInstructionPatternBaseImpl>(
detail::HloInstructionPatternBaseImpl(), matched_inst);
@@ -2038,24 +1964,19 @@
// Creates an instruction pattern that will capture the matched instruction in
// the argument.
-inline constexpr detail::HloInstructionPattern<
- ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
-Op(::xla::HloInstruction** matched_inst) {
+inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
return detail::HloInstructionPattern<::xla::HloInstruction,
detail::HloInstructionPatternBaseImpl>(
detail::HloInstructionPatternBaseImpl(), matched_inst);
}
// Helpers for nullary instructions.
-#define XLA_NULLOP_PATTERN(NAME) \
- inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
- return Op().WithOpcode(HloOpcode::k##NAME); \
- } \
- \
- template <typename HloInstructionType> \
- inline auto NAME(HloInstructionType** matched_inst) \
- ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \
- return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
+#define XLA_NULLOP_PATTERN(NAME) \
+ inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
+ \
+ template <typename HloInstructionType> \
+ inline auto NAME(HloInstructionType** matched_inst) { \
+ return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
}
XLA_NULLOP_PATTERN(Constant)
XLA_NULLOP_PATTERN(Parameter)
@@ -2064,28 +1985,21 @@
#undef XLA_NULLOP_PATTERN
// Helpers for unary instructions.
-#define XLA_UNOP_PATTERN(NAME) \
- inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
- return Op().WithOpcode(HloOpcode::k##NAME); \
- } \
- \
- template <typename Arg> \
- inline auto NAME(Arg&& arg)->decltype( \
- Op().WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Arg>(arg))) { \
- return Op() \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Arg>(arg)); \
- } \
- \
- template <typename HloInstructionType, typename Arg> \
- inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \
- ->decltype(Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Arg>(arg))) { \
- return Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Arg>(arg)); \
+#define XLA_UNOP_PATTERN(NAME) \
+ inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
+ \
+ template <typename Arg> \
+ inline auto NAME(Arg&& arg) { \
+ return Op() \
+ .WithOpcode(HloOpcode::k##NAME) \
+ .WithOperand(0, std::forward<Arg>(arg)); \
+ } \
+ \
+ template <typename HloInstructionType, typename Arg> \
+ inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \
+ return Op(matched_inst) \
+ .WithOpcode(HloOpcode::k##NAME) \
+ .WithOperand(0, std::forward<Arg>(arg)); \
}
XLA_UNOP_PATTERN(Abs)
XLA_UNOP_PATTERN(RoundNearestAfz)
@@ -2124,55 +2038,40 @@
#undef XLA_UNOP_PATTERN
// Helpers for binary instructions.
-#define XLA_BINOP_PATTERN(NAME) \
- inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
- return Op().WithOpcode(HloOpcode::k##NAME); \
- } \
- \
- template <typename Lhs, typename Rhs> \
- inline auto NAME(Lhs&& lhs, Rhs&& rhs) \
- ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs))) { \
- return Op() \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs)); \
- } \
- \
- template <typename HloInstructionType, typename Lhs, typename Rhs> \
- inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
- ->decltype(Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs))) { \
- return Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs)); \
+#define XLA_BINOP_PATTERN(NAME) \
+ inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
+ \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \
+ return Op() \
+ .WithOpcode(HloOpcode::k##NAME) \
+ .WithOperand(0, std::forward<Lhs>(lhs)) \
+ .WithOperand(1, std::forward<Rhs>(rhs)); \
+ } \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
+ return Op(matched_inst) \
+ .WithOpcode(HloOpcode::k##NAME) \
+ .WithOperand(0, std::forward<Lhs>(lhs)) \
+ .WithOperand(1, std::forward<Rhs>(rhs)); \
}
-#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
- XLA_BINOP_PATTERN(NAME) \
- \
- template <typename HloInstructionType, typename Lhs, typename Rhs> \
- inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
- Rhs&& rhs) \
- ->decltype(Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
- std::forward<Rhs>(rhs))) { \
- return Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
- std::forward<Rhs>(rhs)); \
- } \
- template <typename Lhs, typename Rhs> \
- inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
- ->decltype(NAME##AnyOrder<const HloInstruction>( \
- nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) { \
- return NAME##AnyOrder<const HloInstruction>( \
- nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
+#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
+ XLA_BINOP_PATTERN(NAME) \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
+ Rhs&& rhs) { \
+ return Op(matched_inst) \
+ .WithOpcode(HloOpcode::k##NAME) \
+ .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
+ std::forward<Rhs>(rhs)); \
+ } \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
+ return NAME##AnyOrder<const HloInstruction>( \
+ nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
}
XLA_COMMUTATIVE_BINOP_PATTERN(Add)
XLA_BINOP_PATTERN(Atan2)
@@ -2202,16 +2101,10 @@
// Helpers for ternary instructions.
#define XLA_TERNOP_PATTERN(NAME) \
- inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
- return Op().WithOpcode(HloOpcode::k##NAME); \
- } \
+ inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
\
template <typename Arg0, typename Arg1, typename Arg2> \
- inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \
- ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Arg0>(arg0)) \
- .WithOperand(1, std::forward<Arg1>(arg1)) \
- .WithOperand(2, std::forward<Arg2>(arg2))) { \
+ inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \
return Op() \
.WithOpcode(HloOpcode::k##NAME) \
.WithOperand(0, std::forward<Arg0>(arg0)) \
@@ -2222,12 +2115,7 @@
template <typename HloInstructionType, typename Arg0, typename Arg1, \
typename Arg2> \
inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \
- Arg1&& arg1, Arg2&& arg2) \
- ->decltype(Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithOperand(0, std::forward<Arg0>(arg0)) \
- .WithOperand(1, std::forward<Arg1>(arg1)) \
- .WithOperand(2, std::forward<Arg2>(arg2))) { \
+ Arg1&& arg1, Arg2&& arg2) { \
return Op(matched_inst) \
.WithOpcode(HloOpcode::k##NAME) \
.WithOperand(0, std::forward<Arg0>(arg0)) \
@@ -2241,17 +2129,13 @@
namespace detail {
template <typename Matcher, typename FirstArg>
-inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg)
- -> decltype(m.WithOperand(operand_num, std::forward<FirstArg>(first_arg))) {
+inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) {
return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
}
template <typename Matcher, typename FirstArg, typename... Args>
inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
- Args&&... args)
- -> decltype(WithOperands(m.WithOperand(operand_num,
- std::forward<FirstArg>(first_arg)),
- operand_num + 1, std::forward<Args>(args)...)) {
+ Args&&... args) {
return WithOperands(
m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)),
operand_num + 1, std::forward<Args>(args)...);
@@ -2259,26 +2143,17 @@
} // namespace detail
#define XLA_VARIADIC_OP_PATTERN(NAME) \
- inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
- return Op().WithOpcode(HloOpcode::k##NAME); \
- } \
+ inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
\
template <typename... Args> \
- inline auto NAME(Args&&... args) \
- ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME) \
- .WithNumOperands(sizeof...(Args)), \
- 0, std::forward<Args>(args)...)) { \
+ inline auto NAME(Args&&... args) { \
return detail::WithOperands( \
Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
/*operand_num=*/0, std::forward<Args>(args)...); \
} \
\
template <typename HloInstructionType, typename... Args> \
- inline auto NAME(HloInstructionType** matched_inst, Args&&... args) \
- ->decltype(detail::WithOperands(Op(matched_inst) \
- .WithOpcode(HloOpcode::k##NAME) \
- .WithNumOperands(sizeof...(Args)), \
- 0, std::forward<Args>(args)...)) { \
+ inline auto NAME(HloInstructionType** matched_inst, Args&&... args) { \
return detail::WithOperands(Op(matched_inst) \
.WithOpcode(HloOpcode::k##NAME) \
.WithNumOperands(sizeof...(Args)), \
@@ -2299,63 +2174,46 @@
XLA_VARIADIC_OP_PATTERN(Tuple);
// Helpers for comparison instructions.
-#define XLA_COMPARE_PATTERN(NAME) \
- inline auto NAME()->decltype( \
- Op().WithOpcode(HloOpcode::kCompare) \
- .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
- return Op() \
- .WithOpcode(HloOpcode::kCompare) \
- .WithComparisonDirection(ComparisonDirection::k##NAME); \
- } \
- \
- template <typename Lhs, typename Rhs> \
- inline auto NAME(Lhs&& lhs, Rhs&& rhs) \
- ->decltype(Op().WithOpcode(HloOpcode::kCompare) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs)) \
- .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
- return Op() \
- .WithOpcode(HloOpcode::kCompare) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs)) \
- .WithComparisonDirection(ComparisonDirection::k##NAME); \
- } \
- \
- template <typename HloInstructionType, typename Lhs, typename Rhs> \
- inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
- ->decltype(Op(matched_inst) \
- .WithOpcode(HloOpcode::kCompare) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs)) \
- .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
- return Op(matched_inst) \
- .WithOpcode(HloOpcode::kCompare) \
- .WithOperand(0, std::forward<Lhs>(lhs)) \
- .WithOperand(1, std::forward<Rhs>(rhs)) \
- .WithComparisonDirection(ComparisonDirection::k##NAME); \
+#define XLA_COMPARE_PATTERN(NAME) \
+ inline auto NAME() { \
+ return Op() \
+ .WithOpcode(HloOpcode::kCompare) \
+ .WithComparisonDirection(ComparisonDirection::k##NAME); \
+ } \
+ \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \
+ return Op() \
+ .WithOpcode(HloOpcode::kCompare) \
+ .WithOperand(0, std::forward<Lhs>(lhs)) \
+ .WithOperand(1, std::forward<Rhs>(rhs)) \
+ .WithComparisonDirection(ComparisonDirection::k##NAME); \
+ } \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
+ return Op(matched_inst) \
+ .WithOpcode(HloOpcode::kCompare) \
+ .WithOperand(0, std::forward<Lhs>(lhs)) \
+ .WithOperand(1, std::forward<Rhs>(rhs)) \
+ .WithComparisonDirection(ComparisonDirection::k##NAME); \
}
-#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
- XLA_COMPARE_PATTERN(NAME) \
- \
- template <typename HloInstructionType, typename Lhs, typename Rhs> \
- inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
- Rhs&& rhs) \
- ->decltype(Op(matched_inst) \
- .WithOpcode(HloOpcode::kCompare) \
- .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
- std::forward<Rhs>(rhs))) { \
- return Op(matched_inst) \
- .WithOpcode(HloOpcode::kCompare) \
- .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
- std::forward<Rhs>(rhs)); \
- } \
- template <typename Lhs, typename Rhs> \
- inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
- ->decltype(NAME##AnyOrder<const HloInstruction>( \
- nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) { \
- return NAME##AnyOrder<const HloInstruction>( \
- nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
+#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
+ XLA_COMPARE_PATTERN(NAME) \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
+ Rhs&& rhs) { \
+ return Op(matched_inst) \
+ .WithOpcode(HloOpcode::kCompare) \
+ .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
+ std::forward<Rhs>(rhs)); \
+ } \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
+ return NAME##AnyOrder<const HloInstruction>( \
+ nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
}
XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
@@ -2366,23 +2224,17 @@
XLA_COMPARE_PATTERN(Lt);
// Helpers for matching non-constant instructions.
-inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
- return Op().IsNonConstant();
-}
+inline auto NonConstant() { return Op().IsNonConstant(); }
template <typename HloInstructionType>
-inline auto NonConstant(HloInstructionType** matched_inst)
- -> decltype(Op(matched_inst).IsNonConstant()) {
+inline auto NonConstant(HloInstructionType** matched_inst) {
return Op(matched_inst).IsNonConstant();
}
// Add overloads for GetTupleElement which take a int64 specifying which tuple
// element is selected.
template <typename Arg>
-inline auto GetTupleElement(Arg&& arg, int64 tuple_index)
- -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement)
- .WithOperand(0, std::forward<Arg>(arg))
- .WithTupleIndex(tuple_index)) {
+inline auto GetTupleElement(Arg&& arg, int64 tuple_index) {
return Op()
.WithOpcode(HloOpcode::kGetTupleElement)
.WithOperand(0, std::forward<Arg>(arg))
@@ -2391,11 +2243,7 @@
template <typename HloInstructionType, typename Arg>
inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
- int64 tuple_index)
- -> decltype(Op(matched_inst)
- .WithOpcode(HloOpcode::kGetTupleElement)
- .WithOperand(0, std::forward<Arg>(arg))
- .WithTupleIndex(tuple_index)) {
+ int64 tuple_index) {
return Op(matched_inst)
.WithOpcode(HloOpcode::kGetTupleElement)
.WithOperand(0, std::forward<Arg>(arg))
@@ -2404,62 +2252,50 @@
// Add overloads for Parameter which take an int64 specifying the parameter
// number.
-inline auto Parameter(int64 parameter_num) -> decltype(
- Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) {
+inline auto Parameter(int64 parameter_num) {
return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
}
template <typename HloInstructionType>
-inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num)
- -> decltype(Op(matched_inst)
- .WithOpcode(HloOpcode::kParameter)
- .WithParameterNum(parameter_num)) {
+inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) {
return Op(matched_inst)
.WithOpcode(HloOpcode::kParameter)
.WithParameterNum(parameter_num);
}
-inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) {
- return Op().IsConstantScalar();
-}
+inline auto ConstantScalar() { return Op().IsConstantScalar(); }
template <typename HloInstructionType>
-inline auto ConstantScalar(HloInstructionType** matched_inst)
- -> decltype(Op(matched_inst).IsConstantScalar()) {
+inline auto ConstantScalar(HloInstructionType** matched_inst) {
return Op(matched_inst).IsConstantScalar();
}
template <typename ScalarTy>
-inline auto ConstantScalar(ScalarTy val)
- -> decltype(Op().IsConstantScalar(val)) {
+inline auto ConstantScalar(ScalarTy val) {
return Op().IsConstantScalar(val);
}
template <typename HloInstructionType, typename ScalarTy>
-inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val)
- -> decltype(Op(matched_inst).IsConstantScalar(val)) {
+inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) {
return Op(matched_inst).IsConstantScalar(val);
}
-inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) {
+inline auto ConstantEffectiveScalar() {
return Op().IsConstantEffectiveScalar();
}
template <typename HloInstructionType>
-inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst)
- -> decltype(Op(matched_inst).IsConstantScalar()) {
+inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) {
return Op(matched_inst).IsConstantEffectiveScalar();
}
template <typename ScalarTy>
-inline auto ConstantEffectiveScalar(ScalarTy val)
- -> decltype(Op().IsConstantEffectiveScalar(val)) {
+inline auto ConstantEffectiveScalar(ScalarTy val) {
return Op().IsConstantEffectiveScalar(val);
}
template <typename HloInstructionType, typename ScalarTy>
inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
- ScalarTy val)
- -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) {
+ ScalarTy val) {
return Op(matched_inst).IsConstantEffectiveScalar(val);
}