blob: a7e1d3a80d7ff21fe505daf8a54642121ea70aef [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <string>
#include <type_traits>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (direction) {
case ComparisonDirection::kEq:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el == rhs_el;
};
break;
case ComparisonDirection::kNe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el != rhs_el;
};
break;
case ComparisonDirection::kGe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el >= rhs_el;
};
break;
case ComparisonDirection::kGt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el > rhs_el;
};
break;
case ComparisonDirection::kLe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el <= rhs_el;
};
break;
case ComparisonDirection::kLt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el < rhs_el;
};
break;
}
Literal result(shape);
TF_RETURN_IF_ERROR(
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
}));
return std::move(result);
}
template <>
StatusOr<Literal> Compare<complex64>(const Shape& shape,
ComparisonDirection direction,
LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (direction) {
case ComparisonDirection::kEq:
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
return lhs_el == rhs_el;
};
break;
case ComparisonDirection::kNe:
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
return lhs_el != rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
<< ComparisonDirectionToString(direction);
}
Literal result(shape);
TF_RETURN_IF_ERROR(
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
}));
return std::move(result);
}
template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape,
ComparisonDirection direction,
LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
std::function<bool(complex128, complex128)> compare_op;
switch (direction) {
case ComparisonDirection::kEq:
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
return lhs_el == rhs_el;
};
break;
case ComparisonDirection::kNe:
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
return lhs_el != rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
<< ComparisonDirectionToString(direction);
}
Literal result(shape);
TF_RETURN_IF_ERROR(
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
return compare_op(lhs_literal.Get<complex128>(multi_index),
rhs_literal.Get<complex128>(multi_index));
}));
return std::move(result);
}
} // namespace
// Note that unsupported types by the typed visitor does not necessarily imply
// the non-typed HloEvaluator (parent evaluator) would not support them either
// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent
// type-agnostic evaluator will be able to accept Tuple primitive type, whereas
// HloEvaluatorTypedVisitor cannot.
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
typed_visitors_[PRED] =
absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
typed_visitors_[U8] =
absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
typed_visitors_[U16] =
absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
typed_visitors_[U32] =
absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
typed_visitors_[U64] =
absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
typed_visitors_[S16] =
absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
typed_visitors_[S32] =
absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
typed_visitors_[S64] =
absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
typed_visitors_[F16] =
absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
typed_visitors_[F32] =
absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
typed_visitors_[F64] =
absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
typed_visitors_[C64] =
absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
typed_visitors_[C128] =
absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
// Most of the evaluator computations we use don't support BF16 (e.g.,
// std::ceil, std::tanh). To make evaluator work with BF16, we set all
// elementwise computations to be done in F32 and do BF16<->F32 conversion
// around the input and the output of the computations.
typed_visitors_[BF16] =
absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
typed_visitors_[TUPLE] =
absl::make_unique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented(
"HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
});
typed_visitors_[OPAQUE_TYPE] =
absl::make_unique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented(
"HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE.");
});
typed_visitors_[TOKEN] =
absl::make_unique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented(
"HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
});
}
StatusOr<Literal> HloEvaluator::Evaluate(
const HloComputation& computation,
absl::Span<const Literal* const> arg_literals) {
CHECK(computation.parent() != nullptr);
XLA_VLOG_LINES(
2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
if (arg_literals.size() != computation.num_parameters()) {
return InvalidArgument(
"Expected %d argument%s, but got %d.", computation.num_parameters(),
computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
}
for (int64 i = 0; i < arg_literals.size(); ++i) {
const auto& computation_shape =
computation.parameter_instruction(i)->shape();
const auto& arg_shape = arg_literals[i]->shape();
if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape,
arg_shape)) {
return InvalidArgument(
"Shape mismatch at parameter %d. Computation expected %s, but arg "
"was %s.",
i, ShapeUtil::HumanStringWithLayout(computation_shape),
ShapeUtil::HumanStringWithLayout(arg_shape));
}
}
evaluated_.clear();
arg_literals_.clear();
for (const auto& literal_ptr : arg_literals) {
arg_literals_.push_back(&*literal_ptr);
}
// Re-seed RNG, either from the configuration's seed or a monotonic
// per-evaluator seed (which prevents two evaluators from returning the same
// random sequence).
if (computation.parent()->config().seed()) {
seed_ = computation.parent()->config().seed();
} else {
// Start global_seed at a (true) random value.
static std::atomic<uint64> global_seed{std::random_device()()};
seed_ = global_seed.fetch_add(1);
}
engine_.seed(seed_);
TF_RETURN_IF_ERROR(computation.Accept(this));
return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
}
StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kParameter) {
return tensorflow::errors::FailedPrecondition(
"Cannot evaluate a parameter.");
}
if (!hlo_query::AllOperandsAreConstants(*instruction)) {
return tensorflow::errors::FailedPrecondition(
"Not all operands are constants.");
}
arg_literals_.clear();
evaluated_.clear();
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
return GetEvaluatedLiteralFor(instruction).Clone();
}
bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
CHECK(result != nullptr);
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
VLOG(1) << "TryEvaluate failed:" << result_or.status();
return false;
}
*result = result_or.ConsumeValueOrDie();
return true;
}
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions) {
std::vector<std::unique_ptr<HloInstruction>> owned_operands;
for (const HloInstruction* operand : instruction->operands()) {
auto it = substitutions.find(operand);
if (it == substitutions.end()) {
owned_operands.push_back(operand->Clone());
} else {
owned_operands.push_back(
HloInstruction::CreateConstant(it->second->Clone()));
}
}
std::vector<HloInstruction*> operands;
operands.reserve(owned_operands.size());
for (auto& operand : owned_operands) {
operands.push_back(operand.get());
}
std::unique_ptr<HloInstruction> cloned_instruction =
instruction->CloneWithNewOperands(instruction->shape(), operands);
auto result = Evaluate(cloned_instruction.get());
return result;
}
StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
HloInstruction::CreateConstant(rhs.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
rhs_instr.get());
auto result = Evaluate(cloned_instruction.get());
return result;
}
StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
HloInstruction::CreateConstant(operand.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
auto result = Evaluate(cloned_instruction.get());
return result;
}
StatusOr<Literal> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
dim_numbers, precision_config);
return Evaluate(cloned_instruction.get());
}
Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
Literal result(bitcast->shape());
TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
memcpy(result.untyped_data(), operand_literal.untyped_data(),
operand_literal.size_bytes());
evaluated_[bitcast] = std::move(result);
return Status::OK();
}
Status HloEvaluator::HandleGetDimensionSize(
HloInstruction* get_dimension_size) {
HloInstruction* operand = get_dimension_size->mutable_operand(0);
int64 dim = get_dimension_size->dimension();
if (dynamic_dimension_inference_ == nullptr) {
return InvalidArgument(
"Evaluator cannot evaluate get_dimension_size without "
"set_dynamic_dimension_inference.");
}
HloInstruction* dynamic_size =
dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
if (dynamic_size != nullptr) {
evaluated_[get_dimension_size] =
GetEvaluatedLiteralFor(dynamic_size).Clone();
return Status::OK();
}
const Shape& shape = get_dimension_size->operand(0)->shape();
Literal output(ShapeUtil::MakeShape(U32, {}));
output.PopulateWithValue(
static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
evaluated_[get_dimension_size] = std::move(output);
return Status::OK();
}
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
// Nothing to do other than sanity checks. Parameters' values are stored in
// arg_literals_.
CHECK_LT(parameter->parameter_number(), arg_literals_.size());
#ifndef NDEBUG
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(),
input_literal->shape()))
<< "parameter shape is: "
<< ShapeUtil::HumanStringWithLayout(parameter->shape())
<< ", but input literal shape is: "
<< ShapeUtil::HumanStringWithLayout(input_literal->shape());
#endif
return Status::OK();
}
Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
TF_ASSIGN_OR_RETURN(
evaluated_[reshape],
GetEvaluatedLiteralFor(reshape->operand(0))
.Reshape(AsInt64Slice(reshape->shape().dimensions())));
return Status::OK();
}
Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
.Transpose(transpose->dimensions());
return Status::OK();
}
Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
absl::Span<HloInstruction* const> operands(concatenate->operands());
// The result concatenate dimension is going to be the sum of all
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
CHECK(reference_shape.IsArray());
const int64 rank = reference_shape.rank();
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
CHECK_LT(concat_dim, rank);
DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
reference_shape.dimensions().end());
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
CHECK(operand_shape.IsArray());
// Accumulate the concat dimension from all tensors taking part to the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
auto result_literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
DimensionVector source_indices(rank, 0);
DimensionVector dest_indices(concat_dimensions.size(), 0);
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
evaluated_[concatenate] = std::move(result_literal);
return Status::OK();
}
Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
auto operand = is_finite->operand(0);
auto elem_ty = operand->shape().element_type();
switch (elem_ty) {
case PRED:
case TUPLE:
case OPAQUE_TYPE:
case TOKEN:
case S8:
case S16:
case S32:
case S64:
case U8:
case U16:
case U32:
case U64:
case C64:
case C128:
// Explicitly enumerate all types in this switch so that when we add a new
// type, we'll get a compile error here.
case PRIMITIVE_TYPE_INVALID:
case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
return InvalidArgument(
"expected element type in shape to be floating point, but "
"got: %s",
PrimitiveType_Name(elem_ty));
case F16: {
auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
is_finite,
[](Eigen::half elem_operand) {
return std::isfinite(static_cast<float>(elem_operand));
},
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
case BF16: {
auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
is_finite,
[](bfloat16 elem_operand) {
return std::isfinite(static_cast<float>(elem_operand));
},
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
case F32: {
auto result_or = ElementWiseUnaryOpImpl<bool, float>(
is_finite,
[](float elem_operand) { return std::isfinite(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
case F64: {
auto result_or = ElementWiseUnaryOpImpl<bool, double>(
is_finite,
[](double elem_operand) { return std::isfinite(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
}
return Status::OK();
}
Status HloEvaluator::HandleReal(HloInstruction* real) {
auto operand = real->operand(0);
switch (operand->shape().element_type()) {
case BF16: {
auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
real, [](bfloat16 elem_operand) { return elem_operand; },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
break;
}
case C64: {
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
real, [](complex64 elem_operand) { return std::real(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
break;
}
case C128: {
auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
real, [](complex128 elem_operand) { return std::real(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
break;
}
case F16: {
auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
real, [](Eigen::half elem_operand) { return elem_operand; },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
break;
}
case F32: {
auto result_or = ElementWiseUnaryOpImpl<float, float>(
real, [](float elem_operand) { return elem_operand; },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
break;
}
case F64: {
auto result_or = ElementWiseUnaryOpImpl<double, double>(
real, [](double elem_operand) { return elem_operand; },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
break;
}
default:
LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
<< PrimitiveType_Name(operand->shape().element_type());
}
return Status::OK();
}
Status HloEvaluator::HandleImag(HloInstruction* imag) {
auto operand = imag->operand(0);
switch (operand->shape().element_type()) {
case C64: {
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
GetEvaluatedLiteralFor(imag->operand(0)));
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
break;
}
case C128: {
auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
GetEvaluatedLiteralFor(imag->operand(0)));
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
break;
}
default:
LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
<< PrimitiveType_Name(operand->shape().element_type());
}
return Status::OK();
}
Status HloEvaluator::HandleComplex(HloInstruction* complex) {
const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
Literal result(complex->shape());
switch (complex->shape().element_type()) {
case C64: {
TF_RETURN_IF_ERROR(
result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
return std::complex<float>(real.Get<float>(multi_index),
imag.Get<float>(multi_index));
}));
break;
}
case C128: {
TF_RETURN_IF_ERROR(
result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
return std::complex<double>(real.Get<double>(multi_index),
imag.Get<double>(multi_index));
}));
break;
}
default:
LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
<< PrimitiveType_Name(complex->shape().element_type());
}
evaluated_[complex] = std::move(result);
return Status::OK();
}
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
ComparisonDirection direction = compare->comparison_direction();
auto lhs = compare->operand(0);
auto rhs = compare->operand(1);
DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
// Note here we switch on the operand's type.
switch (lhs->shape().element_type()) {
case PRED: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
} break;
case U8: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint8>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case U16: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint16>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case U32: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint32>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case U64: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint64>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case S8: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
} break;
case S16: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<int16>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case S32: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<int32>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case S64: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<int64>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case F16: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
} break;
case BF16: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<bfloat16>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case F32: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<float>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case F64: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<double>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case C64: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<complex64>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case C128: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<complex128>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
default:
LOG(FATAL) << "HandleCompare: unknown primitive type: "
<< PrimitiveType_Name(lhs->shape().element_type());
}
return Status::OK();
}
Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
std::vector<const Literal*> operand_literals;
for (auto operand : tuple->operands()) {
operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
}
evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
return Status::OK();
}
namespace {
// Common code used by 1D implementations, which copies data from the input to
// the contiguous buffer. Returns true if all copied values are zero.
bool GatherToBuffer(absl::Span<complex128> data, int64 length, int64 start,
int64 stride, bool expand_input,
absl::Span<complex128> buffer) {
CHECK_GE(buffer.size(), length);
bool input_is_zero = true;
const int64 ub = expand_input ? length / 2 + 1 : length;
CHECK_GE(data.size(), start + (ub - 1) * stride);
for (int64 k = 0; k < ub; k++) {
complex128 value = data[start + k * stride];
input_is_zero &= value == complex128(0.0, 0.0);
buffer[k] = value;
if (expand_input) {
// Use conjugates of the values at indices [1 ... (ub - 2)] when the
// length is even and at indices [1 ... (ub - 1)] when the length is odd
// to calculate missing values at indices [(length - 1) ... ub].
if (k > 0 && k < (length - ub + 1)) {
buffer[length - k] = std::conj(value);
}
}
}
return input_is_zero;
}
// Returns (conjugated, if 'inverse' is true) k-th twiddle for the given length.
inline complex128 Twiddle(int64 k, int64 length, bool inverse) {
auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * k / length));
return inverse ? std::conj(coeff) : coeff;
}
// Straightforward implementation of 1D DFT transform of arbitrary length. Uses
// passed-in start index and stride to gather inputs from the data vector into
// the preallocated buffer, computes the result, and writes it back to the same
// locations in the data vector. Runs in O(length^2) time.
//
// Parameters contract_output and expand_input are used to avoid unnecessary
// calculations. When contract_output is set to true, then only (length / 2) + 1
// output values are computed. When expand_input is set to true, then
// (length / 2) + 1 values from the data set are used to re-create the full set
// of size 'length', on which the transform is then performed.
//
void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse,
bool contract_output, bool expand_input,
absl::Span<complex128> data, absl::Span<complex128> buffer) {
const bool input_is_zero =
GatherToBuffer(data, length, start, stride, expand_input, buffer);
if (!input_is_zero) {
const int64 ub = contract_output ? length / 2 + 1 : length;
for (int64 k = 0; k < ub; k++) {
complex128 value = complex128(0.0, 0.0);
for (int n = 0; n < length; n++) {
value += buffer[n] * Twiddle(n * k, length, inverse);
}
data[start + k * stride] =
inverse ? value / complex128(length, 0.0) : value;
}
}
}
// Non-recursive implementation of the Cooley-Tukey radix-2 decimation in time.
// Performs 1D FFT transform for the lengths, which are powers of 2. Runs in
// O(length * log(length)) time. Uses the same parameters as the naive
// implementation above, except that the preallocated buffer must be at least
// twice as big as the length of the transform, because the buffer is used to
// hold both input and output values for each stage of the transform.
//
void Fft1D(int64 length, int64 start, int64 stride, bool inverse,
bool contract_output, bool expand_input, absl::Span<complex128> data,
absl::Span<complex128> buffer) {
CHECK(IsPowerOfTwo(static_cast<uint64>(length)));
const bool input_is_zero =
GatherToBuffer(data, length, start, stride, expand_input, buffer);
if (!input_is_zero) {
auto generate_twiddles = [](int64 length, bool inverse) {
std::vector<complex128> twiddles;
// Need only half the twiddles.
for (int64 k = 0; k < length / 2; k++) {
twiddles.push_back(Twiddle(k, length, inverse));
}
return twiddles;
};
// Indices into the parts of the buffer used for input and output values.
int64 in_base = length;
int64 out_base = 0;
// At each stage, we "split" the input data into num_blocks, with block_size
// values in each block.
for (int64 num_blocks = 1; num_blocks < length; num_blocks *= 2) {
// Swap input and output parts of the buffer.
std::swap(in_base, out_base);
auto twiddles = generate_twiddles(num_blocks * 2, inverse);
const int64 block_size = length / num_blocks;
const int64 next_iteration_block_size = block_size / 2;
for (int64 block = 0; block < num_blocks; block++) {
const int64 in_offset = in_base + block * block_size;
const int64 out_offset = out_base + block * next_iteration_block_size;
// For each (even, odd) pair of values in the block, calculate two
// output values as even + twiddle * odd and even - twiddle * odd.
for (int64 pair = 0; pair < block_size / 2; pair++) {
const complex128 even = buffer[in_offset + pair];
const complex128 odd = buffer[in_offset + block_size / 2 + pair];
const complex128 twiddled_odd = twiddles[block] * odd;
buffer[out_offset + pair] = even + twiddled_odd;
buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
}
}
}
// Copy computed result back to data.
const int64 ub = contract_output ? length / 2 + 1 : length;
for (int64 k = 0; k < ub; k++) {
complex128 value = buffer[out_base + k];
data[start + k * stride] =
inverse ? value / complex128(length, 0.0) : value;
}
}
}
// Determine, which implementation of 1D transform to use and call it.
void Dft1D(int64 length, int64 start, int64 stride, bool inverse,
bool contract_output, bool expand_input, absl::Span<complex128> data,
absl::Span<complex128> buffer) {
if (IsPowerOfTwo(static_cast<uint64>(length))) {
Fft1D(length, start, stride, inverse, contract_output, expand_input, data,
buffer);
} else {
NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
data, buffer);
}
}
// Helper to reverse the order of dimension lengths in the passed-in literal.
std::vector<int64> GetDimensionLengths(const Literal& literal) {
std::vector<int64> lengths = literal.shape().dimensions();
absl::c_reverse(lengths);
return lengths;
}
// Helper to compute strides for creating linear indices into multidimensional
// data from the dimension lengths and the layout. Returns a new vector of size
// lengths.size() + 1. The last element of the returned vector at index
// [lengths.size()] contains the product of all dimension lengths.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
const Layout& layout) {
const int64 num_dimensions = lengths.size();
// Make sure that the layout length matches the number of dimensions.
CHECK_EQ(num_dimensions, layout.minor_to_major_size());
// Calculate strides using layout-specified ordering of the dimensions and
// place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
std::vector<int64> strides(num_dimensions + 1);
int64 stride = 1;
for (int64 i = 0; i < num_dimensions; i++) {
// Reverse the ordering of the dimensions in the layout.
const int64 index = (num_dimensions - 1) - layout.minor_to_major(i);
strides[index] = stride;
stride *= lengths[index];
}
strides[num_dimensions] = stride;
return strides;
}
// Compute strides as above using the default layout.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths) {
return ComputeStrides(lengths,
LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
}
// Compute strides as above using the layout from the literal, if available.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
const Literal& literal) {
return literal.shape().has_layout()
? ComputeStrides(lengths, literal.shape().layout())
: ComputeStrides(lengths);
}
// Make 1D sweeps along each transform axis.
void Sweep(int64 fft_rank, FftType fft_type,
const absl::Span<const int64> fft_lengths,
const absl::Span<const int64> fft_strides,
absl::Span<complex128> data, absl::Span<complex128> buffer) {
const bool inverse = fft_type == FftType::IFFT || fft_type == FftType::IRFFT;
const bool input_is_truncated = fft_type == FftType::IRFFT;
const bool output_is_truncated = fft_type == FftType::RFFT;
// Recursively visit each column of the data along the sweep_axis. Calculate
// linearized index of that column's first element and the stride, then invoke
// 1D transform.
// For RFFT, avoid calculating unused output values: first, compute only
// (length_x / 2) + 1 values along the X axis, then limit the X coordinate to
// [0 ... (length / 2)] during the sweeps along other axes. Similarly, for
// IRFFT sweep along higher dimensions first, while keeping the X coordinate
// in the [0 ... (length / 2)] range, then re-create negative frequencies
// omitted in the input and perform the full-length transform along the X axis
// in the last sweep.
std::function<void(int64, int64, int64)> sweep = [&](int64 sweep_axis,
int64 axis,
int64 start) {
if (axis < 0) {
// Base case: invoke 1D transform.
const int64 length = fft_lengths[sweep_axis];
const int64 stride = fft_strides[sweep_axis];
const bool expand_input = input_is_truncated && sweep_axis == 0;
const bool contract_oputput = output_is_truncated && sweep_axis == 0;
Dft1D(length, start, stride, inverse, contract_oputput, expand_input,
data, buffer);
} else if (axis == sweep_axis) {
// Visit only the elements with coordinate 0 along the sweep axis.
sweep(sweep_axis, axis - 1, start);
} else {
const int64 length = fft_lengths[axis];
const bool is_truncated = input_is_truncated || output_is_truncated;
const int64 ub = is_truncated && axis == 0 ? (length / 2) + 1 : length;
for (int64 i = 0; i < ub; i++) {
sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
}
}
};
if (input_is_truncated) {
// Sweep along the X axis last for IRFFT.
for (int64 sweep_axis = fft_rank - 1; sweep_axis >= 0; sweep_axis--) {
sweep(sweep_axis, fft_rank - 1, 0);
}
} else {
// Sweep along the X axis first for RFFT. The order does not matter for FFT
// and IFFT types; handle them here as well.
for (int64 sweep_axis = 0; sweep_axis < fft_rank; sweep_axis++) {
sweep(sweep_axis, fft_rank - 1, 0);
}
}
}
// These templates convert the data from the input data type to the type used in
// calculations and then to the output data type. They are intended to be used
// only within the DFT implementation. One special case is IRFFT, where the
// specialization drops imaginary parts of complex values (which is expected to
// be 0) and returns real numbers.
template <typename ToType, typename FromType>
ToType GetAs(FromType value) {
return static_cast<ToType>(value);
}
template <>
float GetAs<float, complex128>(complex128 value) {
return static_cast<float>(value.real());
}
// This template generates two linearized indices, which can be used to access
// multidimensional arrays. It uses a recursive function, which passes the
// indices to the user-supplied callback function. The destination index is
// always within dst_lengths[] bounds. The boolean parameter within_src_bounds
// indicates whether the source index is within src_lengths[] bounds.
//
// The value returned from the callback function controls the recursion depth.
// Returning true indicates that the base case had been hit and the recursion
// stops. Otherwise, the recursion proceeds along the next less-major axis.
//
// For example, the base case when the axis value becomes negative invokes the
// callback function for each possible index within dst_lengths[] bounds. The
// base case when the axis value is equal to zero limits the indices to point
// only to first elements along the minor-most dimension, allowing the callback
// function to handle all values along the X axis.
//
template <typename BaseFn>
void GenerateIndices(const absl::Span<const int64> dst_lengths,
const absl::Span<const int64> dst_strides,
const absl::Span<const int64> src_lengths,
const absl::Span<const int64> src_strides, int64 fft_rank,
int64 dst_start, int64 src_start, BaseFn&& base) {
CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
CHECK_GE(dst_lengths.size(), fft_rank);
CHECK_EQ(src_lengths.size() + 1, src_strides.size());
CHECK_GE(src_lengths.size(), fft_rank);
std::function<void(int64, int64, int64, bool)> generate =
[&](int64 axis, int64 dst_index, int64 src_index,
bool within_src_bounds) {
if (!base(axis, dst_index, src_index, within_src_bounds)) {
for (int64 i = 0; i < dst_lengths[axis]; i++) {
// Because the loop goes over dst_lengths[], the source index may be
// out of src_lengths[] bounds. In this case, within_src_bounds is
// false.
within_src_bounds &= i < src_lengths[axis];
generate(axis - 1, dst_index, src_index, within_src_bounds);
dst_index += dst_strides[axis];
src_index += src_strides[axis];
}
}
};
generate(fft_rank - 1, dst_start, src_start, true);
}
// Copies the input data from a literal to a pre-allocated vector. The sizes of
// the input and the transform do not need to match. For each axis of the
// transform, any extra input values beyond the transform length are ignored.
// Conversely, if the input does not contain enough elements along any axis, the
// data is padded with zeroes.
//
// For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
// where length_x is the size of the full transform along the X axis.
//
// The input literal may have a rank higher than the rank of the transform.
// Passed-in input_index value points to the first element of the input literal
// to be copied.
//
// Returns true if all values in the work data set are zeroes.
//
template <typename InputType>
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
int64 fft_rank, FftType fft_type, int64 fft_size,
const absl::Span<const int64> fft_lengths,
const absl::Span<const int64> fft_strides,
const absl::Span<const int64> input_lengths,
const absl::Span<const int64> input_strides,
absl::Span<complex128> data) {
CHECK_GE(data.size(), fft_size);
const bool input_is_truncated = fft_type == FftType::IRFFT;
// Recursively visit each transform dimension to copy input values to the
// working data set. The base case handles inputs along the X axis.
bool input_is_zero = true;
const InputType* input_data = input_literal.data<InputType>().data();
auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
bool within_src_bounds) {
if (axis == 0) {
// For IRFFT, the negavie frequencies are only needed for the sweep along
// the X axis, which is performed last. Leave this part of the working set
// uninitialized until then.
const int64 length = fft_lengths[axis];
const int64 ub = input_is_truncated ? (length / 2) + 1 : length;
for (int64 i = 0; i < ub; i++) {
complex128 value = InputType(0);
// Read input value only if the index is within bounds.
if (within_src_bounds && i < input_lengths[axis]) {
value = GetAs<complex128, InputType>(
input_data[src_index + i * input_strides[axis]]);
input_is_zero &= value == complex128(0.0, 0.0);
}
data[dst_index + i * fft_strides[axis]] = value;
}
return true;
}
return false;
};
GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
fft_rank, 0, input_start, base_case);
return input_is_zero;
}
// Copies the result of the transform to the literal output. The sizes of the
// transform and output must match.
//
// For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
// the size of the full transform along the X axis (the most minor dimension).
//
// The output literal may have a rank higher than the rank of the transform.
// Passed-in output_index value points to the first element of the output
// literal to be filled in.
//
template <typename OutputType>
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
int64 fft_rank, FftType fft_type,
const absl::Span<const int64> fft_lengths,
const absl::Span<const int64> fft_strides,
const absl::Span<const int64> output_lengths,
const absl::Span<const int64> output_strides,
Literal* output_literal) {
const bool output_is_truncated = fft_type == FftType::RFFT;
// Base case for recursive copy of the results to the output. The code avoids
// making a recursive call for each output element by handling axis 0 in the
// loop (as opposed to making "axis < 0" to be the base case).
OutputType* output_data = output_literal->data<OutputType>().data();
auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
bool within_src_bounds) {
if (axis == 0) {
// Drop negative frequencies for RFFT.
const int64 length = fft_lengths[axis];
const int64 ub = output_is_truncated ? (length / 2) + 1 : length;
for (int64 i = 0; i < output_lengths[axis]; i++) {
OutputType value = OutputType(0);
// Read data only if the index is within bounds.
if (within_src_bounds && i < ub) {
value = GetAs<OutputType, complex128>(
data[src_index + i * fft_strides[axis]]);
}
output_data[dst_index + i * output_strides[axis]] = value;
}
return true;
}
return false;
};
GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
fft_rank, output_start, 0, base_case);
}
// Determine the type to use with the CopyDataFromInput<> template above.
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
int64 fft_rank, FftType fft_type, int64 fft_size,
const absl::Span<const int64> fft_lengths,
const absl::Span<const int64> fft_strides,
const absl::Span<const int64> input_lengths,
const absl::Span<const int64> input_strides,
absl::Span<complex128> data) {
const bool input_is_float = fft_type == FftType::RFFT;
if (input_is_float) {
return CopyDataFromInput<float>(
input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
fft_strides, input_lengths, input_strides, data);
} else {
return CopyDataFromInput<complex64>(
input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
fft_strides, input_lengths, input_strides, data);
}
}
// Determine the type to use with the CopyDataToOutput<> template above.
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
int64 fft_rank, FftType fft_type,
const absl::Span<const int64> fft_lengths,
const absl::Span<const int64> fft_strides,
const absl::Span<const int64> output_lengths,
const absl::Span<const int64> output_strides,
Literal* output_literal) {
const bool output_is_float = fft_type == FftType::IRFFT;
if (output_is_float) {
CopyDataToOutput<float>(data, output_start, fft_rank, fft_type, fft_lengths,
fft_strides, output_lengths, output_strides,
output_literal);
} else {
CopyDataToOutput<complex64>(data, output_start, fft_rank, fft_type,
fft_lengths, fft_strides, output_lengths,
output_strides, output_literal);
}
}
Status CheckParameters(const Shape& input_shape, const Shape& output_shape,
int64 fft_rank, FftType fft_type,
const absl::Span<const int64> fft_lengths) {
// Check FFT parameters.
if (fft_rank <= 0) {
return InvalidArgument("Zero or negative FFT rank.");
}
if (*absl::c_min_element(fft_lengths) < 0) {
return InvalidArgument("Negative FFT length.");
}
// Check input-related values.
TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
if (!input_shape.IsArray()) {
return Unimplemented("Only array input shapes are supported.");
}
auto input_elt_type = input_shape.element_type();
if (fft_type == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
return InvalidArgument("Invalid input type: %d, must be %d (float).",
input_elt_type, PrimitiveType::F32);
}
if (fft_type != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
input_elt_type, PrimitiveType::C64);
}
const int64 input_rank = input_shape.rank();
if (input_rank < fft_rank) {
return InvalidArgument("Input shape rank is smaller than FFT rank.");
}
// Check output-related values.
TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
if (!output_shape.IsArray()) {
return Unimplemented("Only array output shapes are supported.");
}
auto output_elt_type = output_shape.element_type();
if (fft_type == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
return InvalidArgument("Invalid output type: %d, must be %d (float).",
output_elt_type, PrimitiveType::F32);
}
if (fft_type != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
output_elt_type, PrimitiveType::C64);
}
const int64 output_rank = output_shape.rank();
if (output_rank < fft_rank) {
return InvalidArgument("Output shape rank is smaller than FFT rank.");
}
// Consistency of input and output parameters.
if (input_rank != output_rank) {
return InvalidArgument(
"Ranks of input shape and output shape do not match.");
}
for (int64 dim = 0; dim < input_rank - fft_rank; dim++) {
if (ShapeUtil::GetDimension(input_shape, dim) !=
ShapeUtil::GetDimension(output_shape, dim)) {
return InvalidArgument(
"Higher dimension lengths of input shape and output shape do not "
"match.");
}
}
return Status::OK();
}
} // namespace
// Flexible implementation of the discrete Fourier transform. All transform
// types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the arbitrary
// rank and length of each dimension of the transform, and arbitrary layouts of
// the input and output literals.
//
// The input literal in operand 0 provides input data, which must be complex64
// for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed
// over the innermost dimensions of the input, thus the rank of the input data
// must be same as fft_rank or larger. The input is expected to provide Ni
// values along each transform axis with one exception: for IRFFT, only
// (N0 / 2) + 1 values are needed along the X axis (the innermost index). To
// increase flexibility, this implementation can handle mismatches between the
// input size and transform lengths by either dropping extra input values or
// using zeroes in place of missing input values as necessary. If the input data
// has rank higher than the transform, the transform is applied for each valid
// combination of the higher-ranking indices.
//
// The output contains complex64 values for FFT, IFFT, RFFT, and float values
// for IRFFT. The rank of the output as well as the sizes of the dimensions
// above the rank of the transform must match those of the input. Sizes of the
// output's "fft_rank" innermost dimensions are expected to match the length of
// the transform along respective axes with one exception: for RFFT, the output
// is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the
// length(s) mismatch, the FFT output is trimmed to fit into the provided output
// shape, or the output is padded with zero values appropriately.
//
// For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17]
// input array will perform two transforms over the [][15][17] data in the sub
// arrays [0][][] and [1][][], dropping the values along axis X and padding axis
// Y with zeroes to create 16x16 working sets, and generating
// complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to
// complex64[64][16][9] input array will use all input values and will produce
// float[64][16][16] output.
//
// The implementation of the 1D transform for lengths, that are powers of 2, is
// the Cooley-Tukey radix-2 decimation-in-time. For all other 1D transform
// lengths, a straightforward, but slow, loop nest is used. The transforms of
// higher ranks apply sets of 1D transforms along each axis. For example, the 2D
// transform is computed by applying 1D transforms to each column followed by
// applying 1D transforms to each row.
//
// In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
// time, where Ni is the length of the transform's i-th dimension. However, for
// dimension lengths, which are powers of 2, the run time along these dimensions
// is reduced to log(Ni) in the summation, giving the runtime of
// O(N0*N1*...*Nn*(log(N0)+log(N1)+...+log(Nn)) in the best case.
//
Status HloEvaluator::HandleFft(HloInstruction* fft) {
const FftType fft_type = fft->fft_type();
std::vector<int64> fft_lengths = fft->fft_length();
const int64 fft_rank = fft_lengths.size();
const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
const Shape& input_shape = input_literal.shape();
const Shape& output_shape = fft->shape();
Literal output_literal = Literal::CreateFromShape(output_shape);
// Make fft_lengths[0] the minor-most dimension.
absl::c_reverse(fft_lengths);
TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape, fft_rank,
fft_type, fft_lengths));
const auto fft_strides = ComputeStrides(fft_lengths);
// Working set size.
const int64 fft_size = fft_strides[fft_rank];
if (fft_size > 0) {
// Linearized working data set.
std::vector<complex128> data(fft_size);
// Temporary buffer allocated once and used in 1D sweeps. For dimension
// length values that are powers of 2, the buffer should be twice as large.
int64 buffer_size = 0;
for (auto len : fft_lengths) {
int64 size = IsPowerOfTwo(static_cast<uint64>(len)) ? len * 2 : len;
buffer_size = std::max(buffer_size, size);
}
std::vector<complex128> buffer(buffer_size);
// Sizes of each axis of input and output literals.
const auto input_lengths = GetDimensionLengths(input_literal);
const auto output_lengths = GetDimensionLengths(output_literal);
// Strides for generating linearized indices into multidimensional arrays.
const auto input_strides = ComputeStrides(input_lengths, input_literal);
const auto output_strides = ComputeStrides(output_lengths, output_literal);
// Visit all elements in the dimensions with ranks above the FFT rank. For
// each such element invoke the transform. Use separate indices for the
// input and the output to allow different layouts.
auto base_case = [&](int64 axis, int64 output_index, int64 input_index,
bool within_src_bounds) {
if (axis == fft_rank - 1) {
// Base case: copy the data from the input literal, apply the
// transform, and copy the result to the output literal.
CHECK(within_src_bounds);
bool input_is_zero =
CopyDataFromInput(input_literal, input_index, fft_rank, fft_type,
fft_size, fft_lengths, fft_strides, input_lengths,
input_strides, absl::MakeSpan(data));
if (!input_is_zero) {
// Make 1D sweeps along each transform axis.
Sweep(fft_rank, fft_type, fft_lengths, fft_strides,
absl::MakeSpan(data), absl::MakeSpan(buffer));
}
CopyDataToOutput(absl::MakeSpan(data), output_index, fft_rank, fft_type,
fft_lengths, fft_strides, output_lengths,
output_strides, &output_literal);
return true;
}
return false;
};
GenerateIndices(output_lengths, output_strides, input_lengths,
input_strides, input_shape.rank(), 0, 0, base_case);
}
evaluated_[fft] = std::move(output_literal);
return Status::OK();
}
// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
// dimensions while keeping the rest of the output dimensions clamped to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
int64 output_rank = output_shape.dimensions_size();
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count;
index_count.reserve(output_rank);
for (int64 i = 0; i < output_rank; i++) {
bool is_output_batch_dim =
!absl::c_binary_search(dim_numbers.offset_dims(), i);
index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
}
return {std::move(index_base), std::move(index_count),
std::vector<int64>(output_rank, 1)};
}
// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
int64 output_rank, absl::Span<const int64> slice_sizes,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count(output_rank, 1);
int64 slice_sizes_idx = 0;
for (int64 i = 0; i < output_rank; i++) {
bool is_output_window_dim =
absl::c_binary_search(dim_numbers.offset_dims(), i);
if (is_output_window_dim) {
while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
slice_sizes_idx)) {
slice_sizes_idx++;
}
index_count[i] = slice_sizes[slice_sizes_idx++];
}
}
return {std::move(index_base), std::move(index_count),
std::vector<int64>(output_rank, 1)};
}
// This functor computes the contribution of start_indices to an input index
// corresponding to an output index. That is, given an output index I, it picks
// out the batch indices in I and uses them to look up a starting index, G, from
// the start indices tensor, and expands G into the input space according to
// start_index_map.
class OutputBatchIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
explicit OutputBatchIndexToInputIndex(
const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
const Shape& output_shape, const Literal* start_indices)
: dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
output_dim_is_batch_dims_.push_back(
!absl::c_binary_search(dim_numbers_.offset_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
int64 index_of_input_dim_in_index_vector =
std::distance(dim_numbers_.start_index_map().begin(),
absl::c_find(dim_numbers_.start_index_map(), i));
if (index_of_input_dim_in_index_vector ==
dim_numbers_.start_index_map_size()) {
input_dim_value_to_index_vector_.push_back(-1);
} else {
input_dim_value_to_index_vector_.push_back(
index_of_input_dim_in_index_vector);
}
}
index_vector_index_.resize(start_indices_.shape().dimensions_size());
input_index_.resize(input_shape.dimensions_size());
int64 index_vector_size =
start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
index_vector_.resize(index_vector_size);
}
// Returns the contribution of start_indices to the input index corresponding
// to output_index. See gather_inner_loop_body.
//
// This is conceptually a stateless transformation from output_index to the
// gather input index, but:
//
// - Instead of allocating memory to represent the gather input index on
// every invocation we reuse the same storage for the result
// (input_index_), mutating it in place.
// - Instead of allocating buffers for temporary values like
// index_vector_index_ and index_vector on every invocation, we reuse the
// same storage for all invocations.
//
// This returns a Span into memory owned by the class.
StatusOr<absl::Span<const int64>> operator()(
absl::Span<const int64> output_index) {
PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
PropagateIndexVectorToInputIndex();
return absl::Span<const int64>(input_index_);
}
private:
// Propagates the batch dimensions from the output index into
// index_vector_index_ by mutating index_vector_index_ in place. Does not
// update the dim_numbers.index_vector_dim() dimension -- that's the dimension
// we iterate over in FetchIndexVector.
void PropagateOutputIndexGatherDimsToIndexVectorIndex(
absl::Span<const int64> output_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = output_index.size(); i < e; i++) {
if (!output_dim_is_batch_dims_[i]) {
continue;
}
if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
index_vector_index_i++;
}
index_vector_index_[index_vector_index_i++] = output_index[i];
}
}
// Populates index_vector_ by iterating over start_indices_ according to
// index_vector_index_.
Status FetchIndexVector() {
int64 index_vector_dim = dim_numbers_.index_vector_dim();
for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
index_vector_index_[index_vector_dim] = i;
TF_ASSIGN_OR_RETURN(index_vector_[i],
start_indices_.GetIntegralAsS64(index_vector_index_));
}
return Status::OK();
}
// Populates input_index_.
void PropagateIndexVectorToInputIndex() {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_index_vector_[i] != -1) {
input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
}
// If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
// remains 0, as set by the constructor.
}
}
// input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
// the input index from the index vector. See
// PropagateIndexVectorToInputIndex.
std::vector<int64> input_dim_value_to_index_vector_;
// output_dim_is_batch_dims_[i] is true iff the output index i is a gather
// dimension.
std::vector<bool> output_dim_is_batch_dims_;
// The buffer into which we construct an index into start_indices_ to fetch
// the index vector.
std::vector<int64> index_vector_index_;
// The index vector fetched from start_indices_.
std::vector<int64> index_vector_;
// The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
const GatherDimensionNumbers& dim_numbers_;
const Literal& start_indices_;
};
// This functor computes the contribution of the offset indices in an output
// index to an input index. That is, given an output index I it picks out the
// output offset indices in I and expands it into an index into the input shape.
class OutputOffsetIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
explicit OutputOffsetIndexToInputIndex(
const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
const Shape& output_shape) {
std::vector<int64> window_index_to_output_index;
int64 output_index_count = 0;
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
window_index_to_output_index.push_back(output_index_count++);
} else {
output_index_count++;
}
}
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
input_dim_value_to_output_index_.push_back(-1);
} else {
input_dim_value_to_output_index_.push_back(
window_index_to_output_index[window_dim_count++]);
}
}
input_index_.resize(input_shape.dimensions_size());
}
// Returns the contribution of the window indices to the input index
// corresponding to output_index. See gather_inner_loop_body.
//
// This is conceptually a stateless transformation from output_index to the
// window input index, but instead of allocating memory to represent the
// gather input index on every invocation we reuse the same storage for the
// result (input_index_), mutating it in place.
//
// This returns a Span into memory owned by the class.
StatusOr<absl::Span<const int64>> operator()(
absl::Span<const int64> output_index) {
PropagateOutputIndexWindowDimsToInputIndex(output_index);
return absl::Span<const int64>(input_index_);
}
// Returns for a given 'input_dim' the corresponding output dimension index,
// or -1 if 'input_dim' is an elided window dimension.
int64 input_dim_value_to_output_index(int64 input_dim) {
return input_dim_value_to_output_index_[input_dim];
}
private:
// Propagates window dimensions from the output index to input_index_ by
// mutating input_index_ in place.
void PropagateOutputIndexWindowDimsToInputIndex(
absl::Span<const int64> output_index) {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_output_index_[i] != -1) {
input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
}
// If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
// remains 0, as set by the constructor.
}
}
// input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
// the input index from the output index. See
// PropagateOutputIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_output_index_;
// The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
};
// Rehapes the gather indices input to have a trailing degenerate `1` dimension
// if necessary. Hands over the ownership of the newly created literal (if
// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
int64 index_vector_dim, const Literal& start_indices,
Literal* reshaped_start_indices) {
if (start_indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(start_indices);
}
std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
start_indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
start_indices.Reshape(new_shape));
return std::cref(*reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
Literal result = Literal::CreateFromShape(gather->shape());
const Shape& shape = gather->shape();
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
Literal reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
GetEvaluatedLiteralFor(gather->operand(1)),
&reshaped_start_indices));
// We iterate over the gather dimensions in the output shape in an outer loop
// nest, and iterate over the window dimensions in the output shape in an
// inner loop nest.
ShapeUtil::IndexIterationSpace start_indices_iteration_space =
IterationSpaceForOutputBatchIndices(shape, dim_numbers);
ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
IterationSpaceForOutputOffsetIndices(
shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
// Scratch buffers that hold an index in the output shape and the
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
OutputBatchIndexToInputIndex output_batch_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
/*output_shape=*/shape, &start_indices);
OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
/*output_shape=*/shape);
const Shape& operand_shape = operand.shape();
auto gather_inner_loop_body =
[&](absl::Span<const int64> output_window_index,
absl::Span<const int64> input_gather_index,
absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
absl::Span<const int64> input_window_index,
output_offset_index_to_input_index(output_window_index));
for (int i = 0, e = output_index.size(); i < e; i++) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
for (int i = 0, e = input_gather_index.size(); i < e; i++) {
int64 output_dim =
output_offset_index_to_input_index.input_dim_value_to_output_index(i);
// If 'output_dim' is -1, it means 'i' is an elided window dim. This means
// we set the iteration index to 0, so for the purpose of the following
// calculations we can consider the output dimension size to be 1.
int64 output_dim_size =
output_dim == -1 ? 1 : shape.dimensions(output_dim);
// Clamp the gather index so that the gather region fits in the operand.
// input_index_clamped[i] = clamp(input_gather_index[i], 0,
// operand_shape.dimensions(i) -
// output_dim_size);
input_index_clamped[i] =
std::min(operand_shape.dimensions(i) - output_dim_size,
std::max(0LL, input_gather_index[i]));
}
for (int i = 0, e = input_index.size(); i < e; i++) {
input_index[i] = input_index_clamped[i] + input_window_index[i];
DCHECK_GE(input_index[i], 0);
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
result.CopyElementFrom(operand, input_index, output_index));
return true;
};
auto gather_outer_loop_body =
[&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
shape, offset_indices_iteration_space,
std::bind(gather_inner_loop_body, std::placeholders::_1,
input_gather_index, output_gather_index)));
return true;
};
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
shape, start_indices_iteration_space, gather_outer_loop_body));
evaluated_[gather] = std::move(result);
return Status::OK();
}
Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
<< "broadcast dimensions is of size: " << broadcast->dimensions().size()
<< " and rank of operand_to_broadcast is: " << operand.shape().rank();
// Checks that operand's dimensions are the same as the broadcast's
// dimensions along the dimensions to be broadcasted.
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
auto operand_dim_size = operand.shape().dimensions(i);
auto broadcast_dim_size =
broadcast->shape().dimensions(broadcast->dimensions(i));
TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
"Operand dimension %d is broadcast to output dimension %d, but the "
"sizes of these two dims do not match (%d vs %d): %s",
i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
broadcast->ToString());
}
TF_ASSIGN_OR_RETURN(
evaluated_[broadcast],
operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
return Status::OK();
}
Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
evaluated_[after_all] = LiteralUtil::CreateToken();
return Status::OK();
}
Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
// AddDedendency just forwards its zero-th operand.
evaluated_[add_dependency] =
GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
return Status::OK();
}
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const auto result_shape = get_tuple_element->shape();
const int64 index = get_tuple_element->tuple_index();
auto operand = get_tuple_element->operand(0);
TF_ASSIGN_OR_RETURN(
auto inferred_return_shape,
ShapeInference::InferGetTupleElementShape(operand->shape(), index));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
evaluated_[get_tuple_element] =
Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
/*dest_shape_index=*/{},
/*src_shape_index=*/{index});
}
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
Status HloEvaluator::HandleCall(HloInstruction* call) {
auto* computation = call->to_apply();
auto operands = call->operands();
std::vector<const Literal*> arg_literals;
arg_literals.reserve(operands.size());
for (auto operand : operands) {
const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
arg_literals.push_back(&arg_literal);
}
HloEvaluator embedded_evaluator;
embedded_evaluator.set_dynamic_dimension_inference(
dynamic_dimension_inference_);
TF_ASSIGN_OR_RETURN(Literal result,
embedded_evaluator.Evaluate(*computation, arg_literals));
evaluated_[call] = std::move(result);
return Status::OK();
}
Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
HloModuleConfig config;
// Attach cloned computation to an empty HLO module so the existing ones are
// not modified.
HloModule empty_hlo_module("EmptyModuleForFusion", config);
HloCloneContext context(&empty_hlo_module);
auto cloned_fused_computation =
fusion->fused_instructions_computation()->Clone(
/*suffix=*/"clone_with_layout", &context);
for (auto* instruction : cloned_fused_computation->instructions()) {
if (!LayoutUtil::HasLayout(instruction->shape())) {
LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
}
}
auto readded_computation =
empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
auto operands = fusion->operands();
std::vector<const Literal*> arg_literals;
arg_literals.reserve(operands.size());
for (auto operand : operands) {
const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
arg_literals.push_back(&arg_literal);
}
HloEvaluator embedded_evaluator;
embedded_evaluator.set_dynamic_dimension_inference(
dynamic_dimension_inference_);
TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
*readded_computation, arg_literals));
evaluated_[fusion] = std::move(result);
return Status::OK();
}
Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
const auto& branch_index_literal =
GetEvaluatedLiteralFor(conditional->operand(0));
int branch_index;
if (conditional->operand(0)->shape().element_type() == PRED) {
branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
} else {
branch_index = branch_index_literal.Get<int32>({});
if (branch_index < 0 || branch_index >= conditional->branch_count()) {
branch_index = conditional->branch_count() - 1;
}
}
const auto& branch_computation_arg =
GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
HloEvaluator embedded_evaluator;
embedded_evaluator.set_dynamic_dimension_inference(
dynamic_dimension_inference_);
TF_ASSIGN_OR_RETURN(Literal result,
embedded_evaluator.Evaluate(
*conditional->branch_computation(branch_index),
{&branch_computation_arg}));
evaluated_[conditional] = std::move(result);
return Status::OK();
}
Status HloEvaluator::HandleSelect(HloInstruction* select) {
const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
// If predicate is of scalar type, no element-wise selection would be needed.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
evaluated_[select] = on_true.Clone();
} else {
evaluated_[select] = on_false.Clone();
}
return Status::OK();
}
return DefaultAction(select);
}
Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
if (pred.Get<bool>({})) {
evaluated_[tuple_select] = on_true.Clone();
} else {
evaluated_[tuple_select] = on_false.Clone();
}
return Status::OK();
}
Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
// Initialize the loop carried valued with the input to the While instruction.
auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
bool keep_going = true;
int64 iteration_count = 0;
HloEvaluator cond_evaluator(max_loop_iterations_);
cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
HloEvaluator loop_body_evaluator(max_loop_iterations_);
loop_body_evaluator.set_dynamic_dimension_inference(
dynamic_dimension_inference_);
while (keep_going) {
if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
while_hlo->name(), max_loop_iterations_);
}
TF_ASSIGN_OR_RETURN(auto cond_val,
cond_evaluator.Evaluate(*cond_comp, {&lcv}));
keep_going = cond_val.GetFirstElement<bool>();
if (keep_going) {
TF_ASSIGN_OR_RETURN(auto body_val,
loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
VLOG(3) << "Loop iteration result: " << body_val.ToString();
lcv = std::move(body_val);
cond_evaluator.ResetVisitStates();
loop_body_evaluator.ResetVisitStates();
}
}
evaluated_[while_hlo] = std::move(lcv);
return Status::OK();
}
namespace {
template <typename NativeT>
Literal ExtractLiteralFromIndexPositions(const Literal& from,
absl::Span<int64 const> indices,
bool extract_as_scalar) {
if (extract_as_scalar) {
return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
}
// We use a InlinedVector here because we need to convert it to an
// absl::Span later, and this would not work with std::vector<bool>.
absl::InlinedVector<NativeT, 10> values;
for (int64 index : indices) {
values.push_back(from.Get<NativeT>({index}));
}
return LiteralUtil::CreateR1<NativeT>(values);
}
StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
absl::Span<int64 const> indices,
bool extract_as_scalar = false) {
if (extract_as_scalar) {
CHECK_EQ(indices.size(), 1);
}
PrimitiveType type = from.shape().element_type();
switch (type) {
case PRED: {
return ExtractLiteralFromIndexPositions<bool>(from, indices,
extract_as_scalar);
}
case U8: {
return ExtractLiteralFromIndexPositions<uint8>(from, indices,
extract_as_scalar);
}
case S8: {
return ExtractLiteralFromIndexPositions<int8>(from, indices,
extract_as_scalar);
}
case BF16: {
return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
extract_as_scalar);
}
case F16: {
return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
extract_as_scalar);
}
case U16: {
return ExtractLiteralFromIndexPositions<uint16>(from, indices,
extract_as_scalar);
}
case S16: {
return ExtractLiteralFromIndexPositions<int16>(from, indices,
extract_as_scalar);
}
case F32: {
return ExtractLiteralFromIndexPositions<float>(from, indices,
extract_as_scalar);
}
case U32: {
return ExtractLiteralFromIndexPositions<uint32>(from, indices,
extract_as_scalar);
}
case S32: {
return ExtractLiteralFromIndexPositions<int32>(from, indices,
extract_as_scalar);
}
case F64: {
return ExtractLiteralFromIndexPositions<double>(from, indices,
extract_as_scalar);
}
case U64: {
return ExtractLiteralFromIndexPositions<uint64>(from, indices,
extract_as_scalar);
}
case S64: {
return ExtractLiteralFromIndexPositions<int64>(from, indices,
extract_as_scalar);
}
default:
return InvalidArgument("Unsupported type for Sort: %s",
PrimitiveType_Name(type));
}
}
} // namespace
Status HloEvaluator::HandleSort(HloInstruction* sort) {
TF_RET_CHECK(sort->operand_count() >= 1)
<< "Expected at least 1 operand for sort";
for (int64 i = 1; i < sort->operand_count(); ++i) {
TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
sort->operand(i)->shape()))
<< "All Sort operands must have the same dimensions";
}
if (VLOG_IS_ON(3)) {
for (int64 i = 0; i < sort->operand_count(); ++i) {
VLOG(3) << "HandleSort operand " << i << " literal: "
<< GetEvaluatedLiteralFor(sort->operand(i)).ToString();
}
}
Shape key_shape = sort->operand(0)->shape();
auto rank = key_shape.rank();
std::vector<Literal> result_literals;
result_literals.reserve(sort->operand_count());
for (int64 i = 0; i < sort->operand_count(); ++i) {
result_literals.emplace_back(sort->operand(i)->shape());
}
std::vector<int64> zero_base(rank, 0);
std::vector<int64> increment(rank, 1);
int64 sort_dim = sort->dimensions(0);
int64 sort_dim_elements = key_shape.dimensions(sort_dim);
increment[sort_dim] = sort_dim_elements;
HloEvaluator embedded_evaluator(max_loop_iterations_);
// Iterate through each dimension except 'sort_dim'.
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
[&](absl::Span<const int64> indices) -> StatusOr<bool> {
// Extract a slice from each operand literal that corresponds to
// exactly the row in dimension 'sort_dim'.
std::vector<int64> limit_indices(indices.begin(), indices.end());
absl::c_for_each(limit_indices, [](int64& index) { ++index; });
limit_indices[sort_dim] = sort_dim_elements;
std::vector<Literal> literals_to_sort;
literals_to_sort.reserve(sort->operand_count());
for (int64 i = 0; i < sort->operand_count(); ++i) {
TF_ASSIGN_OR_RETURN(auto literal_to_sort,
GetEvaluatedLiteralFor(sort->operand(i))
.Slice(indices, limit_indices)
.Reshape({sort_dim_elements}));
literals_to_sort.push_back(std::move(literal_to_sort));
}
std::vector<int64> indices_to_sort(sort_dim_elements);
std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
Status compare_status = Status::OK();
auto comparator = [sort, &compare_status, &embedded_evaluator,
&literals_to_sort](int64 a, int64 b) {
std::vector<Literal> literals;
literals.reserve(2 * sort->operand_count());
for (int64 i = 0; i < sort->operand_count(); ++i) {
auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
/*extract_as_scalar=*/true);
if (!lhs.ok()) {
compare_status = lhs.status();
return false;
}
literals.push_back(std::move(lhs.ValueOrDie()));
auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
/*extract_as_scalar=*/true);
if (!rhs.ok()) {
compare_status = rhs.status();
return false;
}
literals.push_back(std::move(rhs.ValueOrDie()));
}
std::vector<const Literal*> literal_ptrs;
absl::c_transform(literals, std::back_inserter(literal_ptrs),
[](const Literal& literal) { return &literal; });
auto computed_result =
embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
// Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
if (!computed_result.ok()) {
compare_status = computed_result.status();
return false;
}
return computed_result.ValueOrDie().Get<bool>({});
};
if (Cast<HloSortInstruction>(sort)->is_stable()) {
std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
comparator);
} else {
std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
}
if (!compare_status.ok()) {
return compare_status;
}
std::vector<int64> slice_dimensions(rank, 1);
slice_dimensions[sort_dim] = sort_dim_elements;
std::vector<int64> start_indices(rank, 0);
for (int64 i = 0; i < sort->operand_count(); ++i) {
TF_ASSIGN_OR_RETURN(
Literal sorted_literal,
ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
sorted_literal.Reshape(slice_dimensions));
TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
sorted_literal_reshaped, start_indices, indices,
slice_dimensions));
}
return true;
}));
if (sort->operand_count() == 1) {
evaluated_[sort] = std::move(result_literals[0]);
} else {
std::vector<const Literal*> literal_ptrs;
absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
[](const Literal& literal) { return &literal; });
Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
evaluated_[sort] = std::move(result_tuple);
}
return Status::OK();
}
static bool IsScalarAdd(HloComputation* computation) {
HloInstruction* instruction = computation->root_instruction();
if (instruction->opcode() == HloOpcode::kAdd &&
computation->num_parameters() == 2) {
const HloInstruction* lhs = instruction->operand(0);
const HloInstruction* rhs = instruction->operand(1);
return lhs->opcode() == HloOpcode::kParameter &&
ShapeUtil::IsScalar(lhs->shape()) &&
rhs->opcode() == HloOpcode::kParameter &&
ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
}
return false;
}
// Run a single step of an inner loop while running reduction, which applies
// the user-provided computation on the accumulator and the output element
// (until the reduction is completed, the output element is also used as
// an accumulator).
static StatusOr<bool> PerformReductionStep(
absl::Span<const int64> input_index, absl::Span<const int64> output_index,
absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
HloComputation* computation, HloEvaluator* embedded_evaluator) {
int num_args = results.size();
bool is_tuple = num_args > 1;
absl::InlinedVector<Literal, 1> arg_values;
arg_values.reserve(num_args);
absl::InlinedVector<Literal, 1> accumulators;
accumulators.reserve(num_args);
for (int64 i = 0; i < num_args; ++i) {
arg_values.emplace_back(
ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
accumulators.emplace_back(
ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {}));
TF_RETURN_IF_ERROR(
arg_values[i].CopyElementFrom(*input_args[i], input_index, {}));
TF_RETURN_IF_ERROR(
accumulators[i].CopyElementFrom(results[i], output_index, {}));
}
// Evaluate computation with specified literal operands.
absl::InlinedVector<Literal*, 2> embedded_operands;
for (Literal& accumulator : accumulators) {
embedded_operands.push_back(&accumulator);
}
for (Literal& local_input : arg_values) {
embedded_operands.push_back(&local_input);
}
TF_ASSIGN_OR_RETURN(
Literal computed_result,
embedded_evaluator->Evaluate(*computation, embedded_operands));
// Clear visit states so that we can use the evaluator again on the same
// computation.
embedded_evaluator->ResetVisitStates();
if (is_tuple) {
std::vector<Literal> computed_results = computed_result.DecomposeTuple();
for (int64 i = 0; i < num_args; ++i) {
TF_RETURN_IF_ERROR(
results[i].CopyElementFrom(computed_results[i], {}, output_index));
}
} else {
TF_RETURN_IF_ERROR(
results[0].CopyElementFrom(computed_result, {}, output_index));
}
return true;
}
static StatusOr<bool> GenerateReduceOutputElement(
absl::Span<const int64> output_index,
absl::Span<const Literal* const> init_values,
absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
HloComputation* function, HloEvaluator* embedded_evaluator,
absl::Span<const int64> arg_dim_steps,
absl::Span<const int64> arg_dim_counts,
absl::Span<const int64> result_to_arg_index) {
bool is_tuple = results.size() > 1;
bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) &&
IsScalarAdd(function) && !is_tuple;
const Shape& arg_shape = input_args[0]->shape();
absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
std::vector<int64> base(arg_dimensions.size());
for (int64 i = 0; i < output_index.size(); ++i) {
base[result_to_arg_index[i]] = output_index[i];
}
for (int64 i = 0; i < results.size(); ++i) {
TF_RETURN_IF_ERROR(
results[i].CopyElementFrom(*init_values[i], {}, output_index));
}
if (use_fast_add) {
TF_ASSIGN_OR_RETURN(double computed_result,
init_values[0]->GetAsDouble({}));
auto reduction_step =
[&](absl::Span<const int64> input_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(double argument,
input_args[0]->GetAsDouble(input_index));
computed_result += argument;
return true;
};
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step));
TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result));
return true;
}
// Iterates only over reduced shape, as counts and steps are set to zero
// for all non-reduced dimensions.
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
arg_shape, base, arg_dim_counts, arg_dim_steps,
[&](absl::Span<const int64> input_index) {
return PerformReductionStep(input_index, output_index, input_args,
results, function, embedded_evaluator);
}));
return true;
}
Status HloEvaluator::HandleReduce(HloInstruction* instr) {
HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
int64 num_args = reduce->inputs().size();
absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
HloComputation* function = reduce->to_apply();
absl::InlinedVector<const Shape*, 1> operand_shapes;
for (const HloInstruction* operand : reduce->operands()) {
operand_shapes.push_back(&operand->shape());
}
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
operand_shapes, dimensions_to_reduce,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(),
inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
absl::InlinedVector<const Literal*, 1> input_args(num_args);
absl::InlinedVector<const Literal*, 1> init_values(num_args);
for (int64 i = 0; i < num_args; ++i) {
input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]);
VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString();
init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]);
VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString();
TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape()));
}
// All args and results have the same dimensions, so pick an arbitrary one.
const Shape& arg_shape = input_args[0]->shape();
const Shape& out_shape = inferred_return_shape;
bool is_tuple = out_shape.IsTuple();
const Shape& output_shape = inferred_return_shape.IsTuple()
? inferred_return_shape.tuple_shapes(0)
: inferred_return_shape;
absl::Span<const int64> arg_dimensions = AsInt64Slice(arg_shape.dimensions());
// All increments are set to 0.
std::vector<int64> arg_dim_steps(arg_dimensions.size());
// All counts are set to 0.
std::vector<int64> arg_dim_counts(arg_dimensions.size());
// Set steps and counts for reduced dimensions.
// This avoids iterating over non-reduced dimensions, as their step
// and count is set to zero.
for (const int64 dim : dimensions_to_reduce) {
arg_dim_steps[dim] = 1;
arg_dim_counts[dim] = arg_dimensions[dim];
}
auto reduced_dimensions = arg_shape.dimensions();
// Map each dimension in the result to a dimension in arg that isn't
// being reduced.
std::vector<int64> result_to_arg_index;
for (int64 i = 0; i < arg_dimensions.size(); ++i) {
if (arg_dim_steps[i] == 0) {
result_to_arg_index.push_back(i);
}
}
HloEvaluator embedded_evaluator(max_loop_iterations_);
absl::InlinedVector<Literal, 1> results(num_args);
for (int64 i = 0; i < num_args; ++i) {
results[i] = Literal(is_tuple ? out_shape.tuple_shapes(i) : out_shape);
}
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
output_shape, [&](absl::Span<const int64> output_index) {
return GenerateReduceOutputElement(
output_index, init_values, input_args, absl::Span<Literal>(results),
function, &embedded_evaluator, arg_dim_steps, arg_dim_counts,
result_to_arg_index);
}));
if (is_tuple) {
Literal tuple_result(inferred_return_shape);
for (int64 i = 0; i < num_args; ++i) {
TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
}
evaluated_[reduce] = std::move(tuple_result);
} else {
CHECK_EQ(results.size(), 1);
evaluated_[reduce] = std::move(results[0]);
}
if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) {
TF_ASSIGN_OR_RETURN(evaluated_[reduce],
evaluated_[reduce].ConvertToShape(reduce->shape()));
}
return Status::OK();
}
Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
if (!custom_call_handler_) {
// No handler is registered; this means custom-calls are not allowed.
return DefaultAction(custom_call);
}
// Evaluate input operands so the handler has access to the operand data.
std::vector<const Literal*> operands;
operands.reserve(custom_call->operand_count());
for (const HloInstruction* operand : custom_call->operands()) {
operands.push_back(&GetEvaluatedLiteralFor(operand));
}
// Synchronously issue the handler to populate the instruction output literal.
TF_ASSIGN_OR_RETURN(
auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));
evaluated_[custom_call] = std::move(output);
return Status::OK();
}
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
return ShapeUtil::ValidateShape(hlo->shape());
}
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
VLOG(2) << "Finished visiting " << hlo->ToString()
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
// Out of convenience the literal may have been produced with a different
// layout. Relayout as indicated by the HLO instruction.
if (!Layout::Equal().MinorToMajorOnly()(
GetEvaluatedLiteralFor(hlo).shape().layout(),
hlo->shape().layout())) {
evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
}
return Status::OK();
}
namespace {
template <typename T>
std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
const Array2D<T>& lhs, const Array2D<T>& rhs,
const std::function<void(
const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs, int32 transpose_rhs)>& impl_fn) {
CHECK_EQ(lhs.width(), rhs.height());
int m = lhs.height();
int n = rhs.width();
int k = lhs.width();
auto result = absl::make_unique<Array2D<T>>(m, n);
// Because Eigen is a header-oriented library, make sure that the Eigen code
// is the same as the code used by the CPU backend (otherwise the linker will
// randomly pick *some* definition).
impl_fn(
/*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
k,
/*transpose_lhs=*/0,
/*transpose_rhs=*/0);
return result;
}
} // namespace
std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
return MatmulArray2DImpl<Eigen::half>(
lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
}
std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
const Array2D<float>& lhs, const Array2D<float>& rhs) {
return MatmulArray2DImpl<float>(
lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
}
std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
const Array2D<double>& lhs, const Array2D<double>& rhs) {
return MatmulArray2DImpl<double>(
lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
}
} // namespace xla