blob: 978c1f69a95b4b5cd7b455c81825ab5bcb8bca47 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
// Checks whether instr is or transitively contains an instruction that we
// shouldn't fold.
//
// Specifically, we don't fold kRng or kAfterAll instructions:
//
// - kRng is already marked as side-effecting and so is skipped elsewhere, but
// we check for it here. Even kRng weren't side-effecting and took an
// explicit seed, we *still* wouldn't want to constant-fold it, because the
// evaluator's handling of rng is not guaranteed to be identical to any
// particular backend's rng.
//
// - kAfterAll needs to be skipped because a kAfterAll op with no args can
// currently materialize a token "out of thin air". TODO(b/110532604):
// Remove this check once AfterAll requires at least one operand, in which
// case constant folding will be impossible.
static bool IsOrContainsIllegalInstr(const HloInstruction* instr) {
if (instr->opcode() == HloOpcode::kAfterAll ||
instr->opcode() == HloOpcode::kRng) {
return true;
}
for (const HloComputation* c : instr->called_computations()) {
if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) {
return true;
}
}
return false;
}
/*static*/ std::atomic<int64_t> HloConstantFolding::slow_op_counter_{0};
StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Limit the constant folding to 0 iterations to skip folding loops. This
// retains the behavior from before while loop support in HloEvaluator and may
// be revised.
auto evaluator = std::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
// fast-path lets us e.g. use Eigen for matmuls.
evaluator->set_use_fast_path(true);
bool changed = false;
for (auto* computation : module->MakeNonfusionComputations()) {
for (auto* instruction : computation->MakeInstructionPostOrder()) {
// Skip dead code.
if (instruction->IsDead()) {
continue;
}
// We only handle instructions where
//
// - at least one operand is a constant, and
// - all other operands are either constants or broadcast(constant).
//
// Why this particular set of rules around broadcasts?
//
// - We don't want to fold broadcast(constant) on its own, because in
// general it's "simpler" to remember that it's a broadcast. Also,
// algsimp will fold an all-one-value constant into a broadcast, so
// we'd just end up fighting with it.
//
// - We don't want to fold an op where all operands are broadcasts of
// constants, because algsimp will transform op(broadcast(constant) =>
// broadcast(op(constant)). Then we can constant-fold the smaller op.
//
// - So the only remaining case is where some but not all operands are
// broadcasts of constants, e.g. op(constant, broadcast(constant)).
//
if (!absl::c_any_of(instruction->operands(),
[](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConstant;
}) ||
!absl::c_all_of(
instruction->operands(), [](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConstant ||
(operand->opcode() == HloOpcode::kBroadcast &&
operand->operand(0)->opcode() == HloOpcode::kConstant);
})) {
continue;
}
// Don't fold Constant, Parameter, and Tuple instructions. Tuple
// constants are not directly supported by any backends, hence folding
// Tuple is not useful and would in fact be expanded back into kTuple by
// Algebraic Simplifier.
//
// (We do allow folding subcomputations that contain these instructions.)
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kTuple) {
continue;
}
// Broadcasts dramatically increase the size of constants, which is often
// detrimental to performance and memory capacity, so do not fold
// broadcasts.
if (instruction->opcode() == HloOpcode::kBroadcast ||
instruction->opcode() == HloOpcode::kIota) {
continue;
}
// Do not fold FFT. Evaluating it may significantly increase compile time.
if (instruction->opcode() == HloOpcode::kFft) {
continue;
}
// Check for instructions that we can't fold even if they appear inside of
// a subcomputation (e.g. a kCall).
if (IsOrContainsIllegalInstr(instruction)) {
continue;
}
// Don't constant-fold side-effecting instructions or instructions which
// contain side-effecting instructions.
if (instruction->HasSideEffect()) {
continue;
}
// Don't constant fold unless it's a net positive or the output is small.
if (instruction->shape().IsArray()) {
int64_t elements_in_removed_operands = 0;
for (HloInstruction* operand : instruction->operands()) {
if (operand->user_count() == 1 && operand->shape().IsArray()) {
elements_in_removed_operands +=
ShapeUtil::ElementsIn(operand->shape());
}
}
int64_t elements_in_constant =
ShapeUtil::ElementsIn(instruction->shape());
static const int64_t kMaximumConstantSizeElements = 45 * 1000 * 1000;
if (elements_in_constant > elements_in_removed_operands &&
elements_in_constant > kMaximumConstantSizeElements) {
continue;
}
}
VLOG(5) << "Constant folding: " << instruction->ToString();
absl::Duration slow_timeout =
absl::Seconds(uint64_t{1} << slow_op_counter_.load());
// We cannot call `instruction->ToString() within the callback, because
// the instruction may be modified and invalidated in place, and ToString
// will fail if the compilation is slow. We probably do not want to
// call `ToString()` for all the instructions, thus, we only display the
// name by default.
std::string instruction_msg;
if (VLOG_IS_ON(4)) {
instruction_msg = instruction->ToString();
} else {
instruction_msg =
absl::StrCat(instruction->name(),
" (displaying the full instruction incurs a runtime "
"overhead. Raise your logging level to 4 or above).");
}
SlowOperationAlarm slow_alarm(slow_timeout, [instruction_msg = std::move(
instruction_msg),
slow_timeout] {
const bool ndebug =
#if NDEBUG
true;
#else
false;
#endif
absl::string_view explanation_msg =
ndebug
? "This isn't necessarily a bug; constant-folding is "
"inherently a trade-off between compilation time and speed "
"at runtime. XLA has some guards that attempt to keep "
"constant folding from taking too long, but fundamentally "
"you'll always be able to come up with an input program that "
"takes a long time.\n\n"
"If you'd like to file a bug, run with envvar "
"XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results."
: "XLA was built without compiler optimizations, which can be "
"slow. Try rebuilding with -c opt.";
return absl::StrFormat(
"Constant folding an instruction is taking > %s:\n\n"
" %s\n\n" // instruction->name() or instruction->ToString()
"%s", // explanation_msg
absl::FormatDuration(slow_timeout), instruction_msg,
explanation_msg);
});
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
Literal result;
if (!evaluator->TryEvaluate(
instruction, &result,
/*recursively_evaluate_nonconstant_operands=*/true)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
}
slow_alarm.cancel();
if (slow_alarm.fired()) {
slow_op_counter_++;
}
VLOG(4) << "Constant folded: " << instruction->ToString();
TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(std::move(result))));
changed = true;
}
}
return changed;
}
} // namespace xla