| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" |
| |
| #include <algorithm> |
| #include <cmath> |
| #include <functional> |
| #include <iterator> |
| #include <memory> |
| #include <numeric> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/memory/memory.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/comparison_util.h" |
| #include "tensorflow/compiler/xla/layout_util.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/literal_util.h" |
| #include "tensorflow/compiler/xla/overflow_util.h" |
| #include "tensorflow/compiler/xla/permutation_util.h" |
| #include "tensorflow/compiler/xla/primitive_util.h" |
| #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" |
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_evaluator.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/service/hlo_query.h" |
| #include "tensorflow/compiler/xla/service/pattern_matcher.h" |
| #include "tensorflow/compiler/xla/shape.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/compiler/xla/window_util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/lib/core/bits.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/statusor.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/stream_executor/lib/statusor.h" |
| |
| namespace xla { |
| |
| namespace { |
| |
| namespace m = match; |
| |
| bool IsAll(const HloInstruction* op, int8 value) { |
| switch (op->opcode()) { |
| case HloOpcode::kBroadcast: |
| return IsAll(op->operand(0), value); |
| case HloOpcode::kConstant: |
| return op->literal().IsAll(value); |
| default: |
| return false; |
| } |
| } |
| |
| bool IsAnyOperandComplex(const HloInstruction* hlo) { |
| for (auto operand : hlo->operands()) { |
| if (ShapeUtil::ElementIsComplex(operand->shape())) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| bool IsPositive(const HloInstruction* hlo, |
| const AlgebraicSimplifierOptions& options) { |
| // Utility only handles real types. |
| if (IsAnyOperandComplex(hlo)) { |
| return false; |
| } |
| switch (hlo->opcode()) { |
| case HloOpcode::kGetTupleElement: { |
| const HloInstruction* gte_operand = hlo->operand(0); |
| switch (gte_operand->opcode()) { |
| case HloOpcode::kCustomCall: { |
| const auto& target = gte_operand->custom_call_target(); |
| return target == |
| options.get_cudnn_batchnorm_forward_training_metadata() && |
| hlo->tuple_index() == 2; |
| } |
| default: |
| return false; |
| } |
| } |
| case HloOpcode::kPower: |
| case HloOpcode::kAbs: |
| case HloOpcode::kRsqrt: |
| case HloOpcode::kSqrt: |
| return IsPositive(hlo->operand(0), options); |
| |
| case HloOpcode::kMultiply: { |
| return hlo->operand(0) == hlo->operand(1) && |
| IsPositive(hlo->operand(0), options); |
| } |
| default: |
| return false; |
| } |
| } |
| |
| bool IsNonNegative(const HloInstruction* hlo, |
| const AlgebraicSimplifierOptions& options) { |
| // Utility only handles real types. |
| if (IsAnyOperandComplex(hlo)) { |
| return false; |
| } |
| switch (hlo->opcode()) { |
| case HloOpcode::kMultiply: { |
| return hlo->operand(0) == hlo->operand(1); |
| } |
| case HloOpcode::kAbs: { |
| return true; |
| } |
| default: |
| return IsPositive(hlo, options); |
| } |
| } |
| |
| // Checks whether `op` is a floating-point constant or a broadcast of a |
| // constant of the form +/- 2^k for some integer k (positive, negative, or |
| // zero). Such values are interesting because multiplying by a power of 2 |
| // just moves the exponent. |
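| // For example, multiplying 3.0f (1.5 * 2^1) by 4.0f only raises its |
| // exponent by two, yielding 12.0f (1.5 * 2^3); no mantissa bits change. |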
| bool IsAllFpConstantPowerOf2(const HloInstruction* op) { |
| // Unwrap the broadcast if necessary. |
| const HloInstruction* c; |
| if (!Match(op, m::ConstantEffectiveScalar(&c)) && |
| !Match(op, m::Broadcast(m::Constant(&c).WithShape( |
| m::Shape().IsEffectiveScalar())))) { |
| return false; |
| } |
| auto val = [&]() -> absl::optional<double> { |
| switch (c->shape().element_type()) { |
| case BF16: |
| return static_cast<double>(c->literal().GetFirstElement<bfloat16>()); |
| case F16: |
| return static_cast<double>(c->literal().GetFirstElement<Eigen::half>()); |
| case F32: |
| return c->literal().GetFirstElement<float>(); |
| case F64: |
| return c->literal().GetFirstElement<double>(); |
| default: |
| // Cowardly refuse to consider complex types. |
| return absl::nullopt; |
| } |
| }(); |
| if (!val) { |
| return false; |
| } |
| |
| int exp; |
| double mantissa = std::frexp(*val, &exp); |
| // frexp returns a value in the range (-1, -0.5] U [0.5, 1). A return value |
| // of +/-0.5 therefore indicates that the floating point value is a power of |
| // 2. |
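| // For example, frexp(8.0) returns 0.5 with exp = 4 (8 = 0.5 * 2^4), while |
| // frexp(6.0) returns 0.75 with exp = 3. |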
| return mantissa == 0.5 || mantissa == -0.5; |
| } |
| |
| // Returns whether the given transpose produces a result which is bit-wise |
| // identical to its operand and thus may be replaced with a bitcast. |
| bool TransposeIsBitcast(const HloInstruction* transpose) { |
| CHECK_EQ(HloOpcode::kTranspose, transpose->opcode()); |
| const HloInstruction* operand = transpose->operand(0); |
| return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(), |
| transpose->dimensions()); |
| } |
| |
| // Recursive helper for method below. |
| HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper( |
| HloInstruction* instr, HloInstruction* operand, |
| const AlgebraicSimplifierOptions& options) { |
| // A chain of copies and reshapes can only be replaced with a bitcast if |
| // the compiler chose compatible memory layouts for the two shapes. |
| if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) { |
| return operand; |
| } |
| |
| // If the operand is a copy or reshape, check whether the operand's operand |
| // would produce a bitcast with the initial instruction. |
| if (HloOpcode::kReshape == operand->opcode() || |
| HloOpcode::kCopy == operand->opcode()) { |
| return BitcastingOperandOfReshapeOrCopyChainHelper( |
| instr, operand->mutable_operand(0), options); |
| } |
| return nullptr; |
| } |
| |
| // Returns an operand of a chain of reshapes and copies that is bit-wise |
| // identical to the first reshape or copy in the chain. |
| HloInstruction* BitcastingOperandOfReshapeOrCopyChain( |
| HloInstruction* instr, const AlgebraicSimplifierOptions& options) { |
| if (!options.is_layout_sensitive()) { |
| return nullptr; |
| } |
| CHECK(HloOpcode::kReshape == instr->opcode() || |
| HloOpcode::kCopy == instr->opcode()); |
| return BitcastingOperandOfReshapeOrCopyChainHelper( |
| instr, instr->mutable_operand(0), options); |
| } |
| |
| bool IsUnstridedSlice(const HloInstruction* hlo) { |
| return absl::c_all_of(hlo->slice_strides(), |
| [](int64 stride) { return stride == 1; }); |
| } |
| |
| // Returns true if a pair of adjacent converts can be eliminated, i.e. if |
| // convert(convert(x)) is a no-op. |
| bool IsConvertPairNoOp(const HloInstruction* convert) { |
| // [operand_convert] [convert] |
| // (src)->convert-(intermediate)->convert-(dest) |
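| // For example, (f32)->convert->(f64)->convert->(f32) is a no-op because f64 |
| // represents every f32 value exactly, whereas (f32)->convert->(f16)-> |
| // convert->(f32) is lossy and must be kept. |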
| const HloInstruction* operand_convert = convert->operand(0); |
| CHECK_EQ(operand_convert->opcode(), HloOpcode::kConvert); |
| const Shape& src_shape = operand_convert->operand(0)->shape(); |
| const Shape& intermediate_shape = operand_convert->shape(); |
| const Shape& dest_shape = convert->shape(); |
| |
| const PrimitiveType src_type = src_shape.element_type(); |
| const PrimitiveType intermediate_type = intermediate_shape.element_type(); |
| const PrimitiveType dest_type = dest_shape.element_type(); |
| |
| // src_type must be equal to dest_type. |
| if (src_type != dest_type) { |
| return false; |
| } |
| |
| // intermediate_type must be a larger container than src_type, so that the |
| // round trip loses no information. |
| if (ShapeUtil::ByteSizeOfPrimitiveType(intermediate_type) <= |
| ShapeUtil::ByteSizeOfPrimitiveType(src_type)) { |
| return false; |
| } |
| |
| // src_type and intermediate_type must both be floating-point or both be |
| // integral. |
| bool is_conversion_floating = |
| ShapeUtil::ElementIsFloating(src_shape) && |
| ShapeUtil::ElementIsFloating(intermediate_shape); |
| bool is_conversion_integral = |
| ShapeUtil::ElementIsIntegral(src_shape) && |
| ShapeUtil::ElementIsIntegral(intermediate_shape); |
| |
| return is_conversion_floating || is_conversion_integral; |
| } |
| |
| // AlgebraicSimplifierVisitor traverses the HLO computation and reduces |
| // certain algebraic expressions to simplified forms. Note: This only |
| // supports simplifications that look at the operands of an instruction. For |
| // the more general case a worklist-based approach would be needed. |
| class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { |
| public: |
| explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options, |
| AlgebraicSimplifier* simplifier) |
| : options_(options), simplifier_(simplifier) {} |
| |
| Status HandleAbs(HloInstruction* abs) override; |
| |
| Status HandleAdd(HloInstruction* add) override; |
| |
| Status HandleAnd(HloInstruction* logical_and) override; |
| |
| Status HandleBitcast(HloInstruction* bitcast) override; |
| |
| Status HandleBitcastConvert(HloInstruction* bitcast) override; |
| |
| Status HandleBroadcast(HloInstruction* broadcast) override; |
| |
| Status HandleCompare(HloInstruction* compare) override; |
| |
| Status HandleConcatenate(HloInstruction* concatenate) override; |
| |
| Status HandleConstant(HloInstruction* constant) override; |
| |
| Status HandleCopy(HloInstruction* copy) override; |
| |
| Status HandleConvert(HloInstruction* convert) override; |
| |
| Status HandleComplex(HloInstruction* complex) override; |
| |
| Status HandleReal(HloInstruction* real) override; |
| |
| Status HandleImag(HloInstruction* imag) override; |
| |
| Status HandleIota(HloInstruction* instruction) override; |
| |
| Status HandleConvolution(HloInstruction* convolution) override; |
| |
| Status HandleDivide(HloInstruction* divide) override; |
| |
| Status HandleDot(HloInstruction* dot) override; |
| |
| Status HandleGather(HloInstruction* gather) override; |
| |
| Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; |
| |
| Status HandleLog(HloInstruction* log) override; |
| |
| Status HandleMaximum(HloInstruction* maximum) override; |
| |
| Status HandleMinimum(HloInstruction* minimum) override; |
| |
| Status HandleClamp(HloInstruction* clamp) override; |
| |
| Status HandleMultiply(HloInstruction* multiply) override; |
| |
| Status HandleNegate(HloInstruction* negate) override; |
| |
| Status HandleNot(HloInstruction* logical_not) override; |
| |
| Status HandleOr(HloInstruction* logical_or) override; |
| |
| Status HandlePad(HloInstruction* pad) override; |
| |
| Status HandlePower(HloInstruction* power) override; |
| |
| Status HandleRemainder(HloInstruction* remainder) override; |
| |
| Status HandleReshape(HloInstruction* reshape) override; |
| |
| Status HandleReduce(HloInstruction* hlo) override; |
| |
| Status HandleReduceWindow(HloInstruction* hlo) override; |
| |
| Status HandleReverse(HloInstruction* reverse) override; |
| |
| Status HandleRsqrt(HloInstruction* rsqrt) override; |
| |
| Status HandleSlice(HloInstruction* slice) override; |
| |
| Status HandleSqrt(HloInstruction* sqrt) override; |
| |
| Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; |
| |
| Status HandleDynamicUpdateSlice( |
| HloInstruction* dynamic_update_slice) override; |
| Status HandleScatter(HloInstruction* scatter) override; |
| |
| Status HandleSelect(HloInstruction* select) override; |
| |
| Status HandleSort(HloInstruction* sort) override; |
| |
| Status HandleTranspose(HloInstruction* transpose) override; |
| |
| Status HandleSubtract(HloInstruction* sub) override; |
| |
| Status HandleMap(HloInstruction* map) override; |
| |
| // Runs the visitor on a computation. |
| bool Run(HloComputation* computation, |
| const AlgebraicSimplifierOptions& options, |
| AlgebraicSimplifier* simplifier); |
| |
| private: |
| // Removes degenerate dimension from dot. |
| StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot); |
| |
| // Converts the input hlo to the given primitive type, or returns it |
| // unchanged if it already has that type. |
| HloInstruction* AsType(HloInstruction* hlo, |
| const PrimitiveType element_type) { |
| if (hlo->shape().element_type() == element_type) { |
| return hlo; |
| } |
| Shape changed_shape = |
| ShapeUtil::ChangeElementType(hlo->shape(), element_type); |
| simplifier_->UpdateLayout(&changed_shape); |
| return computation_->AddInstruction( |
| HloInstruction::CreateConvert(changed_shape, hlo)); |
| } |
| |
| // Transposes a dot operand so that the batch dimensions are most major and |
| // the contracting dimensions are most minor. |
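| // For example, with batch_dimensions = {2} and contracting_dimensions = {0} |
| // on a rank-3 operand, this creates a transpose with permutation {2, 1, 0}. |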
| StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor( |
| HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions, |
| absl::Span<const int64> contracting_dimensions) { |
| std::vector<int64> transpose_dimensions(batch_dimensions.begin(), |
| batch_dimensions.end()); |
| for (int64 i = 0; i < dot_operand->shape().rank(); ++i) { |
| if (!(absl::c_linear_search(batch_dimensions, i) || |
| absl::c_linear_search(contracting_dimensions, i))) { |
| transpose_dimensions.push_back(i); |
| } |
| } |
| transpose_dimensions.insert(transpose_dimensions.end(), |
| contracting_dimensions.begin(), |
| contracting_dimensions.end()); |
| if (absl::c_is_sorted(transpose_dimensions)) { |
| return dot_operand; |
| } |
| return MakeTransposeHlo(dot_operand, transpose_dimensions); |
| } |
| |
| // Helper method to perform an add reduction over a list of dimensions. |
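| // For example, AddReduce(hlo: f32[4,5,6], dims = {0, 2}, F32) emits |
| // reduce(hlo, 0.0, dims = {0, 2}), producing an f32[5] result. |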
| HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims, |
| PrimitiveType type) { |
| HloInstruction* zero = computation_->AddInstruction( |
| simplifier_->CreateConstantWithLayoutUpdated( |
| LiteralUtil::Zero(hlo->shape().element_type()).Clone())); |
| HloComputation* AddReduce_computation = |
| GetOrCreateScalarAddComputation(type); |
| Shape shape = ShapeUtil::FilterDimensions( |
| [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, |
| hlo->shape()); |
| simplifier_->UpdateLayout(&shape); |
| return computation_->AddInstruction(HloInstruction::CreateReduce( |
| shape, hlo, zero, dims, AddReduce_computation)); |
| } |
| |
| // Moves a scalar multiply to the smallest side of the dot to reduce the |
| // number of multiply computations. |
| Status ScalarMultiplyReduction(HloInstruction* dot); |
| |
| // Convenience method for replacing an instruction with a bitcast. If operand |
| // is not null, then the bitcast will use the specified operand instead of the |
| // operand of the instruction. |
| void ReplaceWithBitcast(HloInstruction* instruction, |
| HloInstruction* operand = nullptr); |
| |
| // Replace old instruction with new instruction if old and new instructions |
| // have the same shape. Updates uses and root instruction. Returns whether a |
| // replacement was made. |
| bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction, |
| HloInstruction* new_instruction); |
| |
| // Returns whether the shapes of the outputs of the given instructions are |
| // the same for the purposes of simplification. If |
| // options_.is_layout_sensitive() is true, then this tests shape equality |
| // including layout (ShapeUtil::Equal). Otherwise, it tests shape |
| // compatibility (ShapeUtil::Compatible). |
| bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; |
| |
| // Returns whether it was possible to transform `root` to a clamp |
| // instruction, where `min` is a minimum instruction, `max` is a maximum |
| // instruction, `min_operand` is an operand of `min`, and `max_operand` is |
| // an operand of `max`. |
| // Precondition: root is either a minimum or a maximum. |
| bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min, |
| HloInstruction* min_operand, |
| HloInstruction* operand, HloInstruction* max, |
| HloInstruction* max_operand); |
| |
| // A Broadcast that feeds an element-wise operation with a unique non-scalar |
| // operand can sink to after the operation. |
| StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( |
| HloInstruction* broadcast); |
| |
| StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot); |
| StatusOr<HloInstruction*> OptimizeDotOfConcatHelper( |
| const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, |
| HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); |
| |
| StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot); |
| |
| StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims( |
| HloInstruction* dot); |
| |
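| // Returns a computation that adds two scalars of the given type, building |
| // and caching it on first use. In HLO the computation is roughly: |
| // scalar_add_computation { |
| // scalar_lhs = type[] parameter(0) |
| // scalar_rhs = type[] parameter(1) |
| // ROOT add = type[] add(scalar_lhs, scalar_rhs) |
| // } |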
| HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { |
| HloComputation*& scalar_add_computation = scalar_add_computations_[type]; |
| if (scalar_add_computation) { |
| return scalar_add_computation; |
| } |
| |
| HloComputation::Builder b("scalar_add_computation"); |
| Shape shape = ShapeUtil::MakeShape(type, {}); |
| simplifier_->UpdateLayout(&shape); |
| auto scalar_lhs = b.AddInstruction( |
| HloInstruction::CreateParameter(0, shape, "scalar_lhs")); |
| auto scalar_rhs = b.AddInstruction( |
| HloInstruction::CreateParameter(1, shape, "scalar_rhs")); |
| auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( |
| shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); |
| scalar_add_computation = |
| computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); |
| return scalar_add_computation; |
| } |
| |
| // Tries to fold a kPad in the input or filter into the convolution |
| // instruction's window. |
| StatusOr<bool> FoldConvInputPad(HloInstruction* convolution); |
| StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution); |
| |
| // Tries to swap convolution operands if they would result in a more efficient |
| // convolution. |
| StatusOr<bool> SwapConvOperands(HloInstruction* convolution); |
| |
| // Tries to use a kDot in place of the given convolution. |
| StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution); |
| |
| // Tries to simplify a slice where the result of the slice is a scalar. |
| StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice); |
| |
| // Tries to convert slice(reshape(X)) into reshape(slice(X)). |
| StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice); |
| |
| // Tries to convert slice(reverse(X)) into reverse(slice(X)). |
| StatusOr<bool> TryToReorderSliceAndReverse(HloInstruction* slice); |
| |
| // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into |
| // `(< a N)`. This is crucial for being able to figure out the loop trip |
| // count. |
| // |
| // Assumes that the input is a conjunction. |
| StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction); |
| |
| // Useful when we want to use the same visitor over multiple computations. |
| void ResetState(HloComputation* computation); |
| |
| // Current HloComputation instance the AlgebraicSimplifierVisitor is |
| // traversing. |
| HloComputation* computation_; |
| |
| // The backend-specific options selected for the algebraic simplifier. |
| const AlgebraicSimplifierOptions& options_; |
| |
| // Whether algebraic simplification has occurred. |
| bool changed_ = false; |
| |
| // Cached computation for adding two scalars of a given type. |
| absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_; |
| |
| AlgebraicSimplifier* simplifier_ = nullptr; |
| }; |
| |
| } // namespace |
| |
| void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) { |
| changed_ = false; |
| ResetVisitStates(); |
| computation_ = computation; |
| } |
| |
| bool AlgebraicSimplifierVisitor::Run(HloComputation* computation, |
| const AlgebraicSimplifierOptions& options, |
| AlgebraicSimplifier* simplifier) { |
| ResetState(computation); |
| TF_CHECK_OK(computation->Accept(this)); |
| return changed_ || changed(); |
| } |
| |
| bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, |
| const HloInstruction* rhs) const { |
| if (options_.is_layout_sensitive()) { |
| return ShapeUtil::Equal(lhs->shape(), rhs->shape()); |
| } else { |
| return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); |
| } |
| } |
| |
| namespace { |
| |
| float GetConstantValue(HloInstruction* inst) { |
| switch (inst->shape().element_type()) { |
| case BF16: |
| return static_cast<float>(inst->literal().GetFirstElement<bfloat16>()); |
| case F32: |
| return inst->literal().GetFirstElement<float>(); |
| default: |
| LOG(FATAL) << "Unsupported data type: " << inst->shape().element_type(); |
| } |
| } |
| |
| bool IsOpCodeMultiplyCommutative(HloOpcode opcode) { |
| switch (opcode) { |
| case HloOpcode::kMultiply: |
| case HloOpcode::kTranspose: |
| case HloOpcode::kReshape: |
| case HloOpcode::kSelect: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target, |
| float multiplier) { |
| switch (target->shape().element_type()) { |
| case BF16: |
| return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16( |
| LiteralUtil::CreateR0<float>(multiplier))); |
| break; |
| case F32: |
| return HloInstruction::CreateConstant( |
| LiteralUtil::CreateR0<float>(multiplier)); |
| break; |
| default: |
| LOG(FATAL) << "Unsupported data type: " << target->shape().element_type(); |
| } |
| } |
| |
| } // namespace |
| |
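| // Searches the operand and user chains of a dot for scalar multiplies, |
| // removes them, and reattaches a single combined scalar multiply on the |
| // smallest of {lhs, rhs, dot}. A sketch, with c1 and c2 scalar constants: |
| // dot(multiply(A, broadcast(c1)), multiply(B, broadcast(c2))) |
| // => multiply(dot(A, B), broadcast(c1 * c2)) // if the dot is smallest |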
| Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction( |
| HloInstruction* dot) { |
| // We only process bfloat16 and float32 for now. |
| if (dot->shape().element_type() != BF16 && |
| dot->shape().element_type() != F32) { |
| return Status::OK(); |
| } |
| |
| auto lhs = dot->mutable_operand(0); |
| auto rhs = dot->mutable_operand(1); |
| |
| const int64 dot_size = ShapeUtil::ElementsIn(dot->shape()); |
| const int64 lhs_size = ShapeUtil::ElementsIn(lhs->shape()); |
| const int64 rhs_size = ShapeUtil::ElementsIn(rhs->shape()); |
| |
| HloInstruction* target = nullptr; |
| // (current node, user, operand_index) |
| std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands; |
| std::vector<HloInstruction*> users; |
| |
| // Find which side of dot has the smallest size: |
| // operand 0, operand 1, or output. |
| if (dot_size <= std::min(lhs_size, rhs_size)) { |
| target = dot; |
| if (dot_size < lhs_size) { |
| operands.emplace_back(lhs, dot, 0); |
| } |
| if (dot_size < rhs_size) { |
| operands.emplace_back(rhs, dot, 1); |
| } |
| } else if (lhs_size <= rhs_size) { |
| target = lhs; |
| if (lhs_size < rhs_size) { |
| operands.emplace_back(rhs, dot, 1); |
| } |
| if (lhs_size < dot_size && dot->user_count() == 1) { |
| users.push_back(dot->users().front()); |
| } |
| } else { |
| target = rhs; |
| if (rhs_size < lhs_size) { |
| operands.emplace_back(lhs, dot, 0); |
| } |
| if (rhs_size < dot_size && dot->user_count() == 1) { |
| users.push_back(dot->users().front()); |
| } |
| } |
| |
| std::vector<float> values; |
| |
| // DFS to find scalar multiply ops from the operands. |
| while (!operands.empty()) { |
| HloInstruction* inst; |
| HloInstruction* user; |
| int64 index; |
| std::tie(inst, user, index) = operands.back(); |
| operands.pop_back(); |
| |
| // Skip the op types that are not commutative with multiply. |
| if (!IsOpCodeMultiplyCommutative(inst->opcode())) { |
| continue; |
| } |
| |
| HloInstruction* operand; |
| HloInstruction* multiplier; |
| // Pattern match a scalar multiply. |
| if (Match(inst, m::MultiplyAnyOrder( |
| m::Op(&operand), |
| m::Broadcast(m::ConstantScalar(&multiplier))))) { |
| CHECK_LT(index, user->operand_count()); |
| CHECK_EQ(inst, user->operands()[index]); |
| |
| // When a scalar multiply is found, save its scalar value. |
| values.push_back(GetConstantValue(multiplier)); |
| // And remove the scalar multiply op. |
| TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand)); |
| inst = operand; |
| } |
| |
| // Push the operands of inst. |
| int64 i = 0; |
| for (auto* operand : inst->operands()) { |
| operands.emplace_back(operand, inst, i++); |
| } |
| } |
| |
| // DFS to find scalar multiply ops from the users. |
| while (!users.empty()) { |
| auto inst = users.back(); |
| users.pop_back(); |
| |
| if (!IsOpCodeMultiplyCommutative(inst->opcode())) { |
| continue; |
| } |
| |
| HloInstruction* operand; |
| HloInstruction* multiplier; |
| if (Match(inst, m::MultiplyAnyOrder( |
| m::Op(&operand), |
| m::Broadcast(m::ConstantScalar(&multiplier))))) { |
| values.push_back(GetConstantValue(multiplier)); |
| |
| TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand)); |
| inst = operand; |
| } |
| |
| // Only continue through instructions with a single user. Otherwise, moving |
| // the scalar multiply would change the values seen by the other users. |
| if (inst->user_count() == 1) { |
| users.push_back(inst->users().front()); |
| } |
| } |
| |
| if (values.empty()) { |
| return Status::OK(); |
| } |
| |
| changed_ = true; |
| |
| // Combine all constant multipliers. |
| float multiplier = 1.0; |
| for (const float v : values) { |
| multiplier *= v; |
| } |
| |
| // Create a new const scalar multiply instruction. |
| HloInstruction* new_const_inst; |
| new_const_inst = |
| computation_->AddInstruction(MakeScalarInstruction(target, multiplier)); |
| |
| // Broadcast the scalar multiplier. |
| HloInstruction* new_broadcast = computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {})); |
| // Create a new scalar multiply instruction. |
| HloInstruction* new_multiply = |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| target->shape(), HloOpcode::kMultiply, target, new_broadcast)); |
| CHECK_EQ(new_multiply->shape(), target->shape()); |
| |
| // Update the dependency with the rest of the instructions. |
| if (target == lhs) { |
| return dot->ReplaceOperandWith(0, new_multiply); |
| } else if (target == rhs) { |
| return dot->ReplaceOperandWith(1, new_multiply); |
| } else { |
| CHECK_EQ(target, dot); |
| return dot->ReplaceAllUsesWith(new_multiply); |
| } |
| } |
| |
| void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction, |
| HloInstruction* operand) { |
| CHECK_EQ(1, instruction->operand_count()); |
| if (operand == nullptr) { |
| operand = instruction->mutable_operand(0); |
| } |
| CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), |
| ShapeUtil::ElementsIn(operand->shape())); |
| CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), |
| ShapeUtil::ByteSizeOf(operand->shape())); |
| |
| auto bitcast = computation_->AddInstruction( |
| HloInstruction::CreateBitcast(instruction->shape(), operand)); |
| bitcast->set_metadata(instruction->metadata()); |
| TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); |
| } |
| |
| bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( |
| HloInstruction* old_instruction, HloInstruction* new_instruction) { |
| if (!SameShape(old_instruction, new_instruction)) { |
| return false; |
| } |
| TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction)); |
| return true; |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) { |
| HloInstruction* abs_operand = abs->mutable_operand(0); |
| VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString() |
| << " Abs operand is: " << abs_operand->ToString(); |
| if (IsNonNegative(abs->operand(0), options_)) { |
| return ReplaceInstruction(abs, abs_operand); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs)))); |
| |
| // A + 0 => A |
| VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); |
| if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { |
| return Status::OK(); |
| } |
| // 0 + A => A |
| VLOG(10) << "trying transform [0 + A => A]: " << add->ToString(); |
| if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { |
| return Status::OK(); |
| } |
| |
| // Canonicalization: Put constants on the right. This makes the reassociation |
| // rules below simpler. |
| VLOG(10) << "trying transform [Const + A => A + Const]"; |
| if (Match(add, m::Add(m::Constant(), m::NonConstant()))) { |
| return ReplaceWithNewInstruction( |
| add, |
| HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); |
| } |
| |
| // Reassociate to allow constant folding. |
| // |
| // Note: This is not general. For example, we won't reassociate |
| // |
| // (A + C1) + (B + C2) => A + B + (C1 + C2). |
| // |
| VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; |
| HloInstruction *a, *c1, *c2; |
| if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), |
| m::Constant(&c2))) || |
| Match(add, m::Add(m::Add(m::NonConstant(&a), |
| m::Broadcast(m::ConstantScalar(&c1))), |
| m::Broadcast(m::ConstantScalar(&c2))))) { |
| TF_ASSIGN_OR_RETURN(auto* sum_of_constants, |
| MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); |
| if (ShapeUtil::IsScalar(sum_of_constants->shape()) && |
| !ShapeUtil::IsScalar(add->shape())) { |
| sum_of_constants = computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {})); |
| } |
| return ReplaceWithNewInstruction( |
| add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, |
| sum_of_constants)); |
| } |
| |
| // Convert an add over the full shape into an add over a partial shape when |
| // only a portion of the add is effective: |
| // zero (fullshape) rhs (partialshape) |
| // . | | |
| // . lhs . dynamic_update_slice (fullshape) |
| // . | | |
| // Add (fullshape) |
| // |
| // to: |
| // lhs |
| // | |
| // dynamic_slice (partialshape) rhs (partialshape) |
| // . | | |
| // . lhs . add (partial_shape)+----+ |
| // . | | |
| // dynamic_update_slice (fullshape) |
| // |
| // This pattern is discovered in control flow V2 gradient updates. |
| if (Match(add, |
| m::Add(m::Op(&lhs), |
| m::Op(&rhs) |
| .WithOpcode(HloOpcode::kDynamicUpdateSlice) |
| .WithOperand( |
| 0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) { |
| const Shape& partial_shape = rhs->operand(1)->shape(); |
| auto sliced_lhs = |
| computation_->AddInstruction(HloInstruction::CreateDynamicSlice( |
| partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2), |
| partial_shape.dimensions())); |
| |
| auto add_partial = computation_->AddInstruction( |
| HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd, |
| sliced_lhs, rhs->mutable_operand(1))); |
| |
| auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice( |
| lhs->shape(), lhs, add_partial, |
| absl::MakeSpan(rhs->operands()).subspan(2)); |
| |
| return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full)); |
| } |
| |
| // A*C + B*C => (A+B)*C |
| // |
| // - If A, B, and C are integers, do this unconditionally. Proof of |
| // correctness: https://rise4fun.com/Alive/u9X. |
| // |
| // - If A, B, and C are floating point, do this if C is a scalar constant or |
| // broadcast of scalar constant and is equal to +/- 2^k for some (possibly |
| // negative) integer k. |
| // |
| // Multiplying by a power of 2 just moves the exponent, so our answer is |
| // exact modulo rounding of intermediate results so long as |
| // |
| // - none of the three products has an exponent which underflows (so the |
| // result is 0 or denormal), and |
| // - none of the three products overflows to inf. |
| // |
| // Proof: See algebraic_simplifier_proof_distributive_property.py. |
| // |
| // We deem these differences in rounding, underflow, and overflow |
| // acceptable in the ML context. |
| HloInstruction *b, *c; |
| if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && |
| Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || |
| (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && |
| Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && |
| (ShapeUtil::ElementIsIntegral(add->shape()) || |
| options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) { |
| return ReplaceWithNewInstruction( |
| add, HloInstruction::CreateBinary( |
| add->shape(), HloOpcode::kMultiply, |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| add->shape(), HloOpcode::kAdd, a, b)), |
| c)); |
| } |
| |
| if (options_.is_layout_sensitive()) { |
| return Status::OK(); |
| } |
| |
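| // Try to combine two scatters (whose update computation is an add) that |
| // feed this add: |
| // add(scatter(A, i1, u1), scatter(B, i2, u2)) |
| // => scatter(add(A, B), concat(i1, i2), concat(u1, u2)) |
| // when the dimension numbers match and concatenating the updates is cheap; |
| // otherwise chain the two scatters over the summed operands. |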
| HloInstruction* lhs_scatter_operand = nullptr; |
| HloInstruction* rhs_scatter_operand = nullptr; |
| HloInstruction* lhs_scatter_update = nullptr; |
| HloInstruction* rhs_scatter_update = nullptr; |
| HloInstruction* lhs_scatter_index = nullptr; |
| HloInstruction* rhs_scatter_index = nullptr; |
| bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand), |
| m::Op(&lhs_scatter_index), |
| m::Op(&lhs_scatter_update)) |
| .WithOneUse()) && |
| Match(lhs->to_apply()->root_instruction(), |
| m::Add(m::Parameter(), m::Parameter())); |
| bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand), |
| m::Op(&rhs_scatter_index), |
| m::Op(&rhs_scatter_update)) |
| .WithOneUse()) && |
| Match(rhs->to_apply()->root_instruction(), |
| m::Add(m::Parameter(), m::Parameter())); |
| if (rhs_scatter && lhs_scatter) { |
| const auto& lhs_dnums = lhs->scatter_dimension_numbers(); |
| const auto& rhs_dnums = rhs->scatter_dimension_numbers(); |
| absl::optional<int64> index_concat_dimension; |
| absl::optional<int64> update_concat_dimension; |
| // Don't try to combine scatters of different ranks. |
| if (lhs_scatter_index->shape().rank() != |
| rhs_scatter_index->shape().rank()) { |
| return Status::OK(); |
| } |
| |
| int64 first_index_dim = lhs_scatter_index->shape().rank(); |
| int64 first_update_dim = lhs_scatter_update->shape().rank(); |
| // Find a dimension where it is possible to concatenate the indices and |
| // updates. This is the first and only non-equal dimension or the first |
| // equally sized dimension. |
| for (int64 d = lhs_scatter_index->shape().rank() - 1, |
| update_dim = lhs_scatter_update->shape().rank() - 1; |
| d >= 0; --d) { |
| if (d == lhs_dnums.index_vector_dim()) { |
| continue; |
| } |
| while ( |
| absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) { |
| --update_dim; |
| } |
| if (lhs_scatter_index->shape().dimensions(d) == |
| rhs_scatter_index->shape().dimensions(d)) { |
| first_index_dim = d; |
| first_update_dim = update_dim--; |
| continue; |
| } |
| // More than one dimension of unequal size was found; bail out. |
| if (index_concat_dimension) { |
| return Status::OK(); |
| } |
| index_concat_dimension = d; |
| update_concat_dimension = update_dim--; |
| } |
| if (!index_concat_dimension) { |
| index_concat_dimension = first_index_dim; |
| update_concat_dimension = first_update_dim; |
| } |
| |
| // A scalar scatter will require additional reshapes of the index and |
| // update. |
| if (*index_concat_dimension == lhs_scatter_index->shape().rank()) { |
| return Status::OK(); |
| } |
| const bool update_concat_is_cheap = |
| ShapeUtil::ElementsIn(rhs_scatter_update->shape()) + |
| ShapeUtil::ElementsIn(lhs_scatter_update->shape()) < |
| ShapeUtil::ElementsIn(lhs->shape()); |
| if (!update_concat_is_cheap) { |
| return Status::OK(); |
| } |
| const bool same_dimension_numbers = |
| lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() && |
| absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(), |
| rhs_dnums.scatter_dims_to_operand_dims()) && |
| absl::c_equal(lhs_dnums.inserted_window_dims(), |
| rhs_dnums.inserted_window_dims()) && |
| absl::c_equal(lhs_dnums.update_window_dims(), |
| rhs_dnums.update_window_dims()); |
| const bool index_concat_is_safe = |
| !lhs->unique_indices() && !rhs->unique_indices() && |
| !DynCast<HloScatterInstruction>(lhs)->indices_are_sorted() && |
| !DynCast<HloScatterInstruction>(rhs)->indices_are_sorted(); |
| |
| Shape lhs_update_window = ShapeUtil::FilterDimensions( |
| [&](int64 dim) { |
| return absl::c_linear_search(lhs_dnums.update_window_dims(), dim); |
| }, |
| lhs_scatter_update->shape()); |
| Shape rhs_update_window = ShapeUtil::FilterDimensions( |
| [&](int64 dim) { |
| return absl::c_linear_search(rhs_dnums.update_window_dims(), dim); |
| }, |
| rhs_scatter_update->shape()); |
| // Concatenate the indices and updates. |
| if (index_concat_is_safe && same_dimension_numbers && |
| index_concat_dimension && |
| lhs_scatter_index->shape().element_type() == |
| rhs_scatter_index->shape().element_type() && |
| ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) { |
| TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, |
| MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, |
| rhs_scatter_operand)); |
| TF_ASSIGN_OR_RETURN(HloInstruction * new_index, |
| MakeConcatHlo({lhs_scatter_index, rhs_scatter_index}, |
| *index_concat_dimension)); |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_update, |
| MakeConcatHlo({lhs_scatter_update, rhs_scatter_update}, |
| *update_concat_dimension)); |
| return ReplaceWithNewInstruction( |
| add, HloInstruction::CreateScatter( |
| add->shape(), new_operand, new_index, new_update, |
| lhs->to_apply(), lhs_dnums, false, false)); |
| } |
| TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, |
| MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, |
| rhs_scatter_operand)); |
| TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); |
| TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs)); |
| return ReplaceInstruction(add, lhs); |
| } else if (rhs_scatter) { |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_operand, |
| MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand)); |
| TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); |
| return ReplaceInstruction(add, rhs); |
| } else if (lhs_scatter) { |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_operand, |
| MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs)); |
| TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand)); |
| return ReplaceInstruction(add, lhs); |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( |
| HloInstruction* conjunction) { |
| HloInstruction *lhs, *rhs; |
| if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { |
| return false; |
| } |
| struct LessThanCompareInfo { // (LT var constant) |
| HloInstruction* var; |
| int64 constant; |
| }; |
| |
| auto get_compare_info = |
| [&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> { |
| HloInstruction *lhs, *rhs; |
| auto scalar_shape_matcher = |
| m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32); |
| if (Match(cmp, m::Compare(m::Op(&lhs), |
| m::Constant(&rhs).WithShape(scalar_shape_matcher)) |
| .WithComparisonDirection(ComparisonDirection::kLt))) { |
| return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}}; |
| } else if (Match( |
| cmp, |
| m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher), |
| m::Op(&rhs)) |
| .WithComparisonDirection(ComparisonDirection::kGt))) { |
| return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}}; |
| } |
| return absl::nullopt; |
| }; |
| |
| absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs); |
| absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs); |
| if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) { |
| int64 new_bound = std::min(lhs_info->constant, rhs_info->constant); |
| TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( |
| conjunction, |
| HloInstruction::CreateCompare(lhs->shape(), lhs_info->var, |
| MakeScalarLike(lhs_info->var, new_bound), |
| ComparisonDirection::kLt))); |
| return true; |
| } |
| return false; |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); |
| // Simplify logical and |
| if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && |
| ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { |
| // A && True => A |
| VLOG(10) << "trying transform [A && True => A]: " |
| << logical_and->ToString(); |
| if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) { |
| return Status::OK(); |
| } |
| // True && A => A |
| VLOG(10) << "trying transform [True && A => A]: " |
| << logical_and->ToString(); |
| if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) { |
| return Status::OK(); |
| } |
| } |
| |
| // A && False => False or A & 0 => 0 |
| VLOG(10) << "trying transform [A && False => False]: " |
| << logical_and->ToString(); |
| if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { |
| return Status::OK(); |
| } |
| |
| // False && A => False or A & 0 => 0 |
| VLOG(10) << "trying transform [False && A => False]: " |
| << logical_and->ToString(); |
| if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { |
| return Status::OK(); |
| } |
| |
| // Simplify tautological conjunctions. |
| TF_ASSIGN_OR_RETURN(bool found_tautological_compare, |
| TrySimplifyTautologicalCompare(logical_and)); |
| if (found_tautological_compare) { |
| return Status::OK(); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { |
| // If a bitcast feeds a bitcast, make it a single bitcast. |
| HloInstruction* op; |
| if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { |
| return ReplaceWithNewInstruction( |
| bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op)); |
| } |
| // All bitcasts can be eliminated (assuming layout constraints are |
| // satisfied). |
| ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleBitcastConvert( |
| HloInstruction* bitcast) { |
| // Eliminate bitcast converts between same shape. |
| ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { |
| // If a copy feeds a copy, make it a single copy. |
| HloInstruction* op; |
| if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) { |
| return ReplaceWithNewInstruction( |
| copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); |
| } |
| // All copies can be eliminated (assuming layout constraints are satisfied). |
| if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) { |
| return Status::OK(); |
| } |
| |
| if (HloInstruction* bitcast_operand = |
| BitcastingOperandOfReshapeOrCopyChain(copy, options_)) { |
| ReplaceWithBitcast(copy, bitcast_operand); |
| return Status::OK(); |
| } |
| |
| // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical bitcast. |
| if (copy->operand(0)->opcode() == HloOpcode::kReshape && |
| copy->operand(0)->user_count() == 1 && |
| ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) { |
| return ReplaceWithNewInstruction( |
| copy, |
| copy->operand(0)->CloneWithNewOperands( |
| copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)})); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleConcatenate( |
| HloInstruction* concatenate) { |
| absl::Span<HloInstruction* const> operands(concatenate->operands()); |
| if (operands.size() == 1) { |
| // Unary concatenates are useless. |
| ReplaceInstructionIfSameShape(concatenate, operands[0]); |
| return Status::OK(); |
| } |
| // Filter out and remove empty operands. |
| std::vector<HloInstruction*> nonempty_operands; |
| for (HloInstruction* operand : operands) { |
| if (!ShapeUtil::IsZeroElementArray(operand->shape())) { |
| nonempty_operands.push_back(operand); |
| } |
| } |
| if (nonempty_operands.size() < operands.size()) { |
| HloInstruction* replacement; |
| if (nonempty_operands.empty()) { |
| replacement = operands[0]; |
| } else if (nonempty_operands.size() == 1) { |
| replacement = nonempty_operands[0]; |
| } else { |
| replacement = |
| computation_->AddInstruction(concatenate->CloneWithNewOperands( |
| concatenate->shape(), nonempty_operands)); |
| } |
| VLOG(10) << "trying to replace " << concatenate->ToString() << " with " |
| << replacement->ToString(); |
| ReplaceInstructionIfSameShape(concatenate, replacement); |
| return Status::OK(); |
| } |
| |
| if (options_.is_layout_sensitive()) { |
| return Status::OK(); |
| } |
| |
| // Check if we can merge "adjacent" slice operands which take slices from the |
| // same other op. For simplicity we only merge unstrided slices. |
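| // For example: concat(slice(x, [0:2]), slice(x, [2:5])) => slice(x, [0:5]). |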
| int64 concatenate_dimension = concatenate->concatenate_dimension(); |
| std::vector<HloInstruction*> new_operands; |
| int64 i = 0; |
| while (i < operands.size()) { |
| if (operands[i]->opcode() != HloOpcode::kSlice || |
| !IsUnstridedSlice(operands[i])) { |
| new_operands.push_back(operands[i]); |
| ++i; |
| continue; |
| } |
| int64 slice_end = operands[i]->slice_limits(concatenate_dimension); |
| HloInstruction* slice_operand = operands[i]->mutable_operand(0); |
| int64 j = i + 1; |
| while (j < operands.size()) { |
| if (operands[j]->opcode() != HloOpcode::kSlice || |
| !IsUnstridedSlice(operands[j]) || |
| operands[j]->operand(0) != slice_operand || |
| operands[j]->slice_starts(concatenate_dimension) != slice_end) { |
| break; |
| } |
| // Check that all the slice_start values are the same in all other |
| // dimensions. This implies that the slice_limit values are also the same, |
| // because operands of concatenate need to have the same shape, and we |
| // already checked that the slices are unstrided. |
| bool same_other_starts = true; |
| for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) { |
| if (k == concatenate_dimension) { |
| continue; |
| } |
| if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) { |
| same_other_starts = false; |
| break; |
| } |
| } |
| if (!same_other_starts) { |
| break; |
| } |
| slice_end = operands[j]->slice_limits(concatenate_dimension); |
| ++j; |
| } |
| if (j - i > 1) { |
| Shape new_slice_shape = operands[i]->shape(); |
| new_slice_shape.set_dimensions( |
| concatenate_dimension, |
| slice_end - operands[i]->slice_starts(concatenate_dimension)); |
| simplifier_->UpdateLayout(&new_slice_shape); |
| auto new_limit_indices = operands[i]->slice_limits(); |
| new_limit_indices[concatenate_dimension] = slice_end; |
| auto new_slice_op = |
| computation_->AddInstruction(HloInstruction::CreateSlice( |
| new_slice_shape, slice_operand, |
| /*start_indices=*/operands[i]->slice_starts(), |
| /*limit_indices=*/new_limit_indices, |
| /*strides=*/operands[i]->slice_strides())); |
| new_operands.push_back(new_slice_op); |
| } else { |
| new_operands.push_back(operands[i]); |
| } |
| i = j; |
| } |
| if (new_operands.size() < operands.size()) { |
| auto replacement = computation_->AddInstruction( |
| concatenate->CloneWithNewOperands(concatenate->shape(), new_operands)); |
| ReplaceInstructionIfSameShape(concatenate, replacement); |
| return Status::OK(); |
| } |
| |
| if (operands.size() == 2) { |
| // A binary concat with a broadcasted scalar as an operand can be converted |
| // into a pad which is simpler to fold into other operations. |
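| // For example: concat(broadcast(c), x) along dimension d becomes pad(x, c) |
| // with low edge padding equal to the broadcast's size along d. |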
| bool is_effective_low_pad = Match( |
| operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); |
| bool is_effective_high_pad = Match( |
| operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); |
| if (!is_effective_low_pad && !is_effective_high_pad) { |
| return Status::OK(); |
| } |
| PaddingConfig padding_config; |
| for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) { |
| auto padding_config_dim = padding_config.add_dimensions(); |
| padding_config_dim->set_edge_padding_high(0); |
| padding_config_dim->set_edge_padding_low(0); |
| padding_config_dim->set_interior_padding(0); |
| if (dim == concatenate_dimension) { |
| if (is_effective_low_pad) { |
| padding_config_dim->set_edge_padding_low( |
| operands[0]->shape().dimensions(dim)); |
| } else { |
| padding_config_dim->set_edge_padding_high( |
| operands[1]->shape().dimensions(dim)); |
| } |
| } |
| } |
| int64 operand_to_pad = is_effective_low_pad ? 1 : 0; |
| int64 pad_value_operand = is_effective_low_pad ? 0 : 1; |
| HloInstruction* pad = |
| computation_->AddInstruction(HloInstruction::CreatePad( |
| concatenate->shape(), operands[operand_to_pad], |
| operands[pad_value_operand]->mutable_operand(0), padding_config)); |
| return ReplaceInstruction(concatenate, pad); |
| } |
| |
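| // concat(x, x, ..., x), where x has size 1 along the concatenate dimension, |
| // is equivalent to broadcasting x (reshaped to drop that dimension) into |
| // the concatenate's shape. |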
| if (absl::c_count(operands, operands[0]) == operands.size() && |
| operands[0]->shape().dimensions(concatenate_dimension) == 1) { |
| Shape new_shape = operands[0]->shape(); |
| absl::InlinedVector<int64, 8> broadcast_dims; |
| for (int64 i = 0; i < new_shape.rank(); ++i) { |
| if (i == concatenate_dimension) { |
| continue; |
| } |
| broadcast_dims.push_back(i); |
| } |
| new_shape.DeleteDimension(concatenate_dimension); |
| return ReplaceInstruction( |
| concatenate, |
| MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(), |
| broadcast_dims, concatenate->shape())); |
| } |
| return Status::OK(); |
| } |
| |
| static HloInstruction* BuildTupleConstant(HloComputation* computation, |
| const LiteralSlice& literal, |
| AlgebraicSimplifier* simplifier) { |
| if (literal.shape().IsTuple()) { |
| std::vector<HloInstruction*> elems; |
| elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); |
| for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { |
| elems.push_back(BuildTupleConstant( |
| computation, LiteralSlice(literal, {i}), simplifier)); |
| } |
| return computation->AddInstruction(HloInstruction::CreateTuple(elems)); |
| } else { |
| return computation->AddInstruction( |
| simplifier->CreateConstantWithLayoutUpdated(literal.Clone())); |
| } |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { |
| // Tuple constants aren't directly supported by any backend. Expand them into |
| // explicit Tuple instructions. |
| if (constant->shape().IsTuple()) { |
| return ReplaceInstruction( |
| constant, |
| BuildTupleConstant(computation_, constant->literal(), simplifier_)); |
| } |
| |
| if (constant->shape().element_type() == TOKEN) { |
| return Status::OK(); |
| } |
| |
| // If a literal is all the same element, replace it with a scalar broadcast. |
| if (ShapeUtil::ElementsIn(constant->shape()) > 1 && |
| constant->literal().IsAllFirst()) { |
| Literal unique_scalar( |
| LiteralUtil::GetFirstScalarLiteral(constant->literal())); |
| HloInstruction* scalar = computation_->AddInstruction( |
| simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar))); |
| return ReplaceWithNewInstruction( |
| constant, |
| HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); |
| } |
| |
| // If a literal is an increasing sequence from zero, replace it with an iota. |
| if (ShapeUtil::ElementsIn(constant->shape()) > 1 && |
| constant->literal().IsR1Iota()) { |
| return ReplaceWithNewInstruction( |
| constant, HloInstruction::CreateIota(constant->shape(), 0)); |
| } |
| |
| if (absl::optional<int64> stride = constant->literal().IsR1StridedIota()) { |
| // Replace the constant with iota * stride. |
| HloInstruction* stride_hlo = MakeScalarLike(constant, *stride); |
| HloInstruction* iota = computation_->AddInstruction( |
| HloInstruction::CreateIota(constant->shape(), 0)); |
| return ReplaceWithNewInstruction( |
| constant, |
| HloInstruction::CreateBinary(constant->shape(), HloOpcode::kMultiply, |
| iota, stride_hlo)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs)))); |
| // A - 0 => A |
| VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); |
| if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { |
| return Status::OK(); |
| } |
| |
| // Canonicalize subtraction of a constant to addition. |
| VLOG(10) << "trying transform [A - Const => A + (-Const)]"; |
| if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) || |
| Match(sub, m::Subtract(m::NonConstant(&lhs), |
| m::Broadcast(m::Constant(&rhs))))) { |
| HloInstruction* negative_const = computation_->AddInstruction( |
| HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); |
| if (const HloInstruction* broadcast = |
| DynCast<HloBroadcastInstruction>(sub->operand(1))) { |
| negative_const = |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| broadcast->shape(), negative_const, broadcast->dimensions())); |
| } |
| return ReplaceWithNewInstruction( |
| sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs, |
| negative_const)); |
| } |
| |
| return Status::OK(); |
| } |
| namespace { |
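| // Populates `result` with the element-wise reciprocal of `constant`. |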
| template <typename T> |
| Status InvertConstant(const HloInstruction& constant, Literal* result) { |
| return result->Populate<T>([&](absl::Span<const int64> indices) { |
| return T{1.0} / constant.literal().Get<T>(indices); |
| }); |
| } |
| |
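| // Rewrites a divide by a power-of-two constant into a shift. For signed |
| // types the dividend is first made non-negative so that the logical shift |
| // rounds toward zero, e.g. for b = 8: |
| // a / 8 => select(a < 0, -(abs(a) >> 3), abs(a) >> 3) |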
| template <typename T> |
| std::unique_ptr<HloInstruction> TryDivideToShift( |
| HloInstruction* divide, HloComputation* computation, |
| AlgebraicSimplifier* simplifier) { |
| HloInstruction *a, *b, *c; |
| CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); |
| |
| if (ShapeUtil::ElementIsIntegral(divide->shape()) && |
| !Match(b, m::ConstantEffectiveScalar(&c)) && |
| !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { |
| return nullptr; |
| } |
| |
| if (ShapeUtil::ElementIsSigned(divide->shape())) { |
| int64 b_value = c->literal().GetFirstElement<T>(); |
| if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) { |
| // Handle negative dividends by negating the result of the division. |
| HloInstruction* zero_like_a = MakeScalarLike(a, 0); |
| |
| Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED); |
| simplifier->UpdateLayout(&changed_shape); |
| auto* dividend_is_negative = |
| computation->AddInstruction(HloInstruction::CreateCompare( |
| changed_shape, a, zero_like_a, ComparisonDirection::kLt)); |
| |
| auto* negated_dividend = computation->AddInstruction( |
| HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); |
| |
| auto* abs_dividend = |
| computation->AddInstruction(HloInstruction::CreateTernary( |
| a->shape(), HloOpcode::kSelect, dividend_is_negative, |
| negated_dividend, a)); |
| |
| auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( |
| divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend, |
| MakeScalarLike(abs_dividend, tensorflow::Log2Floor64(b_value)))); |
| |
| auto* negated_quotient = |
| computation->AddInstruction(HloInstruction::CreateUnary( |
| quotient->shape(), HloOpcode::kNegate, quotient)); |
| |
| return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect, |
| dividend_is_negative, |
| negated_quotient, quotient); |
| } |
| } else { |
| uint64 b_value = c->literal().GetFirstElement<T>(); |
| if (IsPowerOfTwo(b_value)) { |
| return HloInstruction::CreateBinary( |
| divide->shape(), HloOpcode::kShiftRightLogical, a, |
| MakeScalarLike(a, tensorflow::Log2Floor64(b_value))); |
| } |
| } |
| |
| return nullptr; |
| } |
| } // namespace |
| |
| Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { |
| HloInstruction *a, *b, *c, *d; |
| CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); |
| // A/1 => A |
| VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); |
| if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) { |
| return Status::OK(); |
| } |
| |
| // A / B => A >> log2(B) if B is a power of 2. |
| switch (divide->shape().element_type()) { |
| case S8: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<int8>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| case S16: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<int16>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| case S32: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<int32>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| case S64: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<int64>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| case U8: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<uint8>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| case U16: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<uint16>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| case U32: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<uint32>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| case U64: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryDivideToShift<uint64>(divide, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(divide, std::move(shift)); |
| } |
| break; |
| default: |
| break; |
| } |
| |
| Shape* shape; |
| // exp(A)/exp(B) => exp(A-B) |
| if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) |
| .WithShape(m::Shape(&shape)))) { |
| VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); |
| HloInstruction* subtract = computation_->AddInstruction( |
| HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b)); |
| return ReplaceWithNewInstruction( |
| divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract)); |
| } |
| |
| // A/exp(B) => A*exp(-B) |
| if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) { |
| VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); |
| HloInstruction* negate = computation_->AddInstruction( |
| HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b)); |
| HloInstruction* new_exp = computation_->AddInstruction( |
| HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); |
| return ReplaceWithNewInstruction( |
| divide, HloInstruction::CreateBinary(divide->shape(), |
| HloOpcode::kMultiply, a, new_exp)); |
| } |
| |
| // A/pow(B,C) => A*pow(B,-C) |
| if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) { |
| VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); |
    // The output shape of the created negate operator should be the same as
    // its input's shape.
| const Shape& negate_shape = c->shape(); |
| HloInstruction* negate = computation_->AddInstruction( |
| HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c)); |
| // And the power operator should retain the output shape of the old one. |
| const Shape& new_power_shape = b->shape(); |
| HloInstruction* new_power = |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| new_power_shape, HloOpcode::kPower, b, negate)); |
| return ReplaceWithNewInstruction( |
| divide, HloInstruction::CreateBinary( |
| divide->shape(), HloOpcode::kMultiply, a, new_power)); |
| } |
| |
  // A/sqrt(B) => A*rsqrt(B).
| if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) { |
| auto* rsqrt = computation_->AddInstruction( |
| HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b)); |
| return ReplaceWithNewInstruction( |
| divide, HloInstruction::CreateBinary(rsqrt->shape(), |
| HloOpcode::kMultiply, a, rsqrt)); |
| } |
| |
| // A/rsqrt(B) => A*sqrt(B). |
| if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) { |
| auto* sqrt = computation_->AddInstruction( |
| HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b)); |
| return ReplaceWithNewInstruction( |
| divide, HloInstruction::CreateBinary(sqrt->shape(), |
| HloOpcode::kMultiply, a, sqrt)); |
| } |
| |
  // The rewrites below assume floating-point (or complex) semantics;
  // applying them to integral division would change the results.
| if (ShapeUtil::ElementIsIntegral(divide->shape())) { |
| return Status::OK(); |
| } |
| |
| // A / Const => A * (1 / Const) |
| // |
| // (Backends can do this transformation, but generally only if the constant is |
| // a scalar.) |
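  // For illustration (hypothetical values): divide(x, broadcast(2.0)) becomes
  // multiply(x, broadcast(0.5)), with the reciprocal computed at compile time.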
| if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) && |
| (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) { |
| Shape result_shape = c->literal().shape(); |
| Literal new_literal(result_shape); |
| switch (result_shape.element_type()) { |
| case F16: |
| TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal)); |
| break; |
| case F32: |
| TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal)); |
| break; |
| case BF16: |
| TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal)); |
| break; |
| case F64: |
| TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal)); |
| break; |
| case C64: |
| TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal)); |
| break; |
| case C128: |
| TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal)); |
| break; |
| default: |
| return Status::OK(); |
| } |
| auto inverse = computation_->AddInstruction( |
| simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone())); |
| if (b != c) { |
| inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| b->shape(), inverse, b->dimensions())); |
| } |
| TF_ASSIGN_OR_RETURN(auto new_divide, |
| MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); |
| return ReplaceInstruction(divide, new_divide); |
| } |
| |
| // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) |
| if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), |
| m::Divide(m::Op(&c), m::Op(&d))))) { |
| TF_ASSIGN_OR_RETURN(auto a_times_d, |
| MakeBinaryHlo(HloOpcode::kMultiply, a, d)); |
| TF_ASSIGN_OR_RETURN(auto b_times_c, |
| MakeBinaryHlo(HloOpcode::kMultiply, b, c)); |
| TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, |
| a_times_d, b_times_c)); |
| |
| return ReplaceInstruction(divide, new_divide); |
| } |
| |
| // (A / B) / C => A / (B * C) |
| if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { |
| TF_ASSIGN_OR_RETURN(auto b_times_c, |
| MakeBinaryHlo(HloOpcode::kMultiply, b, c)); |
| TF_ASSIGN_OR_RETURN(auto new_divide, |
| MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); |
| return ReplaceInstruction(divide, new_divide); |
| } |
| |
| // A / (B / C) => (A*C) / B |
| if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { |
| TF_ASSIGN_OR_RETURN(auto a_times_c, |
| MakeBinaryHlo(HloOpcode::kMultiply, a, c)); |
| TF_ASSIGN_OR_RETURN(auto new_divide, |
| MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); |
| return ReplaceInstruction(divide, new_divide); |
| } |
| |
| // If X is a convert from pred, then |
| // X / broadcast(Y) => broadcast(1/Y) * X |
| if (Match(divide, |
| m::Divide( |
| m::Convert(&a, |
| m::Op().WithShape(m::Shape().WithElementType(PRED))), |
| m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { |
| TF_ASSIGN_OR_RETURN( |
| auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b)); |
| auto recip_bcast = computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(divide->shape(), recip, {})); |
| TF_ASSIGN_OR_RETURN(auto mul, |
| MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); |
| return ReplaceInstruction(divide, mul); |
| } |
| |
| return Status::OK(); |
| } |
| |
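// Strips size-1 (degenerate) dimensions from the operands of a dot, remaps
// the dot dimension numbers accordingly, and reshapes the result back to the
// original shape if needed. For illustration (hypothetical shapes): a dot of
// f32[1,4,5] and f32[5,1] contracting over the size-5 dims is rewritten to a
// dot of f32[4,5] and f32[5], followed by a reshape to f32[1,4,1].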
| StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( |
| HloInstruction* dot) { |
| const Shape& lhs_shape = dot->operand(0)->shape(); |
| int64 num_degenerate_lhs_dims = 0; |
| std::vector<int64> lhs_dimension_map(lhs_shape.rank(), -1); |
| for (int64 i = 0; i < lhs_shape.rank(); ++i) { |
| if (lhs_shape.dimensions(i) == 1) { |
| ++num_degenerate_lhs_dims; |
| } else { |
| lhs_dimension_map[i] = i - num_degenerate_lhs_dims; |
| } |
| } |
| |
| const Shape& rhs_shape = dot->operand(1)->shape(); |
| int64 num_degenerate_rhs_dims = 0; |
| std::vector<int64> rhs_dimension_map(rhs_shape.rank(), -1); |
| for (int64 i = 0; i < rhs_shape.rank(); ++i) { |
| if (rhs_shape.dimensions(i) == 1) { |
| ++num_degenerate_rhs_dims; |
| } else { |
| rhs_dimension_map[i] = i - num_degenerate_rhs_dims; |
| } |
| } |
| if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) { |
| return false; |
| } |
| const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); |
| DotDimensionNumbers new_dnums; |
| for (int64 dim : dnums.lhs_batch_dimensions()) { |
| int64 new_dim = lhs_dimension_map[dim]; |
| if (new_dim != -1) { |
| new_dnums.add_lhs_batch_dimensions(new_dim); |
| } |
| } |
| for (int64 dim : dnums.lhs_contracting_dimensions()) { |
| int64 new_dim = lhs_dimension_map[dim]; |
| if (new_dim != -1) { |
| new_dnums.add_lhs_contracting_dimensions(new_dim); |
| } |
| } |
| |
| for (int64 dim : dnums.rhs_batch_dimensions()) { |
| int64 new_dim = rhs_dimension_map[dim]; |
| if (new_dim != -1) { |
| new_dnums.add_rhs_batch_dimensions(new_dim); |
| } |
| } |
| for (int64 dim : dnums.rhs_contracting_dimensions()) { |
| int64 new_dim = rhs_dimension_map[dim]; |
| if (new_dim != -1) { |
| new_dnums.add_rhs_contracting_dimensions(new_dim); |
| } |
| } |
| |
| HloInstruction* new_lhs = |
| num_degenerate_lhs_dims > 0 |
| ? dot->parent()->AddInstruction(HloInstruction::CreateReshape( |
| ShapeUtil::DropDegenerateDimensions(lhs_shape), |
| dot->mutable_operand(0))) |
| : dot->mutable_operand(0); |
| HloInstruction* new_rhs = |
| num_degenerate_rhs_dims > 0 |
| ? dot->parent()->AddInstruction(HloInstruction::CreateReshape( |
| ShapeUtil::DropDegenerateDimensions(rhs_shape), |
| dot->mutable_operand(1))) |
| : dot->mutable_operand(1); |
| TF_ASSIGN_OR_RETURN( |
| auto new_dot, |
| MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(), |
| /*preferred_element_type=*/dot->shape().element_type())); |
| if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) { |
| TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot)); |
| } else { |
| TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( |
| dot, HloInstruction::CreateReshape(dot->shape(), new_dot))); |
| } |
| return true; |
| } |
| |
| StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat( |
| HloInstruction* dot) { |
| const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); |
| if (dnums.lhs_contracting_dimensions_size() != 1 || |
| dnums.lhs_batch_dimensions_size() != 0 || |
| dot->shape().dimensions_size() != 2) { // dot output 2D |
| return nullptr; |
| } |
| |
| const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); |
| const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); |
| |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * optimized_lhs_concat, |
| OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs, |
| rhs_contracting_dim, /*swapped=*/false)); |
| if (optimized_lhs_concat) { |
| return optimized_lhs_concat; |
| } |
| |
| return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs, |
| lhs_contracting_dim, /*swapped=*/true); |
| } |
| |
| StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( |
| const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, |
| HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { |
| bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && |
| lhs->concatenate_dimension() == lhs_contracting_dim && |
| rhs->opcode() == HloOpcode::kConstant; |
| if (!can_optimize) { |
| return nullptr; |
| } |
| |
| // We're replacing this: |
| // |
| // +-----+-----+-----+ +-------------------+ |
| // | | | | | | |
| // | | | | | R_0 | |
| // | | | | | | |
| // | | | | +-------------------+ |
| // | | | | | | |
| // | L_0 | L_1 | L_2 | * | R_1 | |
| // | | | | | | |
| // | | | | +-------------------+ |
| // | | | | | | |
| // | | | | | R_2 | |
| // | | | | | | |
| // +-----+-----+-----+ +-------------------+ |
| // |
| // with this: |
| // |
| // [Sum over i] |
| // |
| // +-----+ +-------------------+ |
| // | | | | |
| // | | * | R_i | |
| // | | | | |
| // | | +-------------------+ |
| // | | |
| // | L_i | |
| // | | |
| // | | |
| // | | |
| // | | |
| // | | |
| // +-----+ |
| // |
| // where the LHS is a concatenate operation (so we can "split" the LHS tensor |
| // for free) and the RHS is a constant tensor (and thus can be split at |
| // compile time). In the future, we may also want to do this when both the |
| // LHS and the RHS are concatenate operations that line up along the dimension |
| // being contracted over. |
| // |
| // We should be able to generalize this transform to work on a non-constant |
| // RHS when/if we have in-place slices or support input-fusing slices into |
| // Dots. |
| |
| // Dimension numbers for the new dot instructions we'll create (L_i * R_i in |
| // the diagram above). |
| DotDimensionNumbers new_dot_dnums; |
| new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim |
| : lhs_contracting_dim); |
| new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim |
| : rhs_contracting_dim); |
| |
| // Here we use the MKN notation, where the contracted dimension has K |
| // elements and the two non-contracted dimensions have M and N elements. |
| HloInstruction* add_result = nullptr; |
| int64 rhs_contracting_dim_offset = 0; |
| int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim); |
| for (HloInstruction* concat_op : lhs->operands()) { |
| int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); |
| Shape rhs_slice_shape(rhs->shape()); |
| rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); |
| simplifier_->UpdateLayout(&rhs_slice_shape); |
| |
| std::array<int64, 2> start_indices; |
| start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; |
| start_indices[1 - rhs_contracting_dim] = 0; |
| |
| std::array<int64, 2> limit_indices; |
| limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k; |
| limit_indices[1 - rhs_contracting_dim] = n; |
| |
| HloInstruction* rhs_slice = |
| computation_->AddInstruction(HloInstruction::CreateSlice( |
| rhs_slice_shape, rhs, /*start_indices=*/start_indices, |
| /*limit_indices=*/limit_indices, /*strides=*/{1, 1})); |
| |
| // TODO(b/69062148): We can get rid of `swapped` once all backends support |
| // "non-canonical" contraction dimensions (that contracts dimension 1 of the |
| // LHS with dimension 0 of the RHS). But for now we keep the same |
| // contraction dimensions as the incoming dot operation to ensure the new |
| // dot operations can be lowered. |
| HloInstruction *new_dot_lhs, *new_dot_rhs; |
| if (swapped) { |
| new_dot_lhs = rhs_slice; |
| new_dot_rhs = concat_op; |
| } else { |
| new_dot_lhs = concat_op; |
| new_dot_rhs = rhs_slice; |
| } |
| |
| auto* new_dot = computation_->AddInstruction( |
| HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, |
| new_dot_dnums, dot.precision_config())); |
| |
| if (add_result) { |
| add_result = computation_->AddInstruction(HloInstruction::CreateBinary( |
| dot.shape(), HloOpcode::kAdd, add_result, new_dot)); |
| } else { |
| add_result = new_dot; |
| } |
| |
| rhs_contracting_dim_offset += sub_k; |
| } |
| |
| return add_result; |
| } |
| |
| StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather( |
| HloInstruction* dot) { |
| const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); |
| if (dnums.lhs_contracting_dimensions_size() != 1 || |
| dnums.rhs_contracting_dimensions_size() != 1 || |
| dnums.lhs_batch_dimensions_size() != 0 || |
| dnums.rhs_batch_dimensions_size() != 0 || |
| dot->shape().dimensions_size() != 2) { // dot output 2D |
| VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations."; |
| return nullptr; |
| } |
| |
  // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
| // Currently a Gather is a DynamicSlice. |
| auto is_dynamic_slice_constant_combination = |
| [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) { |
| // First operand is a DynamicSlice(Constant). |
| if (a->opcode() != HloOpcode::kDynamicSlice) { |
| return false; |
| } |
| auto* dynamic_slice_op = a->operand(0); |
| if (dynamic_slice_op->opcode() != HloOpcode::kConstant) { |
| return false; |
| } |
| // Second operand is a Constant. |
| if (b->opcode() != HloOpcode::kConstant) { |
| return false; |
| } |
| // The DynamicSlice output is a vector. |
| const Shape& dynamic_slice_shape = a->shape(); |
| if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) { |
| return false; |
| } |
| // Constant size is the same before and after slice in the contracting |
| // dimension, otherwise we either must precompute for all possible slice |
| // indices or dot is invalid. |
| const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape(); |
| if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) != |
| dynamic_slice_shape.dimensions(a_contracting_dimension)) { |
| return false; |
| } |
| return true; |
| }; |
| |
| HloInstruction* lhs = dot->mutable_operand(0); |
| HloInstruction* rhs = dot->mutable_operand(1); |
| int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); |
| int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); |
| |
| if (!is_dynamic_slice_constant_combination( |
| lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) && |
| !is_dynamic_slice_constant_combination( |
| rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) { |
| VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or " |
| "dot(ctB, DS(ctA)), where the two constants have equal " |
| "contracting dimensions."; |
| return nullptr; |
| } |
| |
| // LHS is DynamicSlice: |
  // input: dot(DS(ctA), ctB)
| // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}. |
| // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. |
| // output: DS(dot(ctA, ctB)) |
| // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}. |
| |
| // RHS is DynamicSlice: |
| // input: dot(ctA, DS(ctB)) |
| // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}). |
| // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. |
| // output: DS(dot(ctA, ctB)) |
| // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. |
| |
| bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; |
| HloDynamicSliceInstruction* dynamic_slice = |
| lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs) |
| : Cast<HloDynamicSliceInstruction>(rhs); |
| |
| // ctA: |
| HloInstruction* left_operand = |
| lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs; |
| // ctB: |
| HloInstruction* right_operand = |
| lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0); |
| // Build ctA x ctB. |
| const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); |
| const int n = |
| right_operand->shape().dimensions(1 - rhs_contracting_dimension); |
| auto memoized_shape = |
| ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); |
| simplifier_->UpdateLayout(&memoized_shape); |
| auto* memoized_inst = computation_->AddInstruction( |
| HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, |
| dnums, dot->precision_config())); |
| // Get pair {start, 0} or {0, start}. |
| // Position of start: |
| int index_of_non_zero_start = lhs_is_dynamic_slice |
| ? 1 - lhs_contracting_dimension |
| : 1 - rhs_contracting_dimension; |
| // Position of zero: |
| int index_of_zero_start = 1 - index_of_non_zero_start; |
| |
| // Slice out start and 0 components and reorder if necessary. |
| auto indices_type = dynamic_slice->operand(1)->shape().element_type(); |
| Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); |
| simplifier_->UpdateLayout(&s_shape); |
| Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); |
| simplifier_->UpdateLayout(&d_shape); |
| HloInstruction* non_zero_start = |
| dynamic_slice->mutable_operand(1 + index_of_non_zero_start); |
| HloInstruction* zero_start = |
| dynamic_slice->mutable_operand(1 + index_of_zero_start); |
| std::vector<HloInstruction*> new_start_indices; |
| if (lhs_is_dynamic_slice) { |
| new_start_indices = {non_zero_start, zero_start}; |
| } else { |
| new_start_indices = {zero_start, non_zero_start}; |
| } |
| |
| // Build DynamicSlice(ctA x ctB). |
| const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; |
| const int new_slice_n = lhs_is_dynamic_slice ? n : 1; |
| auto* memoized_lookup = |
| computation_->AddInstruction(HloInstruction::CreateDynamicSlice( |
| dot->shape(), memoized_inst, new_start_indices, |
| {new_slice_m, new_slice_n})); |
| |
| return memoized_lookup; |
| } |
| |
| // This function tries to transform |
| // dot(reshape(transpose(A)), Const) to |
| // dot(reshape(A), reshape(transpose(reshape(Const)))), |
| // so that the reshape and transpose on the Const side can be constant folded. |
| // |
// The basic idea is that the accumulation in the dot operation is
// associative, so as long as we permute the elements of the contracting
// dimensions on both sides of the dot in the same way, the result of the
// dot is not affected.
| StatusOr<HloInstruction*> |
| AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( |
| HloInstruction* dot) { |
| // This transformation assumes layout is not assigned yet. |
| if (options_.is_layout_sensitive()) { |
| return nullptr; |
| } |
| |
| // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the |
| // remainder of this function easier. |
| auto dnums = dot->dot_dimension_numbers(); |
| auto lhs_contracting_dims = dnums.lhs_contracting_dimensions(); |
| auto rhs_contracting_dims = dnums.rhs_contracting_dimensions(); |
| auto* lhs = dot->mutable_operand(0); |
| auto* rhs = dot->mutable_operand(1); |
| if (dot->operand(0)->IsConstant()) { |
| std::swap(lhs, rhs); |
| std::swap(lhs_contracting_dims, rhs_contracting_dims); |
| } |
| |
  // Require a single contracting dim to keep the bookkeeping of contracting
  // dims simple.
| if (dnums.lhs_contracting_dimensions_size() != 1) { |
| return nullptr; |
| } |
| |
  // Pattern match dot(reshape(transpose(input)), constant).
| HloInstruction* reshape; |
| HloInstruction* transpose; |
| HloInstruction* input; |
| HloInstruction* constant; |
| if (!Match(lhs, |
| m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) || |
| !Match(rhs, m::Constant(&constant))) { |
| return nullptr; |
| } |
| |
  // Check that the reshape squishes some dims into one dim and that this one
  // dim is the dot's lhs contracting dim. Concretely, unmodified_dims must
  // have size N - 1, where N is the rank of the reshape output (i.e. exactly
  // one output dim absorbs the squished dims), and the lhs contracting dim
  // must not appear in unmodified_dims (i.e. the absorbing dim is the lhs
  // contracting dim).
| auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape( |
| reshape->operand(0)->shape(), reshape->shape()); |
| CHECK_EQ(lhs_contracting_dims.size(), 1); |
| if ((unmodified_dims.size() != reshape->shape().rank() - 1) || |
| absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) { |
| return p.second == lhs_contracting_dims[0]; |
| })) { |
| return nullptr; |
| } |
| |
| // Virtually pull the reshape into the dot so the dot operates on the |
| // transpose, with "unsquished" lhs contracting dims. The new contracting |
| // dims are all of the dims that are modified by the reshape -- that is, every |
| // dimension that's not in `unmodified_dims[i].first`. |
| // |
| // (We don't need to actually create a new dot instruction. We can just keep |
| // track of lhs and lhs_contracting_dims.) |
| absl::flat_hash_set<int64> unmodified_transpose_dims; |
| for (const auto& pair : unmodified_dims) { |
| unmodified_transpose_dims.insert(pair.first); |
| } |
| lhs_contracting_dims.Clear(); |
| for (int64 i = 0; i < transpose->shape().dimensions_size(); ++i) { |
| if (!unmodified_transpose_dims.contains(i)) { |
| lhs_contracting_dims.Add(i); |
| } |
| } |
| // We require the "unsquished" lhs contracting dims to be consecutive. |
| auto is_iota = [](absl::Span<const int64> dims) { |
| return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) { |
| return (b != a + 1); |
| }) == dims.end(); |
| }; |
| if (!is_iota(AsInt64Slice(lhs_contracting_dims))) { |
| return nullptr; |
| } |
| lhs = lhs->mutable_operand(0); |
| |
| // Check that the transpose only permutes the contracting dims. |
| const auto& transpose_dims = transpose->dimensions(); |
| for (int64 i = 0; i < transpose_dims.size(); ++i) { |
| if (transpose_dims[i] != i && |
| !absl::c_linear_search(lhs_contracting_dims, i)) { |
| return nullptr; |
| } |
| } |
| // Virtually pull the transpose into the dot. Now the dot is equivalent to |
| // a new dot with "permuted" lhs contracting dims. |
| std::vector<int64> permutation; |
| for (auto dim : lhs_contracting_dims) { |
| permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]); |
| } |
| CHECK(IsPermutation(permutation)); |
| auto new_lhs_contracting_dims = |
| ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation); |
| lhs_contracting_dims.Clear(); |
| for (auto dim : new_lhs_contracting_dims) { |
| lhs_contracting_dims.Add(dim); |
| } |
| lhs = lhs->mutable_operand(0); |
| |
| // All checks are passed at this point. |
| // |
| // Transform lhs. Remove the transpose and reshape by sorting the lhs |
| // contracting dims and squishing them into a single one. We don't actually |
| // squish the lhs_contracting_dims here because we still need the unsquished |
| // contracting dims to invert reshape and transpose. |
| absl::c_sort(lhs_contracting_dims); |
| lhs = computation_->AddInstruction( |
| HloInstruction::CreateReshape(reshape->shape(), lhs)); |
| |
| // Transform rhs. Say the input HLO is: |
| // |
| // t0 = f32[2, 2, 3] parameter(0) |
| // t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1} |
| // t2 = f32[2, 6] reshape(t1) |
| // t3 = f32[6, 2] constant(...) |
| // dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1}, |
| // rhs_contracting_dims={0} |
| // |
| // At this point in the function, we have decided that the second and third |
| // dims of t0 can be switched to remove the transpose, and we have |
| // "virtually decomposed" the input HLO to: |
| // |
| // t0 = f32[2, 2, 3] parameter(0) |
| // t2' = f32[2, 6] reshape(t0) |
| // t3' = f32[6, 2] ops-to-be-filled ... |
| // dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1}, |
| // rhs_contracting_dims={0} |
| // |
| // The rest of this function is to fill in the ops of t3'. To do this, we |
| // unsquish the contracting dimensions in t3 and then apply the inverse of |
| // the transpose from t1. |
| |
| // Invert reshape. |
| CHECK_EQ(rhs_contracting_dims.size(), 1); |
| std::vector<int64> rhs_unsquished_shape_dims = |
| SpanToVector(constant->shape().dimensions()); |
| auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() + |
| rhs_contracting_dims[0]); |
| for (auto dim : lhs_contracting_dims) { |
| it = rhs_unsquished_shape_dims.insert(it, |
| transpose->shape().dimensions(dim)); |
| ++it; |
| } |
| HloInstruction* rhs_reshape = |
| computation_->AddInstruction(HloInstruction::CreateReshape( |
| ShapeUtil::MakeShape(constant->shape().element_type(), |
| rhs_unsquished_shape_dims), |
| constant)); |
| rhs = rhs_reshape; |
| |
| // Rhs reshape "unsquishes" the single rhs contracting dim into multiple dims. |
| rhs_contracting_dims.Resize(lhs_contracting_dims.size(), |
| rhs_contracting_dims[0]); |
| absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]); |
| |
| // Invert transpose. First compute the shape. |
| std::vector<int64> rhs_transpose_shape_dims = |
| SpanToVector(rhs_reshape->shape().dimensions()); |
| it = rhs_transpose_shape_dims.erase( |
| rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0], |
| rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] + |
| rhs_contracting_dims.size()); |
| for (auto dim : lhs_contracting_dims) { |
| it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim)); |
| ++it; |
| } |
| // Then compute the transpose dims. |
| std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank()); |
| absl::c_iota(rhs_transpose_dims, 0); |
| it = rhs_transpose_dims.erase( |
| rhs_transpose_dims.begin() + rhs_contracting_dims[0], |
| rhs_transpose_dims.begin() + rhs_contracting_dims[0] + |
| rhs_contracting_dims.size()); |
| auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims); |
| for (auto dim : lhs_contracting_dims) { |
| it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] - |
| lhs_contracting_dims[0] + |
| rhs_contracting_dims[0]); |
| ++it; |
| } |
| HloInstruction* rhs_transpose = |
| computation_->AddInstruction(HloInstruction::CreateTranspose( |
| ShapeUtil::MakeShape(constant->shape().element_type(), |
| rhs_transpose_shape_dims), |
| rhs_reshape, rhs_transpose_dims)); |
| rhs = rhs_transpose; |
| |
| // Squish the multiple rhs contracting dims into a single one. |
| rhs = computation_->AddInstruction( |
| HloInstruction::CreateReshape(constant->shape(), rhs)); |
| |
  // If we virtually swapped lhs and rhs, we need to swap them back before
  // creating the new dot.
| if (dot->operand(0)->IsConstant()) { |
| std::swap(lhs, rhs); |
| } |
| |
| HloInstruction* new_dot = |
| computation_->AddInstruction(HloInstruction::CreateDot( |
| dot->shape(), lhs, rhs, dnums, dot->precision_config())); |
| return new_dot; |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { |
| CHECK(computation_ == dot->parent()); |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); |
| if (options_.is_layout_sensitive()) { |
| return Status::OK(); |
| } |
| // Replace a zero element dot with a broadcast of the constant 0. |
| if (ShapeUtil::IsZeroElementArray(dot->shape()) || |
| ShapeUtil::IsZeroElementArray(lhs->shape()) || |
| ShapeUtil::IsZeroElementArray(rhs->shape())) { |
| auto zero = computation_->AddInstruction( |
| simplifier_->CreateConstantWithLayoutUpdated( |
| LiteralUtil::Zero(dot->shape().element_type()))); |
| return ReplaceWithNewInstruction( |
| dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); |
| } |
| |
| // If there are no contracting dimensions, a dot can be rewritten as |
| // mul(broadcast(transpose(x)),broadcast(transpose(y))) |
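  // For illustration (hypothetical shapes): dot(f32[2] x, f32[3] y) with no
  // contracting dims is an outer product and becomes
  // multiply(broadcast(x, dims={0}), broadcast(y, dims={1})) : f32[2,3].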
| if (options_.enable_dot_to_multiply_rewrite() && |
| dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_lhs, |
| NormalizeDotOperandToBatchMajorAndContractingMinor( |
| lhs, |
| AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), |
| AsInt64Slice( |
| dot->dot_dimension_numbers().lhs_contracting_dimensions()))); |
| if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) { |
| new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type()); |
| } |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_rhs, |
| NormalizeDotOperandToBatchMajorAndContractingMinor( |
| rhs, |
| AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), |
| AsInt64Slice( |
| dot->dot_dimension_numbers().rhs_contracting_dimensions()))); |
| if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) { |
| new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type()); |
| } |
| if (dot->shape().rank() != lhs->shape().rank()) { |
| std::vector<int64> lhs_broadcast_dims(lhs->shape().rank()); |
| absl::c_iota(lhs_broadcast_dims, 0); |
| new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| dot->shape(), new_lhs, lhs_broadcast_dims)); |
| } |
| if (dot->shape().rank() != rhs->shape().rank()) { |
| std::vector<int64> rhs_broadcast_dims( |
| dot->dot_dimension_numbers().lhs_batch_dimensions_size()); |
| absl::c_iota(rhs_broadcast_dims, 0); |
| for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) { |
| rhs_broadcast_dims.push_back(i); |
| } |
| new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| dot->shape(), new_rhs, rhs_broadcast_dims)); |
| } |
| return ReplaceWithNewInstruction( |
| dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, |
| new_lhs, new_rhs)); |
| } |
| |
  // If the lhs or rhs has only batch and contracting dimensions, the dot can
  // be rewritten as
  // reduce(mul(broadcast(transpose(x)), broadcast(transpose(y))))
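  // For illustration (hypothetical shapes): dot(f32[4,5] x, f32[5] y)
  // contracting over the size-5 dims becomes
  // reduce(multiply(x, broadcast(y, dims={1})), dims={1}) : f32[4].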
| if (options_.enable_dot_strength_reduction() && |
| ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + |
| dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == |
| lhs->shape().rank()) || |
| (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + |
| dot->dot_dimension_numbers().rhs_batch_dimensions_size() == |
| rhs->shape().rank()))) { |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_lhs, |
| NormalizeDotOperandToBatchMajorAndContractingMinor( |
| lhs, |
| AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), |
| AsInt64Slice( |
| dot->dot_dimension_numbers().lhs_contracting_dimensions()))); |
| if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) { |
| new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type()); |
| } |
| |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_rhs, |
| NormalizeDotOperandToBatchMajorAndContractingMinor( |
| rhs, |
| AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), |
| AsInt64Slice( |
| dot->dot_dimension_numbers().rhs_contracting_dimensions()))); |
| if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) { |
| new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type()); |
| } |
| |
| int64 lhs_outer_dims = |
| lhs->shape().rank() - |
| (dot->dot_dimension_numbers().lhs_batch_dimensions_size() + |
| dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); |
| int64 rhs_outer_dims = |
| rhs->shape().rank() - |
| (dot->dot_dimension_numbers().rhs_batch_dimensions_size() + |
| dot->dot_dimension_numbers().rhs_contracting_dimensions_size()); |
| CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0); |
| if (rhs_outer_dims > 0) { |
| std::vector<int64> lhs_broadcast_dims( |
| dot->dot_dimension_numbers().lhs_batch_dimensions_size()); |
| absl::c_iota(lhs_broadcast_dims, 0); |
| lhs_broadcast_dims.resize(lhs->shape().rank()); |
| std::iota(lhs_broadcast_dims.begin() + |
| dot->dot_dimension_numbers().lhs_batch_dimensions_size(), |
| lhs_broadcast_dims.end(), |
| dot->dot_dimension_numbers().lhs_batch_dimensions_size() + |
| rhs_outer_dims); |
| new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| new_rhs->shape(), new_lhs, lhs_broadcast_dims)); |
| } else if (lhs_outer_dims > 0) { |
| std::vector<int64> rhs_broadcast_dims( |
| dot->dot_dimension_numbers().rhs_batch_dimensions_size()); |
| absl::c_iota(rhs_broadcast_dims, 0); |
| rhs_broadcast_dims.resize(rhs->shape().rank()); |
| std::iota(rhs_broadcast_dims.begin() + |
| dot->dot_dimension_numbers().rhs_batch_dimensions_size(), |
| rhs_broadcast_dims.end(), |
| dot->dot_dimension_numbers().rhs_batch_dimensions_size() + |
| lhs_outer_dims); |
| new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| new_lhs->shape(), new_rhs, rhs_broadcast_dims)); |
| } |
| |
| TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, |
| MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); |
| std::vector<int64> reduce_dims( |
| dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); |
| PrimitiveType dot_type = |
| ShapeUtil::ElementIsFloating(dot->shape()) |
| ? (dot->shape().element_type() == F64 ? F64 : F32) |
| : dot->shape().element_type(); |
| new_dot = AsType(new_dot, dot_type); |
| const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); |
| absl::c_iota( |
| reduce_dims, |
| outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); |
| new_dot = AddReduce(new_dot, reduce_dims, dot_type); |
| new_dot = AsType(new_dot, dot->shape().element_type()); |
| return ReplaceInstruction(dot, new_dot); |
| } |
| |
| // Simplify dot(reshape(transpose(A)), Const) to: |
| // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape |
| // and transpose on the Const side can be constant folded. |
| TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized, |
| OptimizeDotOfReorderContractingDims(dot)); |
| if (dot_of_reorder_optimized) { |
| VLOG(10) << " Replaced dot " << dot->ToString() |
| << " with new dot operation: " |
| << dot_of_reorder_optimized->ToString(); |
| return ReplaceInstruction(dot, dot_of_reorder_optimized); |
| } |
| |
| TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, |
| OptimizeDotOfConcat(dot)); |
| if (dot_of_concat_optimized) { |
| VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " |
| "constant)...)"; |
| return ReplaceInstruction(dot, dot_of_concat_optimized); |
| } |
| |
| // Simplify dot(ConstA, Gather(Index, ConstB)) to: |
| // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately |
| // batched version of dot. |
| TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, |
| OptimizeDotOfGather(dot)); |
| if (dot_of_gather_optimized) { |
| VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " |
| "gather(i, dot*(constA, constB))"; |
| return ReplaceInstruction(dot, dot_of_gather_optimized); |
| } |
| |
| TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions, |
| RemoveDegenerateDimensionFromDot(dot)); |
| if (removed_degenerate_dimensions) { |
| return Status::OK(); |
| } |
| |
| // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). |
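  // For rank-2 operands this is the identity A^T * B^T == (B * A)^T.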
| if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 && |
| dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 && |
| dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 && |
| dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 && |
| lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { |
| DotDimensionNumbers dot_dimension_numbers; |
| dot_dimension_numbers.add_lhs_contracting_dimensions(1); |
| dot_dimension_numbers.add_rhs_contracting_dimensions(0); |
| auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( |
| ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), |
| rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, |
| dot->precision_config())); |
| return ReplaceWithNewInstruction( |
| dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { |
| const Shape& operand_shape = gather->operand(0)->shape(); |
| if (ShapeUtil::IsZeroElementArray(operand_shape)) { |
| return ReplaceInstruction(gather, MakeScalarLike(gather, 0)); |
| } |
| |
  // Gathering from a scalar operand is simply a broadcast of that scalar.
| if (ShapeUtil::IsEffectiveScalar(operand_shape)) { |
| HloInstruction* new_operand = gather->mutable_operand(0); |
| if (operand_shape.rank()) { |
| TF_ASSIGN_OR_RETURN(new_operand, |
| MakeReshapeHlo(ShapeUtil::MakeScalarShape( |
| operand_shape.element_type()), |
| new_operand)); |
| } |
| HloInstruction* new_gather = |
| MakeBroadcastHlo(new_operand, {}, gather->shape()); |
| return ReplaceInstruction(gather, new_gather); |
| } |
  // If the operand of a gather is very small, it is cheaper to express the
  // gather as a sequence of selects over the operand's elements, which is
  // also easier for backends to fuse.
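  // For illustration (hypothetical), with an f32[3] operand holding
  // {v0, v1, v2} and index i, the loop below builds:
  //   select(i >= 2, v2, select(i >= 1, v1, v0)).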
| const Shape& index_shape = gather->operand(1)->shape(); |
| if (operand_shape.rank() == 1 && |
| operand_shape.dimensions(0) <= options_.very_small_gather_size() && |
| gather->gather_dimension_numbers().index_vector_dim() == |
| index_shape.rank() && |
| gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) { |
| const int64 operand_elements = operand_shape.dimensions(0); |
| auto get_value = [&](int64 i) { |
| auto slice = computation_->AddInstruction(HloInstruction::CreateSlice( |
| ShapeUtil::MakeShape(operand_shape.element_type(), {1}), |
| gather->mutable_operand(0), {i}, {i + 1}, {1})); |
| auto scalar = computation_->AddInstruction(HloInstruction::CreateReshape( |
| ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice)); |
| return computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(gather->shape(), scalar, {})); |
| }; |
| auto result = get_value(0); |
| auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED); |
| simplifier_->UpdateLayout(&pred_shape); |
| auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(), |
| index_shape.element_type()); |
| simplifier_->UpdateLayout(&iter_shape); |
| for (int64 i = 0; i < operand_elements; ++i) { |
| auto index_mask = |
| computation_->AddInstruction(HloInstruction::CreateCompare( |
| pred_shape, gather->mutable_operand(1), |
| MakeScalarLike(gather->mutable_operand(1), i), |
| ComparisonDirection::kGe)); |
| result = computation_->AddInstruction( |
| HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect, |
| index_mask, get_value(i), result)); |
| } |
| return ReplaceInstruction(gather, result); |
| } |
| return Status::OK(); |
| } |
| |
| namespace { |
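// Attempts to rewrite a min/max sandwich such as
// max(broadcast(lo), min(x, broadcast(hi))) (matched by the callers in either
// operand order) as clamp(lo, x, hi). The rewrite is only valid when
// lo < hi, which is checked here by constant-evaluating a compare of the
// scalar bounds; otherwise a null instruction is returned.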
| StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp( |
| HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp, |
| HloInstruction* clamp_upper_bound_bcast, AlgebraicSimplifier* simplifier) { |
| HloInstruction* clamp_lower_bound; |
| CHECK(Match(clamp_lower_bound_bcast, |
| m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound)))) |
| << clamp_lower_bound_bcast->ToString(); |
| |
| HloInstruction* clamp_upper_bound; |
| CHECK(Match(clamp_upper_bound_bcast, |
| m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound)))) |
| << clamp_upper_bound_bcast->ToString(); |
| |
| const Literal& lower_bound = |
| Cast<HloConstantInstruction>(clamp_lower_bound)->literal(); |
| const Literal& upper_bound = |
| Cast<HloConstantInstruction>(clamp_upper_bound)->literal(); |
| |
| TF_ASSIGN_OR_RETURN(Literal lower_bound_literal_reshaped, |
| lower_bound.Reshape({})); |
| TF_ASSIGN_OR_RETURN(Literal upper_bound_literal_reshaped, |
| upper_bound.Reshape({})); |
| std::unique_ptr<HloInstruction> lower_bound_instr = |
| HloInstruction::CreateConstant(std::move(lower_bound_literal_reshaped)); |
| std::unique_ptr<HloInstruction> upper_bound_instr = |
| HloInstruction::CreateConstant(std::move(upper_bound_literal_reshaped)); |
| |
| Shape compare_shape = |
| ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED); |
| simplifier->UpdateLayout(&compare_shape); |
| std::unique_ptr<HloInstruction> cloned_instruction = |
| HloInstruction::CreateCompare(compare_shape, lower_bound_instr.get(), |
| upper_bound_instr.get(), |
| ComparisonDirection::kLt); |
| |
| HloEvaluator evaluator; |
| TF_ASSIGN_OR_RETURN(auto result, |
| evaluator.Evaluate(cloned_instruction.get())); |
| if (result.IsAll(true)) { |
| return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp, |
| clamp_lower_bound_bcast, to_clamp, |
| clamp_upper_bound_bcast); |
| } |
| return std::unique_ptr<HloInstruction>(); |
| } |
| } // namespace |
| |
| Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs)))); |
| |
| HloInstruction* clamp_upper_bound_bcast; |
| HloInstruction* clamp_lower_bound_bcast; |
| HloInstruction* to_clamp; |
| if (Match(maximum, m::MaximumAnyOrder( |
| m::Broadcast(&clamp_lower_bound_bcast, |
| m::ConstantEffectiveScalar()), |
| m::MinimumAnyOrder( |
| m::Op(&to_clamp), |
| m::Broadcast(&clamp_upper_bound_bcast, |
| m::ConstantEffectiveScalar()))))) { |
| TF_ASSIGN_OR_RETURN(auto clamp, |
| MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, |
| clamp_upper_bound_bcast, simplifier_)); |
| if (clamp) { |
| return ReplaceWithNewInstruction(maximum, std::move(clamp)); |
| } |
| } |
| |
| HloInstruction* clamp_lower_bound; |
| HloInstruction* clamp_upper_bound; |
| HloInstruction* max_operand; |
| HloInstruction* clamp; |
| if (Match(maximum, |
| m::MaximumAnyOrder( |
| m::Op(&max_operand), |
| m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp), |
| m::Op(&clamp_upper_bound))))) { |
| if (max_operand == clamp_lower_bound && |
| ReplaceInstructionIfSameShape(maximum, clamp)) { |
| return Status::OK(); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs)))); |
| |
| HloInstruction* clamp_upper_bound_bcast; |
| HloInstruction* clamp_lower_bound_bcast; |
| HloInstruction* to_clamp; |
| if (Match(minimum, m::MinimumAnyOrder( |
| m::Broadcast(&clamp_upper_bound_bcast, |
| m::ConstantEffectiveScalar()), |
| m::MaximumAnyOrder( |
| m::Op(&to_clamp), |
| m::Broadcast(&clamp_lower_bound_bcast, |
| m::ConstantEffectiveScalar()))))) { |
| TF_ASSIGN_OR_RETURN(auto clamp, |
| MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, |
| clamp_upper_bound_bcast, simplifier_)); |
| if (clamp) { |
| return ReplaceWithNewInstruction(minimum, std::move(clamp)); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) { |
| HloInstruction* clamp_lower_bound; |
| HloInstruction* clamp_upper_bound; |
| HloInstruction* to_clamp; |
| CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp), |
| m::Op(&clamp_upper_bound)))); |
| |
| // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b) |
| if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(), |
| m::Op().Is(clamp_upper_bound))) && |
| ReplaceInstructionIfSameShape(clamp, to_clamp)) { |
| return Status::OK(); |
| } |
| |
| // Eliminate redundant clamping of replica-id or partition-id. |
| if ((Match(to_clamp, m::PartitionId()) || Match(to_clamp, m::ReplicaId())) && |
| Match(clamp_lower_bound, m::ConstantScalar(0U)) && |
| Match(clamp_upper_bound, m::ConstantScalar())) { |
| int64 upper_bound = Cast<HloConstantInstruction>(clamp_upper_bound) |
| ->literal() |
| .GetFirstElement<uint32_t>(); |
| const HloModuleConfig& config = clamp->GetModule()->config(); |
| int64 runtime_bound = Match(to_clamp, m::PartitionId()) |
| ? config.num_partitions() |
| : config.replica_count(); |
| |
    // If num_partitions or replica_count is 1, the module may not have been
    // configured for multiple devices, so treat the bound as unknown. Since
    // pid/rid < runtime_bound, clamp(0, pid/rid, upper_bound) is redundant
    // whenever runtime_bound <= upper_bound + 1.
| if (runtime_bound != 1 && runtime_bound <= upper_bound + 1) { |
| return ReplaceInstruction(clamp, to_clamp); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); |
| // LHS*1 => LHS |
| VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString(); |
| if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { |
| return Status::OK(); |
| } |
| // 1*RHS => RHS |
| VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString(); |
| if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { |
| return Status::OK(); |
| } |
| |
  // 0*RHS => 0. Only applied to integral types: for floating-point types,
  // 0*NaN is NaN, so the rewrite would be unsound.
| if (IsAll(lhs, 0) && |
| primitive_util::IsIntegralType(multiply->shape().element_type()) && |
| ReplaceInstructionIfSameShape(multiply, lhs)) { |
| return Status::OK(); |
| } |
| // LHS*0 => 0 |
| if (IsAll(rhs, 0) && |
| primitive_util::IsIntegralType(multiply->shape().element_type()) && |
| ReplaceInstructionIfSameShape(multiply, rhs)) { |
| return Status::OK(); |
| } |
| |
| { |
| HloInstruction* abs_operand; |
| if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) && |
| !ShapeUtil::ElementIsComplex(abs_operand->shape())) { |
| TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand)); |
| TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand)); |
| changed_ = true; |
| return Status::OK(); |
| } |
| } |
| |
| { |
| HloInstruction *convert_operand, *operand; |
| // Mul(Convert(Pred), operand) => select(pred, operand, 0) |
| if (Match(multiply, |
| m::MultiplyAnyOrder( |
| m::Op(&operand), |
| m::Convert( |
| m::Op(&convert_operand) |
| .WithShape(m::Shape().WithElementType(PRED)))))) { |
| HloInstruction* zero_like_multiply = |
| BroadcastZeros(computation_, multiply->shape().element_type(), |
| multiply->shape().dimensions()); |
| return ReplaceWithNewInstruction( |
| multiply, HloInstruction::CreateTernary( |
| multiply->shape(), HloOpcode::kSelect, convert_operand, |
| operand, zero_like_multiply)); |
| } |
| } |
| |
| { |
| HloInstruction *a, *b, *c1, *c2; |
| // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y), |
| // constant1*constant2) |
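    // For illustration (hypothetical values): (x * 3) * (y * 4) becomes
    // (x * y) * 12, exposing the combined constant to later folding.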
| if (Match(multiply, |
| m::MultiplyAnyOrder( |
| m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), |
| m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) { |
| TF_ASSIGN_OR_RETURN(auto* product_of_constants, |
| MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); |
| if (ShapeUtil::IsScalar(product_of_constants->shape()) && |
| !ShapeUtil::IsScalar(multiply->shape())) { |
| product_of_constants = |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| multiply->shape(), product_of_constants, {})); |
| } |
| |
| return ReplaceWithNewInstruction( |
| multiply, |
| HloInstruction::CreateBinary( |
| multiply->shape(), HloOpcode::kMultiply, |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| multiply->shape(), HloOpcode::kMultiply, a, b)), |
| product_of_constants)); |
| } |
| } |
| |
| { |
| HloInstruction *a, *c1, *c2; |
| // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2) |
| if (Match(multiply, |
| m::MultiplyAnyOrder( |
| m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), |
| m::Constant(&c2)))) { |
| TF_ASSIGN_OR_RETURN(auto* product_of_constants, |
| MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); |
| if (ShapeUtil::IsScalar(product_of_constants->shape()) && |
| !ShapeUtil::IsScalar(multiply->shape())) { |
| product_of_constants = |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| multiply->shape(), product_of_constants, {})); |
| } |
| |
| return ReplaceWithNewInstruction( |
| multiply, |
| HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, |
| a, product_of_constants)); |
| } |
| } |
| |
| { |
| HloInstruction *a, *b, *constant, *op; |
    // Mul(Mul(a, constant1), Broadcast(b)) =>
    // Mul(Broadcast(Mul(b, constant1)), a)
| if (Match(multiply, |
| m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a), |
| m::Constant(&constant)), |
| m::Op(&op))) || |
| Match(multiply, |
| m::MultiplyAnyOrder( |
| m::MultiplyAnyOrder(m::NonConstant(&a), |
| m::Broadcast(m::Constant(&constant))), |
| m::Op(&op)))) { |
| // Check that the other side was a broadcast, and not of a constant. |
| if (ShapeUtil::IsScalar(constant->shape()) && |
| Match(op, m::Broadcast(m::NonConstant()))) { |
| auto dims = op->dimensions(); |
| b = op->mutable_operand(0); |
| if (!ShapeUtil::IsScalar(b->shape())) { |
| constant = computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(b->shape(), constant, {})); |
| } |
| |
| auto new_mul = |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| b->shape(), HloOpcode::kMultiply, b, constant)); |
| |
| return ReplaceWithNewInstruction( |
| multiply, |
| HloInstruction::CreateBinary( |
| multiply->shape(), HloOpcode::kMultiply, a, |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| multiply->shape(), new_mul, dims)))); |
| } |
| } |
| } |
| |
| VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]"; |
| HloInstruction *a, *c1, *c2; |
| if (Match(multiply, |
| m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)), |
| m::Constant(&c2))) || |
| Match(multiply, |
| m::Multiply( |
| m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))), |
| m::Broadcast(m::ConstantScalar(&c2))))) { |
| TF_ASSIGN_OR_RETURN(auto* product_of_constants, |
| MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); |
| if (ShapeUtil::IsScalar(product_of_constants->shape()) && |
| !ShapeUtil::IsScalar(multiply->shape())) { |
| product_of_constants = |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| multiply->shape(), product_of_constants, {})); |
| } |
| return ReplaceWithNewInstruction( |
| multiply, |
| HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a, |
| product_of_constants)); |
| } |
| |
| VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) " |
| << multiply->ToString(); |
| if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { |
| auto add = computation_->AddInstruction(HloInstruction::CreateBinary( |
| multiply->shape(), HloOpcode::kAdd, lhs, rhs)); |
| return ReplaceWithNewInstruction( |
| multiply, |
| HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); |
| } |
| |
| VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] " |
| << multiply->ToString(); |
| HloInstruction* b; |
| if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) && |
| IsPositive(b, options_)) { |
| return ReplaceWithNewInstruction( |
| multiply, |
| HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide, |
| MakeScalarLike(b, 1), b)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) { |
| // negate(negate(x)) => x |
| HloInstruction* x; |
| if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) && |
| ReplaceInstructionIfSameShape(negate, x)) { |
| return Status::OK(); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) { |
| // not(not(x)) => x |
| HloInstruction* x; |
| if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) && |
| ReplaceInstructionIfSameShape(logical_not, x)) { |
| return Status::OK(); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) { |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs)))); |
| |
| // Simplify logical or |
| if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && |
| ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { |
| // A || True => True |
| VLOG(10) << "trying transform [A || True => True]: " |
| << logical_or->ToString(); |
| if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) { |
| return Status::OK(); |
| } |
| // True || A => True |
| VLOG(10) << "trying transform [True || A => True]: " |
| << logical_or->ToString(); |
| if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) { |
| return Status::OK(); |
| } |
| } |
| |
| // A || False => A and A | 0 => A |
| VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString(); |
| if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { |
| return Status::OK(); |
| } |
| |
| // False || A => A and 0 | A => A |
| VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString(); |
| if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { |
| return Status::OK(); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { |
| // ln(exp(A)) => A |
| VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); |
| HloInstruction *a, *b; |
| if (Match(log, m::Log(m::Exp(m::Op(&a)))) && |
| ReplaceInstructionIfSameShape(log, a)) { |
| return Status::OK(); |
| } |
| |
| // ln(pow(A,B)) => B*ln(abs(A)) |
| // or B*ln(A) if A is complex. |
| if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) { |
| auto abs_a = ShapeUtil::ElementIsComplex(a->shape()) |
| ? a |
| : computation_->AddInstruction(HloInstruction::CreateUnary( |
| log->shape(), HloOpcode::kAbs, a)); |
| auto new_log = computation_->AddInstruction( |
| HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a)); |
| return ReplaceWithNewInstruction( |
| log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, |
| new_log, b)); |
| } |
| |
| if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) { |
| auto new_log = computation_->AddInstruction( |
| HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); |
| return ReplaceWithNewInstruction( |
| log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, |
| new_log, MakeScalarLike(log, 0.5))); |
| } |
| |
| if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) { |
| auto new_log = computation_->AddInstruction( |
| HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); |
| return ReplaceWithNewInstruction( |
| log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, |
| new_log, MakeScalarLike(log, -0.5))); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleGetTupleElement( |
| HloInstruction* get_tuple_element) { |
| auto operand = get_tuple_element->mutable_operand(0); |
| if (operand->opcode() == HloOpcode::kTuple) { |
| // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i |
| VLOG(10) << "trying transform " |
| << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: " |
| << get_tuple_element->ToString(); |
| if (ReplaceInstructionIfSameShape( |
| get_tuple_element, |
| operand->mutable_operand(get_tuple_element->tuple_index()))) { |
| return Status::OK(); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| namespace { |
| |
| absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified( |
| const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) { |
| CHECK_EQ(hlo->opcode(), HloOpcode::kReshape); |
| return ShapeUtil::ReshapeLeavesDimensionsUnmodified( |
| hlo->operand(0)->shape(), hlo->shape(), input_dim_indices); |
| } |
| |
| // Returns true if the output of "instruction" is a permutation of the |
| // elements of "operand". Precondition: "operand" is an operand of |
| // "instruction". |
| bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, |
| HloInstruction* operand) { |
| DCHECK(!instruction->OperandIndices(operand).empty()); |
| switch (instruction->opcode()) { |
| case HloOpcode::kReshape: |
| case HloOpcode::kReverse: |
| case HloOpcode::kTranspose: |
| return true; |
| case HloOpcode::kSort: |
| return (!instruction->shape().IsTuple()); |
| default: |
| return false; |
| } |
| } |
| |
| // Returns true if the output of "instruction" is a subset of the elements of |
| // "operand". Precondition: "operand" is an operand of "instruction". |
| bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, |
| HloInstruction* operand) { |
| const auto operand_indices = instruction->OperandIndices(operand); |
| CHECK(!operand_indices.empty()); |
| if (operand_indices.size() != 1) { |
| return false; |
| } |
| int64 operand_index = operand_indices[0]; |
| switch (instruction->opcode()) { |
| case HloOpcode::kSlice: |
| CHECK_EQ(0, operand_index); |
| return true; |
| case HloOpcode::kDynamicSlice: |
| return operand_index == 0; |
| default: |
| return false; |
| } |
| } |
| |
| } // namespace |
| |
| Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { |
| HloInstruction* operand; |
| CHECK(Match(broadcast, m::Broadcast(m::Op(&operand)))); |
| auto dims = broadcast->dimensions(); |
| // A degenerate broadcast that does not change the number of elements can
| // be replaced by a reshape.
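| // Illustrative example: f32[1,2,3,1] broadcast(f32[2,3] x),
| // dimensions={1,2} has sorted dimensions and the same element count as x,
| // so it is equivalent to f32[1,2,3,1] reshape(x).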
| if (std::is_sorted(dims.begin(), dims.end()) && |
| ShapeUtil::ElementsIn(broadcast->shape()) == |
| ShapeUtil::ElementsIn(operand->shape())) { |
| VLOG(10) << "transform broadcast(X) -> reshape(X) where " |
| "n(broadcast(X)) == n(X)"; |
| return ReplaceWithNewInstruction( |
| broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); |
| } |
| |
| // A degenerate broadcast that has the same input and output rank and does
| // not change the number of elements can be converted into a transpose.
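| // Illustrative example: f32[3,2] broadcast(f32[2,3] x), dimensions={1,0}
| // has equal rank and element count, so it is equivalent to
| // f32[3,2] transpose(x), dimensions={1,0}.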
| if (broadcast->shape().rank() == operand->shape().rank() && |
| ShapeUtil::ElementsIn(broadcast->shape()) == |
| ShapeUtil::ElementsIn(operand->shape())) { |
| VLOG(10) << "transform broadcast(X) -> transpose(X) where " |
| "n(broadcast(X)) == n(X)"; |
| return ReplaceWithNewInstruction( |
| broadcast, |
| HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); |
| } |
| |
| // A broadcast whose operand is a reshape that merely inserts 1-sized
| // dimensions can skip the reshape and broadcast the reshape's operand
| // directly.
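| // Illustrative example: with x = f32[2,3] and r = f32[2,1,3] reshape(x),
| // f32[4,2,1,3] broadcast(r), dimensions={1,2,3} becomes
| // f32[4,2,1,3] broadcast(x), dimensions={1,3}.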
| { |
| bool merely_inserts_or_deletes_1_sized_dimensions; |
| std::vector<int64> inserted_indices, deleted_indices; |
| std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices, |
| inserted_indices) = |
| operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); |
| if (merely_inserts_or_deletes_1_sized_dimensions && |
| deleted_indices.empty()) { |
| std::reverse(inserted_indices.begin(), inserted_indices.end()); |
| for (auto inserted_index : inserted_indices) { |
| dims.erase(dims.begin() + inserted_index); |
| } |
| return ReplaceWithNewInstruction( |
| broadcast, |
| HloInstruction::CreateBroadcast(broadcast->shape(), |
| operand->mutable_operand(0), dims)); |
| } |
| } |
| |
| // A Broadcast that feeds a unary element-wise operation can sink the |
| // broadcast after the unary element-wise operation. |
| TF_ASSIGN_OR_RETURN( |
| bool sink_succeeded, |
| TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); |
| changed_ |= sink_succeeded; |
| if (sink_succeeded) { |
| return Status::OK(); |
| } |
| |
| // A scalar broadcast feeding an instruction which only permutes (reshape, |
| // transpose, sort, reverse) or selects a subset of operand elements (slice, |
| // dynamic slice) can be replaced with a broadcast directly to the output |
| // shape of the instruction. |
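| // Illustrative example: reverse(broadcast(s)) or slice(broadcast(s)) for
| // scalar s is simply broadcast(s) to the user's output shape.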
| if (ShapeUtil::IsScalar(operand->shape())) { |
| for (HloInstruction* user : broadcast->users()) { |
| // Skip if the broadcast user has no uses itself. |
| if (user->user_count() == 0 && user != computation_->root_instruction()) { |
| continue; |
| } |
| if (OutputIsPermutationOfOperandElements(user, broadcast) || |
| OutputIsSubsetOfOperandElements(user, broadcast)) { |
| VLOG(10) << "transform permuting/subset of a scalar broadcast into " |
| << "a single broadcast"; |
| HloInstruction* new_broadcast = computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(user->shape(), operand, {})); |
| // Use HloInstruction::ReplaceAllUsesWith instead of |
| // HloComputation::ReplaceWithNewInstruction because we are replacing an |
| // instruction other than the visited instruction. |
| changed_ = true; |
| return user->ReplaceAllUsesWith(new_broadcast); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // broadcast(iota) -> iota. |
| if (operand->opcode() == HloOpcode::kIota) { |
| return ReplaceWithNewInstruction( |
| broadcast, |
| HloInstruction::CreateIota( |
| broadcast->shape(), |
| dims[Cast<HloIotaInstruction>(operand)->iota_dimension()])); |
| } |
| |
| // Merge two consecutive broadcasts into a single one. |
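| // Illustrative example: with x = f32[3], b1 = f32[3,4] broadcast(x),
| // dimensions={0}, and b2 = f32[2,3,4] broadcast(b1), dimensions={1,2},
| // the pair collapses to f32[2,3,4] broadcast(x), dimensions={1}.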
| if (operand->opcode() == HloOpcode::kBroadcast) { |
| std::vector<int64> new_dimensions; |
| for (auto dim : operand->dimensions()) { |
| new_dimensions.push_back(dims[dim]); |
| } |
| return ReplaceWithNewInstruction( |
| broadcast, |
| HloInstruction::CreateBroadcast( |
| broadcast->shape(), operand->mutable_operand(0), new_dimensions)); |
| } |
| if (options_.is_layout_sensitive()) { |
| return Status::OK(); |
| } |
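| // A broadcast whose operand has degenerate (size-1) dimensions can
| // reshape those dimensions away first and broadcast from the lower-rank
| // operand. Illustrative example: f32[2,4,1,3] broadcast(f32[2,1,3] x),
| // dimensions={0,2,3} becomes
| // f32[2,4,1,3] broadcast(f32[2,3] reshape(x)), dimensions={0,3}.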
| if (ShapeUtil::HasDegenerateDimensions(operand->shape())) { |
| auto new_operand = |
| operand->parent()->AddInstruction(HloInstruction::CreateReshape( |
| ShapeUtil::DropDegenerateDimensions(operand->shape()), operand)); |
| std::vector<int64> new_dims; |
| new_dims.reserve(new_operand->shape().rank()); |
| for (int64 i = 0; i < operand->shape().rank(); ++i) { |
| if (operand->shape().dimensions(i) != 1) { |
| new_dims.push_back(dims[i]); |
| } |
| } |
| return ReplaceWithNewInstruction( |
| broadcast, HloInstruction::CreateBroadcast(broadcast->shape(), |
| new_operand, new_dims)); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) { |
| HloInstruction* lhs; |
| HloInstruction* rhs; |
| CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs)))); |
| { |
| // compare(broadcast(a) + x, broadcast(b)) ==>
| // compare(x, broadcast(b-a)), only enabled for signed integral types.
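| // Illustrative example: for signed integral x,
| // compare(x + broadcast(5), broadcast(10)) becomes
| // compare(x, broadcast(5)).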
| HloInstruction *x, *a, *b; |
| if (Match(compare, |
| m::Compare( |
| m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape( |
| m::Shape().IsScalar()))), |
| m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { |
| if (ShapeUtil::ElementIsSigned(x->shape()) && |
| ShapeUtil::ElementIsIntegral(x->shape())) { |
| HloInstruction* sub = |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| b->shape(), HloOpcode::kSubtract, b, a)); |
| HloInstruction* broadcast = computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(x->shape(), sub, {})); |
| HloInstruction* new_compare = computation_->AddInstruction( |
| HloInstruction::CreateCompare(compare->shape(), x, broadcast, |
| compare->comparison_direction())); |
| return ReplaceInstruction(compare, new_compare); |
| } |
| } |
| } |
| |
| if (Cast<HloCompareInstruction>(compare)->type() == |
| Comparison::Type::kUnsigned) { |
| // X u< 0 -> false |
| if (compare->comparison_direction() == ComparisonDirection::kLt && |
| IsAll(rhs, 0)) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, false)); |
| } |
| // X u>= 0 -> true |
| if (compare->comparison_direction() == ComparisonDirection::kGe && |
| IsAll(rhs, 0)) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, true)); |
| } |
| // 0 u> X -> false |
| if (compare->comparison_direction() == ComparisonDirection::kGt && |
| IsAll(lhs, 0)) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, false)); |
| } |
| // 0 u<= X -> true |
| if (compare->comparison_direction() == ComparisonDirection::kLe && |
| IsAll(lhs, 0)) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, true)); |
| } |
| } |
| |
| if (compare->comparison_direction() == ComparisonDirection::kLt && |
| lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, false)); |
| } else if (compare->comparison_direction() == ComparisonDirection::kGt && |
| IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, false)); |
| } else if (compare->comparison_direction() == ComparisonDirection::kGe && |
| lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, true)); |
| } else if (compare->comparison_direction() == ComparisonDirection::kLe && |
| IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { |
| return ReplaceInstruction(compare, MakeScalarLike(compare, true)); |
| } |
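| // An instruction compared against itself has a statically known result
| // for integral types. The rewrite below is not applied to floating-point
| // types because a NaN operand would make e.g. x != x evaluate to true.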
| if (lhs == rhs && |
| primitive_util::IsIntegralType(lhs->shape().element_type())) { |
| switch (compare->comparison_direction()) { |
| case ComparisonDirection::kGt: |
| case ComparisonDirection::kLt: |
| case ComparisonDirection::kNe: |
| return ReplaceInstruction(compare, MakeScalarLike(compare, false)); |
| case ComparisonDirection::kEq: |
| case ComparisonDirection::kGe: |
| case ComparisonDirection::kLe: |
| return ReplaceInstruction(compare, MakeScalarLike(compare, true)); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { |
| PrimitiveType src_type = convert->operand(0)->shape().element_type(); |
| PrimitiveType dest_type = convert->shape().element_type(); |
| // A conversion to the same element type as the operand is a nop and can
| // be removed.
| if (src_type == dest_type) { |
| return ReplaceInstruction(convert, convert->mutable_operand(0)); |
| } |
| |
| // Eliminate a convert pair if it is a no-op. The following are a few
| // example cases that are being handled:
| // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
| // and convert(A, $TYPE1) is an upcast.
| // 2. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
| // and convert(A, $TYPE1) is an upcast that is an integral conversion from
| // unsigned to signed (only signed-to-unsigned conversion is NOT allowed).
| // 3. Tuple(convert(A, $TYPE1), floor(convert(convert(A, $TYPE1), $TYPE2)),
| // convert(convert(A, $TYPE1), $TYPE2)) is simplified to
| // Tuple(convert(A, $TYPE1), floor(A), A) -- a case where the first convert
| // has a fan-out.
| if (convert->operand(0)->opcode() == HloOpcode::kConvert && |
| IsConvertPairNoOp(convert)) { |
| return ReplaceInstruction(convert, |
| convert->mutable_operand(0)->mutable_operand(0)); |
| } |
| return Status::OK(); |
| } |
| |
| // Complex(Real(c), Imag(c)) -> c |
| Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { |
| HloInstruction *c0, *c1; |
| if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) && |
| c0 == c1) { |
| return ReplaceInstruction(complex, c0); |
| } |
| return Status::OK(); |
| } |
| |
| // Real(Complex(r, i)) -> r |
| Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { |
| HloInstruction* op; |
| if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) { |
| return ReplaceInstruction(real, op); |
| } |
| return Status::OK(); |
| } |
| |
| // Imag(Complex(r, i)) -> i |
| Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { |
| HloInstruction* op; |
| if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) { |
| return ReplaceInstruction(imag, op); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { |
| // iota -> zero if the iota dimension never produces an element other than |
| // zero. |
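| // Illustrative example: s32[4,1] iota(), iota_dimension=1 only ever
| // produces index 0 along the iota dimension, so it is all zeros.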
| auto* iota = Cast<HloIotaInstruction>(instruction); |
| if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { |
| auto zero = computation_->AddInstruction( |
| simplifier_->CreateConstantWithLayoutUpdated( |
| LiteralUtil::Zero(iota->shape().element_type()).Clone())); |
| return ReplaceWithNewInstruction( |
| iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { |
| if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { |
| return ReplaceWithNewInstruction( |
| pad, HloInstruction::CreateBroadcast(pad->shape(), |
| pad->mutable_operand(1), {})); |
| } |
| |
| // Interior padding on size-1 dimensions has no effect; clearing it makes
| // other simplifications possible, since they require the absence of
| // interior padding.
| if (HasInteriorPadding(pad->padding_config())) { |
| PaddingConfig padding_config = pad->padding_config(); |
| bool cleared_interior_padding = false; |
| for (int64 i = 0; i < pad->shape().rank(); ++i) { |
| if (padding_config.dimensions(i).interior_padding() > 0 && |
| pad->operand(0)->shape().dimensions(i) == 1) { |
| cleared_interior_padding = true; |
| padding_config.mutable_dimensions(i)->set_interior_padding(0); |
| } |
| } |
| if (cleared_interior_padding) { |
| return ReplaceWithNewInstruction( |
| pad, |
| HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0), |
| pad->mutable_operand(1), padding_config)); |
| } |
| } |
| |
| // Eliminate nop pads (padding all zero), and replace a pad with negative |
| // padding with a pad with non-negative padding followed by a slice. |
| bool all_zero = true; |
| bool has_negative = false; |
| // Records the dimensions that actually receive padding, so the pad can
| // potentially be split off from the unpadded dimensions below.
| std::vector<int64> padding_dimensions; |
| int64 dimension_index = 0; |
| for (auto& padding_dimension : pad->padding_config().dimensions()) { |
| if (padding_dimension.edge_padding_low() < 0 || |
| padding_dimension.edge_padding_high() < 0) { |
| has_negative = true; |
| } |
| if (padding_dimension.edge_padding_low() != 0 || |
| padding_dimension.edge_padding_high() != 0) { |
| all_zero = false; |
| padding_dimensions.push_back(dimension_index); |
| } else if (padding_dimension.interior_padding()) { |
| padding_dimensions.push_back(dimension_index); |
| } |
| dimension_index++; |
| } |
| |
| if (all_zero) { |
| if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) { |
| return Status::OK(); |
| } |
| } |
| |
| // The context of this optimization can be found at b/163617402.
| // It captures the case of pad(broadcast(x)), where
| // x->shape().dimensions(), or broadcast(x)->dimensions(), is
| // a subset of the padded dimensions in pad->padding_config(),
| // and the padded dimensions are in turn a strict
| // subset of broadcast->shape().dimensions(). The combined op can be
| // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends
| // x with the dimensions that need to be padded, and broadcast2 extends
| // the result of the padding to the full dimensions.
| // TODO(qyi): For future extensions: the condition that
| // broadcast(x)->dimensions() be a subset of the padded dimensions in
| // pad->padding_config() is not strictly required, but it simplifies the
| // shape calculations, so the current implementation requires it. Only the
| // second condition, between the padded dimensions and the dimensions of
| // the final shape, has to be enforced for the optimization to make sense.
| // If the first constraint is ever removed, the shape calculations across
| // the implementation need to be re-adjusted.
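| // Illustrative example: with x = f32[3] and b = f32[3,4] broadcast(x),
| // dimensions={0}, a pad of b that pads only dimension 0 by two elements
| // becomes f32[5,4] broadcast(f32[5] pad(x, ...)), dimensions={0}: the pad
| // is applied to x directly and the result is re-broadcast to the full
| // shape.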
| auto pad_dims = padding_dimensions.size(); |
| if (pad_dims < dimension_index && |
| pad->operand(0)->opcode() == HloOpcode::kBroadcast && |
| pad->operand(0)->user_count() == 1 && |
| pad->operand(0)->operand(0)->shape().rank() <= pad_dims) { |
| // Check that the broadcast's dimensions are a subset of
| // padding_dimensions. If not, skip the optimization.
| bool opt_is_valid = true; |
| std::vector<int64> broadcast_dimensions; |
| HloBroadcastInstruction* broadcast = |
| static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0)); |
| for (auto broadcast_index : broadcast->dimensions()) { |
| bool found = false; |
| for (int i = 0; i < pad_dims; ++i) { |
| if (broadcast_index == padding_dimensions[i]) { |
| broadcast_dimensions.push_back(i); |
| found = true; |
| break; |
| } |
| } |
| if (!found) { |
| opt_is_valid = false; |
| break; |
| } |
| } |
| if (opt_is_valid) { |
| auto pad_shape = pad->shape(); |
| auto broadcast_shape = broadcast->shape(); |
| auto pad_shape1 = pad_shape; |
| auto broadcast_shape1 = broadcast_shape; |
| PaddingConfig pad_config; |
| for (int i = padding_dimensions.size() - 1; i >= 0; --i) { |
| int64 j = padding_dimensions[i]; |
| while (--dimension_index > j) { |
| broadcast_shape1.DeleteDimension(dimension_index); |
| pad_shape1.DeleteDimension(dimension_index); |
| } |
| } |
| while (--dimension_index >= 0) { |
| broadcast_shape1.DeleteDimension(dimension_index); |
| pad_shape1.DeleteDimension(dimension_index); |
| } |
| for (auto dimension_to_pad : padding_dimensions) { |
| auto dimension = pad_config.add_dimensions(); |
| *dimension = pad->padding_config().dimensions(dimension_to_pad); |
| } |
| *broadcast->mutable_shape() = broadcast_shape1; |
| *broadcast->mutable_dimensions() = broadcast_dimensions; |
| simplifier_->UpdateLayout(broadcast->mutable_shape()); |
| auto pad2 = |
| computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1)); |
| *pad2->mutable_padding_config() = pad_config; |
| simplifier_->UpdateLayout(pad2->mutable_shape()); |
| auto broadcast2 = computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions)); |
| return ReplaceInstruction(pad, broadcast2); |
| } |
| } |
| |
| if (has_negative && options_.enable_negative_padding_replacement()) { |
| // Pad has negative padding. Replace with a pad with the non-negative |
| // padding followed by a slice which effectively performs the negative |
| // padding. |
| // TODO(b/34628603): Add support for negative padding in the backends, or |
| // change kPad semantics to disallow negative padding and use slice |
| // instead. |
| |
| // First construct the padding config with non-negative entries and then
| // compute the shape of this new pad instruction.
| PaddingConfig nonzero_padding = pad->padding_config(); |
| for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) { |
| PaddingConfig::PaddingConfigDimension* padding_dimension = |
| nonzero_padding.mutable_dimensions(i); |
| // Set negative padding to zero. |
| if (padding_dimension->edge_padding_low() < 0) { |
| padding_dimension->set_edge_padding_low(0); |
| } |
| if (padding_dimension->edge_padding_high() < 0) { |
| padding_dimension->set_edge_padding_high(0); |
| } |
| } |
| |
| TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, |
| MakePadHlo(pad->mutable_operand(0), |
| pad->mutable_operand(1), nonzero_padding)); |
| // Copy the layout from the original pad instruction. The new pad and the
| // slice instruction should both have the same layout.
| TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( |
| pad->shape(), nonzero_pad->mutable_shape())); |
| simplifier_->UpdateLayout(nonzero_pad->mutable_shape()); |
| |
| // Second, construct the slice instruction to perform the negative |
| // padding. |
| std::vector<int64> start_indices; |
| std::vector<int64> end_indices; |
| std::vector<int64> strides; |
| for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) { |
| const PaddingConfig::PaddingConfigDimension& padding_dimension = |
| pad->padding_config().dimensions(i); |
| int64 start = 0; |
| if (padding_dimension.edge_padding_low() < 0) { |
| start = -1 * padding_dimension.edge_padding_low(); |
| } |
| int64 end = nonzero_pad->shape().dimensions(i); |
| if (padding_dimension.edge_padding_high() < 0) { |
| end += padding_dimension.edge_padding_high(); |
| } |
| start_indices.push_back(start); |
| end_indices.push_back(end); |
| strides.push_back(1); |
| } |
| |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * slice, |
| MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); |
| TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( |
| pad->shape(), slice->mutable_shape())); |
| simplifier_->UpdateLayout(slice->mutable_shape()); |
| |
| // Verify that the slice shape matches the pad shape. |
| auto equal = Shape::Equal(); |
| if (!options_.is_layout_sensitive()) { |
| equal.IgnoreTilesInLayout(); |
| } |
| TF_RET_CHECK(equal(slice->shape(), pad->shape())); |
| |
| return ReplaceInstruction(pad, slice); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { |
| VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); |
| HloInstruction *lhs, *rhs; |
| CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); |
| if (IsAll(rhs, 0)) { |
| return ReplaceInstruction(power, MakeScalarLike(power, 1)); |
| } |
| |
| VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString(); |
| if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { |
| return Status::OK(); |
| } |
| |
| // pow(exp(A),B) => exp(A*B) |
| HloInstruction *a, *b; |
| if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) { |
| auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( |
| power->shape(), HloOpcode::kMultiply, a, b)); |
| return ReplaceWithNewInstruction( |
| power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, |
| a_times_b)); |
| } |
| |
| VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); |
| if (IsAll(rhs, 2)) { |
| return ReplaceWithNewInstruction( |
| power, HloInstruction::CreateBinary(power->shape(), |
| HloOpcode::kMultiply, lhs, lhs)); |
| } |
| |
| // Pow(A, 3) is used in GELU. |
| VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString(); |
| if (IsAll(rhs, 3)) { |
| HloInstruction* tmp = |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| power->shape(), HloOpcode::kMultiply, lhs, lhs)); |
| return ReplaceWithNewInstruction( |
| power, HloInstruction::CreateBinary(power->shape(), |
| HloOpcode::kMultiply, lhs, tmp)); |
| } |
| |
| VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); |
| if (IsAll(rhs, -1)) { |
| return ReplaceWithNewInstruction( |
| power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, |
| MakeScalarLike(lhs, 1), lhs)); |
| } |
| |
| return Status::OK(); |
| } |
| |
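| // Tries to sink a broadcast past an elementwise user so that the operation
| // is done on the lower-rank pre-broadcast shape and broadcast afterwards.
| // Illustrative example: add(f32[2,3] broadcast(f32[3] x), dimensions={1},
| // f32[2,3] broadcast(f32[] c)) becomes
| // f32[2,3] broadcast(add(x, f32[3] broadcast(c))), dimensions={1}.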
| StatusOr<bool> |
| AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( |
| HloInstruction* broadcast) { |
| TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); |
| bool changed = false; |
| if (ShapeUtil::IsScalar(broadcast->shape())) { |
| return false; |
| } |
| HloInstruction* operand = broadcast->mutable_operand(0); |
| auto is_scalar_broadcast = [](const HloInstruction* instruction) { |
| return instruction->opcode() == HloOpcode::kBroadcast && |
| ShapeUtil::IsScalar(instruction->operand(0)->shape()); |
| }; |
| auto is_equal_broadcast = [operand, |
| broadcast](const HloInstruction* instruction) { |
| return instruction->opcode() == HloOpcode::kBroadcast && |
| ShapeUtil::Equal(operand->shape(), |
| instruction->operand(0)->shape()) && |
| broadcast->dimensions() == instruction->dimensions(); |
| }; |
| auto is_compatible_broadcast = [&](const HloInstruction* instruction) { |
| return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction); |
| }; |
| for (HloInstruction* user : broadcast->users()) { |
| if (user->user_count() == 0 && user != computation_->root_instruction()) { |
| continue; |
| } |
| // Do not move reshapes or broadcasts past copies since the shape the copy |
| // will operate on will change. |
| if (user->opcode() == HloOpcode::kCopy) { |
| continue; |
| } |
| // Do not change the shape of fusion nodes in case there are already
| // multiple shapes inside the fusion node.
| if (user->opcode() == HloOpcode::kFusion) { |
| continue; |
| } |
| if (!user->IsElementwise()) { |
| continue; |
| } |
| |
| // Check if all the operands of the user are compatible broadcasts for
| // sinking: either scalar broadcasts or broadcasts from the same operand
| // shape with the same broadcast dimensions.
| int64 compatible_broadcast_count = 0; |
| int64 broadcast_use_count = 0; |
| for (HloInstruction* user_operand : user->operands()) { |
| if (is_compatible_broadcast(user_operand)) { |
| ++compatible_broadcast_count; |
| } else if (broadcast == user_operand) { |
| ++broadcast_use_count; |
| } |
| } |
| if (compatible_broadcast_count + broadcast_use_count != |
| user->operand_count()) { |
| continue; |
| } |
| std::vector<HloInstruction*> new_operands; |
| new_operands.reserve(user->operand_count()); |
| |
| Shape changed_shape; |
| for (HloInstruction* user_operand : user->operands()) { |
| // If this is a broadcast operand that is not our original broadcast input |
| // to this function, then we might need to change the input.
| if (is_compatible_broadcast(user_operand)) { |
| // If this is a broadcast from a scalar value, rewrite it as a broadcast
| // from the scalar to the new shape enforced by the other broadcast
| // operands.
| if (is_scalar_broadcast(user_operand)) { |
| changed_shape = ShapeUtil::ChangeElementType( |
| operand->shape(), user_operand->shape().element_type()); |
| simplifier_->UpdateLayout(&changed_shape); |
| new_operands.push_back( |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| changed_shape, user_operand->mutable_operand(0), {}))); |
| user_operand->SetupDerivedInstruction(new_operands.back()); |
| } else { |
| // For the non-scalar broadcasts, the shape of the broadcast's operand is
| // already guaranteed to be a compatible shape.
| new_operands.push_back(user_operand->mutable_operand(0)); |
| } |
| } else { |
| CHECK_EQ(broadcast, user_operand); |
| new_operands.push_back(operand); |
| } |
| } |
| VLOG(4) << "Sinking broadcast after user:"; |
| VLOG(4) << " old broadcast: " << broadcast->ToString(); |
| VLOG(4) << " old user: " << user->ToString(); |
| changed_shape = ShapeUtil::ChangeElementType(operand->shape(), |
| user->shape().element_type()); |
| simplifier_->UpdateLayout(&changed_shape); |
| HloInstruction* new_user = computation_->AddInstruction( |
| user->CloneWithNewOperands(changed_shape, new_operands)); |
| VLOG(4) << " new user: " << new_user->ToString(); |
| HloInstruction* new_broadcast = |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| user->shape(), new_user, broadcast->dimensions())); |
| broadcast->SetupDerivedInstruction(new_broadcast); |
| VLOG(4) << " new broadcast: " << new_broadcast->ToString(); |
| TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); |
| changed = true; |
| } |
| return changed; |
| } |
| |
| namespace { |
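| // Tries to rewrite a remainder by a (possibly broadcast) constant scalar
| // power of two as a bitwise AND: for non-negative A and B == 2^k,
| // A % B == A & (B - 1). Signed dividends are handled by taking the
| // absolute value first and negating the masked result when the dividend
| // is negative.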
| template <typename T> |
| std::unique_ptr<HloInstruction> TryRemainderToAnd( |
| HloInstruction* remainder, HloComputation* computation, |
| AlgebraicSimplifier* simplifier) { |
| HloInstruction *a, *b, *c; |
| CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); |
| |
| if (ShapeUtil::ElementIsIntegral(remainder->shape()) && |
| !Match(b, m::ConstantEffectiveScalar(&c)) && |
| !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { |
| return nullptr; |
| } |
| |
| if (ShapeUtil::ElementIsSigned(remainder->shape())) { |
| int64 b_value = c->literal().GetFirstElement<T>(); |
| if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) { |
| // Handle negative dividends by negating the result of the division. |
| HloInstruction* zero_like_a = BroadcastZeros( |
| computation, a->shape().element_type(), a->shape().dimensions()); |
| |
| Shape compare_shape = ShapeUtil::ChangeElementType(a->shape(), PRED); |
| simplifier->UpdateLayout(&compare_shape); |
| auto* dividend_is_negative = |
| computation->AddInstruction(HloInstruction::CreateCompare( |
| compare_shape, a, zero_like_a, ComparisonDirection::kLt)); |
| |
| auto* negated_dividend = computation->AddInstruction( |
| HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); |
| |
| auto* abs_dividend = |
| computation->AddInstruction(HloInstruction::CreateTernary( |
| a->shape(), HloOpcode::kSelect, dividend_is_negative, |
| negated_dividend, a)); |
| |
| auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( |
| remainder->shape(), HloOpcode::kAnd, abs_dividend, |
| MakeScalarLike(abs_dividend, b_value - 1))); |
| |
| auto* negated_quotient =
| computation->AddInstruction(HloInstruction::CreateUnary(
| quotient->shape(), HloOpcode::kNegate, quotient));
|
| return HloInstruction::CreateTernary(
| remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
| negated_quotient, quotient);
| } |
| } else { |
| uint64 b_value = c->literal().GetFirstElement<T>(); |
| if (IsPowerOfTwo(b_value)) { |
| HloInstruction* mask_amount = computation->AddInstruction( |
| simplifier->CreateConstantWithLayoutUpdated( |
| LiteralUtil::CreateR0<T>(b_value - 1))); |
| if (!ShapeUtil::IsScalar(b->shape())) { |
| mask_amount = computation->AddInstruction( |
| HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); |
| } |
| return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd, |
| a, mask_amount); |
| } |
| } |
| return nullptr; |
| } |
| } // namespace |
| |
| Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { |
| HloInstruction *a, *b; |
| CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); |
| |
| // (A % B) % B == A % B. |
| if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) { |
| return ReplaceInstruction(remainder, a); |
| } |
| |
| // A % B => A & (B - 1) if B is a power of 2. |
| switch (remainder->shape().element_type()) { |
| case S8: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| case S16: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| case S32: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| case S64: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| case U8: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| case U16: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| case U32: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| case U64: |
| if (std::unique_ptr<HloInstruction> shift = |
| TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) { |
| return ReplaceWithNewInstruction(remainder, std::move(shift)); |
| } |
| break; |
| default: |
| break; |
| } |
| |
| // If M < N, then {0, ..., M} % N ==> {0, ..., M}. |
| // |
| // Currently this only covers the case when N is a broadcasted constant |
| // scalar. We could also cover the case when N is a non-broadcasted constant |
| // with the same value repeated. |
| HloInstruction* iota; |
| HloInstruction* divisor; |
| if (Match(remainder, |
| m::Remainder(m::Iota(&iota), |
| m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) { |
| // The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is |
| // conservative; the iota may overflow and count up to a smaller value than |
| // this. But that's OK for our purposes here.) |
| int64 iota_upper_bound = iota->shape().dimensions( |
| Cast<HloIotaInstruction>(iota)->iota_dimension()); |
| absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64( |
| std::vector<int64>(divisor->shape().dimensions_size(), 0));
| if (divisor_val && *divisor_val >= iota_upper_bound) { |
| return ReplaceInstruction(remainder, iota); |
| } |
| } |
| |
| // (X + N) % N = X % N, so long as X + N does not overflow. |
| // |
| // We don't have range tracking in XLA that would let us know whether X + N |
| // overflows, so for now we only do this simplification when X is an iota. We |
| // could add other operations where it's easy to see a range, such as |
| // remainder, convert, etc., though at some point we'd probably want a |
| // range-tracking analysis. |
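| // Illustrative example: (iota + broadcast(8)) % broadcast(8) becomes
| // iota % broadcast(8) when iota's largest value plus 8 still fits in the
| // element type.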
| HloInstruction* bcast; |
| HloInstruction* addend; |
| if (Match( |
| remainder, |
| m::Remainder( |
| m::AddAnyOrder(m::Iota(&iota), |
| m::Broadcast(m::ConstantEffectiveScalar(&addend))), |
| m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) && |
| addend == divisor) { |
| // The iota counts {0, ..., iota_upper_bound - 1}, with the same caveat above
| // that iota_upper_bound is conservative, and the true upper bound may be |
| // smaller. |
| int64 iota_upper_bound = iota->shape().dimensions( |
| Cast<HloIotaInstruction>(iota)->iota_dimension()); |
| absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64( |
| std::vector<int64>(divisor->shape().dimensions_size(), 0));
| if (divisor_val) { |
| // Conservatively check whether divisor_val + iota_upper_bound overflows;
| // the true maximum value of the sum is divisor_val + iota_upper_bound - 1.
| absl::optional<int64> max_val = |
| OverflowSafeAdd(*divisor_val, iota_upper_bound); |
| if (max_val.has_value() && |
| FitsInIntegralType(*max_val, iota->shape().element_type())) { |
| return ReplaceWithNewInstruction( |
| remainder, |
| HloInstruction::CreateBinary(remainder->shape(), |
| HloOpcode::kRemainder, iota, bcast)); |
| } |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { |
| auto operand = reshape->mutable_operand(0); |
| |
| // Reshape directly to an empty constant if the shape contains a
| // zero-element dimension.
| if (ShapeUtil::IsZeroElementArray(reshape->shape())) { |
| // If the instruction doesn't have a layout, use a default layout for |
| // the literal result. |
| Shape reshaped_shape = reshape->shape(); |
| if (!LayoutUtil::HasLayout(reshaped_shape)) { |
| LayoutUtil::SetToDefaultLayout(&reshaped_shape); |
| } |
| auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated( |
| Literal::CreateFromShape(reshaped_shape)); |
| |
| return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); |
| } |
| |
| // Delete no-op reshapes, i.e. where shape = operand shape. |
| if (SameShape(reshape, operand)) { |
| VLOG(10) << "deleting no-op reshape"; |
| return ReplaceInstruction(reshape, operand); |
| } |
| |
| // Merge reshapes. |
| if (HloOpcode::kReshape == operand->opcode()) { |
| return ReplaceWithNewInstruction( |
| reshape, HloInstruction::CreateReshape(reshape->shape(), |
| operand->mutable_operand(0))); |
| } |
| |
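| // An rng produces independent, identically distributed elements, so a
| // reshape of an rng with no other users can be folded away by giving the
| // rng the reshaped shape directly.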
| if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { |
| *operand->mutable_shape() = reshape->shape(); |
| return ReplaceInstruction(reshape, operand); |
| } |
| |
| if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { |
| auto opt_dims = ReshapeLeavesDimensionsUnmodified( |
| reshape, reshape->operand(0)->dimensions()); |
| if (opt_dims.has_value()) { |
| return ReplaceWithNewInstruction( |
| reshape, |
| HloInstruction::CreateBroadcast( |
| reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), |
| *opt_dims)); |
| } |
| } |
| |
| // reshape(iota) -> iota or a mixed radix calculation like |
| // s32[2,3,4] reshape(s32[24] iota()) to |
| // add( |
| // add(s32[2,3,4] iota() iota_dimension=2, |
| // 4 * s32[2,3,4] iota() iota_dimension=1), |
| // 12 * s32[2,3,4] iota() iota_dimension=0). |
| if (operand->opcode() == HloOpcode::kIota) { |
| auto* iota = Cast<HloIotaInstruction>(operand); |
| auto common_factors = |
| CommonFactors(reshape->operand(0)->shape().dimensions(), |
| reshape->shape().dimensions()); |
| auto iota_dim = absl::c_find_if( |
| common_factors, [&](const std::pair<int64, int64>& dim_pair) { |
| return dim_pair.first == iota->iota_dimension() && |
| reshape->shape().dimensions(dim_pair.second) > 1; |
| }); |
| auto next_dim = absl::c_find_if( |
| common_factors, [&](const std::pair<int64, int64>& dim_pair) { |
| return dim_pair.first == iota->iota_dimension() + 1; |
| }); |
| if (iota_dim != common_factors.end() && next_dim != common_factors.end()) { |
| int64 multiplier = 1; |
| HloInstruction* new_reshape = nullptr; |
| |
| for (int64 dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second; |
| --dim) { |
| HloInstruction* new_iota = computation_->AddInstruction( |
| HloInstruction::CreateIota(reshape->shape(), dim)); |
| iota->SetupDerivedInstruction(new_iota); |
| if (new_reshape) { |
| new_reshape = |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| reshape->shape(), HloOpcode::kAdd, new_reshape, |
| computation_->AddInstruction(HloInstruction::CreateBinary( |
| reshape->shape(), HloOpcode::kMultiply, new_iota, |
| MakeScalarLike(reshape, multiplier))))); |
| reshape->SetupDerivedInstruction(new_reshape); |
| } else { |
| new_reshape = new_iota; |
| } |
| multiplier *= reshape->shape().dimensions(dim); |
| } |
| reshape->SetupDerivedInstruction(new_reshape); |
| return ReplaceInstruction(reshape, new_reshape); |
| } |
| } |
| |
| // Make this a bitcast if possible. |
| if (HloInstruction* bitcast_operand = |
| BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) { |
| ReplaceWithBitcast(reshape, bitcast_operand); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { |
| // When all the dimensions to reverse are trivial (i.e. the bound is 1), |
| // there is nothing to be done. |
| auto dim_is_one = [&](int64 i) -> bool { |
| return reverse->shape().dimensions(i) == 1; |
| }; |
| if (absl::c_all_of(reverse->dimensions(), dim_is_one)) { |
| return ReplaceInstruction(reverse, reverse->mutable_operand(0)); |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( |
| HloInstruction* slice) { |
| // Only try to do this for effective scalars. We could do the same for slicing |
| // out larger pieces of padding (replacing with a broadcast of the padding |
| // value), but this is probably not worth it. |
| if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { |
| return false; |
| } |
| |
| if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { |
| VLOG(10) << "Trying to simplify scalar slice of concat"; |
| // Only do this for R1, there's no chance of this being useful otherwise. |
| if (slice->shape().rank() != 1) { |
| VLOG(10) << "Not folding, slice is not rank 1"; |
| return false; |
| } |
| HloConcatenateInstruction* concat = |
| Cast<HloConcatenateInstruction>(slice->mutable_operand(0)); |
| int64 operand_start = 0; |
| int64 operand_num = 0; |
| // Weird loop structure to avoid annoying off-by-one errors. |
| while (true) { |
| TF_RET_CHECK(operand_num < concat->operand_count()); |
| const HloInstruction* operand = concat->operand(operand_num); |
| int64 next_operand_start = operand_start + operand->shape().dimensions(0); |
| if (next_operand_start > slice->slice_starts(0)) { |
| break; |
| } |
| operand_start = next_operand_start; |
| operand_num++; |
| } |
| |
| bool replaced = ReplaceInstructionIfSameShape( |
| slice, concat->mutable_operand(operand_num)); |
| if (replaced) { |
| VLOG(10) << "Folding scalar slice of concat into concat operand"; |
| } else { |
| VLOG(10) << "Folding scalar slice of concat into slice of concat operand"; |
| TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( |
| slice, HloInstruction::CreateSlice( |
| slice->shape(), concat->mutable_operand(operand_num), |
| {slice->slice_starts(0) - operand_start}, |
| {slice->slice_starts(0) - operand_start + 1}, |
| slice->slice_strides()))); |
| } |
| return true; |
| } |
| |
| return false; |
| } |
| |
| StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( |
| HloInstruction* slice) { |
| CHECK_EQ(slice->opcode(), HloOpcode::kSlice); |
| if (!IsUnstridedSlice(slice)) { |
| return false; |
| } |
| HloInstruction* reshape = slice->mutable_operand(0); |
| if (reshape->opcode() != HloOpcode::kReshape) { |
| return false; |
| } |
| HloInstruction* new_slice_operand = reshape->mutable_operand(0); |
| int64 slice_rank = slice->shape().rank(); |
| std::vector<int64> sliced_dims; |
| for (int64 i = 0; i < slice_rank; ++i) { |
| if (slice->slice_starts(i) != 0 || |
| slice->slice_limits(i) != reshape->shape().dimensions(i)) { |
| sliced_dims.push_back(i); |
| } |
| } |
| |
| if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && |
| slice->slice_starts(0) == 0) { |
| const Shape& new_slice_shape = new_slice_operand->shape(); |
| const int64 rank = new_slice_shape.rank(); |
| std::vector<int64> new_slice_starts(rank, 0); |
| std::vector<int64> new_slice_strides(rank, 1);
| std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(), |
| new_slice_shape.dimensions().end()); |
| int64 slice_elements = ShapeUtil::ElementsIn(slice->shape()); |
| for (int64 i = rank - 1; i >= 0; --i) { |
| if (slice_elements >= new_slice_limits[i]) { |
| if (slice_elements % new_slice_limits[i] != 0) { |
| return false; |
| } |
| slice_elements /= new_slice_limits[i]; |
| } else { |
| new_slice_limits[i] = slice_elements; |
| slice_elements = 1; |
| } |
| } |
| HloInstruction* new_slice = |
| computation_->AddInstruction(HloInstruction::CreateSlice( |
| ShapeUtil::MakeShape(new_slice_shape.element_type(), |
| new_slice_limits), |
| new_slice_operand, new_slice_starts, new_slice_limits, |
| new_slice_strides));
| simplifier_->UpdateLayout(new_slice->mutable_shape()); |
| TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( |
| slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); |
| return true; |
| } |
| return false; |
| } |
| |
| // Allows a slice to move through a reverse, updating the slice config as
| // necessary.
| StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( |
| HloInstruction* slice) { |
| VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:" |
| << slice->ToString(); |
| if (Match(slice, m::Slice(m::Reverse()))) { |
| HloInstruction* reverse = slice->mutable_operand(0); |
| HloInstruction* reverse_operand = reverse->mutable_operand(0); |
| std::vector<int64> new_starts = slice->slice_starts(); |
| std::vector<int64> new_limits = slice->slice_limits(); |
| std::vector<int64> new_strides = slice->slice_strides(); |
| for (auto rdim : reverse->dimensions()) { |
| int64 start = slice->slice_starts(rdim); |
| int64 limit = slice->slice_limits(rdim); |
| int64 stride = slice->slice_strides(rdim); |
| // find_nth computes the index of the last element the strided slice
| // actually reads, which gives the appropriate starting index for the
| // reversed slice even in the presence of non-unit strides.
| int64 find_nth = (limit - start - 1) / stride; |
| find_nth = start + find_nth * stride; |
| limit = find_nth + 1; |
| new_starts[rdim] = |
| (reverse->shape().dimensions(rdim) - start) - (limit - start); |
| new_limits[rdim] = reverse->shape().dimensions(rdim) - start; |
| VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << "," |
| << limit << " and new (start, limit):" << new_starts[rdim] << "," |
| << new_limits[rdim]; |
| } |
| // New slice formed from the reverse_operand, but strides and shape of the |
| // slice output remains the same. New slice's starts and limits are updated |
| // for ONLY the reversed dimensions as indicated above. |
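| // Illustrative example: slicing [2:7:2] of a reverse of a size-10
| // dimension reads reversed indices 2, 4, 6, which are original indices
| // 7, 5, 3; the rewrite slices [3:8:2] (original indices 3, 5, 7) and then
| // reverses that dimension.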
| HloInstruction* new_slice = computation_->AddInstruction( |
| HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts, |
| new_limits, new_strides)); |
| simplifier_->UpdateLayout(new_slice->mutable_shape()); |
| TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( |
| slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice, |
| reverse->dimensions()))); |
| // We do not delete the old reverse, since there might be another |
| // consumer of that reverse (i.e., full reverse output). DCE should take |
| // care of any deletion that is necessary if there was no use of reverse. |
| return true; |
| } |
| return false; |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { |
| // Delete no-op slices, i.e. where shape = operand shape. |
| if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { |
| return Status::OK(); |
| } |
| |
| HloInstruction* pad; |
| HloInstruction* pad_operand; |
| if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) { |
| // Whether the result of the slice equals the pad operand.
| bool slice_undoes_pad = true;
| // Whether the slice can be moved to the pad_operand without reading any
| // padding.
| bool slice_inside_pad = true;
| // Whether this slice reads padding only.
| bool slice_in_padding = false;
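| // Illustrative example of the slice_in_padding case: slicing [0:2] of a
| // dimension whose edge_padding_low is 3 reads only padding, so the whole
| // slice is a broadcast of the padding value.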
| std::vector<int64> new_starts = slice->slice_starts(); |
| std::vector<int64> new_limits = slice->slice_limits(); |
| for (int64 i = 0; i < slice->shape().rank(); ++i) { |
| const int64 start = slice->slice_starts(i); |
| const int64 stride = slice->slice_strides(i); |
| const int64 limit = slice->slice_limits(i); |
| const int64 size = pad->shape().dimensions(i); |
| |
| const auto& dim = pad->padding_config().dimensions(i); |
| const int64 low = dim.edge_padding_low(); |
| const int64 high = dim.edge_padding_high(); |
| const int64 interior = dim.interior_padding(); |
| const int64 edge = size - high; |
| |
| if (limit <= low || start >= edge) { |
| slice_in_padding = true; |
| break; |
| } |
| |
| if (start != low || stride - 1 != interior) { |
| slice_undoes_pad = false; |
| } |
| |
| if (start < low || limit > edge || interior != 0 || stride != 1) { |
| slice_inside_pad = false; |
| } |
| new_starts[i] -= low; |
| new_limits[i] -= low; |
| } |
| if (slice_in_padding) { |
| HloInstruction* broadcast = |
| MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()); |
| *(broadcast->mutable_shape()) = slice->shape(); |
| return ReplaceInstruction(slice, broadcast); |
| } |
| if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) { |
| return Status::OK(); |
| } |
| if (slice_inside_pad) { |
| TF_ASSIGN_OR_RETURN(HloInstruction * new_slice, |
| MakeSliceHlo(pad_operand, new_starts, new_limits, |
| slice->slice_strides())); |
| *(new_slice->mutable_shape()) = slice->shape(); |
| return ReplaceInstruction(slice, new_slice); |
| } |
| } |
| |
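| // slice(slice(x)) -> slice(x): compose two unstrided slices by offsetting
| // the outer slice by the inner slice's starts. Illustrative example:
| // slicing [2:5] of a [3:9] slice of x is the [5:8] slice of x.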
| if (slice->operand(0)->opcode() == HloOpcode::kSlice && |
| IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) { |
| HloInstruction* operand_slice = slice->mutable_operand(0); |
| std::vector<int64> new_slice_starts = slice->slice_starts(); |
| std::vector<int64> new_slice_limits = slice->slice_limits(); |
| for (int64 i = 0; i < new_slice_starts.size(); ++i) { |
| new_slice_starts[i] += operand_slice->slice_starts(i); |
| new_slice_limits[i] += operand_slice->slice_starts(i); |
| } |
| return ReplaceWithNewInstruction( |
| slice, HloInstruction::CreateSlice( |
| slice->shape(), operand_slice->mutable_operand(0), |
| new_slice_starts, new_slice_limits, slice->slice_strides())); |
| } |
| |
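| // If the slice is a no-op on every dimension that is mapped from the
| // broadcast's operand, only broadcast-created dimensions are actually
| // sliced, so the slice can be folded into a broadcast to the slice's
| // shape.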
| auto only_broadcast_dims_sliced = [&] { |
| if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) { |
| return false; |
| } |
| for (int64 dim : slice->operand(0)->dimensions()) { |
| if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 || |
| slice->slice_limits(dim) != |
| slice->operand(0)->shape().dimensions(dim)) { |
| return false; |
| } |
| } |
| return true; |
| }; |
| if (only_broadcast_dims_sliced()) { |
| return ReplaceWithNewInstruction( |
| slice, |
| HloInstruction::CreateBroadcast( |
| slice->shape(), slice->mutable_operand(0)->mutable_operand(0), |
| slice->mutable_operand(0)->dimensions())); |
| } |
| |
| TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); |
| if (replaced) { |
| return Status::OK(); |
| } |
| |
| HloInstruction* broadcast; |
| HloInstruction* broadcast_operand; |
| if (Match(slice, |
| m::Slice(m::Broadcast(&broadcast, m::Op(&broadcast_operand))))) { |
| std::vector<int64> new_slice_starts; |
| std::vector<int64> new_slice_strides; |
| std::vector<int64> new_slice_limits; |
| new_slice_starts.reserve(broadcast_operand->shape().rank()); |
| new_slice_strides.reserve(broadcast_operand->shape().rank()); |
| new_slice_limits.reserve(broadcast_operand->shape().rank()); |
| for (int64 dim : broadcast->dimensions()) { |
| new_slice_starts.push_back(slice->slice_starts(dim)); |
| new_slice_strides.push_back(slice->slice_strides(dim)); |
| new_slice_limits.push_back(slice->slice_limits(dim)); |
| } |
| VLOG(3) << "Sink broadcast through slice"; |
| VLOG(3) << "Original slice: " << slice->ToString(); |
| VLOG(3) << "Original broadcast: " << broadcast->ToString(); |
| auto new_slice_shape = broadcast_operand->shape(); |
| for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) { |
| int64 size_i = (new_slice_limits[i] - new_slice_starts[i] + |
| new_slice_strides[i] - 1) / |
| new_slice_strides[i]; |
| new_slice_shape.set_dimensions(i, size_i); |
| } |
| simplifier_->UpdateLayout(&new_slice_shape); |
| HloComputation* computation = broadcast_operand->parent(); |
| auto new_slice = computation->AddInstruction(HloInstruction::CreateSlice( |
| new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits, |
| new_slice_strides)); |
| auto new_broadcast = HloInstruction::CreateBroadcast( |
| slice->shape(), new_slice, broadcast->dimensions()); |
| VLOG(3) << "New slice: " << slice->ToString(); |
| VLOG(3) << "New broadcast: " << new_broadcast->ToString(); |
| return ReplaceWithNewInstruction(slice, std::move(new_broadcast)); |
| } |
| |
| // Try to simplify concat -> slice to an operand of concat. |
| if (slice->operand(0)->opcode() == HloOpcode::kConcatenate && |
| IsUnstridedSlice(slice)) { |
| auto concat = slice->operand(0); |
| int64 concat_dim = concat->concatenate_dimension(); |
| int64 piece_start = 0; |
| for (auto piece : concat->operands()) { |
| if (!SameShape(piece, slice)) { |
| piece_start += piece->shape().dimensions(concat_dim); |
| continue; |
| } |
| if (slice->slice_starts(concat_dim) == piece_start) { |
| return ReplaceInstruction(slice, piece); |
| } |
| piece_start += piece->shape().dimensions(concat_dim); |
| } |
| } |
| |
| // Do not try to reorder slices and reshapes after layout assignment as it may |
| // be invalid. |
| if (!options_.is_layout_sensitive()) { |
| TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); |
| } |
| if (replaced) { |
| return Status::OK(); |
| } |
| |
| bool reversed = false; |
| if (Match(slice, m::Slice(m::Reverse(m::Op())))) { |
| TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice)); |
| } |
| if (reversed) { |
| return Status::OK(); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) { |
| VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] " |
| << rsqrt->ToString(); |
| HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0); |
| if (rsqrt_operand->opcode() == HloOpcode::kPower && |
| IsAll(rsqrt_operand->operand(1), -2) && |
| IsPositive(rsqrt_operand, options_)) { |
| return ReplaceWithNewInstruction( |
| rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs, |
| rsqrt_operand->mutable_operand(0))); |
| } |
| |
| VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] " |
| << rsqrt->ToString(); |
| if (rsqrt_operand->opcode() == HloOpcode::kDivide && |
| IsAll(rsqrt_operand->operand(0), 1) && |
| IsPositive(rsqrt_operand->operand(1), options_)) { |
| return ReplaceWithNewInstruction( |
| rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt, |
| rsqrt_operand->mutable_operand(1))); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleDynamicSlice( |
| HloInstruction* dynamic_slice) { |
| auto operand = dynamic_slice->mutable_operand(0); |
| if (ShapeUtil::IsScalar(dynamic_slice->shape())) { |
| return ReplaceInstruction(dynamic_slice, operand); |
| } |
| // DynamicSlice where operand has the same size as the output is simply equal |
| // to operand. |
| if (SameShape(operand, dynamic_slice)) { |
| return ReplaceInstruction(dynamic_slice, operand); |
| } |
| |
| HloInstruction* broadcast_operand; |
| if (Match(operand, m::Broadcast(m::Op(&broadcast_operand)))) { |
| std::vector<HloInstruction*> new_indices; |
| new_indices.reserve(broadcast_operand->shape().rank()); |
| std::vector<int64> new_slice_sizes; |
| new_slice_sizes.reserve(broadcast_operand->shape().rank()); |
| |
| for (int64 dim : operand->dimensions()) { |
| new_indices.push_back(dynamic_slice->mutable_operand(1 + dim)); |
| new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim)); |
| } |
| |
| VLOG(3) << "Sink broadcast through dynamic slice"; |
| VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString(); |
| VLOG(3) << "Original broadcast: " << operand->ToString(); |
| HloInstruction* new_dynamic_slice = broadcast_operand; |
| if (!new_slice_sizes.empty()) { |
| auto new_ds_shape = broadcast_operand->shape(); |
| for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) { |
| new_ds_shape.set_dimensions(i, new_slice_sizes[i]); |
| } |
| simplifier_->UpdateLayout(&new_ds_shape); |
| HloComputation* computation = broadcast_operand->parent(); |
| new_dynamic_slice = |
| computation->AddInstruction(HloInstruction::CreateDynamicSlice( |
| new_ds_shape, broadcast_operand, new_indices, new_slice_sizes)); |
| } |
| auto new_broadcast = HloInstruction::CreateBroadcast( |
| dynamic_slice->shape(), new_dynamic_slice, operand->dimensions()); |
| VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString(); |
| VLOG(3) << "New broadcast: " << new_broadcast->ToString(); |
| return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast)); |
| } |
| |
| // Convert a dynamic slice into a slice if all offsets are constant and the |
| // operand is not constant. |
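  // Illustrative example (hypothetical HLO, not from the original source):
  //   dynamic-slice(f32[10] x, s32[] constant(7)), slice_sizes={4}
  // becomes
  //   slice(x), starts={6}, limits={10}, since the start offset is clamped to
  //   10 - 4 = 6.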
| if (operand->opcode() != HloOpcode::kConstant && |
| absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, |
| dynamic_slice->operands().end()), |
| [](HloInstruction* operand) { |
| return operand->opcode() == HloOpcode::kConstant && |
| ShapeUtil::ElementIsIntegral(operand->shape()); |
| })) { |
| const int64 rank = operand->shape().rank(); |
| std::vector<int64> slice_starts(rank); |
| std::vector<int64> slice_limits(rank); |
| std::vector<int64> slice_strides(rank, 1); |
| |
| for (int64 i = 0; i < rank; ++i) { |
| absl::optional<int64> offset = |
| dynamic_slice->operand(i + 1)->literal().GetFirstInteger(); |
| if (!offset || *offset < 0) { |
| return Status::OK(); |
| } |
| const int64 max_offset = |
| dynamic_slice->operand(0)->shape().dimensions(i) - |
| dynamic_slice->shape().dimensions(i); |
| slice_starts[i] = std::min(max_offset, *offset); |
| slice_limits[i] = |
| std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i); |
| } |
| return ReplaceWithNewInstruction( |
| dynamic_slice, |
| HloInstruction::CreateSlice(dynamic_slice->shape(), operand, |
| slice_starts, slice_limits, slice_strides)); |
| } |
| |
  // Convert a dynamic slice of an iota into just a reference to the index
  // (possibly clamped and scaled). The index is always a scalar integer, and
  // the output is a rank-1 array of size 1 whose element type matches that of
  // the scalar index (up to signedness).
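  // Illustrative example (hypothetical HLO, not from the original source):
  //   dynamic-slice(s32[4] iota, s32[] i), slice_sizes={1}
  // becomes
  //   reshape(clamp(0, i, 3)) -> s32[1],
  // with a convert if the index type differs and a multiply if the iota was
  // scaled by a broadcast scalar.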
| const PrimitiveType element_type = dynamic_slice->shape().element_type(); |
| if (operand->shape().rank() == 1 && dynamic_slice->shape().rank() == 1 && |
| dynamic_slice->shape().dimensions(0) == 1 && |
| (element_type == S32 || element_type == U32)) { |
    // Match multiply(x, broadcast(scalar)) and return the scalar operand of
    // the broadcast.
| auto match_multiply_with_scalar = |
| [&](HloInstruction* hlo) -> HloInstruction* { |
| if (hlo->opcode() != HloOpcode::kMultiply) { |
| return nullptr; |
| } |
| HloInstruction* broadcast = hlo->mutable_operand(1); |
| if (broadcast->opcode() == HloOpcode::kBroadcast && |
| broadcast->dimensions().empty() && |
| ShapeUtil::IsScalar(broadcast->operand(0)->shape())) { |
| return broadcast->mutable_operand(0); |
| } |
| return nullptr; |
| }; |
| |
| HloInstruction* multiplier = match_multiply_with_scalar(operand); |
| if (multiplier) { |
| operand = operand->mutable_operand(0); |
| } |
| |
| if (operand->opcode() == HloOpcode::kIota) { |
| // This dynamic_slice will have a single start_index operand (since its |
| // operand is rank 1). |
| HloInstruction* index = dynamic_slice->mutable_operand(1); |
| const PrimitiveType index_type = index->shape().element_type(); |
| |
| auto create_constant = [&](int64 value) { |
| if (index_type == S32) { |
| return MakeScalarLike<int32_t>(index, value); |
| } else { |
| return MakeScalarLike<uint32_t>(index, value); |
| } |
| }; |
| |
| if (index_type == S32 || index_type == U32) { |
| // Clamp the index to the range of the iota. |
| int64 iota_size = operand->shape().dimensions(0); |
| HloInstruction* low = create_constant(0); |
| HloInstruction* high = create_constant(iota_size - 1); |
| HloInstruction* clamped = |
| computation_->AddInstruction(HloInstruction::CreateTernary( |
| index->shape(), HloOpcode::kClamp, low, index, high)); |
| |
| // Convert the clamped index from index_type to element_type and |
| // multiply with the multiplier. |
| HloInstruction* result = clamped; |
| if (index_type != element_type) { |
| result = computation_->AddInstruction(HloInstruction::CreateConvert( |
| ShapeUtil::MakeScalarShape(element_type), clamped)); |
| } |
| |
| if (multiplier) { |
| result = computation_->AddInstruction(HloInstruction::CreateBinary( |
| result->shape(), HloOpcode::kMultiply, result, multiplier)); |
| } |
| |
| return ReplaceWithNewInstruction( |
| dynamic_slice, |
| HloInstruction::CreateReshape( |
| ShapeUtil::MakeShape(element_type, {1}), result)); |
| } |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( |
| HloInstruction* dynamic_update_slice) { |
  // Rewrite a DynamicUpdateSlice that matches
  //   dynamic_update_slice(broadcast(constant), data, constant_index0, ...)
  // into Pad(data, constant). Only a broadcast operand is handled currently;
  // other ops may be considered in the future.
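  // Illustrative example (hypothetical HLO, not from the original source):
  //   dynamic-update-slice(f32[8] broadcast(f32[] c), f32[3] u,
  //                        s32[] constant(2))
  // becomes
  //   pad(u, c), padding=2_3, i.e. 2 elements low and 8 - (2 + 3) = 3 high.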
| HloInstruction* updated = dynamic_update_slice->mutable_operand(0); |
| HloInstruction* dus_update = dynamic_update_slice->mutable_operand(1); |
| HloInstruction* pad_value; |
| if (Match(updated, |
| m::Broadcast(m::Op(&pad_value).WithShape(m::Shape().IsScalar())))) { |
| auto updated_shape = updated->shape(); |
| auto update_shape = dus_update->shape(); |
| auto update_start_indx = dynamic_update_slice->operand(2); |
| int64 offset = 0; |
| bool compatible = true; |
    // Whether the start indices are given as a flat list of scalar operands
    // or as the output of a tuple/concatenate, set up update_start_indx and
    // the operand offset appropriately.
| if (ShapeUtil::IsScalar(update_start_indx->shape())) { |
| update_start_indx = dynamic_update_slice; |
| offset = 2; |
| } else { |
| if (update_start_indx->opcode() == HloOpcode::kTuple || |
| update_start_indx->opcode() == HloOpcode::kConcatenate) { |
| offset = 0; |
| } else { |
| compatible = false; |
| } |
| } |
| PaddingConfig padding_config; |
| if (compatible) { |
| for (int64 dim = 0; dim < updated_shape.rank(); ++dim) { |
| auto padding_config_dim = padding_config.add_dimensions(); |
| auto slice_dim_start = update_start_indx->operand(dim + offset); |
| if (!Match(slice_dim_start, m::ConstantScalar())) { |
| compatible = false; |
| break; |
| } |
| VLOG(2) << "slice: " << slice_dim_start->ToString(); |
| absl::optional<int64> beg = |
| slice_dim_start->literal().GetFirstInteger(); |
| if (!beg) { |
| compatible = false; |
| break; |
| } |
| VLOG(2) << "beg value: " << *beg; |
| auto update_width = ShapeUtil::GetDimension(update_shape, dim); |
| auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim); |
| // Clamp beg so that it is non-negative. |
| *beg = std::max<int64>(0, *beg); |
| // Clamp beg so that it is in-bounds. |
| *beg = std::min<int64>(bcast_width - update_width, *beg); |
| VLOG(2) << "adjusted beg value: " << *beg; |
| padding_config_dim->set_edge_padding_low(*beg); |
| padding_config_dim->set_edge_padding_high(bcast_width - |
| (*beg + update_width)); |
| // dynamic_update_slice does not specify a stride |
| padding_config_dim->set_interior_padding(0); |
| } |
| } |
| |
| if (compatible) { |
| HloInstruction* pad = |
| computation_->AddInstruction(HloInstruction::CreatePad( |
| updated_shape, dus_update, pad_value, padding_config)); |
| VLOG(2) << dynamic_update_slice->ToString(); |
| VLOG(2) << " with pad:" << pad->ToString(); |
| VLOG(2) << " Computation before rewrite is: " |
| << dynamic_update_slice->parent()->ToString(); |
| return ReplaceInstruction(dynamic_update_slice, pad); |
| } |
| } |
| |
| // DynamicUpdateSlice where operand and dus_update have the same size is |
| // equal to dus_update. |
| if (SameShape(dynamic_update_slice, dus_update)) { |
| return ReplaceInstruction(dynamic_update_slice, dus_update); |
| } |
| |
| // If any dimension of dus_update is 0, elide the DynamicUpdateSlice. This |
| // optimization becomes invalid should we later prefer to warn about out of |
| // bound indices. |
| if (ShapeUtil::IsZeroElementArray(dus_update->shape())) { |
| return ReplaceInstruction(dynamic_update_slice, updated); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { |
| HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo); |
| bool multi_output_reduce = reduce->shape().IsTuple(); |
  // For tuple reduce, we require all reduce shapes to be the same, up to the
  // element types, so we can just use the first operand and the first result
  // as representatives.
| auto arg = reduce->inputs()[0]; |
| auto init_value = reduce->init_values()[0]; |
| const Shape& reduce_result_shape = |
| multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape(); |
| |
| absl::Span<const int64> dimensions(reduce->dimensions()); |
| HloComputation* function = reduce->to_apply(); |
| if (ShapeUtil::IsZeroElementArray(arg->shape()) || |
| ShapeUtil::IsZeroElementArray(reduce_result_shape)) { |
| if (multi_output_reduce) { |
| std::vector<HloInstruction*> broadcast_inits; |
| int64 inputs = reduce->input_count(); |
| for (int64 i = 0; i < inputs; ++i) { |
| broadcast_inits.push_back(computation_->AddInstruction( |
| HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i), |
| reduce->init_values()[i], {}))); |
| } |
| return ReplaceWithNewInstruction( |
| reduce, HloInstruction::CreateTuple(broadcast_inits)); |
| } else { |
| return ReplaceWithNewInstruction( |
| reduce, |
| HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {})); |
| } |
| } |
| |
| // Turn trivial variadic reductions into normal reductions. |
| if (multi_output_reduce && reduce->shape().tuple_shapes_size() == 1 && |
| reduce->input_count() == 1 && |
| Match(function->root_instruction(), m::Tuple())) { |
| absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> |
| replacements; |
| replacements[function->root_instruction()] = nullptr; |
| auto new_function = computation_->parent()->AddEmbeddedComputation( |
| function->CloneWithReplacements( |
| std::move(replacements), /*extra_parameters=*/{}, |
| /*context=*/nullptr, |
| /*suffix=*/"clone", |
| /*new_root=*/function->root_instruction()->operand(0))); |
| auto new_reduce = computation_->AddInstruction( |
| HloInstruction::CreateReduce(reduce_result_shape, arg, init_value, |
| reduce->dimensions(), new_function)); |
| return ReplaceWithNewInstruction(reduce, |
| HloInstruction::CreateTuple({new_reduce})); |
| } |
| |
| if (options_.is_layout_sensitive()) { |
| return Status::OK(); |
| } |
| |
| // If the reduction results in the same number of elements, then the only |
| // possible side effect would be a reshape. Since the init_value is an |
| // identity of the reduction function, we can therefore replace the reduce |
| // with a simple reshape, ignoring the reduction function completely. |
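  // Illustrative example (hypothetical HLO, not from the original source):
  //   reduce(f32[1,4] x, dims={0}) -> f32[4]
  // has as many output elements as input elements, so it becomes
  //   reshape(x) -> f32[4].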
| if (ShapeUtil::ElementsIn(reduce_result_shape) == |
| ShapeUtil::ElementsIn(arg->shape())) { |
| if (multi_output_reduce) { |
| std::vector<HloInstruction*> reshaped_args; |
| int64 inputs = reduce->input_count(); |
| for (int64 i = 0; i < inputs; ++i) { |
| reshaped_args.push_back( |
| computation_->AddInstruction(HloInstruction::CreateReshape( |
| reduce->shape().tuple_shapes(i), reduce->inputs()[i]))); |
| } |
| return ReplaceWithNewInstruction( |
| reduce, HloInstruction::CreateTuple(reshaped_args)); |
| } else { |
| return ReplaceWithNewInstruction( |
| reduce, HloInstruction::CreateReshape(reduce_result_shape, arg)); |
| } |
| } |
| |
| // TODO(b/131122694): Most of those optimizations below can be done for |
| // multi-output reduces. |
| if (multi_output_reduce) { |
| return Status::OK(); |
| } |
| |
  // A Transpose feeding a reduce can simply permute the reduction dimensions
  // field if the output of the reduce is a vector or scalar. A higher-rank
  // result may require a transpose of the output.
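  // Illustrative example (hypothetical HLO, not from the original source):
  //   reduce(transpose(f32[4,5] x, perm={1,0}), dims={0}) -> f32[4]
  // becomes
  //   reduce(x, dims={1}) -> f32[4], with no output transpose needed since
  //   the result is a vector.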
| if (arg->opcode() == HloOpcode::kTranspose && |
| (reduce->shape().rank() < 2 || arg->user_count() == 1 || |
| absl::c_all_of(arg->users(), [](HloInstruction* use) { |
| return use->opcode() == HloOpcode::kReduce; |
| }))) { |
| auto transpose_dimensions = arg->dimensions(); |
| std::vector<int64> new_reduce_dimensions; |
| for (auto dim : dimensions) { |
| new_reduce_dimensions.push_back(transpose_dimensions[dim]); |
| } |
| |
| Shape new_reduce_result_shape = ShapeUtil::FilterDimensions( |
| [&](const int64 dim) { |
| return !absl::c_linear_search(new_reduce_dimensions, dim); |
| }, |
| arg->mutable_operand(0)->shape()); |
| HloInstruction* new_reduce = |
| computation_->AddInstruction(HloInstruction::CreateReduce( |
| new_reduce_result_shape, arg->mutable_operand(0), init_value, |
| new_reduce_dimensions, function)); |
| reduce->SetupDerivedInstruction(new_reduce); |
| std::vector<int64> new_transpose_dimensions; |
| for (auto dim : transpose_dimensions) { |
| if (absl::c_linear_search(new_reduce_dimensions, dim)) { |
| continue; |
| } |
| new_transpose_dimensions.push_back(dim); |
| } |
| |
| // If new transpose dimensions are sorted, then there is no need to |
| // transpose reduce result. |
| if (absl::c_is_sorted(new_transpose_dimensions)) { |
| return ReplaceInstruction(reduce, new_reduce); |
| } |
| for (auto& d : new_transpose_dimensions) { |
| auto old_dim = d; |
| for (auto reduced_dim : new_reduce_dimensions) { |
| if (old_dim > reduced_dim) { |
| --d; |
| } |
| } |
| } |
| TF_ASSIGN_OR_RETURN(HloInstruction * new_transpose, |
| MakeTransposeHlo(new_reduce, new_transpose_dimensions)); |
| return ReplaceInstruction(reduce, new_transpose); |
| } |
| |
| // If a reduce feeds a reduce with the same computation and initial value, |
| // they can be combined into a single reduce. |
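  // Illustrative example (hypothetical HLO, not from the original source):
  //   reduce(reduce(f32[4,5,6] x, dims={1}), dims={1})
  // becomes
  //   reduce(x, dims={1,2}), since dim 1 of the outer reduce is dim 2 of x.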
| if (arg->opcode() == HloOpcode::kReduce && |
| init_value->Identical(*arg->operand(1)) && |
| *function == *arg->to_apply()) { |
| // Create a new reduce with the combined reduction dimensions of both |
| // reduces. |
| std::vector<int64> arg_dims = arg->dimensions(); |
| absl::c_sort(arg_dims); |
| std::vector<int64> reduce_dims = reduce->dimensions(); |
| absl::c_sort(reduce_dims); |
    // Shift reduce_dims so they index into the inner reduce's operand rather
    // than into arg.
| for (int64 arg_dim : arg_dims) { |
| for (int64& dim : reduce_dims) { |
| if (dim >= arg_dim) { |
| ++dim; |
| } |
| } |
| } |
| std::vector<int64> new_dimensions; |
| new_dimensions.reserve(arg->dimensions().size() + |
| reduce->dimensions().size()); |
| std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), |
| reduce_dims.end(), std::back_inserter(new_dimensions)); |
| return ReplaceWithNewInstruction( |
| reduce, HloInstruction::CreateReduce( |
| reduce_result_shape, arg->mutable_operand(0), init_value, |
| new_dimensions, function)); |
| } |
| |
| // A reshape that collapses multiple dimensions into a dimension being |
| // reduced can just reduce all of those dimensions instead of doing a |
| // collapsing reshape before a reduction. |
| if (options_.enable_reduce_of_reshape() && |
| arg->opcode() == HloOpcode::kReshape) { |
| std::vector<std::pair<int64, int64>> unmodified_dims = |
| ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), |
| arg->shape()); |
| std::vector<bool> arg_dim_in_output(arg->shape().rank(), true); |
| std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false); |
| for (auto dim : dimensions) { |
| arg_dim_in_output[dim] = false; |
| } |
| for (auto dim_pair : unmodified_dims) { |
| arg_dim_unmodified[dim_pair.second] = true; |
| } |
| // The goal is to verify that all dimensions that are not removed in the |
| // reduce are unmodified by the reshape. For example: |
| // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2]) |
| bool can_move_reshape_into_reduce = true; |
| for (int64 i = 0; i < arg_dim_in_output.size(); ++i) { |
| if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) { |
| can_move_reshape_into_reduce = false; |
| } |
| } |
| if (can_move_reshape_into_reduce) { |
| changed_ = true; |
| absl::flat_hash_set<int64> dimensions_not_to_reduce; |
| for (auto dim_pair : unmodified_dims) { |
| if (arg_dim_in_output[dim_pair.second]) { |
| dimensions_not_to_reduce.insert(dim_pair.first); |
| } |
| } |
| std::vector<int64> new_reduce_dimensions; |
| for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) { |
| if (!dimensions_not_to_reduce.contains(i)) { |
| new_reduce_dimensions.push_back(i); |
| } |
| } |
| return ReplaceWithNewInstruction( |
| reduce, HloInstruction::CreateReduce( |
| reduce_result_shape, arg->mutable_operand(0), init_value, |
| new_reduce_dimensions, function)); |
| } |
| } |
  // Convert Reduce(concat({a,b,...})) to a chain of maps:
  //   map(...map(map(reduce(a), reduce(b)), reduce(c))...)
| // |
| // This should make fusion easier or use less memory bandwidth in the unfused |
| // case. |
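  // Illustrative example (hypothetical HLO, not from the original source):
  //   reduce(concatenate({f32[2,3] a, f32[4,3] b}, dim=0), dims={0})
  // becomes
  //   map(reduce(a, dims={0}), reduce(b, dims={0})) -> f32[3],
  // where the map applies the same binary reduction computation.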
| if (arg->opcode() == HloOpcode::kConcatenate && |
| absl::c_linear_search(reduce->dimensions(), |
| arg->concatenate_dimension())) { |
| HloInstruction* old_reduce = nullptr; |
| for (HloInstruction* operand : arg->operands()) { |
| HloInstruction* new_reduce = computation_->AddInstruction( |
| HloInstruction::CreateReduce(reduce_result_shape, operand, init_value, |
| reduce->dimensions(), function)); |
| if (old_reduce != nullptr) { |
| new_reduce = computation_->AddInstruction(HloInstruction::CreateMap( |
| reduce_result_shape, {old_reduce, new_reduce}, function)); |
| } |
| old_reduce = new_reduce; |
| } |
| return ReplaceInstruction(reduce, old_reduce); |
| } |
| |
| HloInstruction *dot, *lhs, *rhs; |
| // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were |
| // batch dimensions of the dot. The transformation supports reducing other |
| // dimensions as well. |
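  // Illustrative example (hypothetical HLO, not from the original source):
  //   reduce(dot(f32[8,2,3] x, f32[8,3,4] y), dims={0})  // batch dims {0},{0}
  // becomes
  //   dot(x, y) with contracting dims lhs={2,0}, rhs={1,0} -> f32[2,4],
  // assuming the reduce computation is a scalar add.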
| if (options_.enable_dot_strength_reduction() && |
| Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && |
| Match(reduce->to_apply()->root_instruction(), |
| m::Add(m::Parameter(), m::Parameter())) && |
| absl::c_any_of(reduce->dimensions(), [&](int64 dim) { |
| return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size(); |
| })) { |
| const auto& dnums = dot->dot_dimension_numbers(); |
| DotDimensionNumbers new_dnums = dnums; |
| new_dnums.clear_lhs_batch_dimensions(); |
| new_dnums.clear_rhs_batch_dimensions(); |
| int64 removed_dims = 0; |
| for (int64 batch_dim = 0; batch_dim < dnums.lhs_batch_dimensions_size(); |
| ++batch_dim) { |
| if (absl::c_linear_search(reduce->dimensions(), batch_dim)) { |
| new_dnums.add_rhs_contracting_dimensions( |
| dnums.rhs_batch_dimensions(batch_dim)); |
| new_dnums.add_lhs_contracting_dimensions( |
| dnums.lhs_batch_dimensions(batch_dim)); |
| ++removed_dims; |
| } else { |
| new_dnums.add_rhs_batch_dimensions( |
| dnums.rhs_batch_dimensions(batch_dim)); |
| new_dnums.add_lhs_batch_dimensions( |
| dnums.lhs_batch_dimensions(batch_dim)); |
| } |
| } |
| std::vector<int64> reduce_dims; |
| for (int64 dim : reduce->dimensions()) { |
| if (dim >= dnums.lhs_batch_dimensions_size()) { |
| reduce_dims.push_back(dim - removed_dims); |
| } |
| } |
| TF_ASSIGN_OR_RETURN( |
| auto new_dot, |
| MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(), |
| /*preferred_element_type=*/dot->shape().element_type())); |
| dot->SetupDerivedInstruction(new_dot); |
| if (reduce_dims.empty()) { |
| return ReplaceInstruction(hlo, new_dot); |
| } |
| TF_ASSIGN_OR_RETURN( |
| auto new_reduce, |
| MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd)); |
| reduce->SetupDerivedInstruction(new_reduce); |
| return ReplaceInstruction(hlo, new_reduce); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleReduceWindow(HloInstruction* hlo) { |
| auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo); |
| const bool multi_output_reduce_window = reduce_window->shape().IsTuple(); |
| auto inputs = reduce_window->inputs(); |
| auto init_values = reduce_window->init_values(); |
| auto input_count = reduce_window->input_count(); |
| auto input_shapes = reduce_window->input_shapes(); |
| auto output_shapes = reduce_window->output_shapes(); |
| auto replace_with_span = [&](const std::vector<HloInstruction*>& elements) { |
| CHECK(multi_output_reduce_window || elements.size() == 1); |
| if (multi_output_reduce_window) { |
| return ReplaceWithNewInstruction(reduce_window, |
| HloInstruction::CreateTuple(elements)); |
| } |
| return ReplaceInstruction(reduce_window, elements[0]); |
| }; |
| // For tuple reduce, we require all reduce shapes to be the same, up to the |
| // element types, so we can use just the first operand and the first result as |
| // a representative. |
| if (ShapeUtil::IsZeroElementArray(*input_shapes[0]) || |
| ShapeUtil::IsZeroElementArray(*output_shapes[0])) { |
| std::vector<HloInstruction*> broadcast_inits; |
| for (int64 i = 0; i < input_count; ++i) { |
| broadcast_inits.push_back( |
| computation_->AddInstruction(HloInstruction::CreateBroadcast( |
| *output_shapes[i], init_values[i], {}))); |
| } |
| return replace_with_span(broadcast_inits); |
| } |
| if (ShapeUtil::IsScalar(*input_shapes[0]) && |
| (!multi_output_reduce_window || |
| reduce_window->to_apply()->root_instruction()->opcode() == |
| HloOpcode::kTuple)) { |
| std::vector<HloInstruction*> maps; |
| for (int64 i = 0; i < input_count; ++i) { |
| TF_RET_CHECK(ShapeUtil::IsScalar(*input_shapes[i])); |
| TF_RET_CHECK(ShapeUtil::IsScalar(*output_shapes[i])); |
| HloInstruction* map_computation_root; |
| absl::flat_hash_map<const HloInstruction*, |
| std::unique_ptr<HloInstruction>> |
| replacements; |
| if (multi_output_reduce_window) { |
| map_computation_root = |
| reduce_window->to_apply()->root_instruction()->mutable_operand(i); |
| replacements[reduce_window->to_apply()->root_instruction()] = nullptr; |
| } else { |
| map_computation_root = reduce_window->to_apply()->root_instruction(); |
| } |
| auto map_computation = computation_->parent()->AddEmbeddedComputation( |
| reduce_window->to_apply()->CloneWithReplacements( |
| std::move(replacements), |
| /*extra_parameters=*/{}, nullptr, "clone", map_computation_root)); |
      auto map = computation_->AddInstruction(HloInstruction::CreateMap(
          *output_shapes[i], {init_values[i], inputs[i]}, map_computation));
| maps.push_back(map); |
| } |
| return replace_with_span(maps); |
| } |
| // Turn trivial variadic reduce windows into normal reduce windows. |
| auto reduce_function_root = reduce_window->to_apply()->root_instruction(); |
| if (multi_output_reduce_window && input_count == 1 && |
| Match(reduce_function_root, m::Tuple())) { |
| // Make a new reducer which is identical but does not have a tuple |
| // instruction at the bottom. |
| absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>> |
| replacements; |
| replacements[reduce_function_root] = nullptr; |
| auto new_function = computation_->parent()->AddEmbeddedComputation( |
| reduce_window->to_apply()->CloneWithReplacements( |
| std::move(replacements), /*extra_parameters=*/{}, |
| /*context=*/nullptr, |
| /*suffix=*/"clone", |
| /*new_root=*/reduce_function_root->operand(0))); |
| auto new_reduce = |
| computation_->AddInstruction(HloInstruction::CreateReduceWindow( |
| *output_shapes[0], inputs[0], init_values[0], |
| reduce_window->window(), new_function)); |
| return ReplaceWithNewInstruction(reduce_window, |
| HloInstruction::CreateTuple({new_reduce})); |
| } |
| // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier. |
| if (multi_output_reduce_window) { |
| return Status::OK(); |
| } |
| auto operand = reduce_window->mutable_operand(0); |
| auto function = reduce_window->to_apply(); |
| const Window& window = reduce_window->window(); |
| |
| if (options_.enable_window_reduce_to_reduce_replacement()) { |
| // A reduce window can be expressed as a reduce and a reshape if all |
| // dimensions either have a window size of one or the entire dimension. If |
| // there is no stride, dilation, or padding, this is as easy as checking the |
| // size of the output shape and window dimension. |
| // |
    // The reshape is a bitcast since it only adds one-sized dimensions. These
    // are often immediately removed again by another reshape. The
| // implementation of reduce tends to be slightly more efficient at reducing |
| // entire dimensions compared to reduce window. |
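    // Illustrative example (hypothetical HLO, not from the original source):
    //   reduce-window(f32[5,8] x, window={5x1}) -> f32[1,8]
    // becomes
    //   reshape(reduce(x, dims={0}) -> f32[8]) -> f32[1,8].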
| auto effective_reduce_dims = [&] { |
| if (window_util::HasStride(window) || window_util::HasDilation(window) || |
| window_util::HasPadding(window)) { |
| return absl::InlinedVector<int64, 8>{}; |
| } |
| absl::InlinedVector<int64, 8> reduce_dims; |
| for (int64 i = 0; i < window.dimensions_size(); ++i) { |
| if (window.dimensions(i).size() == 1) { |
| continue; |
| } else if (reduce_window->shape().dimensions(i) == 1) { |
| reduce_dims.push_back(i); |
| } else { |
| return absl::InlinedVector<int64, 8>{}; |
| } |
| } |
| return reduce_dims; |
| }(); |
| |
| // If a reduce window can be expressed as a reduce, do so and reshape the |
| // output. |
| if (!effective_reduce_dims.empty()) { |
| Shape reduce_shape = ShapeUtil::FilterDimensions( |
| [&](int64 dim) { |
| return !absl::c_linear_search(effective_reduce_dims, dim); |
| }, |
| reduce_window->shape()); |
| simplifier_->UpdateLayout(&reduce_shape); |
| HloInstruction* reduce = |
| computation_->AddInstruction(HloInstruction::CreateReduce( |
| /*shape=*/reduce_shape, |
| /*operand=*/operand, |
| /*init_value=*/reduce_window->mutable_operand(1), |
| /*dimensions_to_reduce=*/effective_reduce_dims, |
| /*reduce_computation=*/function)); |
| return ReplaceWithNewInstruction( |
| reduce_window, |
| HloInstruction::CreateReshape(reduce_window->shape(), reduce)); |
| } |
| } |
| |
| // This optimization folds a pad op into reduce_window. |
| HloInstruction* pad; |
| const HloInstruction* convert = nullptr; |
| if (operand->opcode() == HloOpcode::kPad) { |
| pad = operand; |
| } else if (operand->opcode() == HloOpcode::kConvert && |
| operand->operand(0)->opcode() == HloOpcode::kPad) { |
| convert = operand; |
| pad = operand->mutable_operand(0); |
| } else { |
| VLOG(10) << "Not folding pad into reduce-window as there is no pad."; |
| return Status::OK(); |
| } |
| |
| VLOG(10) << "Considering folding Pad: " << pad->ToString() |
| << "\ninto reduce-window: " << reduce_window->ToString() |
| << (convert != nullptr |
| ? absl::StrCat("\nvia convert: ", convert->ToString()) |
| : ""); |
| |
  // Do not fold the pad's interior padding into a reduce-window that already
  // uses base dilation, since composing the two dilations is not supported.
| const PaddingConfig& pad_config = pad->padding_config(); |
| if (HasInteriorPadding(pad_config) && window_util::HasBaseDilation(window)) { |
| VLOG(10) << "Not folding interior pad into base-dilated reduce-window."; |
| return Status::OK(); |
| } |
| |
| // If reduce_window already has padding, the pad value of the pad op and the |
| // init value of reduce_window must match to allow folding the pad. |
| const HloInstruction* pad_value = pad->operand(1); |
| const HloInstruction* reduce_init_value = reduce_window->operand(1); |
| if (pad_value != reduce_init_value) { |
| auto literals_are_equivalent = [&] { |
| auto& pad_literal = pad_value->literal(); |
| auto& reduce_init_literal = reduce_init_value->literal(); |
| if (pad_literal == reduce_init_literal) { |
| return true; |
| } |
| auto converted_pad_literal = |
| pad_literal.ConvertToShape(reduce_init_value->shape()); |
| if (!converted_pad_literal.ok()) { |
| return false; |
| } |
| return converted_pad_literal.ValueOrDie() == reduce_init_literal; |
| }; |
| // The pad value is usually a constant, so we handle that case and do not |
| // try to get more fancy about proving equivalence in cases beyond that. |
| if (pad_value->opcode() != HloOpcode::kConstant || |
| reduce_init_value->opcode() != HloOpcode::kConstant || |
| !literals_are_equivalent()) { |
| VLOG(10) << "Not folding pad into reduce-window due to different pad " |
| "values."; |
| return Status::OK(); |
| } |
| } |
| |
| // If the pad puts a single non-identity value in each window that we're |
| // reducing, then this is a broadcast. |
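  // Illustrative example (hypothetical HLO, not from the original source):
  //   reduce-window(pad(f32[1] x, 1_1) -> f32[3], window={size=2}) -> f32[2]
  // sees exactly one copy of x[0] (plus identity padding) per window, so it
  // is equivalent to broadcast(x) -> f32[2].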
| HloInstruction* pad_operand = pad->mutable_operand(0); |
| auto is_effective_broadcast = [&] { |
| if (window_util::HasStride(window)) { |
| VLOG(10) << "Window has stride."; |
| return false; |
| } |
| if (!window_util::HasSymmetricPadding(pad_config)) { |
| VLOG(10) << "Window has uneven padding."; |
| return false; |
| } |
| if (HasInteriorPadding(pad_config)) { |
| VLOG(10) << "Window has interior padding."; |
| return false; |
| } |
| for (int64 i = 0; i < pad_config.dimensions_size(); ++i) { |
| const auto& pad_dimension = pad_config.dimensions(i); |
| if ((pad_dimension.edge_padding_low() != 0 || |
| pad_dimension.edge_padding_high() != 0) && |
| pad_operand->shape().dimensions(i) != 1) { |
| VLOG(10) << "Found non-trivial dimension being padded: " << i; |
| return false; |
| } |
| } |
| VLOG(10) << "Found to be padding trivial dimensions only."; |
| |
| for (int64 i = 0; i < window.dimensions_size(); ++i) { |
| const auto& pad_dimension = pad_config.dimensions(i); |
| const WindowDimension& window_dimension = window.dimensions(i); |
| bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 || |
| pad_dimension.edge_padding_high() != 0); |
| if (dimension_has_padding && |
| window_dimension.size() < pad_dimension.edge_padding_low() + 1) { |
| VLOG(10) << "Found window did not cover single unpadded element in " |
| "dimension: " |
| << i; |
| return false; |
| } |
| if (pad_operand->shape().dimensions(i) != 1 && |
| window_dimension.size() != 1) { |
| VLOG(10) << "Found window covers more than one element in non-trivial " |
| "dimension: " |
| << i; |
| return false; |
| } |
| } |
| VLOG(10) << "Found window covers a single unpadded element."; |
| return true; |
| }; |
| |
| HloInstruction* new_reduce_window_operand; |
| if (convert != nullptr) { |
| Shape changed_shape = ShapeUtil::ChangeElementType( |
| pad_operand->shape(), convert->shape().element_type()); |
| simplifier_->UpdateLayout(&changed_shape); |
| new_reduce_window_operand = computation_->AddInstruction( |
| HloInstruction::CreateConvert(changed_shape, pad_operand)); |
| } else { |
| new_reduce_window_operand = pad_operand; |
| } |
| |
| if (is_effective_broadcast()) { |
| VLOG(10) << "Replacing pad/reduce-window with broadcast."; |
| auto fadd = [this](std::unique_ptr<HloInstruction> x) { |
| return computation_->AddInstruction(std::move(x)); |
| }; |
| return ReplaceWithNewInstruction( |
| reduce_window, HloInstruction::CreateBroadcastSequence( |
| /*output_shape=*/reduce_window->shape(), |
| /*operand=*/new_reduce_window_operand, fadd)); |
| } |
| |
| // Carry out the folding of the pad into reduce_window. |
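  // Illustrative example (hypothetical HLO, not from the original source):
  //   reduce-window(pad(x, 2_1), window={size=3, pad=0_0})
  // becomes
  //   reduce-window(x, window={size=3, pad=2_1}); interior padding in the pad
  //   likewise becomes base dilation of the window.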
| VLOG(10) << "Folding pad into reduce-window."; |
| Window new_window = window; |
| const int64 rank = reduce_window->shape().rank(); |
| TF_RET_CHECK(pad_config.dimensions_size() == rank); |
| TF_RET_CHECK(window.dimensions_size() == rank); |
| for (int64 i = 0; i < rank; ++i) { |
| const auto& pad_dim = pad_config.dimensions(i); |
| auto& window_dim = *new_window.mutable_dimensions(i); |
| window_dim.set_padding_low(window_dim.padding_low() + |
| pad_dim.edge_padding_low()); |
| window_dim.set_padding_high(window_dim.padding_high() + |
| pad_dim.edge_padding_high()); |
| if (pad_dim.interior_padding() != 0) { |
| CHECK_EQ(window_dim.base_dilation(), 1); |
| window_dim.set_base_dilation(1 + pad_dim.interior_padding()); |
| } |
| } |
| |
| return ReplaceWithNewInstruction( |
| reduce_window, HloInstruction::CreateReduceWindow( |
| /*shape=*/reduce_window->shape(), |
| /*operand=*/new_reduce_window_operand, |
| /*init_value=*/reduce_window->mutable_operand(1), |
| /*window=*/new_window, |
| /*reduce_computation=*/function)); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { |
| // select(x, y, y) -> y. |
| if (select->operand(1) == select->operand(2)) { |
| return ReplaceInstruction(select, select->mutable_operand(1)); |
| } |
| // select(true, x, y) -> x. |
| if (IsAll(select->operand(0), true)) { |
| return ReplaceInstruction(select, select->mutable_operand(1)); |
| } |
| // select(false, x, y) -> y. |
| if (IsAll(select->operand(0), false)) { |
| return ReplaceInstruction(select, select->mutable_operand(2)); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* scatter) { |
| if (ShapeUtil::IsZeroElementArray(scatter->operand(2)->shape()) && |
| ReplaceInstructionIfSameShape(scatter, scatter->mutable_operand(0))) { |
| return Status::OK(); |
| } |
| if (ShapeUtil::IsZeroElementArray(scatter->operand(1)->shape()) && |
| SameShape(scatter, scatter->operand(0)) && |
| SameShape(scatter, scatter->operand(2))) { |
| return ReplaceWithNewInstruction( |
| scatter, HloInstruction::CreateMap( |
| scatter->shape(), |
| {scatter->mutable_operand(0), scatter->mutable_operand(2)}, |
| scatter->to_apply())); |
| } |
| return Status::OK(); |
| } |

Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
| auto operand = sort->mutable_operand(0); |
| int64 dimension_to_sort = sort->dimensions(0); |
| if (ShapeUtil::IsZeroElementArray(operand->shape()) || |
| operand->shape().dimensions(dimension_to_sort) <= 1) { |
| if (sort->operand_count() == 1) { |
| return ReplaceInstruction(sort, operand); |
| } |
| // If it is key/value sort, the output of sort is a tuple. |
| return ReplaceWithNewInstruction( |
| sort, HloInstruction::CreateTuple(sort->operands())); |
| } |
| return Status::OK(); |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) { |
| VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString(); |
| HloInstruction* sqrt_operand = sqrt->mutable_operand(0); |
| if (sqrt_operand->opcode() == HloOpcode::kMultiply && |
| sqrt_operand->operand(0) == sqrt_operand->operand(1)) { |
| return ReplaceWithNewInstruction( |
| sqrt, HloInstruction::CreateUnary( |
| sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs, |
| sqrt_operand->mutable_operand(0))); |
| } |
| return Status::OK(); |
| } |
| |
| namespace { |
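// Returns true when the transpose that produced `shape` (with output-to-input
// mapping `perm`) moves only degenerate (size-1) dimensions, i.e. all other
// dimensions keep their relative order, so the transpose is really a reshape.
// Illustrative example (not from the original source): transposing f32[5,1,7]
// with perm={1,0,2} yields f32[1,5,7]; only the size-1 dimension moved.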
| bool OnlyPermutesDegenerateDims(const Shape& shape, |
| absl::Span<const int64> perm) { |
| std::vector<int64> new_permutation; |
| int64 degenerate_count = 0; |
| for (int64 i = 0; i < perm.size(); ++i) { |
| if (shape.dimensions(i) != 1) { |
| new_permutation.push_back(perm[i]); |
| } else { |
| ++degenerate_count; |
| } |
| } |
| return degenerate_count > 0 && absl::c_is_sorted(new_permutation); |
| } |
| } // namespace |
| |
| Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { |
| auto operand = transpose->mutable_operand(0); |
| if (std::is_sorted(transpose->dimensions().begin(), |
| transpose->dimensions().end())) { |
| VLOG(10) << "deleting no-op transpose"; |
| return ReplaceInstruction(transpose, operand); |
| } |
| |
| if (HloOpcode::kTranspose == operand->opcode()) { |
| return ReplaceWithNewInstruction( |
| transpose, HloInstruction::CreateTranspose( |
| transpose->shape(), operand->mutable_operand(0), |
| ComposePermutations(operand->dimensions(), |
| transpose->dimensions()))); |
| } |
| |
| // Convert transpose(dot(a,b)) to dot(b,a). |
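  // Illustrative example (hypothetical HLO, not from the original source):
  //   transpose(dot(f32[3,4] a, f32[4,5] b), perm={1,0}) -> f32[5,3]
  // becomes
  //   dot(b, a) with the contracting dimensions swapped, producing f32[5,3]
  //   directly.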
| if (operand->opcode() == HloOpcode::kDot && operand->user_count() == 1 && |
| operand->shape().rank() == 2) { |
| TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> StatusOr<bool> { |
| const auto& dnums = operand->dot_dimension_numbers(); |
| if (dnums.lhs_batch_dimensions_size() != 0) { |
| return false; |
| } |
| HloInstruction* lhs = operand->mutable_operand(0); |
| if (lhs->shape().rank() != 1 + dnums.lhs_contracting_dimensions_size()) { |
| return false; |
| } |
| HloInstruction* rhs = operand->mutable_operand(1); |
| if (rhs->shape().rank() != 1 + dnums.rhs_contracting_dimensions_size()) { |
| return false; |
| } |
| DotDimensionNumbers new_dnums; |
| *new_dnums.mutable_lhs_contracting_dimensions() = |
| dnums.rhs_contracting_dimensions(); |
| *new_dnums.mutable_rhs_contracting_dimensions() = |
| dnums.lhs_contracting_dimensions(); |
| TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( |
| transpose, HloInstruction::CreateDot(transpose->shape(), /*lhs=*/rhs, |
| /*rhs=*/lhs, new_dnums, |
| operand->precision_config()))); |
| return true; |
| }()); |
| if (did_transform) { |
| return Status::OK(); |
| } |
| } |
| |
  // Replace the transpose with a reshape if it permutes only degenerate
  // (size-1) dimensions.
| if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) { |
| return ReplaceWithNewInstruction( |
| transpose, HloInstruction::CreateReshape( |
| transpose->shape(), transpose->mutable_operand(0))); |
| } |
| |
| if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { |
| *operand->mutable_shape() = transpose->shape(); |
| return ReplaceInstruction(transpose, operand); |
| } |
| |
| if (options_.is_layout_sensitive() && |
| options_.replace_transpose_with_bitcast() && |
| TransposeIsBitcast(transpose)) { |
| ReplaceWithBitcast(transpose); |
| return Status::OK(); |
| } |
| |
| return Status::OK(); |
| } |
| |
| StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad( |
| HloInstruction* convolution) { |
| HloInstruction *lhs, *a, *b; |
| if (Match(convolution, |
| m::Convolution(m::Pad(&lhs, m::Op(&a), m::ConstantScalar(0)), |
| m::Op(&b)))) { |
| const auto& window = convolution->window(); |
| const ConvolutionDimensionNumbers& dnums = |
| convolution->convolution_dimension_numbers(); |
| |
| const auto& padding = lhs->padding_config(); |
| |
| // Can't pad batch or feature dims. |
| for (int64 dim : |
| {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { |
| const auto& p = padding.dimensions(dim); |
| if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || |
| p.interior_padding() != 0) { |
| return false; |
| } |
| } |
| |
| // Compute the window which is the result of merging the kPad and the |
| // convolution's existing window. |
| Window new_window = window; |
| for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { |
| auto& w = *new_window.mutable_dimensions(dim); |
| const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); |
| // Edge padding composes with itself in the straightforward way, but |
| // composing interior padding is nontrivial, and we cowardly refuse to |
| // think about it. If we see interior padding in either the kPad or conv, |
| // bail if there's any sort of padding in the other. |
| if (p.interior_padding() != 0 && |
| (w.padding_low() != 0 || w.padding_high() != 0 || |
| w.base_dilation() != 1)) { |
| return false; |
| } |
| if (w.base_dilation() != 1 && |
| (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || |
| p.interior_padding() != 0)) { |
| return false; |
| } |
| |
| w.set_padding_low(w.padding_low() + p.edge_padding_low()); |
| w.set_padding_high(w.padding_high() + p.edge_padding_high()); |
| if (p.interior_padding() != 0) { |
| CHECK_EQ(w.base_dilation(), 1); |
| w.set_base_dilation(1 + p.interior_padding()); |
| } |
| } |
| |
| auto new_conv = |
| convolution->CloneWithNewOperands(convolution->shape(), {a, b}); |
| new_conv->set_window(new_window); |
| TF_RETURN_IF_ERROR( |
| ReplaceWithNewInstruction(convolution, std::move(new_conv))); |
| return true; |
| } |
| return false; |
| } |
| |
| StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad( |
| HloInstruction* convolution) { |
| auto* lhs = convolution->mutable_operand(0); |
| auto* rhs = convolution->mutable_operand(1); |
| const ConvolutionDimensionNumbers& dnums = |
| convolution->convolution_dimension_numbers(); |
| |
| if (rhs->opcode() != HloOpcode::kPad) { |
| return false; |
| } |
| |
| // Convolution's padding is always zero, so bail if the kPad is adding |
| // something other than zero. |
| if (!IsAll(rhs->operand(1), 0)) { |
| return false; |
| } |
| |
| const auto& padding = rhs->padding_config(); |
| |
| // Can't pad or dilate feature dims. |
| for (int64 dim : {dnums.kernel_input_feature_dimension(), |
| dnums.kernel_output_feature_dimension()}) { |
| const auto& p = padding.dimensions(dim); |
| if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || |
| p.interior_padding() != 0) { |
| return false; |
| } |
| } |
| |
| // Compute the window which is the result of merging the kPad and the |
| // convolution's existing window. |
| Window new_window = convolution->window(); |
| for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { |
| auto& w = *new_window.mutable_dimensions(dim); |
| const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); |
| |
| // We can only do this transformation if p adds dilation to the filter -- |
| // edge padding on the filter is not supported in conv. |
| if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { |
| return false; |
| } |
| |
| // Nothing to do if the kPad for this dim is entirely a nop. |
| if (p.interior_padding() == 0) { |
| continue; |
| } |
| |
| // We cowardly refuse to think about how dilation composes with itself; |
| // bail if both the kPad and conv have dilation on this dimension. |
| if (w.window_dilation() > 1) { |
| return false; |
| } |
| CHECK_EQ(w.window_dilation(), 1); |
| w.set_window_dilation(1 + p.interior_padding()); |
| w.set_size(rhs->operand(0)->shape().dimensions( |
| dnums.kernel_spatial_dimensions(dim))); |
| } |
| |
| auto new_conv = convolution->CloneWithNewOperands( |
| convolution->shape(), {lhs, rhs->mutable_operand(0)}); |
| new_conv->set_window(new_window); |
| TF_RETURN_IF_ERROR( |
| ReplaceWithNewInstruction(convolution, std::move(new_conv))); |
| return true; |
| } |
| |
| StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands( |
| HloInstruction* convolution) { |
| if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) { |
| return false; |
| } |
| if (convolution->feature_group_count() > 1 || |
| convolution->batch_group_count() > 1) { |
| return false; |
| } |
| |
| const auto& dnums = convolution->convolution_dimension_numbers(); |
| const auto& window_dims = convolution->window().dimensions(); |
| Window swapped_window; |
| |
| HloInstruction *input = convolution->mutable_operand(0), |
| *kernel = convolution->mutable_operand(1); |
| int64 kernel_product = 1; |
| int64 swapped_kernel_product = 1; |
| DimensionVector reverse_dimensions; |
| for (int64 spatial_dim = 0; |
| spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) { |
| const int64 kernel_size = window_dims[spatial_dim].size(); |
| const bool can_be_group_or_contraction = |
| !window_dims[spatial_dim].window_reversal() && |
| window_dims[spatial_dim].padding_low() == 0 && |
| window_dims[spatial_dim].padding_high() == 0 && |
| window_dims[spatial_dim].window_dilation() == 1; |
| const bool is_group_dim = |
| can_be_group_or_contraction && |
| window_dims[spatial_dim].base_dilation() == kernel_size && |
| window_dims[spatial_dim].stride() == kernel_size - 1; |
| const int64 input_size = |
| input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim)); |
| const bool is_pure_contraction_dim = |
| kernel_size == input_size && can_be_group_or_contraction && |
| window_dims[spatial_dim].base_dilation() == 1 && |
| window_dims[spatial_dim].stride() == 1; |
| if (is_group_dim || is_pure_contraction_dim) { |
| *(swapped_window.add_dimensions()) = window_dims[spatial_dim]; |
| continue; |
| } |
| |
| const int64 dilated_kernel_size = |
| 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); |
| const int64 dilated_input_size = |
| 1 + (input_size - 1) * window_dims[spatial_dim].base_dilation(); |
| |
    // Don't decide to swap if the input size is one, since many convolution
    // implementations can easily handle that special case efficiently.
| kernel_product *= kernel_size; |
| swapped_kernel_product *= input_size == 1 ? kernel_size : input_size; |
| |
| auto new_dim = swapped_window.add_dimensions(); |
| new_dim->set_size(input_size); |
    // The old kernel becomes the new activations; if it was not already read
    // in reverse, reverse it explicitly.
| if (!window_dims[spatial_dim].window_reversal()) { |
| reverse_dimensions.push_back( |
| dnums.kernel_spatial_dimensions(spatial_dim)); |
| } |
    // The old input becomes the new kernel; it was not originally reversed,
    // so it must be read in reverse by the swapped convolution.
| new_dim->set_window_reversal(true); |
| // Base dilation and window dilation switch places. |
| new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation()); |
| new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation()); |
| new_dim->set_stride(window_dims[spatial_dim].stride()); |
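    // The swapped padding is chosen so that the new (reversed) kernel visits
    // every alignment the old kernel did. Worked example (illustrative, not
    // from the original source): input_size=5, kernel_size=3, no dilation,
    // padding 1_1 gives padding_low = padding_high = 5 + 1 - 3 = 3, and both
    // the original and swapped convolutions produce 5 output elements.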
| new_dim->set_padding_low(dilated_input_size + |
| window_dims[spatial_dim].padding_low() - |
| dilated_kernel_size); |
| new_dim->set_padding_high(dilated_input_size + |
| window_dims[spatial_dim].padding_high() - |
| dilated_kernel_size); |
| } |
| |
| // Don't transform if a naive convolution implementation would not have fewer |
| // flops. |
| if (kernel_product <= swapped_kernel_product) { |
| return false; |
| } |
| ConvolutionDimensionNumbers swapped_dnums; |
| *swapped_dnums.mutable_output_spatial_dimensions() = |
| dnums.output_spatial_dimensions(); |
| // Swap batch and output feature of the output. |
| swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension()); |
| swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension()); |
| |
| // Swap input dnums with kernel dnums |
| *swapped_dnums.mutable_input_spatial_dimensions() = |
| dnums.kernel_spatial_dimensions(); |
| swapped_dnums.set_input_batch_dimension( |
| dnums.kernel_output_feature_dimension()); |
| swapped_dnums.set_input_feature_dimension( |
| dnums.kernel_input_feature_dimension()); |
| |
| // Swap kernel dnums with input dnums |
| *swapped_dnums.mutable_kernel_spatial_dimensions() = |
| dnums.input_spatial_dimensions(); |
| swapped_dnums.set_kernel_output_feature_dimension( |
| dnums.input_batch_dimension()); |
| swapped_dnums.set_kernel_input_feature_dimension( |
| dnums.input_feature_dimension()); |
| |
| PrecisionConfig precision_config; |
| precision_config.add_operand_precision( |
| convolution->precision_config().operand_precision(1)); |
| precision_config.add_operand_precision( |
| convolution->precision_config().operand_precision(0)); |
| if (!reverse_dimensions.empty()) { |
| TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); |
| } |
| TF_ASSIGN_OR_RETURN( |
| HloInstruction * new_convolution, |
| MakeConvolveHlo( |
| kernel, input, /*feature_group_count=*/1, |
| /*batch_group_count=*/1, swapped_window, swapped_dnums, |
| precision_config, |
| /*preferred_element_type=*/convolution->shape().element_type())); |
| |
| convolution->SetupDerivedInstruction(new_convolution); |
| TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution)); |
| |
| return true; |
| } |
| |
| StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot( |
| HloInstruction* convolution) { |
| auto* lhs = convolution->mutable_operand(0); |
| auto* rhs = convolution->mutable_operand(1); |
| const auto& window = convolution->window(); |
| const ConvolutionDimensionNumbers& dnums = |
| convolution->convolution_dimension_numbers(); |
| |
| if (!options_.enable_conv_simplification()) { |
| return false; |
| } |
| |
| // TODO(b/31337498): For now, we cowardly refuse to do this optimization in |
| // layout-insensitive mode, for fear of adding nontrivial reshapes. |
| if (!options_.is_layout_sensitive()) { |
| return false; |
| } |
| |
| const Shape& input_shape = lhs->shape(); |
| const Shape& filter_shape = rhs->shape(); |
| const Shape& convolution_shape = convolution->shape(); |
| TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); |
| TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape)); |
| TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape)); |
| |
| // Require the spatial dimensions in the kernel to have a bound of one. |
| for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { |
| if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { |
| return false; |
| } |
| } |
| |
| // Stride ignores part of the output, which matrix multiplication does not do, |
| // so require no stride. Padding and base (lhs) dilation both implicitly |
| // extend the data, which matrix multiplication also does not do, so require |
| // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect |
| // for a 1x1 window, so window dilation is no problem. |
| if (window_util::HasStride(window) || window_util::HasPadding(window) || |
| window_util::HasBaseDilation(window)) { |
| return false; |
| } |
| |
| // Also, the shapes must align for a rowmajor matmul: |
| // - the input and output have the same layout. |
| // - for input/output, the channel dimension must be the most minor. Other |
| // spatial dims can be in any order. |
| // - for filters, the input channel dimension must be more major than the |
| // output channel dimension. The width+height don't matter because |
| // they are 1. |
| // |
| // These constraints are harsh. If the channel dimension is the most major |
| // and/or the layout of input/output feature dimensions are reversed, we can |
| // still convert Conv into more efficient Matmul with operand transposition |
| // (such as the transposition flags in cuBLAS SGEMM). |
| if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || |
| LayoutUtil::Minor(input_shape.layout(), 0) != |
| dnums.input_feature_dimension() || |
| LayoutUtil::Minor(convolution_shape.layout(), 0) != |
| dnums.output_feature_dimension() || |
| // The input feature dimension should come later in the minor-to-major |
| // order. |
| (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), |
| dnums.kernel_input_feature_dimension()) < |
| PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), |
| dnums.kernel_output_feature_dimension()))) { |
| return false; |
| } |
| |
  auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
    return computation_->AddInstruction(
        HloInstruction::CreateBitcast(shape, operand));
  };
| |
| // Replace it with a dot, with bitcasts around it to get the right shape. |
| const int64 input_channels = |
| input_shape.dimensions(dnums.input_feature_dimension()); |
| const int64 output_channels = |
| filter_shape.dimensions(dnums.kernel_output_feature_dimension()); |
| |
| // Computes the product of the non-feature dimensions. |
| int64 conv_width = 1; |
| for (int i = 0; i < input_shape.dimensions_size(); ++i) { |
| if (i != dnums.input_feature_dimension()) { |
| conv_width *= input_shape.dimensions(i); |
| } |
| } |
| |
| // We already checked feature_dimension is most minor, so data in input_shape |
| // and row-major {conv_width,input_channels} are bitwise identical. |
| Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( |
| input_shape.element_type(), {conv_width, input_channels}); |
| simplifier_->UpdateLayout(&new_input_shape); |
| // We already checked input_feature_dimension is more major than |
| // output_feature_dimension, so data in filter_shape and row-major |
| // {input_channels,output_channels} are bitwise identical. |
| Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( |
| filter_shape.element_type(), {input_channels, output_channels}); |
| simplifier_->UpdateLayout(&new_filter_shape); |
| Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( |
| convolution_shape.element_type(), {conv_width, output_channels}); |
| simplifier_->UpdateLayout(&dot_output_shape); |
| |
| auto new_lhs = add_bitcast(new_input_shape, lhs); |
| auto new_rhs = add_bitcast(new_filter_shape, rhs); |
| DotDimensionNumbers dot_dimension_numbers; |
| dot_dimension_numbers.add_lhs_contracting_dimensions(1); |
| dot_dimension_numbers.add_rhs_contracting_dimensions(0); |
| auto dot = computation_->AddInstruction(HloInstruction::CreateDot( |
| dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, |
| convolution->precision_config())); |
| |
| TF_RETURN_IF_ERROR( |
| ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); |
| return true; |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleConvolution( |
| HloInstruction* convolution) { |
| if (options_.enable_scalar_multiply_reduction()) { |
| TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution)); |
| } |
| |
| // Zero-sized input or filter. |
| if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || |
| ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { |
| return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0)); |
| } |
| |
| // Try to merge padding/dilation of the input with the convolution's window. |
| TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); |
| if (folded_input_pad) { |
| return Status::OK(); |
| } |
| |
| // Try to merge dilation of the filter with the convolution's window. |
| TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); |
| if (folded_filter_pad) { |
| return Status::OK(); |
| } |
| |
| // Try to swap convolution operands. |
| TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution)); |
| if (swapped) { |
| return Status::OK(); |
| } |
| // Try to replace the convolution with a kDot instruction. |
| TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); |
| if (replaced_with_dot) { |
| return Status::OK(); |
| } |
| |
| return Status::OK(); |
| } |
| |
| bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( |
| HloInstruction* root, HloInstruction* min, HloInstruction* min_operand, |
| HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) { |
| // Ensure shapes of min and max operand are equal to match current shape |
| // inference. |
| if (!SameShape(min_operand, max_operand)) { |
| return false; |
| } |
| |
| auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp, |
| max_operand, operand, min_operand); |
| TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp))); |
| return true; |
| } |
| |
| Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { |
| auto* map_computation = map->to_apply(); |
| auto* map_root = map_computation->root_instruction(); |
| if (map_root->opcode() == HloOpcode::kParameter) { |
| ReplaceInstructionIfSameShape( |
| map, map->mutable_operand(map_root->parameter_number())); |
| return Status::OK(); |
| } |
| if (map_root->opcode() == HloOpcode::kConstant) { |
| if (!ShapeUtil::IsScalar(map_root->shape())) { |
| return Status::OK(); |
| } |
| auto clone = map_root->CloneWithNewOperands(map_root->shape(), {}); |
| if (ShapeUtil::IsScalar(map->shape())) { |
| return ReplaceWithNewInstruction(map, std::move(clone)); |
| } |
| return ReplaceWithNewInstruction( |
| map, |
| HloInstruction::CreateBroadcast( |
| map->shape(), computation_->AddInstruction(std::move(clone)), {})); |
| } |
| // Inline the map if the map computation only contains an elementwise |
| // operation that can accept arbitrary shapes. |
| if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) { |
| return Status::OK(); |
| } |
| std::vector<HloInstruction*> new_operands; |
| for (auto* root_operand : map_root->operands()) { |
| if (root_operand->opcode() != HloOpcode::kParameter) { |
| return Status::OK(); |
| } |
| new_operands.push_back( |
| map->mutable_operand(root_operand->parameter_number())); |
| } |
| auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands); |
| return ReplaceWithNewInstruction(map, std::move(clone)); |
| } |
| |
| StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { |
| XLA_VLOG_LINES(2, |
| "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); |
| bool changed = false; |
| AlgebraicSimplifierVisitor visitor(options_, this); |
| for (auto* comp : module->MakeNonfusionComputations()) { |
| if (visitor.Run(comp, options_, this)) { |
| changed = true; |
| } |
| } |
| XLA_VLOG_LINES(2, |
| "AlgebraicSimplifier::Run(), after:\n" + module->ToString()); |
| return changed; |
| } |
| |
| } // namespace xla |