/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
namespace {
using absl::nullopt;
using absl::optional;
using absl::StrAppend;
using absl::StrCat;
using absl::StrFormat;
using absl::StrJoin;
// Creates and returns a schedule that follows the order of the instructions
// in each HloComputation::instructions() vector in the module.
HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
HloSchedule schedule(module);
for (HloComputation* computation : module->computations()) {
if (!computation->IsFusionComputation()) {
for (HloInstruction* instruction : computation->instructions()) {
schedule.GetOrCreateSequence(computation).push_back(instruction);
}
}
}
return schedule;
}
bool CanInferShape(HloOpcode code) {
switch (code) {
case HloOpcode::kAbs:
case HloOpcode::kAdd:
case HloOpcode::kAddDependency:
case HloOpcode::kAfterAll:
case HloOpcode::kAtan2:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBroadcast:
case HloOpcode::kCall:
case HloOpcode::kCeil:
case HloOpcode::kCholesky:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConcatenate:
case HloOpcode::kConditional:
case HloOpcode::kConvolution:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kDomain:
case HloOpcode::kDot:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFft:
case HloOpcode::kFloor:
case HloOpcode::kGather:
case HloOpcode::kGetDimensionSize:
case HloOpcode::kSetDimensionSize:
case HloOpcode::kGetTupleElement:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kMap:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNegate:
case HloOpcode::kPad:
case HloOpcode::kPartitionId:
case HloOpcode::kPopulationCount:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kReduce:
case HloOpcode::kRemainder:
case HloOpcode::kReplicaId:
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRsqrt:
case HloOpcode::kScatter:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSort:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTrace:
case HloOpcode::kTranspose:
case HloOpcode::kTriangularSolve:
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
case HloOpcode::kWhile:
return true;
// Technically the following ops do not require an explicit result shape,
// but we chose to always write their shapes explicitly.
case HloOpcode::kAllGather:
case HloOpcode::kAllGatherStart:
case HloOpcode::kAllGatherDone:
case HloOpcode::kAllReduce:
case HloOpcode::kAllReduceStart:
case HloOpcode::kAllReduceDone:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopyDone:
case HloOpcode::kCopyStart:
case HloOpcode::kDynamicReshape:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReduceScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kSlice:
// The following ops require an explicit result shape.
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kConstant:
case HloOpcode::kConvert:
case HloOpcode::kCustomCall:
case HloOpcode::kFusion:
case HloOpcode::kInfeed:
case HloOpcode::kIota:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kReducePrecision:
case HloOpcode::kReshape:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:
return false;
}
}
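// Illustrative note (instruction names below are hypothetical): because
// CanInferShape returns true for kAdd, an instruction such as
//   %sum = add(f32[8]{0} %a, f32[8]{0} %b)
// may omit its result shape in the HLO text and the parser will infer it via
// ShapeInference, whereas a kConstant must always spell out its shape.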
// Parser for the HloModule::ToString() format text.
class HloParserImpl : public HloParser {
public:
using LocTy = HloLexer::LocTy;
explicit HloParserImpl(absl::string_view str) : lexer_(str) {}
// Runs the parser and constructs the resulting HLO in the given (empty)
// HloModule. Returns an error status if parsing fails.
Status Run(HloModule* module) override;
// Returns the error information.
std::string GetError() const { return StrJoin(error_, "\n"); }
// Stand-alone parsing utilities for various aggregate data types.
StatusOr<Shape> ParseShapeOnly();
StatusOr<HloSharding> ParseShardingOnly();
StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
StatusOr<Window> ParseWindowOnly();
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
StatusOr<PaddingConfig> ParsePaddingConfigOnly();
StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();
private:
using InstrNameTable =
absl::flat_hash_map<std::string, std::pair<HloInstruction*, LocTy>>;
// Returns the map from instruction names to the instructions themselves and
// their locations in the current scope.
InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
// Locates an instruction with the given name in the current_name_table() or
// returns nullptr.
//
// When the name is not found or is empty, and the create_missing_instruction_
// hook is registered and a "shape" is provided, the hook is called to create
// an instruction. This is useful when we reify parameters as they're
// resolved; i.e. for ParseSingleInstruction.
std::pair<HloInstruction*, LocTy>* FindInstruction(
const std::string& name, const optional<Shape>& shape = nullopt);
// Parses a single instruction's worth of text.
bool ParseSingleInstruction(HloModule* module);
// Parses a module, returning false if an error occurred.
bool ParseHloModule(HloModule* module);
bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
bool ParseInstructionList(HloComputation** computation,
const std::string& computation_name);
bool ParseInstruction(HloComputation::Builder* builder,
std::string* root_name);
bool ParseInstructionRhs(HloComputation::Builder* builder, std::string name,
LocTy name_loc, bool allow_attributes = true);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(Literal* literal);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
bool ParseDenseLiteral(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given linear index to the
// given value. If the literal is dense, it must have the default layout.
//
// `loc` should be the source location of the value.
bool SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
Literal* literal);
bool SetValueInLiteral(LocTy loc, double value, int64_t index,
Literal* literal);
bool SetValueInLiteral(LocTy loc, bool value, int64_t index,
Literal* literal);
bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64_t index,
Literal* literal);
// `loc` should be the source location of the value.
template <typename LiteralNativeT, typename ParsedElemT>
bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64_t index,
Literal* literal);
// Checks whether the given value is within the range of LiteralNativeT.
// `loc` should be the source location of the value.
template <typename LiteralNativeT, typename ParsedElemT>
bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
template <typename LiteralNativeT>
bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);
bool ParseOperands(std::vector<HloInstruction*>* operands,
HloComputation::Builder* builder);
// Fills parsed operands into 'operands' and expects exactly 'expected_size'
// operands.
bool ParseOperands(std::vector<HloInstruction*>* operands,
HloComputation::Builder* builder, const int expected_size);
// Describes the start, limit, and stride on every dimension of the operand
// being sliced.
struct SliceRanges {
std::vector<int64_t> starts;
std::vector<int64_t> limits;
std::vector<int64_t> strides;
};
// The data parsed for the kDomain instruction.
struct DomainData {
std::unique_ptr<DomainMetadata> entry_metadata;
std::unique_ptr<DomainMetadata> exit_metadata;
};
// Types of attributes.
enum class AttrTy {
kBool,
kInt64,
kInt32,
kFloat,
kString,
kLiteral,
kBracedInt64List,
kBracedInt64ListList,
kHloComputation,
kBracedHloComputationList,
kFftType,
kPaddingType,
kComparisonDirection,
kComparisonType,
kWindow,
kConvolutionDimensionNumbers,
kSharding,
kFrontendAttributes,
kParameterReplication,
kInstructionList,
kSliceRanges,
kPaddingConfig,
kMetadata,
kFusionKind,
kDistribution,
kDomain,
kPrecisionList,
kShape,
kShapeList,
kEnum,
kRandomAlgorithm,
kAliasing,
kInstructionAliasing,
kCustomCallSchedule,
kCustomCallApiVersion,
};
struct AttrConfig {
bool required; // whether it's required or optional
AttrTy attr_type; // what type it is
void* result; // where to store the parsed result.
};
// attributes ::= (',' attribute)*
//
// Parses attributes given names and configs of the attributes. Each parsed
// result is passed back through the result pointer in the corresponding
// AttrConfig. Note that the result pointer must point to an optional<T>-typed
// variable which outlives this function. Returns false on error. You should
// not use any of the results if this function fails.
//
// Example usage:
//
// absl::flat_hash_map<std::string, AttrConfig> attrs;
// optional<int64_t> foo;
// attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
// optional<Window> bar;
// attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
// if (!ParseAttributes(attrs)) {
// return false;  // Do not use 'foo' or 'bar' if this fails.
// }
// // Do something with 'bar'.
// if (foo) { // If attr foo is seen, do something with 'foo'. }
//
bool ParseAttributes(
const absl::flat_hash_map<std::string, AttrConfig>& attrs);
// sub_attributes ::= '{' (','? attribute)* '}'
//
// Usage is the same as ParseAttributes. See immediately above.
bool ParseSubAttributes(
const absl::flat_hash_map<std::string, AttrConfig>& attrs);
// Parses one attribute. If it has already been seen, returns an error.
// Returns true and adds to seen_attrs on success.
//
// Do not call this except in ParseAttributes or ParseSubAttributes.
bool ParseAttributeHelper(
const absl::flat_hash_map<std::string, AttrConfig>& attrs,
absl::flat_hash_set<std::string>* seen_attrs);
// Copies attributes from `attrs` to `message`, unless the attribute name is
// in `non_proto_attrs`.
bool CopyAttributeToProtoMessage(
absl::flat_hash_set<std::string> non_proto_attrs,
const absl::flat_hash_map<std::string, AttrConfig>& attrs,
tensorflow::protobuf::Message* message);
// Parses an attribute string into a protocol buffer `message`.
// Since proto3 has no notion of mandatory fields, `required_attrs` gives the
// set of mandatory attributes.
// `non_proto_attrs` specifies attributes that are not written to the proto,
// but added to the HloInstruction.
bool ParseAttributesAsProtoMessage(
const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
tensorflow::protobuf::Message* message);
// Parses a name and finds the corresponding hlo computation.
bool ParseComputationName(HloComputation** value);
// Parses a list of names and finds the corresponding hlo instructions.
bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
// Pass expect_outer_curlies == true when parsing a Window in the context of a
// larger computation. Pass false when parsing a stand-alone Window string.
bool ParseWindow(Window* window, bool expect_outer_curlies);
bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
bool ParsePaddingConfig(PaddingConfig* padding);
bool ParseMetadata(OpMetadata* metadata);
bool ParseSingleOrListMetadata(
tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata);
bool ParseOpShardingType(OpSharding::Type* type);
bool ParseListShardingType(std::vector<OpSharding::Type>* types);
bool ParseSharding(OpSharding* sharding);
bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
bool ParseParameterReplication(ParameterReplication* parameter_replication);
bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
// Parses the metadata behind a kDomain instruction.
bool ParseDomain(DomainData* domain);
// Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
bool ParseDxD(const std::string& name, std::vector<int64_t>* result);
// Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
bool ParseWindowPad(std::vector<std::vector<int64_t>>* pad);
bool ParseSliceRanges(SliceRanges* result);
bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
bool ParseHloComputation(HloComputation** result);
bool ParseHloComputationList(std::vector<HloComputation*>* result);
bool ParseShapeList(std::vector<Shape>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim, std::vector<int64_t>* result);
bool ParseInt64ListList(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<std::vector<int64_t>>* result);
// 'parse_and_add_item' is a lambda that parses an element of the list and
// adds the parsed element to the result. It is expected to capture the result.
bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
const std::function<bool()>& parse_and_add_item);
bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
bool ParseParamList();
bool ParseName(std::string* result);
bool ParseAttributeName(std::string* result);
bool ParseString(std::string* result);
bool ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
std::vector<bool>* dynamic_dimensions);
bool ParseShape(Shape* result);
bool ParseLayout(Layout* layout);
bool ParseLayoutIntAttribute(int64_t* attr_value,
absl::string_view attr_description);
bool ParseTiles(std::vector<Tile>* tiles);
bool ParseOpcode(HloOpcode* result);
bool ParseFftType(FftType* result);
bool ParsePaddingType(PaddingType* result);
bool ParseComparisonDirection(ComparisonDirection* result);
bool ParseComparisonType(Comparison::Type* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
bool ParseRandomAlgorithm(RandomAlgorithm* result);
bool ParsePrecision(PrecisionConfig::Precision* result);
bool ParseInt64(int64_t* result);
bool ParseDouble(double* result);
bool ParseComplex(std::complex<double>* result);
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const std::string& msg);
using AliasingData =
absl::flat_hash_map<ShapeIndex, HloInputOutputAliasConfig::Alias>;
// Parses the aliasing information, returning `false` if parsing fails.
bool ParseAliasing(AliasingData* data);
// Parses the per-instruction aliasing information, returning `false` if
// parsing fails.
bool ParseInstructionOutputOperandAliasing(
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>*
aliasing_output_operand_pairs);
bool ParseCustomCallSchedule(CustomCallSchedule* result);
bool ParseCustomCallApiVersion(CustomCallApiVersion* result);
bool ParseShapeIndex(ShapeIndex* out);
// Returns true if the current token is the beginning of a shape.
bool CanBeShape();
// Returns true if the current token is the beginning of a
// param_list_to_shape.
bool CanBeParamListToShape();
// Logs the current parsing line and the given message. Always returns false.
bool TokenError(absl::string_view msg);
bool Error(LocTy loc, absl::string_view msg);
// If the current token is 'kind', eats it (i.e. lexes the next token) and
// returns true.
bool EatIfPresent(TokKind kind);
// Adds the instruction to the pool. Returns false and emits an error if the
// instruction already exists.
bool AddInstruction(const std::string& name, HloInstruction* instruction,
LocTy name_loc);
// Adds the computation to the pool. Returns false and emits an error if the
// computation already exists.
bool AddComputation(const std::string& name, HloComputation* computation,
LocTy name_loc);
HloLexer lexer_;
// A stack for the instruction names. The top of the stack stores the
// instruction name table for the current scope.
//
// An instruction's name is unique within its scope (i.e. its parent
// computation), but it's not necessarily unique among all computations in the
// module. When there are multiple levels of nested computations, the same
// name could appear in both an outer computation and an inner computation, so
// we need a stack to make sure a name is only visible within its scope.
std::vector<InstrNameTable> scoped_name_tables_;
// A helper class which pushes and pops to an InstrNameTable stack via RAII.
class Scope {
public:
explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
: scoped_name_tables_(scoped_name_tables) {
scoped_name_tables_->emplace_back();
}
~Scope() { scoped_name_tables_->pop_back(); }
private:
std::vector<InstrNameTable>* scoped_name_tables_;
};
// Map from the computation name to the computation itself and its location.
absl::flat_hash_map<std::string, std::pair<HloComputation*, LocTy>>
computation_pool_;
std::vector<std::unique_ptr<HloComputation>> computations_;
std::vector<std::string> error_;
// When an operand name cannot be resolved, this function is called to create
// a parameter instruction with the given name and shape. It registers the
// name, instruction, and a placeholder location in the name table. It returns
// the newly-created instruction and the placeholder location. If `name` is
// empty, this should create the parameter with a generated name. This is
// supposed to be set and used only in ParseSingleInstruction.
std::function<std::pair<HloInstruction*, LocTy>*(const std::string& name,
const Shape& shape)>
create_missing_instruction_;
// Used to generate names for anonymous instructions.
NameUniquer name_uniquer_{/*separator=*/"."};
};
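// Splits `s` on `delim` and parses each piece as an int64, appending the
// values to `out`. For illustration (a hypothetical input):
//   SplitToInt64s("0_2", '_', &out) appends {0, 2} and returns true.
// Returns false if any piece fails to parse.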
bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64_t>* out) {
for (const auto& split : absl::StrSplit(s, delim)) {
int64_t val;
if (!absl::SimpleAtoi(split, &val)) {
return false;
}
out->push_back(val);
}
return true;
}
// Creates replica groups from the provided nested array. groups[i] represents
// the replica ids for group 'i'.
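// For illustration (hypothetical values): {{0, 1}, {2, 3}} yields two
// ReplicaGroups, one with replica ids {0, 1} and one with {2, 3}.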
std::vector<ReplicaGroup> CreateReplicaGroups(
absl::Span<const std::vector<int64_t>> groups) {
std::vector<ReplicaGroup> replica_groups;
absl::c_transform(groups, std::back_inserter(replica_groups),
[](const std::vector<int64_t>& ids) {
ReplicaGroup group;
*group.mutable_replica_ids() = {ids.begin(), ids.end()};
return group;
});
return replica_groups;
}
bool HloParserImpl::Error(LocTy loc, absl::string_view msg) {
auto line_col = lexer_.GetLineAndColumn(loc);
const unsigned line = line_col.first;
const unsigned col = line_col.second;
std::vector<std::string> error_lines;
error_lines.push_back(
StrCat("was parsing ", line, ":", col, ": error: ", msg));
error_lines.emplace_back(lexer_.GetLine(loc));
error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^"));
error_.push_back(StrJoin(error_lines, "\n"));
VLOG(1) << "Error: " << error_.back();
return false;
}
bool HloParserImpl::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
Status HloParserImpl::Run(HloModule* module) {
lexer_.Lex();
if (lexer_.GetKind() == TokKind::kw_HloModule) {
// This means that the text contains a full HLO module.
if (!ParseHloModule(module)) {
return InvalidArgument(
"Syntax error when trying to parse the text as a HloModule:\n%s",
GetError());
}
return Status::OK();
}
// This means that the text is a single HLO instruction.
if (!ParseSingleInstruction(module)) {
return InvalidArgument(
"Syntax error when trying to parse the text as a single "
"HloInstruction:\n%s",
GetError());
}
return Status::OK();
}
std::pair<HloInstruction*, HloParserImpl::LocTy>*
HloParserImpl::FindInstruction(const std::string& name,
const optional<Shape>& shape) {
std::pair<HloInstruction*, LocTy>* instr = nullptr;
if (!name.empty()) {
instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
}
// Potentially call the missing instruction hook.
if (instr == nullptr && create_missing_instruction_ != nullptr &&
scoped_name_tables_.size() == 1) {
if (!shape.has_value()) {
Error(lexer_.GetLoc(),
"Operand had no shape in HLO text; cannot create parameter for "
"single-instruction module.");
return nullptr;
}
return create_missing_instruction_(name, *shape);
}
if (instr != nullptr && shape.has_value() &&
!ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
Error(
lexer_.GetLoc(),
StrCat("The declared operand shape ",
ShapeUtil::HumanStringWithLayout(shape.value()),
" is not compatible with the shape of the operand instruction ",
ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
return nullptr;
}
return instr;
}
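// Parses a ShapeIndex written as a braced list of integers. For illustration
// (a hypothetical input): the text "{0, 1}" parses to ShapeIndex{0, 1}.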
bool HloParserImpl::ParseShapeIndex(ShapeIndex* out) {
if (!ParseToken(TokKind::kLbrace, "Expects '{' at the start of ShapeIndex")) {
return false;
}
std::vector<int64_t> idxs;
while (lexer_.GetKind() != TokKind::kRbrace) {
int64_t idx;
if (!ParseInt64(&idx)) {
return false;
}
idxs.push_back(idx);
if (!EatIfPresent(TokKind::kComma)) {
break;
}
}
if (!ParseToken(TokKind::kRbrace, "Expects '}' at the end of ShapeIndex")) {
return false;
}
*out = ShapeIndex(idxs.begin(), idxs.end());
return true;
}
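// Parses an input/output aliasing description. A sketch of the accepted text,
// with hypothetical indices:
//   { {0}: (0, {1}, may-alias), {1}: (1, {}) }
// i.e. each entry maps an output shape index to a (parameter number,
// parameter shape index) pair with an optional alias kind.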
bool HloParserImpl::ParseAliasing(AliasingData* data) {
if (!ParseToken(TokKind::kLbrace,
"Expects '{' at the start of aliasing description")) {
return false;
}
while (lexer_.GetKind() != TokKind::kRbrace) {
ShapeIndex out;
if (!ParseShapeIndex(&out)) {
return false;
}
std::string errmsg =
"Expected format: <output_shape_index>: (<input_param>, "
"<input_param_shape_index>) OR <output_shape_index>: <input_param>";
if (!ParseToken(TokKind::kColon, errmsg)) {
return false;
}
if (!ParseToken(TokKind::kLparen, errmsg)) {
return false;
}
int64_t param_num;
if (!ParseInt64(&param_num)) {
  return false;
}
if (!ParseToken(TokKind::kComma, errmsg)) {
return false;
}
ShapeIndex param_idx;
if (!ParseShapeIndex(&param_idx)) {
return false;
}
HloInputOutputAliasConfig::AliasKind alias_kind =
HloInputOutputAliasConfig::kMayAlias;
if (EatIfPresent(TokKind::kComma)) {
std::string type;
ParseName(&type);
if (type == "must-alias") {
alias_kind = HloInputOutputAliasConfig::kMustAlias;
} else if (type == "may-alias") {
alias_kind = HloInputOutputAliasConfig::kMayAlias;
} else {
return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
}
}
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
std::forward_as_tuple(param_num, param_idx, alias_kind));
if (!ParseToken(TokKind::kRparen, errmsg)) {
return false;
}
if (!EatIfPresent(TokKind::kComma)) {
break;
}
}
if (!ParseToken(TokKind::kRbrace,
"Expects '}' at the end of aliasing description")) {
return false;
}
return true;
}
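// Parses an output-operand aliasing description. A sketch of the accepted
// text, with hypothetical indices:
//   { {0}: (0, {1}), {1}: (1, {}) }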
bool HloParserImpl::ParseInstructionOutputOperandAliasing(
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>*
aliasing_output_operand_pairs) {
if (!ParseToken(
TokKind::kLbrace,
"Expects '{' at the start of instruction aliasing description")) {
return false;
}
while (lexer_.GetKind() != TokKind::kRbrace) {
ShapeIndex out;
if (!ParseShapeIndex(&out)) {
return false;
}
std::string errmsg =
"Expected format: <output_shape_index>: (<operand_index>, "
"<operand_shape_index>)";
if (!ParseToken(TokKind::kColon, errmsg)) {
return false;
}
if (!ParseToken(TokKind::kLparen, errmsg)) {
return false;
}
int64_t operand_index;
if (!ParseInt64(&operand_index)) {
  return false;
}
if (!ParseToken(TokKind::kComma, errmsg)) {
return false;
}
ShapeIndex operand_shape_index;
if (!ParseShapeIndex(&operand_shape_index)) {
return false;
}
aliasing_output_operand_pairs->emplace_back(
out,
std::pair<int64_t, ShapeIndex>{operand_index, operand_shape_index});
if (!ParseToken(TokKind::kRparen, errmsg)) {
return false;
}
if (!EatIfPresent(TokKind::kComma)) {
break;
}
}
if (!ParseToken(
TokKind::kRbrace,
"Expects '}' at the end of instruction aliasing description")) {
return false;
}
return true;
}
bool HloParserImpl::ParseCustomCallSchedule(CustomCallSchedule* result) {
VLOG(3) << "ParseCustomCallSchedule";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects custom-call schedule");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToCustomCallSchedule(val);
if (!status_or_result.ok()) {
return TokenError(
StrFormat("expects custom-call schedule but sees: %s, error: %s", val,
status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseCustomCallApiVersion(CustomCallApiVersion* result) {
VLOG(3) << "ParseCustomCallApiVersion";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects custom-call API version");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToCustomCallApiVersion(val);
if (!status_or_result.ok()) {
return TokenError(
StrFormat("expects custom-call API version but sees: %s, error: %s",
val, status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
// ::= 'HloModule' name computations
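// For illustration, a module header this accepts (attribute values are
// hypothetical):
//   HloModule m, is_scheduled=true,
//       input_output_alias={ {}: (0, {}, may-alias) }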
bool HloParserImpl::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
// Eat 'HloModule'
lexer_.Lex();
std::string name;
if (!ParseName(&name)) {
return false;
}
absl::optional<bool> is_scheduled;
absl::optional<AliasingData> aliasing_data;
absl::optional<bool> alias_passthrough_params;
absl::flat_hash_map<std::string, AttrConfig> attrs;
attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
&aliasing_data};
attrs["alias_passthrough_params"] = {/*required=*/false, AttrTy::kBool,
&alias_passthrough_params};
if (!ParseAttributes(attrs)) {
return false;
}
module->set_name(name);
if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
if (alias_passthrough_params.has_value() && *alias_passthrough_params) {
HloModuleConfig config = module->config();
config.set_alias_passthrough_params(true);
module->set_config(config);
}
if (aliasing_data) {
HloInputOutputAliasConfig alias_config(module->result_shape());
for (auto& p : *aliasing_data) {
Status st =
alias_config.SetUpAlias(p.first, p.second.parameter_number,
p.second.parameter_index, p.second.kind);
if (!st.ok()) {
return TokenError(st.error_message());
}
}
module->input_output_alias_config() = alias_config;
}
return true;
}
// computations ::= (computation)+
bool HloParserImpl::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kEof);
for (int i = 0; i < computations_.size(); i++) {
// If entry_computation is not nullptr, it means the computation it pointed
// to is marked with "ENTRY"; otherwise, no computation is marked with
// "ENTRY", and we use the last computation as the entry computation. We
// add the non-entry computations as embedded computations to the module.
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameter and result layouts were set to the default layout. Here we
// set the layouts to what the HLO text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
return true;
}
// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
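// For illustration (names and shapes are hypothetical):
//   ENTRY %main (p: f32[8]) -> f32[8] { ... }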
bool HloParserImpl::ParseComputation(HloComputation** entry_computation) {
LocTy maybe_entry_loc = lexer_.GetLoc();
const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
std::string name;
LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name)) {
return false;
}
LocTy shape_loc = nullptr;
Shape shape;
if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
return false;
}
HloComputation* computation = nullptr;
if (!ParseInstructionList(&computation, name)) {
return false;
}
// If param_list_to_shape was present, check compatibility.
if (shape_loc != nullptr &&
!ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
return Error(
shape_loc,
StrCat(
"Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
", is not compatible with that of its root instruction ",
computation->root_instruction()->name(), ", ",
ShapeUtil::HumanString(computation->root_instruction()->shape())));
}
if (is_entry_computation) {
if (*entry_computation != nullptr) {
return Error(maybe_entry_loc, "expects only one ENTRY");
}
*entry_computation = computation;
}
return AddComputation(name, computation, name_loc);
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
bool HloParserImpl::ParseInstructionList(HloComputation** computation,
const std::string& computation_name) {
Scope scope(&scoped_name_tables_);
HloComputation::Builder builder(computation_name);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
std::string root_name;
do {
if (!ParseInstruction(&builder, &root_name)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
if (!ParseToken(TokKind::kRbrace,
"expects '}' at the end of instruction list.")) {
return false;
}
HloInstruction* root = nullptr;
if (!root_name.empty()) {
std::pair<HloInstruction*, LocTy>* root_node =
tensorflow::gtl::FindOrNull(current_name_table(), root_name);
// This means some instruction was marked as ROOT but we didn't find it in
// the pool, which should not happen.
if (root_node == nullptr) {
// LOG(FATAL) crashes the program by calling abort().
LOG(FATAL) << "instruction " << root_name
<< " was marked as ROOT but the parser has not seen it before";
}
root = root_node->first;
}
// Now root is either an existing instruction or nullptr. If it is nullptr,
// the Builder implementation will set the last added instruction as the root
// instruction.
computations_.emplace_back(builder.Build(root));
*computation = computations_.back().get();
return true;
}
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
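// For illustration (names and shapes are hypothetical):
//   ROOT %sum = f32[8]{0} add(f32[8]{0} %a, f32[8]{0} %b)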
bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder,
std::string* root_name) {
std::string name;
LocTy maybe_root_loc = lexer_.GetLoc();
bool is_root = EatIfPresent(TokKind::kw_ROOT);
const LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name) ||
!ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
return false;
}
if (is_root) {
if (!root_name->empty()) {
return Error(maybe_root_loc, "one computation should have only one ROOT");
}
*root_name = name;
}
return ParseInstructionRhs(builder, name, name_loc);
}
bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
std::string name, LocTy name_loc,
bool allow_attributes) {
Shape shape;
HloOpcode opcode;
std::vector<HloInstruction*> operands;
const bool parse_shape = CanBeShape();
if ((parse_shape && !ParseShape(&shape)) || !ParseOpcode(&opcode)) {
return false;
}
if (!parse_shape && !CanInferShape(opcode)) {
return TokenError(StrFormat("cannot infer shape for opcode: %s",
HloOpcodeString(opcode)));
}
// Add optional attributes. These are added to any HloInstruction type if
// present.
absl::flat_hash_map<std::string, AttrConfig> attrs;
optional<OpSharding> sharding;
optional<FrontendAttributes> frontend_attributes;
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
attrs["frontend_attributes"] = {
/*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
optional<ParameterReplication> parameter_replication;
attrs["parameter_replication"] = {/*required=*/false,
AttrTy::kParameterReplication,
&parameter_replication};
optional<std::vector<HloInstruction*>> predecessors;
attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
&predecessors};
optional<OpMetadata> metadata;
attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
optional<std::string> backend_config;
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
optional<std::vector<int64_t>> outer_dimension_partitions;
attrs["outer_dimension_partitions"] = {/*required=*/false,
AttrTy::kBracedInt64List,
&outer_dimension_partitions};
const auto maybe_infer_shape =
[&](const std::function<StatusOr<Shape>()>& infer, Shape* shape) {
if (parse_shape) {
return true;
}
auto inferred = infer();
if (!inferred.ok()) {
return TokenError(StrFormat(
"failed to infer shape for opcode: %s, error: %s",
HloOpcodeString(opcode), inferred.status().error_message()));
}
*shape = std::move(inferred).ValueOrDie();
return true;
};
const auto parse_attributes = [&] {
if (!allow_attributes) return true;
return ParseAttributes(attrs);
};
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
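// For illustration (hypothetical text): %p = f32[8]{0} parameter(0)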
int64_t parameter_number;
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(&parameter_number)) {
return false;
}
if (parameter_number < 0) {
Error(lexer_.GetLoc(), "parameter number must be >= 0");
return false;
}
if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateParameter(parameter_number, shape, name));
break;
}
case HloOpcode::kConstant: {
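// For illustration (hypothetical text): %c = f32[] constant(4.2)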
Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
!ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
break;
}
case HloOpcode::kIota: {
optional<int64_t> iota_dimension;
attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
&iota_dimension};
if (!ParseOperands(&operands, builder, /*expected_size=*/0) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateIota(shape, *iota_dimension));
break;
}
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kAllGatherDone:
case HloOpcode::kAllReduceDone:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh: {
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateUnary(shape, opcode, operands[0]));
break;
}
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferBinaryOpShape(opcode, operands[0],
operands[1]);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBinary(
shape, opcode, operands[0], operands[1]));
break;
}
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect:
case HloOpcode::kTupleSelect: {
if (!ParseOperands(&operands, builder, /*expected_size=*/3) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferTernaryOpShape(
opcode, operands[0], operands[1], operands[2]);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateTernary(
shape, opcode, operands[0], operands[1], operands[2]));
break;
}
// Other supported ops.
case HloOpcode::kConvert: {
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateConvert(shape, operands[0]));
break;
}
case HloOpcode::kBitcastConvert: {
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateBitcastConvert(shape, operands[0]));
break;
}
case HloOpcode::kAllGather:
case HloOpcode::kAllGatherStart: {
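// For illustration (hypothetical text):
//   %ag = f32[16]{0} all-gather(f32[8]{0} %p), dimensions={0},
//       replica_groups={{0,1}}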
optional<std::vector<std::vector<int64_t>>> tmp_groups;
optional<std::vector<int64_t>> replica_group_ids;
optional<int64_t> channel_id;
optional<std::vector<int64_t>> dimensions;
optional<bool> constrain_layout;
optional<bool> use_global_device_ids;
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
&constrain_layout};
attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
&use_global_device_ids};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
replica_groups = CreateReplicaGroups(*tmp_groups);
}
if (opcode == HloOpcode::kAllGather) {
instruction = builder->AddInstruction(HloInstruction::CreateAllGather(
shape, operands, dimensions->at(0), replica_groups,
constrain_layout ? *constrain_layout : false, channel_id,
use_global_device_ids ? *use_global_device_ids : false));
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateAllGatherStart(
shape, operands, dimensions->at(0), replica_groups,
constrain_layout ? *constrain_layout : false, channel_id,
use_global_device_ids ? *use_global_device_ids : false));
}
break;
}
case HloOpcode::kAllReduce:
case HloOpcode::kAllReduceStart:
case HloOpcode::kReduceScatter: {
optional<std::vector<std::vector<int64_t>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<std::vector<int64_t>> replica_group_ids;
optional<int64_t> channel_id;
optional<bool> constrain_layout;
optional<bool> use_global_device_ids;
optional<std::vector<int64_t>> dimensions;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
&constrain_layout};
attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
&use_global_device_ids};
if (opcode == HloOpcode::kReduceScatter) {
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
}
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
replica_groups = CreateReplicaGroups(*tmp_groups);
}
if (opcode == HloOpcode::kAllReduce) {
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
shape, operands, *to_apply, replica_groups,
constrain_layout ? *constrain_layout : false, channel_id,
use_global_device_ids ? *use_global_device_ids : false));
} else if (opcode == HloOpcode::kReduceScatter) {
instruction =
builder->AddInstruction(HloInstruction::CreateReduceScatter(
shape, operands, *to_apply, replica_groups,
constrain_layout ? *constrain_layout : false, channel_id,
use_global_device_ids ? *use_global_device_ids : false,
dimensions->at(0)));
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateAllReduceStart(
shape, operands, *to_apply, replica_groups,
constrain_layout ? *constrain_layout : false, channel_id,
use_global_device_ids ? *use_global_device_ids : false));
}
break;
}
case HloOpcode::kAllToAll: {
optional<std::vector<std::vector<int64_t>>> tmp_groups;
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
optional<int64_t> channel_id;
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
&dimensions};
optional<bool> constrain_layout;
attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
&constrain_layout};
if (!ParseOperands(&operands, builder) || !parse_attributes() ||
(dimensions && dimensions->size() != 1)) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
replica_groups = CreateReplicaGroups(*tmp_groups);
}
optional<int64_t> split_dimension;
if (dimensions) {
split_dimension = dimensions->at(0);
}
instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
shape, operands, replica_groups,
constrain_layout ? *constrain_layout : false, channel_id,
split_dimension));
break;
}
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart: {
optional<std::vector<std::vector<int64_t>>> source_targets;
attrs["source_target_pairs"] = {
/*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
optional<int64_t> channel_id;
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
optional<std::vector<std::vector<int64_t>>> slice_sizes;
attrs["slice_sizes"] = {/*required=*/false, AttrTy::kBracedInt64ListList,
&slice_sizes};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
std::vector<std::pair<int64_t, int64_t>> pairs(source_targets->size());
for (int i = 0; i < pairs.size(); i++) {
if ((*source_targets)[i].size() != 2) {
return TokenError(
"expects 'source_target_pairs=' to be a list of pairs");
}
pairs[i].first = (*source_targets)[i][0];
pairs[i].second = (*source_targets)[i][1];
}
if (!slice_sizes.has_value()) {
if (operands.size() != 1) {
return TokenError(
"CollectivePermute and CollectivePermuteStart must "
"have exactly one operand (input buffer) unless "
"it performs dynamic-slice and in-place update.");
}
if (opcode == HloOpcode::kCollectivePermute) {
instruction =
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
shape, operands[0], pairs, channel_id));
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = builder->AddInstruction(
HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
pairs, channel_id));
} else {
LOG(FATAL) << "Expect opcode to be CollectivePermute or "
"CollectivePermuteStart, but got "
<< HloOpcodeString(opcode);
}
} else {
if (operands.size() != 4) {
return TokenError(
"CollectivePermute and CollectivePermuteStart must "
"have exactly four operands for dynamic-slice and "
"in-place update.");
}
if (opcode == HloOpcode::kCollectivePermute) {
instruction =
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
shape, operands[0], operands[1], operands[2], operands[3],
pairs, *slice_sizes, channel_id));
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = builder->AddInstruction(
HloInstruction::CreateCollectivePermuteStart(
shape, operands[0], operands[1], operands[2], operands[3],
pairs, *slice_sizes, channel_id));
} else {
LOG(FATAL) << "Expect opcode to be CollectivePermute or "
"CollectivePermuteStart, but got "
<< HloOpcodeString(opcode);
}
}
break;
}
case HloOpcode::kCopyStart: {
// If the is_cross_program_prefetch attribute is not present then default
// to false.
optional<bool> is_cross_program_prefetch = false;
attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
&is_cross_program_prefetch};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
shape, operands[0], *is_cross_program_prefetch));
break;
}
case HloOpcode::kReplicaId: {
if (!ParseOperands(&operands, builder, /*expected_size=*/0) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
break;
}
case HloOpcode::kPartitionId: {
if (!ParseOperands(&operands, builder, /*expected_size=*/0) ||
!parse_attributes()) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreatePartitionId());
break;
}
case HloOpcode::kDynamicReshape: {
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateDynamicReshape(
shape, operands[0],
absl::Span<HloInstruction* const>(operands).subspan(1)));
break;
}
case HloOpcode::kReshape: {
optional<int64_t> inferred_dimension;
attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
&inferred_dimension};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReshape(
shape, operands[0], inferred_dimension.value_or(-1)));
break;
}
case HloOpcode::kAfterAll: {
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (operands.empty()) {
instruction = builder->AddInstruction(HloInstruction::CreateToken());
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
}
break;
}
case HloOpcode::kAddDependency: {
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateAddDependency(operands[0], operands[1]));
break;
}
case HloOpcode::kSort: {
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
optional<bool> is_stable = false;
attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands, builder) || !parse_attributes() ||
dimensions->size() != 1) {
return false;
}
if (!maybe_infer_shape(
[&] {
absl::InlinedVector<const Shape*, 2> arg_shapes;
arg_shapes.reserve(operands.size());
for (auto* operand : operands) {
arg_shapes.push_back(&operand->shape());
}
return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateSort(shape, dimensions->at(0), operands,
to_apply.value(), is_stable.value()));
break;
}
case HloOpcode::kTuple: {
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
absl::InlinedVector<const Shape*, 2> arg_shapes;
arg_shapes.reserve(operands.size());
for (auto* operand : operands) {
arg_shapes.push_back(&operand->shape());
}
return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
},
&shape)) {
return false;
}
// HloInstruction::CreateTuple() infers the shape of the tuple from
// operands and should not be used here.
instruction = builder->AddInstruction(
HloInstruction::CreateVariadic(shape, HloOpcode::kTuple, operands));
break;
}
case HloOpcode::kWhile: {
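// For illustration (hypothetical text):
//   %w = f32[] while(f32[] %init), condition=%cond, body=%body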
optional<HloComputation*> condition;
optional<HloComputation*> body;
attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
&condition};
attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferWhileShape(
condition.value()->ComputeProgramShape(),
body.value()->ComputeProgramShape(), operands[0]->shape());
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateWhile(
shape, *condition, *body, /*init=*/operands[0]));
break;
}
case HloOpcode::kRecv: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateRecv(
shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kRecvDone: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kSend: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateSend(
operands[0], operands[1], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kSendDone: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kGetTupleElement: {
optional<int64_t> index;
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeUtil::GetTupleElementShape(operands[0]->shape(),
*index);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
break;
}
case HloOpcode::kCall: {
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
absl::InlinedVector<const Shape*, 2> arg_shapes;
arg_shapes.reserve(operands.size());
for (auto* operand : operands) {
arg_shapes.push_back(&operand->shape());
}
return ShapeInference::InferCallShape(
arg_shapes, to_apply.value()->ComputeProgramShape());
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCall(shape, operands, *to_apply));
break;
}
case HloOpcode::kReduceWindow: {
optional<HloComputation*> reduce_computation;
optional<Window> window;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (!window) {
window.emplace();
}
if (operands.size() % 2) {
auto loc = lexer_.GetLoc();
return Error(loc, StrCat("expects an even number of operands, but has ",
operands.size(), " operands"));
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferReduceWindowShape(
operands[0]->shape(), operands[1]->shape(), *window,
reduce_computation.value()->ComputeProgramShape());
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
shape, /*operands=*/
absl::Span<HloInstruction* const>(operands).subspan(
0, operands.size() / 2),
/*init_values=*/
absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
2),
*window, *reduce_computation));
break;
}
case HloOpcode::kConvolution: {
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64_t> feature_group_count;
optional<int64_t> batch_group_count;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/true,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
&batch_group_count};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
if (!window) {
window.emplace();
}
if (!feature_group_count) {
feature_group_count = 1;
}
if (!batch_group_count) {
batch_group_count = 1;
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferConvolveShape(
operands[0]->shape(), operands[1]->shape(),
*feature_group_count, *batch_group_count, *window, *dnums,
/*preferred_element_type=*/absl::nullopt);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
feature_group_count.value(), batch_group_count.value(), *window,
*dnums, precision_config));
break;
}
case HloOpcode::kFft: {
optional<FftType> fft_type;
optional<std::vector<int64_t>> fft_length;
attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
&fft_length};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferFftShape(operands[0]->shape(),
*fft_type, *fft_length);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateFft(
shape, operands[0], *fft_type, *fft_length));
break;
}
case HloOpcode::kTriangularSolve: {
TriangularSolveOptions options;
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
(allow_attributes && !ParseAttributesAsProtoMessage(
/*non_proto_attrs=*/attrs, &options))) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferTriangularSolveShape(
operands[0]->shape(), operands[1]->shape(), options);
},
&shape)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateTriangularSolve(
shape, operands[0], operands[1], options));
break;
}
case HloOpcode::kCompare: {
optional<ComparisonDirection> direction;
optional<Comparison::Type> type;
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
&direction};
attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferBinaryOpShape(opcode, operands[0],
operands[1]);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
shape, operands[0], operands[1], *direction, type));
break;
}
case HloOpcode::kCholesky: {
CholeskyOptions options;
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
(allow_attributes && !ParseAttributesAsProtoMessage(
/*non_proto_attrs=*/attrs, &options))) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferCholeskyShape(operands[0]->shape());
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCholesky(shape, operands[0], options));
break;
}
case HloOpcode::kBroadcast: {
optional<std::vector<int64_t>> broadcast_dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&broadcast_dimensions};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferBroadcastShape(
operands[0]->shape(), *broadcast_dimensions);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
shape, operands[0], *broadcast_dimensions));
break;
}
case HloOpcode::kConcatenate: {
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, builder) || !parse_attributes() ||
dimensions->size() != 1) {
return false;
}
if (!maybe_infer_shape(
[&] {
absl::InlinedVector<const Shape*, 2> arg_shapes;
arg_shapes.reserve(operands.size());
for (auto* operand : operands) {
arg_shapes.push_back(&operand->shape());
}
return ShapeInference::InferConcatOpShape(arg_shapes,
dimensions->at(0));
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, dimensions->at(0)));
break;
}
case HloOpcode::kMap: {
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
absl::InlinedVector<const Shape*, 2> arg_shapes;
arg_shapes.reserve(operands.size());
for (auto* operand : operands) {
arg_shapes.push_back(&operand->shape());
}
return ShapeInference::InferMapShape(
arg_shapes, to_apply.value()->ComputeProgramShape(),
*dimensions);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateMap(shape, operands, *to_apply));
break;
}
case HloOpcode::kReduce: {
auto loc = lexer_.GetLoc();
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
optional<std::vector<int64_t>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (operands.size() % 2) {
return Error(loc, StrCat("expects an even number of operands, but has ",
operands.size(), " operands"));
}
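// Variadic reduce: the first half of the operands are the inputs and the
// second half are the matching init values.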
if (!maybe_infer_shape(
[&] {
absl::InlinedVector<const Shape*, 2> arg_shapes;
arg_shapes.reserve(operands.size());
for (auto* operand : operands) {
arg_shapes.push_back(&operand->shape());
}
return ShapeInference::InferReduceShape(
arg_shapes, *dimensions_to_reduce,
reduce_computation.value()->ComputeProgramShape());
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operands=*/
absl::Span<HloInstruction* const>(operands).subspan(
0, operands.size() / 2),
/*init_values=*/
absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
2),
*dimensions_to_reduce, *reduce_computation));
break;
}
case HloOpcode::kReverse: {
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferReverseShape(operands[0]->shape(),
*dimensions);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateReverse(shape, operands[0], *dimensions));
break;
}
case HloOpcode::kSelectAndScatter: {
optional<HloComputation*> select;
attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
optional<HloComputation*> scatter;
attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
optional<Window> window;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
if (!ParseOperands(&operands, builder, /*expected_size=*/3) ||
!parse_attributes()) {
return false;
}
if (!window) {
window.emplace();
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferSelectAndScatterShape(
operands[0]->shape(), select.value()->ComputeProgramShape(),
*window, operands[1]->shape(), operands[2]->shape(),
scatter.value()->ComputeProgramShape());
},
&shape)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
shape, /*operand=*/operands[0], *select, *window,
/*source=*/operands[1], /*init_value=*/operands[2], *scatter));
break;
}
case HloOpcode::kSlice: {
optional<SliceRanges> slice_ranges;
attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateSlice(
shape, operands[0], slice_ranges->starts, slice_ranges->limits,
slice_ranges->strides));
break;
}
case HloOpcode::kDynamicSlice: {
optional<std::vector<int64_t>> dynamic_slice_sizes;
attrs["dynamic_slice_sizes"] = {
/*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (operands.empty()) {
return Error(loc, "Expected at least one operand.");
}
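// The start indices are given either as a single rank-1 operand (one
// vector of indices) or as one scalar operand per dimension of the sliced
// operand.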
if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
operands.size() != 1 + operands[0]->shape().rank()) {
return Error(loc, "Wrong number of operands.");
}
instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
shape, /*operand=*/operands[0],
/*start_indices=*/absl::MakeSpan(operands).subspan(1),
*dynamic_slice_sizes));
break;
}
case HloOpcode::kDynamicUpdateSlice: {
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (operands.size() < 2) {
return Error(loc, "Expected at least two operands.");
}
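// As with dynamic-slice, the start indices are either a single rank-1
// operand or one scalar operand per dimension of the updated operand.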
if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
operands.size() != 2 + operands[0]->shape().rank()) {
return Error(loc, "Wrong number of operands.");
}
instruction =
builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
shape, /*operand=*/operands[0], /*update=*/operands[1],
/*start_indices=*/absl::MakeSpan(operands).subspan(2)));
break;
}
case HloOpcode::kTranspose: {
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferTransposeShape(operands[0]->shape(),
*dimensions);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
break;
}
case HloOpcode::kBatchNormTraining: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64_t> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, builder, /*expected_size=*/3) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferBatchNormTrainingShape(
operands[0]->shape(), operands[1]->shape(),
operands[2]->shape(), *feature_index);
},
&shape)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*offset=*/operands[2], *epsilon, *feature_index));
break;
}
case HloOpcode::kBatchNormInference: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64_t> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, builder, /*expected_size=*/5) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferBatchNormInferenceShape(
operands[0]->shape(), operands[1]->shape(),
operands[2]->shape(), operands[3]->shape(),
operands[4]->shape(), *feature_index);
},
&shape)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateBatchNormInference(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*offset=*/operands[2], /*mean=*/operands[3],
/*variance=*/operands[4], *epsilon, *feature_index));
break;
}
case HloOpcode::kBatchNormGrad: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64_t> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, builder, /*expected_size=*/5) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferBatchNormGradShape(
operands[0]->shape(), operands[1]->shape(),
operands[2]->shape(), operands[3]->shape(),
operands[4]->shape(), *feature_index);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*mean=*/operands[2], /*variance=*/operands[3],
/*grad_output=*/operands[4], *epsilon, *feature_index));
break;
}
case HloOpcode::kPad: {
optional<PaddingConfig> padding;
attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferPadShape(
operands[0]->shape(), operands[1]->shape(), *padding);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreatePad(
shape, operands[0], /*padding_value=*/operands[1], *padding));
break;
}
case HloOpcode::kFusion: {
optional<HloComputation*> fusion_computation;
attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
&fusion_computation};
optional<HloInstruction::FusionKind> fusion_kind;
attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateFusion(
shape, *fusion_kind, operands, *fusion_computation));
break;
}
case HloOpcode::kInfeed: {
optional<std::string> config;
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
// We need to know the infeed data shape to construct the infeed
// instruction. This is the zero-th element of the tuple-shaped output of
// the infeed instruction. ShapeUtil::GetTupleElementShape will check-fail
// if the shape is not a non-empty tuple, so guard against that case here
// so that an error message can be emitted instead of a check failure.
if (!shape.IsTuple() || ShapeUtil::IsEmptyTuple(shape)) {
return Error(lexer_.GetLoc(),
"infeed must have a non-empty tuple shape");
}
instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
config ? *config : ""));
break;
}
case HloOpcode::kOutfeed: {
optional<std::string> config;
optional<Shape> outfeed_shape;
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
attrs["outfeed_shape"] = {/*required=*/false, AttrTy::kShape,
&outfeed_shape};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
HloInstruction* const outfeed_input = operands[0];
HloInstruction* const outfeed_token = operands[1];
const Shape shape =
outfeed_shape.has_value() ? *outfeed_shape : outfeed_input->shape();
instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
shape, outfeed_input, outfeed_token, config ? *config : ""));
break;
}
case HloOpcode::kRng: {
optional<RandomDistribution> distribution;
attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
&distribution};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRng(shape, *distribution, operands));
break;
}
case HloOpcode::kRngGetAndUpdateState: {
optional<int64_t> delta;
attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
if (!ParseOperands(&operands, builder, /*expected_size=*/0) ||
!parse_attributes()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
break;
}
case HloOpcode::kRngBitGenerator: {
optional<RandomAlgorithm> algorithm;
attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
&algorithm};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
shape, operands[0], *algorithm));
break;
}
case HloOpcode::kReducePrecision: {
optional<int64_t> exponent_bits;
optional<int64_t> mantissa_bits;
attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
&exponent_bits};
attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
&mantissa_bits};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateReducePrecision(
shape, operands[0], static_cast<int>(*exponent_bits),
static_cast<int>(*mantissa_bits)));
break;
}
case HloOpcode::kConditional: {
optional<HloComputation*> true_computation;
optional<HloComputation*> false_computation;
optional<std::vector<HloComputation*>> branch_computations;
if (!ParseOperands(&operands, builder)) {
return false;
}
if (!ShapeUtil::IsScalar(operands[0]->shape())) {
return Error(lexer_.GetLoc(), "The first operand must be a scalar");
}
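// A PRED branch index selects between true_computation and
// false_computation; an S32 branch index selects from a braced list, e.g.
// (illustrative) branch_computations={branch0, branch1}.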
const bool branch_index_is_bool =
operands[0]->shape().element_type() == PRED;
if (branch_index_is_bool) {
attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
&true_computation};
attrs["false_computation"] = {
/*required=*/true, AttrTy::kHloComputation, &false_computation};
} else {
if (operands[0]->shape().element_type() != S32) {
return Error(lexer_.GetLoc(),
"The first operand must be a scalar of PRED or S32");
}
attrs["branch_computations"] = {/*required=*/true,
AttrTy::kBracedHloComputationList,
&branch_computations};
}
if (!parse_attributes()) {
return false;
}
if (branch_index_is_bool) {
branch_computations.emplace({*true_computation, *false_computation});
}
if (branch_computations->empty() ||
operands.size() != branch_computations->size() + 1) {
return false;
}
if (!maybe_infer_shape(
[&] {
absl::InlinedVector<ProgramShape, 2> branch_computation_shapes;
branch_computation_shapes.reserve(branch_computations->size());
for (auto* computation : *branch_computations) {
branch_computation_shapes.push_back(
computation->ComputeProgramShape());
}
absl::InlinedVector<Shape, 2> branch_operand_shapes;
branch_operand_shapes.reserve(operands.size() - 1);
for (int i = 1; i < operands.size(); ++i) {
branch_operand_shapes.push_back(operands[i]->shape());
}
return ShapeInference::InferConditionalShape(
operands[0]->shape(), branch_computation_shapes,
branch_operand_shapes);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConditional(
shape, /*branch_index=*/operands[0],
absl::MakeSpan(*branch_computations),
absl::MakeSpan(operands).subspan(1)));
break;
}
case HloOpcode::kCustomCall: {
optional<std::string> custom_call_target;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64_t> feature_group_count;
optional<int64_t> batch_group_count;
optional<std::vector<Shape>> operand_layout_constraints;
optional<bool> custom_call_has_side_effect;
optional<HloComputation*> to_apply;
optional<
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>>
output_to_operand_aliasing;
optional<PaddingType> padding_type;
optional<std::vector<HloComputation*>> called_computations;
optional<CustomCallSchedule> custom_call_schedule;
optional<CustomCallApiVersion> api_version;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
&batch_group_count};
attrs["operand_layout_constraints"] = {
/*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
&custom_call_has_side_effect};
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
&to_apply};
attrs["called_computations"] = {/*required=*/false,
AttrTy::kBracedHloComputationList,
&called_computations};
attrs["output_to_operand_aliasing"] = {/*required=*/false,
AttrTy::kInstructionAliasing,
&output_to_operand_aliasing};
attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
&padding_type};
optional<Literal> literal;
attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (called_computations.has_value() && to_apply.has_value()) {
return Error(lexer_.GetLoc(),
"A single instruction can't have both to_apply and "
"calls field");
}
attrs["schedule"] = {/*required=*/false, AttrTy::kCustomCallSchedule,
&custom_call_schedule};
attrs["api_version"] = {/*required=*/false, AttrTy::kCustomCallApiVersion,
&api_version};
if (!ParseOperands(&operands, builder) || !parse_attributes()) {
return false;
}
if (api_version.has_value() &&
*api_version == CustomCallApiVersion::API_VERSION_UNSPECIFIED) {
return Error(lexer_.GetLoc(),
StrCat("Invalid API version: ",
CustomCallApiVersion_Name(*api_version)));
}
if (operand_layout_constraints.has_value()) {
if (!LayoutUtil::HasLayout(shape)) {
return Error(lexer_.GetLoc(),
"Layout must be set on layout-constrained custom call");
}
if (operands.size() != operand_layout_constraints->size()) {
return Error(lexer_.GetLoc(),
StrCat("Expected ", operands.size(),
" operand layout constraints, ",
operand_layout_constraints->size(), " given"));
}
for (int64_t i = 0; i < operands.size(); ++i) {
const Shape& operand_shape_with_layout =
(*operand_layout_constraints)[i];
if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
return Error(lexer_.GetLoc(),
StrCat("Operand layout constraint shape ",
ShapeUtil::HumanStringWithLayout(
operand_shape_with_layout),
" for operand ", i, " does not have a layout"));
}
if (!ShapeUtil::Compatible(operand_shape_with_layout,
operands[i]->shape())) {
return Error(
lexer_.GetLoc(),
StrCat(
"Operand layout constraint shape ",
ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
" for operand ", i,
" is not compatible with operand shape ",
ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
}
}
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target, *operand_layout_constraints,
backend_config ? *backend_config : ""));
} else {
if (to_apply.has_value()) {
instruction =
builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *to_apply, *custom_call_target,
backend_config ? *backend_config : ""));
} else if (called_computations.has_value()) {
instruction =
builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *called_computations, *custom_call_target,
backend_config ? *backend_config : ""));
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target,
backend_config ? *backend_config : ""));
}
}
auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
if (window.has_value()) {
custom_call_instr->set_window(*window);
}
if (dnums.has_value()) {
custom_call_instr->set_convolution_dimension_numbers(*dnums);
}
if (feature_group_count.has_value()) {
custom_call_instr->set_feature_group_count(*feature_group_count);
}
if (batch_group_count.has_value()) {
custom_call_instr->set_batch_group_count(*batch_group_count);
}
if (padding_type.has_value()) {
custom_call_instr->set_padding_type(*padding_type);
}
if (custom_call_has_side_effect.has_value()) {
custom_call_instr->set_custom_call_has_side_effect(
*custom_call_has_side_effect);
}
if (custom_call_schedule.has_value()) {
custom_call_instr->set_custom_call_schedule(*custom_call_schedule);
}
if (api_version.has_value()) {
custom_call_instr->set_api_version(*api_version);
}
if (output_to_operand_aliasing.has_value()) {
custom_call_instr->set_output_to_operand_aliasing(
std::move(*output_to_operand_aliasing));
}
if (literal.has_value()) {
custom_call_instr->set_literal(std::move(*literal));
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
*custom_call_instr->mutable_precision_config() = precision_config;
break;
}
case HloOpcode::kDot: {
optional<std::vector<int64_t>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
optional<std::vector<int64_t>> rhs_contracting_dims;
attrs["rhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
optional<std::vector<int64_t>> lhs_batch_dims;
attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&lhs_batch_dims};
optional<std::vector<int64_t>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
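// Assemble the dot dimension numbers from the parsed attributes. For
// example (illustrative), a [2,3]x[3,4] matmul would be written as:
// %dot = f32[2,4] dot(%lhs, %rhs), lhs_contracting_dims={1},
// rhs_contracting_dims={0}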
DotDimensionNumbers dnum;
if (lhs_contracting_dims) {
*dnum.mutable_lhs_contracting_dimensions() = {
lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
}
if (rhs_contracting_dims) {
*dnum.mutable_rhs_contracting_dimensions() = {
rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
}
if (lhs_batch_dims) {
*dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
lhs_batch_dims->end()};
}
if (rhs_batch_dims) {
*dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
rhs_batch_dims->end()};
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferDotOpShape(
operands[0]->shape(), operands[1]->shape(), dnum,
/*preferred_element_type=*/absl::nullopt);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateDot(
shape, operands[0], operands[1], dnum, precision_config));
break;
}
case HloOpcode::kGather: {
optional<std::vector<int64_t>> offset_dims;
attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
&offset_dims};
optional<std::vector<int64_t>> collapsed_slice_dims;
attrs["collapsed_slice_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
optional<std::vector<int64_t>> start_index_map;
attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
&start_index_map};
optional<int64_t> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
optional<std::vector<int64_t>> slice_sizes;
attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
&slice_sizes};
optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
GatherDimensionNumbers dim_numbers =
HloGatherInstruction::MakeGatherDimNumbers(
/*offset_dims=*/*offset_dims,
/*collapsed_slice_dims=*/*collapsed_slice_dims,
/*start_index_map=*/*start_index_map,
/*index_vector_dim=*/*index_vector_dim);
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferGatherShape(
operands[0]->shape(), operands[1]->shape(), dim_numbers,
*slice_sizes);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateGather(
shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
dim_numbers, *slice_sizes, indices_are_sorted.value()));
break;
}
case HloOpcode::kScatter: {
optional<std::vector<int64_t>> update_window_dims;
attrs["update_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
optional<std::vector<int64_t>> inserted_window_dims;
attrs["inserted_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
optional<std::vector<int64_t>> scatter_dims_to_operand_dims;
attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
AttrTy::kBracedInt64List,
&scatter_dims_to_operand_dims};
optional<int64_t> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
optional<HloComputation*> update_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&update_computation};
optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted};
optional<bool> unique_indices = false;
attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
&unique_indices};
if (!ParseOperands(&operands, builder, /*expected_size=*/3) ||
!parse_attributes()) {
return false;
}
ScatterDimensionNumbers dim_numbers =
HloScatterInstruction::MakeScatterDimNumbers(
/*update_window_dims=*/*update_window_dims,
/*inserted_window_dims=*/*inserted_window_dims,
/*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
/*index_vector_dim=*/*index_vector_dim);
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferScatterShape(
operands[0]->shape(), operands[1]->shape(),
operands[2]->shape(),
update_computation.value()->ComputeProgramShape(),
dim_numbers);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateScatter(
shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
/*updates=*/operands[2], *update_computation, dim_numbers,
indices_are_sorted.value(), unique_indices.value()));
break;
}
case HloOpcode::kDomain: {
DomainData domain;
attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
},
&shape)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateDomain(
shape, operands[0], std::move(domain.exit_metadata),
std::move(domain.entry_metadata)));
break;
}
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
case HloOpcode::kGetDimensionSize: {
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, builder, /*expected_size=*/1) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferGetDimensionSizeShape(
operands[0]->shape(), dimensions->at(0));
},
&shape)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
shape, operands[0], (*dimensions)[0]));
break;
}
case HloOpcode::kSetDimensionSize: {
optional<std::vector<int64_t>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, builder, /*expected_size=*/2) ||
!parse_attributes()) {
return false;
}
if (!maybe_infer_shape(
[&] {
return ShapeInference::InferSetDimensionSizeShape(
operands[0]->shape(), operands[1]->shape(),
dimensions->at(0));
},
&shape)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateSetDimensionSize(
shape, operands[0], operands[1], (*dimensions)[0]));
break;
}
}
// Generate a unique name if the name is empty. This is used for nested
// instructions (e.g. the `max` in add(max(x, y), z)).
//
// Otherwise, register the given name with the name uniquer.
if (name.empty()) {
name = name_uniquer_.GetUniqueName(
absl::StrCat(HloOpcodeString(instruction->opcode()), ".anon"));
} else {
name_uniquer_.GetUniqueName(name);
}
instruction->SetAndSanitizeName(name);
if (instruction->name() != name) {
return Error(name_loc,
StrCat("illegal instruction name: ", name,
"; suggest renaming to: ", instruction->name()));
}
// Add shared attributes like metadata to the instruction, if they were seen.
if (sharding) {
instruction->set_sharding(
HloSharding::FromProto(sharding.value()).ValueOrDie());
}
if (parameter_replication) {
int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
const auto& replicated =
parameter_replication->replicated_at_leaf_buffers();
if (leaf_count != replicated.size()) {
return Error(lexer_.GetLoc(),
StrCat("parameter has ", leaf_count,
" leaf buffers, but parameter_replication has ",
replicated.size(), " elements."));
}
instruction->set_parameter_replicated_at_leaf_buffers(replicated);
}
if (predecessors) {
for (auto* pre : *predecessors) {
Status status = pre->AddControlDependencyTo(instruction);
if (!status.ok()) {
return Error(name_loc, StrCat("error adding control dependency for: ",
name, " status: ", status.ToString()));
}
}
}
if (metadata) {
instruction->set_metadata(*metadata);
}
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
if (outer_dimension_partitions) {
instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
}
if (frontend_attributes) {
instruction->set_frontend_attributes(*frontend_attributes);
}
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
// sharding ::= '{' (single_sharding | tuple_sharding) '}'
//
// tuple_sharding ::= (single_sharding (',' single_sharding)*)?
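//
// For example (illustrative):
// {replicated} single sharding
// {maximal device=0} single maximal sharding
// {devices=[2,2]0,1,2,3} single tiled sharding
// {{replicated}, {maximal device=0}} tuple sharding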
bool HloParserImpl::ParseSharding(OpSharding* sharding) {
// A single sharding starts with '{' and is not followed by '{'.
// A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
// an empty tuple.
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start sharding attribute")) {
return false;
}
if (lexer_.GetKind() != TokKind::kLbrace &&
lexer_.GetKind() != TokKind::kRbrace) {
return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
}
// Tuple sharding.
// Allow empty tuple shardings.
if (lexer_.GetKind() != TokKind::kRbrace) {
do {
if (!ParseSingleSharding(sharding->add_tuple_shardings(),
/*lbrace_pre_lexed=*/false)) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
sharding->set_type(OpSharding::TUPLE);
return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
}
// frontend_attributes ::= '{' attributes '}'
// attributes
// ::= /*empty*/
// ::= attribute '=' value (',' attribute '=' value)*
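//
// For example (illustrative): frontend_attributes={attr_a="a",attr_b="b"}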
bool HloParserImpl::ParseFrontendAttributes(
FrontendAttributes* frontend_attributes) {
CHECK(frontend_attributes != nullptr);
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start frontend attributes")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRbrace) {
// empty
} else {
do {
std::string attribute;
if (!ParseAttributeName(&attribute)) {
return false;
}
if (lexer_.GetKind() != TokKind::kString) {
return false;
}
(*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
lexer_.Lex();
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRbrace,
"expects '}' at the end of frontend attributes");
}
// single_sharding
// ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
// ('devices=' ('[' dims ']')* device_list)?
// ('last_tile_dims=' last_tile_dims)?
// 'last_tile_dim_replicate'?
// ('metadata=' metadata)* '}'
//
// dims ::= int_list
// device_list ::= int_list
// metadata ::= single_metadata |
// ('{' [single_metadata (',' single_metadata)*] '}')
// last_tile_dims ::= sharding_type_list
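//
// For example (illustrative), {devices=[2,2]0,1,2,3 last_tile_dim_replicate}
// tiles the operand over devices 0..3 and replicates along the last tile
// dimension.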
bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
bool lbrace_pre_lexed) {
if (!lbrace_pre_lexed &&
!ParseToken(TokKind::kLbrace,
"expected '{' to start sharding attribute")) {
return false;
}
LocTy loc = lexer_.GetLoc();
bool maximal = false;
bool replicated = false;
bool manual = false;
bool last_tile_dim_replicate = false;
bool last_tile_dims = false;
std::vector<int64_t> devices;
std::vector<int64_t> tile_assignment_dimensions;
std::vector<OpSharding::Type> subgroup_types;
while (lexer_.GetKind() != TokKind::kRbrace) {
switch (lexer_.GetKind()) {
case TokKind::kw_maximal:
maximal = true;
lexer_.Lex();
break;
case TokKind::kw_replicated:
replicated = true;
lexer_.Lex();
break;
case TokKind::kw_manual:
manual = true;
lexer_.Lex();
break;
case TokKind::kAttributeName: {
if (lexer_.GetStrVal() == "device") {
if (lexer_.Lex() != TokKind::kInt) {
return TokenError("device= attribute must be an integer");
}
devices = {lexer_.GetInt64Val()};
lexer_.Lex();
} else if (lexer_.GetStrVal() == "devices") {
lexer_.Lex();
if (!ParseToken(TokKind::kLsquare,
"expected '[' to start sharding devices shape")) {
return false;
}
do {
int64_t dim;
if (!ParseInt64(&dim)) {
return false;
}
tile_assignment_dimensions.push_back(dim);
} while (EatIfPresent(TokKind::kComma));
if (!ParseToken(TokKind::kRsquare,
"expected ']' to end sharding devices shape")) {
return false;
}
do {
int64_t device;
if (!ParseInt64(&device)) {
return false;
}
devices.push_back(device);
} while (EatIfPresent(TokKind::kComma));
} else if (lexer_.GetStrVal() == "metadata") {
lexer_.Lex();
if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) {
return false;
}
} else if (lexer_.GetStrVal() == "last_tile_dims") {
last_tile_dims = true;
lexer_.Lex();
if (!ParseListShardingType(&subgroup_types)) {
return false;
}
} else {
return TokenError(
"unknown attribute in sharding: expected device=, devices=, "
"metadata= or last_tile_dims=");
}
break;
}
case TokKind::kw_last_tile_dim_replicate:
last_tile_dim_replicate = true;
lexer_.Lex();
break;
case TokKind::kRbrace:
break;
default:
return TokenError("unexpected token");
}
}
if (replicated) {
if (!devices.empty()) {
return Error(loc,
"replicated shardings should not have any devices assigned");
}
sharding->set_type(OpSharding::REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return Error(loc,
"maximal shardings should have exactly one device assigned");
}
sharding->set_type(OpSharding::MAXIMAL);
sharding->add_tile_assignment_devices(devices[0]);
} else if (manual) {
if (!devices.empty()) {
return Error(loc,
"manual shardings should not have any devices assigned");
}
sharding->set_type(OpSharding::MANUAL);
} else {
if (devices.size() <= 1) {
return Error(
loc, "non-maximal shardings must have more than one device assigned");
}
if (tile_assignment_dimensions.empty()) {
return Error(
loc,
"non-maximal shardings must have a tile assignment list including "
"dimensions");
}
sharding->set_type(OpSharding::OTHER);
for (int64_t dim : tile_assignment_dimensions) {
sharding->add_tile_assignment_dimensions(dim);
}
for (int64_t device : devices) {
sharding->add_tile_assignment_devices(device);
}
if (last_tile_dims) {
for (OpSharding::Type type : subgroup_types) {
sharding->add_last_tile_dims(type);
}
} else {
sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate);
}
}
lexer_.Lex();
return true;
}
// parameter_replication ::=
// '{' ('true' | 'false')* (',' ('true' | 'false'))* '}'
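//
// For example (illustrative), parameter_replication={true,false} marks the
// first leaf buffer of the parameter as replicated and the second as not.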
bool HloParserImpl::ParseParameterReplication(
ParameterReplication* parameter_replication) {
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start parameter_replication attribute")) {
return false;
}
if (lexer_.GetKind() != TokKind::kRbrace) {
do {
if (lexer_.GetKind() == TokKind::kw_true) {
parameter_replication->add_replicated_at_leaf_buffers(true);
} else if (lexer_.GetKind() == TokKind::kw_false) {
parameter_replication->add_replicated_at_leaf_buffers(false);
} else {
return false;
}
lexer_.Lex();
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRbrace,
"expected '}' to end parameter_replication attribute");
}
// replica_groups ::= '{' int64_list_elements '}'
// int64_list_elements
// ::= /*empty*/
// ::= int64_list (',' int64_list)*
// int64_list ::= '{' int64_elements '}'
// int64_elements
// ::= /*empty*/
// ::= int64_val (',' int64_val)*
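//
// For example (illustrative), replica_groups={{0,1},{2,3}} denotes two
// groups of two replicas each.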
bool HloParserImpl::ParseReplicaGroupsOnly(
std::vector<ReplicaGroup>* replica_groups) {
std::vector<std::vector<int64_t>> result;
if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
&result)) {
return false;
}
*replica_groups = CreateReplicaGroups(result);
return true;
}
// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
// 'exit=' exit_sharding '}'
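//
// For example (illustrative):
// domain={kind="sharding", entry={maximal device=0}, exit={replicated}}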
bool HloParserImpl::ParseDomain(DomainData* domain) {
absl::flat_hash_map<std::string, AttrConfig> attrs;
optional<std::string> kind;
optional<OpSharding> entry_sharding;
optional<OpSharding> exit_sharding;
attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
if (!ParseSubAttributes(attrs)) {
return false;
}
if (*kind == ShardingMetadata::KindName()) {
auto entry_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*entry_sharding).ValueOrDie());
auto exit_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*exit_sharding).ValueOrDie());
domain->entry_metadata =
absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
domain->exit_metadata =
absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
} else {
return TokenError(StrCat("unsupported domain kind: ", *kind));
}
return true;
}
// '{' name+ '}'
bool HloParserImpl::ParseInstructionNames(
std::vector<HloInstruction*>* instructions) {
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction name list")) {
return false;
}
LocTy loc = lexer_.GetLoc();
do {
std::string name;
if (!ParseName(&name)) {
return Error(loc, "expects an instruction name");
}
std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
if (!instr) {
return TokenError(StrFormat("instruction '%s' is not defined", name));
}
instructions->push_back(instr->first);
} while (EatIfPresent(TokKind::kComma));
return ParseToken(TokKind::kRbrace,
"expects '}' at the end of instruction name list");
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case S8:
return SetValueInLiteralHelper<int8_t>(loc, value, index, literal);
case S16:
return SetValueInLiteralHelper<int16_t>(loc, value, index, literal);
case S32:
return SetValueInLiteralHelper<int32_t>(loc, value, index, literal);
case S64:
return SetValueInLiteralHelper<int64_t>(loc, value, index, literal);
case U8:
return SetValueInLiteralHelper<uint8_t>(loc, value, index, literal);
case U16:
return SetValueInLiteralHelper<uint16_t>(loc, value, index, literal);
case U32:
return SetValueInLiteralHelper<uint32_t>(loc, value, index, literal);
case U64:
return SetValueInLiteralHelper<uint64_t>(loc, value, index, literal);
case PRED:
// Bool type literals with rank >= 1 are printed as 0s and 1s.
return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
literal);
default:
LOG(FATAL) << "unknown integral primitive type "
<< PrimitiveType_Name(shape.element_type());
}
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64_t index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case F16:
return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
case BF16:
return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
literal);
case F32:
return SetValueInLiteralHelper<float>(loc, value, index, literal);
case F64:
return SetValueInLiteralHelper<double>(loc, value, index, literal);
default:
LOG(FATAL) << "unknown floating point primitive type "
<< PrimitiveType_Name(shape.element_type());
}
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64_t index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case PRED:
return SetValueInLiteralHelper<bool>(loc, value, index, literal);
default:
LOG(FATAL) << PrimitiveType_Name(shape.element_type())
<< " is not PRED type";
}
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
int64_t index, Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case C64:
return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
literal);
case C128:
return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
literal);
default:
LOG(FATAL) << PrimitiveType_Name(shape.element_type())
<< " is not a complex type";
}
}
template <typename T>
std::string StringifyValue(T val) {
return StrCat(val);
}
template <>
std::string StringifyValue(std::complex<double> val) {
return StrFormat("(%f, %f)", std::real(val), std::imag(val));
}
// Evaluates to V when T == U.
template <typename T, typename U, typename V>
using EnableIfSameWithType = std::enable_if_t<std::is_same<T, U>::value, V>;
template <class T, EnableIfSameWithType<T, bool, bool> = false>
uint64_t GetNanPayload(T val) {
return 0;
}
template <class T, EnableIfSameWithType<T, int64_t, bool> = false>
uint64_t GetNanPayload(T val) {
return 0;
}
template <class T, EnableIfSameWithType<T, double, bool> = false>
uint64_t GetNanPayload(T val) {
auto rep = absl::bit_cast<uint64_t>(val);
if (auto payload = rep & NanPayloadBitMask<double>()) {
return payload;
}
return QuietNanWithoutPayload<double>();
}
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, LiteralComponentT, LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
return real;
}
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, std::complex<LiteralComponentT>,
LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
return LiteralNativeT(real, imag);
}
template <typename T>
struct ComponentType {
using Type = T;
};
template <typename T>
struct ComponentType<std::complex<T>> {
using Type = T;
};
template <typename T>
T GetReal(T value) {
return value;
}
template <typename T>
T GetReal(std::complex<T> value) {
return value.real();
}
template <typename T>
T GetImag(T value) {
return 0;
}
template <typename T>
T GetImag(std::complex<T> value) {
return value.imag();
}
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
int64_t index, Literal* literal) {
if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
return false;
}
// Check that the index is in range and assign into the literal
if (index >= ShapeUtil::ElementsIn(literal->shape())) {
return Error(loc, StrCat("tries to set value ", StringifyValue(value),
" to a literal in shape ",
ShapeUtil::HumanString(literal->shape()),
" at linear index ", index,
", but the index is out of range"));
}
using ParsedElemComponentT = typename ComponentType<ParsedElemT>::Type;
using LiteralNativeComponentT = typename ComponentType<LiteralNativeT>::Type;
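// NaN values keep their payload when it fits in the literal's native type;
// a payload that does not fit is reported as an error rather than being
// silently truncated.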
const auto handle_nan = [this, literal, index, loc](
ParsedElemComponentT parsed_value_component,
LiteralNativeComponentT*
literal_value_component) {
if (!std::isnan(static_cast<double>(parsed_value_component))) {
return true;
}
auto nan_payload = GetNanPayload(parsed_value_component);
if (nan_payload == QuietNanWithoutPayload<double>()) {
nan_payload = QuietNanWithoutPayload<LiteralNativeComponentT>();
}
const auto kLargestPayload = NanPayloadBitMask<LiteralNativeComponentT>();
if (nan_payload > kLargestPayload) {
return Error(
loc,
StrCat("tries to set NaN payload 0x", absl::Hex(nan_payload),
" to a literal in shape ",
ShapeUtil::HumanString(literal->shape()), " at linear index ",
index, ", but the NaN payload is out of range (0x",
absl::Hex(kLargestPayload), ")"));
}
*literal_value_component = NanWithSignAndPayload<LiteralNativeComponentT>(
/*sign=*/std::signbit(static_cast<double>(parsed_value_component)),
/*nan_payload=*/nan_payload);
return true;
};
const ParsedElemComponentT parsed_real_value = GetReal(value);
auto literal_real_value =
static_cast<LiteralNativeComponentT>(parsed_real_value);
if (std::is_floating_point<ParsedElemT>::value ||
std::is_same<ParsedElemT, std::complex<double>>::value) {
if (!handle_nan(parsed_real_value, &literal_real_value)) {
return false;
}
}
const ParsedElemComponentT parsed_imag_value = GetImag(value);
auto literal_imag_value =
static_cast<LiteralNativeComponentT>(parsed_imag_value);
if (std::is_same<ParsedElemT, std::complex<double>>::value) {
if (!handle_nan(parsed_imag_value, &literal_imag_value)) {
return false;
}
}
literal->data<LiteralNativeT>().at(index) =
LiteralNativeFromRealImag<LiteralNativeT>(literal_real_value,
literal_imag_value);
return true;
}
// Similar to ParseLiteral(Literal* literal, const Shape& shape), but parses
// the shape instead of accepting one as an argument.
bool HloParserImpl::ParseLiteral(Literal* literal) {
if (lexer_.GetKind() == TokKind::kLparen) {
// Consume Lparen
lexer_.Lex();
std::vector<Literal> elements;
while (lexer_.GetKind() != TokKind::kRparen) {
Literal element;
if (!ParseLiteral(&element)) {
return TokenError("Fails when parsing tuple element");
}
elements.emplace_back(std::move(element));
if (lexer_.GetKind() != TokKind::kRparen &&
!ParseToken(TokKind::kComma, "expects ',' to separate tuple elements")) {
return false;
}
}
*literal = LiteralUtil::MakeTupleOwned(std::move(elements));
// Consume Rparen
return ParseToken(TokKind::kRparen, "expects ')' to close a tuple literal");
}
Shape literal_shape;
if (!ParseShape(&literal_shape)) {
return false;
}
return ParseLiteral(literal, literal_shape);
}
// literal
// ::= tuple
// ::= non_tuple
bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) {
return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
// tuple
// ::= shape '(' literal_list ')'
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
// literal, (',' literal)*
for (int i = 0; i < elements.size(); i++) {
if (i > 0 &&
!ParseToken(TokKind::kComma, "expects ',' to separate tuple elements")) {
return false;
}
if (!ParseLiteral(&elements[i],
ShapeUtil::GetTupleElementShape(shape, i))) {
return TokenError(StrCat("expects the ", i, "th element"));
}
}
}
*literal = LiteralUtil::MakeTupleOwned(std::move(elements));
return ParseToken(TokKind::kRparen,
StrCat("expects ')' at the end of the tuple with ",
ShapeUtil::TupleElementCount(shape), "elements"));
}
// non_tuple
// ::= rank01
// ::= rank2345
// rank2345 ::= shape nested_array
bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
return ParseDenseLiteral(literal, shape);
}
bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
// Cast `rank` to int because we call shape.dimensions(int rank) below, and if
// `rank` is an int64_t, that's an implicit narrowing conversion, which is
// implementation-defined behavior.
const int rank = static_cast<int>(shape.rank());
// Create a literal with the given shape in default layout.
*literal = LiteralUtil::CreateFromDimensions(shape.element_type(),
shape.dimensions());
int64_t nest_level = 0;
int64_t linear_index = 0;
// elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
// the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
// when we are parsing the 2nd '{' (right before '1'), we are seeing a
// sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
// the first '}' (right after '3'), it means the sub-array ends, and the
// sub-array is supposed to contain exactly 3 elements, so check if
// elems_seen_per_dim[1] is 3.
std::vector<int64_t> elems_seen_per_dim(rank);
auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
std::vector<int64_t> elems_seen_until_dim(elems_seen_per_dim.begin(),
elems_seen_per_dim.begin() + dim);
return StrCat("[",
StrJoin(elems_seen_until_dim, ",",
[](std::string* out, const int64_t num_elems) {
StrAppend(out, num_elems - 1);
}),
"]");
};
auto add_one_elem_seen = [&] {
if (rank > 0) {
if (nest_level != rank) {
return TokenError(absl::StrFormat(
"expects nested array in rank %d, but sees %d", rank, nest_level));
}
elems_seen_per_dim[rank - 1]++;
if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
return TokenError(absl::StrFormat(
"expects %d elements on the minor-most dimension, but "
"sees more",
shape.dimensions(rank - 1)));
}
}
return true;
};
do {
switch (lexer_.GetKind()) {
default:
return TokenError("unexpected token type in a literal");
case TokKind::kLbrace: {
nest_level++;
if (nest_level > rank) {
return TokenError(absl::StrFormat(
"expects nested array in rank %d, but sees larger", rank));
}
if (nest_level > 1) {
elems_seen_per_dim[nest_level - 2]++;
if (elems_seen_per_dim[nest_level - 2] >
shape.dimensions(nest_level - 2)) {
return TokenError(absl::StrFormat(
"expects %d elements in the %sth element, but sees more",
shape.dimensions(nest_level - 2),
get_index_str(nest_level - 2)));
}
}
lexer_.Lex();
break;
}
case TokKind::kRbrace: {
if (nest_level == 0) {
return TokenError("unexpected '}' in literal");
}
nest_level--;
if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
return TokenError(absl::StrFormat(
"expects %d elements in the %sth element, but sees %d",
shape.dimensions(nest_level), get_index_str(nest_level),
elems_seen_per_dim[nest_level]));
}
elems_seen_per_dim[nest_level] = 0;
lexer_.Lex();
break;
}
case TokKind::kLparen: {
if (!primitive_util::IsComplexType(shape.element_type())) {
return TokenError(
absl::StrFormat("unexpected '(' in literal. Parens are only "
"valid for complex literals"));
}
std::complex<double> value;
LocTy loc = lexer_.GetLoc();
if (!add_one_elem_seen() || !ParseComplex(&value) ||
!SetValueInLiteral(loc, value, linear_index++, literal)) {
return false;
}
break;
}
case TokKind::kDots: {
if (nest_level != 1) {
return TokenError(absl::StrFormat(
"expects `...` at nest level 1, but sees it at nest level %d",
nest_level));
}
elems_seen_per_dim[0] = shape.dimensions(0);
lexer_.Lex();
// Fill the data with deterministic (garbage) values. Use static to avoid
// creating identical constants which could potentially get CSE'ed away.
// This is a best-effort approach to make sure replaying an HLO gives us
// the same optimized HLO graph.
static uint32_t data = 0;
uint32_t* raw_data = static_cast<uint32_t*>(literal->untyped_data());
for (int64_t i = 0; i < literal->size_bytes() / 4; ++i) {
raw_data[i] = data++;
}
uint8_t* raw_data_int8 = static_cast<uint8_t*>(literal->untyped_data());
static uint8_t data_int8 = 0;
for (int64_t i = 0; i < literal->size_bytes() % 4; ++i) {
raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
}
break;
}
case TokKind::kComma:
// Skip.
lexer_.Lex();
break;
case TokKind::kw_true:
case TokKind::kw_false:
case TokKind::kInt:
case TokKind::kDecimal:
case TokKind::kw_inf:
case TokKind::kNegInf: {
add_one_elem_seen();
if (lexer_.GetKind() == TokKind::kw_true ||
lexer_.GetKind() == TokKind::kw_false) {
if (!SetValueInLiteral(lexer_.GetLoc(),
lexer_.GetKind() == TokKind::kw_true,
linear_index++, literal)) {
return false;
}
lexer_.Lex();
} else if (primitive_util::IsIntegralType(shape.element_type()) ||
shape.element_type() == PRED) {
LocTy loc = lexer_.GetLoc();
int64_t value;
if (!ParseInt64(&value)) {
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
LocTy loc = lexer_.GetLoc();
double value;
if (!ParseDouble(&value)) {
return Error(
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
return false;
}
} else {
return TokenError(StrCat("unsupported primitive type ",
PrimitiveType_Name(shape.element_type())));
}
break;
}
} // end of switch
} while (nest_level > 0);
*literal = literal->Relayout(shape.layout());
return true;
}
// MinMaxFiniteValue is a type-traits helper used by
// HloParserImpl::CheckParsedValueIsInRange.
template <typename T>
struct MinMaxFiniteValue {
static T max() { return std::numeric_limits<T>::max(); }
static T min() { return std::numeric_limits<T>::lowest(); }
};
template <>
struct MinMaxFiniteValue<Eigen::half> {
static double max() {
// Sadly this is not constexpr, which forces max and min to be methods
// rather than constants.
return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
}
static double min() { return -max(); }
};
template <>
struct MinMaxFiniteValue<bfloat16> {
static double max() {
return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
}
static double min() { return -max(); }
};
// MSVC's standard C++ library does not define isnan/isfinite for integer
// types. To work around that, we provide our own overloads here.
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T val) {
return std::isfinite(val);
}
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T val) {
return std::isnan(val);
}
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T val) {
return std::isfinite(static_cast<double>(val));
}
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T val) {
return std::isnan(static_cast<double>(val));
}
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
if (std::is_floating_point<ParsedElemT>::value) {
auto value_as_native_t = static_cast<LiteralNativeT>(value);
auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
if (!IsFinite(value) || IsFinite(value_double_converted)) {
value = value_double_converted;
}
}
PrimitiveType literal_ty =
primitive_util::NativeToPrimitiveType<LiteralNativeT>();
if (IsNaN(value) ||
(std::numeric_limits<ParsedElemT>::has_infinity &&
(std::numeric_limits<ParsedElemT>::infinity() == value ||
-std::numeric_limits<ParsedElemT>::infinity() == value))) {
// Skip range checking for non-finite value.
} else if (std::is_unsigned<LiteralNativeT>::value) {
CHECK((std::is_same<ParsedElemT, int64_t>::value ||
std::is_same<ParsedElemT, bool>::value))
<< "Unimplemented checking for ParsedElemT";
const uint64_t unsigned_value = value;
const uint64_t upper_bound =
static_cast<uint64_t>(std::numeric_limits<LiteralNativeT>::max());
if (unsigned_value > upper_bound) {
// Value is out of range for LiteralNativeT.
return Error(loc, StrCat("value ", value,
" is out of range for literal's primitive type ",
PrimitiveType_Name(literal_ty), " namely [0, ",
upper_bound, "]."));
}
} else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
value < MinMaxFiniteValue<LiteralNativeT>::min()) {
// Value is out of range for LiteralNativeT.
return Error(loc, StrCat("value ", value,
" is out of range for literal's primitive type ",
PrimitiveType_Name(literal_ty), " namely [",
MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
}
return true;
}
template <typename LiteralNativeT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc,
std::complex<double> value) {
// e.g. `float` for std::complex<float>
using LiteralComplexComponentT =
decltype(std::real(std::declval<LiteralNativeT>()));
// We could simply do
//
// return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
// CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
//
// but this would give bad error messages on failure.
auto check_component = [&](absl::string_view name, double v) {
if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
v == -std::numeric_limits<double>::infinity()) {
// Skip range-checking for non-finite values.
return true;
}
double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
if (v < min || v > max) {
// Value is out of range for LiteralComplexComponentT.
return Error(
loc,
StrCat(name, " part ", v,
" is out of range for literal's primitive type ",
PrimitiveType_Name(
primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
", namely [", min, ", ", max, "]."));
}
return true;
};
return check_component("real", std::real(value)) &&
check_component("imaginary", std::imag(value));
}
// operands ::= '(' operands1 ')'
// operands1
// ::= /*empty*/
// ::= operand (, operand)*
// operand ::= (shape)? name
// ::= (shape)? opcode operands
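// For example (illustrative operand lists; names and shapes are arbitrary):
//
//   ()                          // empty
//   (%x, %y)                    // named operands
//   (f32[10] %x)                // operand with an explicit shape
//   (f32[10] add(%a, %b))       // nested instruction as an operand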
bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
HloComputation::Builder* builder) {
CHECK(operands != nullptr);
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
// Try to parse the operand as a name with an optional shape. If that
// doesn't work, try again parsing it as a nested instruction.
//
// (Trying nested instructions second is important here: If you have a
// giant HLO dump, it likely doesn't have any nested instructions, but
// likely has tons of non-nested operands. Generating an error is slow --
// O(n) as of writing -- so we only want to hit the error branch in the
// uncommon case.)
HloLexer lexer_copy = lexer_;
std::vector<std::string> saved_errors;
std::swap(saved_errors, error_);
bool is_normal_operand = [&] {
LocTy loc = lexer_.GetLoc();
std::string name;
optional<Shape> shape;
if (CanBeShape()) {
shape.emplace();
if (!ParseShape(&shape.value())) {
return false;
}
}
if (!ParseName(&name)) {
// When parsing a single instruction (as opposed to a whole module),
// an HLO may have one or more operands with a shape but no name:
//
// foo = add(f32[10], f32[10])
//
// create_missing_instruction_ is always non-null when parsing a
// single instruction, and is responsible for creating kParameter
// instructions for these operands.
if (shape.has_value() && create_missing_instruction_ != nullptr &&
scoped_name_tables_.size() == 1) {
name = "";
} else {
return false;
}
}
std::pair<HloInstruction*, LocTy>* instruction =
FindInstruction(name, shape);
if (instruction == nullptr) {
return Error(loc, StrCat("instruction does not exist: ", name));
}
// If this is a regular named operand, it must be followed by a comma or
// a close-paren. If not, it has to be a named instruction. Don't
// output an error here -- if it fails to parse as a named instruction
// too, we'll just use that set of errors.
auto next = lexer_.GetKind();
if (next != TokKind::kComma && next != TokKind::kRparen) {
return false;
}
operands->push_back(instruction->first);
return true;
}();
if (is_normal_operand) {
error_ = std::move(saved_errors);
continue;
}
// If parsing as a normal operand failed, try parsing as a nested
// instruction.
std::vector<std::string> normal_operand_errors;
std::swap(error_, normal_operand_errors);
lexer_ = lexer_copy;
// Nested instructions can't have attributes because it's ambiguous
// whether the comma separates an instruction from its attribute, or
// whether the comma separates two instructions.
LocTy loc = lexer_.GetLoc();
bool is_nested_instruction = ParseInstructionRhs(
builder, /*name=*/"", loc, /*allow_attributes=*/false);
if (is_nested_instruction) {
operands->push_back(builder->last_added_instruction());
error_ = std::move(saved_errors);
continue;
}
// If neither parsing as a normal operand nor parsing as a nested
// instruction worked, fail. Return both sets of errors.
std::vector<std::string> nested_instruction_errors;
std::swap(error_, nested_instruction_errors);
error_ = std::move(saved_errors);
Error(loc,
"cannot parse as an instruction name or as a nested instruction:");
error_.insert(error_.end(),
std::make_move_iterator(normal_operand_errors.begin()),
std::make_move_iterator(normal_operand_errors.end()));
error_.insert(error_.end(),
std::make_move_iterator(nested_instruction_errors.begin()),
std::make_move_iterator(nested_instruction_errors.end()));
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
HloComputation::Builder* builder,
const int expected_size) {
CHECK(operands != nullptr);
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(operands, builder)) {
return false;
}
if (expected_size != operands->size()) {
return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
operands->size(), " operands"));
}
return true;
}
// sub_attributes ::= '{' (','? attribute)* '}'
bool HloParserImpl::ParseSubAttributes(
const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
LocTy loc = lexer_.GetLoc();
if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
return false;
}
absl::flat_hash_set<std::string> seen_attrs;
if (lexer_.GetKind() == TokKind::kRbrace) {
// empty
} else {
do {
EatIfPresent(TokKind::kComma);
if (!ParseAttributeHelper(attrs, &seen_attrs)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
}
// Check that all required attrs were seen.
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
attr_it.first));
}
}
return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
}
// attributes ::= (',' attribute)*
bool HloParserImpl::ParseAttributes(
const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
LocTy loc = lexer_.GetLoc();
absl::flat_hash_set<std::string> seen_attrs;
while (EatIfPresent(TokKind::kComma)) {
if (!ParseAttributeHelper(attrs, &seen_attrs)) {
return false;
}
}
// Check that all required attrs were seen.
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
return Error(loc, StrFormat("attribute %s is expected but not seen",
attr_it.first));
}
}
return true;
}
bool HloParserImpl::ParseAttributeHelper(
const absl::flat_hash_map<std::string, AttrConfig>& attrs,
absl::flat_hash_set<std::string>* seen_attrs) {
LocTy loc = lexer_.GetLoc();
std::string name;
if (!ParseAttributeName(&name)) {
return Error(loc, "error parsing attributes");
}
VLOG(3) << "Parsing attribute " << name;
if (!seen_attrs->insert(name).second) {
return Error(loc, StrFormat("attribute %s already exists", name));
}
auto attr_it = attrs.find(name);
if (attr_it == attrs.end()) {
std::string allowed_attrs;
if (attrs.empty()) {
allowed_attrs = "No attributes are allowed here.";
} else {
allowed_attrs =
StrCat("Allowed attributes: ",
StrJoin(attrs, ", ",
[&](std::string* out,
const std::pair<std::string, AttrConfig>& kv) {
StrAppend(out, kv.first);
}));
}
return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
allowed_attrs));
}
AttrTy attr_type = attr_it->second.attr_type;
void* attr_out_ptr = attr_it->second.result;
bool success = [&] {
LocTy attr_loc = lexer_.GetLoc();
switch (attr_type) {
case AttrTy::kBool: {
bool result;
if (!ParseBool(&result)) {
return false;
}
static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kInt64: {
int64_t result;
if (!ParseInt64(&result)) {
return false;
}
static_cast<optional<int64_t>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kInt32: {
int64_t result;
if (!ParseInt64(&result)) {
return false;
}
if (result != static_cast<int32_t>(result)) {
return Error(attr_loc, "value out of range for int32_t");
}
static_cast<optional<int32_t>*>(attr_out_ptr)
->emplace(static_cast<int32_t>(result));
return true;
}
case AttrTy::kFloat: {
double result;
if (!ParseDouble(&result)) {
return false;
}
if (result > std::numeric_limits<float>::max() ||
result < std::numeric_limits<float>::lowest()) {
return Error(attr_loc, "value out of range for float");
}
static_cast<optional<float>*>(attr_out_ptr)
->emplace(static_cast<float>(result));
return true;
}
case AttrTy::kHloComputation: {
HloComputation* result = nullptr;
if (!ParseHloComputation(&result)) {
return false;
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kBracedHloComputationList: {
std::vector<HloComputation*> result;
if (!ParseHloComputationList(&result)) {
return false;
}
static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kFftType: {
FftType result;
if (!ParseFftType(&result)) {
return false;
}
static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kPaddingType: {
PaddingType result;
if (!ParsePaddingType(&result)) {
return false;
}
static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kComparisonDirection: {
ComparisonDirection result;
if (!ParseComparisonDirection(&result)) {
return false;
}
static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kComparisonType: {
Comparison::Type result;
if (!ParseComparisonType(&result)) {
return false;
}
static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kEnum: {
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects an enumeration value");
}
std::string result = lexer_.GetStrVal();
lexer_.Lex();
static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kWindow: {
Window result;
if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
return false;
}
static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kConvolutionDimensionNumbers: {
ConvolutionDimensionNumbers result;
if (!ParseConvolutionDimensionNumbers(&result)) {
return false;
}
static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kSharding: {
OpSharding sharding;
if (!ParseSharding(&sharding)) {
return false;
}
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
return true;
}
case AttrTy::kFrontendAttributes: {
FrontendAttributes frontend_attributes;
if (!ParseFrontendAttributes(&frontend_attributes)) {
return false;
}
static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
->emplace(frontend_attributes);
return true;
}
case AttrTy::kParameterReplication: {
ParameterReplication parameter_replication;
if (!ParseParameterReplication(&parameter_replication)) {
return false;
}
static_cast<optional<ParameterReplication>*>(attr_out_ptr)
->emplace(parameter_replication);
return true;
}
case AttrTy::kInstructionList: {
std::vector<HloInstruction*> result;
if (!ParseInstructionNames(&result)) {
return false;
}
static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kFusionKind: {
HloInstruction::FusionKind result;
if (!ParseFusionKind(&result)) {
return false;
}
static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kBracedInt64List: {
std::vector<int64_t> result;
if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
&result)) {
return false;
}
static_cast<optional<std::vector<int64_t>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kBracedInt64ListList: {
std::vector<std::vector<int64_t>> result;
if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
TokKind::kComma, &result)) {
return false;
}
static_cast<optional<std::vector<std::vector<int64_t>>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kSliceRanges: {
SliceRanges result;
if (!ParseSliceRanges(&result)) {
return false;
}
static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kPaddingConfig: {
PaddingConfig result;
if (!ParsePaddingConfig(&result)) {
return false;
}
static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kString: {
std::string result;
if (!ParseString(&result)) {
return false;
}
static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kMetadata: {
OpMetadata result;
if (!ParseMetadata(&result)) {
return false;
}
static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kDistribution: {
RandomDistribution result;
if (!ParseRandomDistribution(&result)) {
return false;
}
static_cast<optional<RandomDistribution>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kDomain: {
return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
}
case AttrTy::kPrecisionList: {
std::vector<PrecisionConfig::Precision> result;
if (!ParsePrecisionList(&result)) {
return false;
}
static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kShape: {
Shape result;
if (!ParseShape(&result)) {
return false;
}
static_cast<optional<Shape>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kShapeList: {
std::vector<Shape> result;
if (!ParseShapeList(&result)) {
return false;
}
static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kRandomAlgorithm: {
RandomAlgorithm result;
if (!ParseRandomAlgorithm(&result)) {
return false;
}
static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kAliasing: {
AliasingData aliasing_data;
if (!ParseAliasing(&aliasing_data)) {
return false;
}
static_cast<optional<AliasingData>*>(attr_out_ptr)
->emplace(aliasing_data);
return true;
}
case AttrTy::kInstructionAliasing: {
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
aliasing_output_operand_pairs;
if (!ParseInstructionOutputOperandAliasing(
&aliasing_output_operand_pairs)) {
return false;
}
static_cast<optional<std::vector<
std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>>*>(
attr_out_ptr)
->emplace(std::move(aliasing_output_operand_pairs));
return true;
}
case AttrTy::kLiteral: {
Literal result;
if (!ParseLiteral(&result)) {
return false;
}
static_cast<optional<Literal>*>(attr_out_ptr)
->emplace(std::move(result));
return true;
}
case AttrTy::kCustomCallSchedule: {
CustomCallSchedule result;
if (!ParseCustomCallSchedule(&result)) {
return false;
}
static_cast<optional<CustomCallSchedule>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kCustomCallApiVersion: {
CustomCallApiVersion result;
if (!ParseCustomCallApiVersion(&result)) {
return false;
}
static_cast<optional<CustomCallApiVersion>*>(attr_out_ptr)
->emplace(result);
return true;
}
}
}();
if (!success) {
return Error(loc, StrFormat("error parsing attribute %s", name));
}
return true;
}
bool HloParserImpl::CopyAttributeToProtoMessage(
absl::flat_hash_set<std::string> non_proto_attrs,
const absl::flat_hash_map<std::string, AttrConfig>& attrs,
tensorflow::protobuf::Message* message) {
const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
const tensorflow::protobuf::Reflection* reflection = message->GetReflection();
for (const auto& p : attrs) {
const std::string& name = p.first;
if (non_proto_attrs.find(name) != non_proto_attrs.end()) {
continue;
}
const tensorflow::protobuf::FieldDescriptor* fd =
descriptor->FindFieldByName(name);
if (!fd) {
std::string allowed_attrs = "Allowed attributes: ";
for (int i = 0; i < descriptor->field_count(); ++i) {
if (i == 0) {
absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
} else {
absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
}
}
return TokenError(
StrFormat("unexpected attribute \"%s\". %s", name, allowed_attrs));
}
CHECK(!fd->is_repeated()); // Repeated fields not implemented.
bool success = [&] {
switch (fd->type()) {
case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
auto attr_value = static_cast<optional<bool>*>(p.second.result);
if (attr_value->has_value()) {
reflection->SetBool(message, fd, **attr_value);
}
return true;
}
case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
auto attr_value =
static_cast<optional<std::string>*>(p.second.result);
if (attr_value->has_value()) {
const tensorflow::protobuf::EnumValueDescriptor* evd =
fd->enum_type()->FindValueByName(**attr_value);
reflection->SetEnum(message, fd, evd);
}
return true;
}
default:
return false;
}
}();
if (!success) {
return TokenError(StrFormat("error parsing attribute %s", name));
}
}
return true;
}
// attributes ::= (',' attribute)*
bool HloParserImpl::ParseAttributesAsProtoMessage(
const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
tensorflow::protobuf::Message* message) {
const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
absl::flat_hash_map<std::string, AttrConfig> attrs;
// Storage for attributes.
std::vector<optional<bool>> bool_params;
std::vector<optional<std::string>> string_params;
// Reserve enough capacity so the vectors never reallocate; this lets us
// rely on pointers into them staying valid.
bool_params.reserve(descriptor->field_count());
string_params.reserve(descriptor->field_count());
// Populate the storage of expected attributes from the protobuf description.
for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
const tensorflow::protobuf::FieldDescriptor* fd =
descriptor->field(field_idx);
const std::string& field_name = fd->name();
switch (fd->type()) {
case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
bool_params.emplace_back(absl::nullopt);
attrs[field_name] = {/*is_required*/ false, AttrTy::kBool,
&bool_params.back()};
break;
}
case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
string_params.emplace_back(absl::nullopt);
attrs[field_name] = {/*is_required*/ false, AttrTy::kEnum,
&string_params.back()};
break;
}
default:
return TokenError(absl::StrFormat(
"Unexpected protocol buffer type: %s ", fd->DebugString()));
}
}
absl::flat_hash_set<std::string> non_proto_attrs_names;
non_proto_attrs_names.reserve(non_proto_attrs.size());
for (const auto& p : non_proto_attrs) {
const std::string& attr_name = p.first;
// If an attribute is both specified within 'non_proto_attrs' and an
// attribute of the proto message, we prefer the attribute of the proto
// message.
if (attrs.find(attr_name) == attrs.end()) {
non_proto_attrs_names.insert(attr_name);
attrs[attr_name] = p.second;
}
}
if (!ParseAttributes(attrs)) {
return false;
}
return CopyAttributeToProtoMessage(non_proto_attrs_names, attrs, message);
}
bool HloParserImpl::ParseComputationName(HloComputation** value) {
std::string name;
LocTy loc = lexer_.GetLoc();
if (!ParseName(&name)) {
return Error(loc, "expects computation name");
}
std::pair<HloComputation*, LocTy>* computation =
tensorflow::gtl::FindOrNull(computation_pool_, name);
if (computation == nullptr) {
return Error(loc, StrCat("computation does not exist: ", name));
}
*value = computation->first;
return true;
}
// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? rhs_reversal? '}'
// The sub-attributes can appear in any order. 'size=' is required; the others
// are optional.
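// For example, the (illustrative) window string
//
//   {size=2x2 stride=1x1 pad=0_1x0_1 rhs_reversal=0x1}
//
// parses into a two-dimensional Window.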
bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
LocTy loc = lexer_.GetLoc();
if (expect_outer_curlies &&
!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
return false;
}
std::vector<int64_t> size;
std::vector<int64_t> stride;
std::vector<std::vector<int64_t>> pad;
std::vector<int64_t> lhs_dilate;
std::vector<int64_t> rhs_dilate;
std::vector<int64_t> rhs_reversal;
const auto end_token =
expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
while (lexer_.GetKind() != end_token) {
LocTy attr_loc = lexer_.GetLoc();
std::string field_name;
if (!ParseAttributeName(&field_name)) {
return Error(attr_loc, "expects sub-attributes in window");
}
bool ok = [&] {
if (field_name == "size") {
return ParseDxD("size", &size);
}
if (field_name == "stride") {
return ParseDxD("stride", &stride);
}
if (field_name == "lhs_dilate") {
return ParseDxD("lhs_dilate", &lhs_dilate);
}
if (field_name == "rhs_dilate") {
return ParseDxD("rls_dilate", &rhs_dilate);
}
if (field_name == "pad") {
return ParseWindowPad(&pad);
}
if (field_name == "rhs_reversal") {
return ParseDxD("rhs_reversal", &rhs_reversal);
}
return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
}();
if (!ok) {
return false;
}
}
if (!stride.empty() && stride.size() != size.size()) {
return Error(loc, "expects 'stride=' has the same size as 'size='");
}
if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
}
if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
}
if (!pad.empty() && pad.size() != size.size()) {
return Error(loc, "expects 'pad=' has the same size as 'size='");
}
for (int i = 0; i < size.size(); i++) {
window->add_dimensions()->set_size(size[i]);
if (!pad.empty()) {
window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
}
// If some field is not present, it has the default value.
window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
window->mutable_dimensions(i)->set_base_dilation(
lhs_dilate.empty() ? 1 : lhs_dilate[i]);
window->mutable_dimensions(i)->set_window_dilation(
rhs_dilate.empty() ? 1 : rhs_dilate[i]);
window->mutable_dimensions(i)->set_window_reversal(
rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
}
return !expect_outer_curlies ||
ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
}
// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
// The string looks like "dim_labels=0bf_0io->0bf".
//
// '?' dims don't appear in ConvolutionDimensionNumbers. There can be more than
// one '?' dim.
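// For example, parsing the (illustrative) pattern "b0f_0io->b0f" sets:
//   input:  batch dimension 0, spatial dimensions {1}, feature dimension 2
//   kernel: spatial dimensions {0}, input feature 1, output feature 2
//   output: batch dimension 0, spatial dimensions {1}, feature dimension 2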
bool HloParserImpl::ParseConvolutionDimensionNumbers(
ConvolutionDimensionNumbers* dnums) {
if (lexer_.GetKind() != TokKind::kDimLabels) {
return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
}
std::string str = lexer_.GetStrVal();
// The string is expected to contain 3 items (lhs, rhs, out) and must look
// like lhs_rhs->out; that is, the first separator is "_" and the second is
// "->".
std::vector<std::string> split1 = absl::StrSplit(str, '_');
if (split1.size() != 2) {
LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
<< str;
}
std::vector<std::string> split2 = absl::StrSplit(split1[1], "->");
if (split2.size() != 2) {
LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
<< str;
}
absl::string_view lhs = split1[0];
absl::string_view rhs = split2[0];
absl::string_view out = split2[1];
auto is_unique = [](absl::string_view str) -> bool {
absl::flat_hash_set<char> chars;
for (char c : str) {
// '?' dims are skipped.
if (c == '?') {
continue;
}
if (!chars.insert(c).second) {
return false;
}
}
return true;
};
// lhs
{
if (!is_unique(lhs)) {
return TokenError(
StrCat("expects unique lhs dimension numbers, but sees ", lhs));
}
// Count number of spatial dimensions.
for (char c : lhs) {
if (c != 'b' && c != 'f' && c != '?') {
dnums->add_input_spatial_dimensions(-1);
}
}
for (int i = 0; i < lhs.size(); i++) {
char c = lhs[i];
if (c == '?') {
continue;
} else if (c == 'b') {
dnums->set_input_batch_dimension(i);
} else if (c == 'f') {
dnums->set_input_feature_dimension(i);
} else if (c < '0' + lhs.size() && c >= '0') {
dnums->set_input_spatial_dimensions(c - '0', i);
} else {
return TokenError(StrFormat(
"expects [0-%dbf?] in lhs dimension numbers", lhs.size() - 1));
}
}
}
// rhs
{
if (!is_unique(rhs)) {
return TokenError(
StrCat("expects unique rhs dimension numbers, but sees ", rhs));
}
// Count number of spatial dimensions.
for (char c : rhs) {
if (c != 'i' && c != 'o' && c != '?') {
dnums->add_kernel_spatial_dimensions(-1);
}
}
for (int i = 0; i < rhs.size(); i++) {
char c = rhs[i];
if (c == '?') {
continue;
} else if (c == 'i') {
dnums->set_kernel_input_feature_dimension(i);
} else if (c == 'o') {
dnums->set_kernel_output_feature_dimension(i);
} else if (c < '0' + rhs.size() && c >= '0') {
dnums->set_kernel_spatial_dimensions(c - '0', i);
} else {
return TokenError(StrFormat(
"expects [0-%dio?] in rhs dimension numbers", rhs.size() - 1));
}
}
}
// output
{
if (!is_unique(out)) {
return TokenError(
StrCat("expects unique output dimension numbers, but sees ", out));
}
// Count number of spatial dimensions.
for (char c : out) {
if (c != 'b' && c != 'f' && c != '?') {
dnums->add_output_spatial_dimensions(-1);
}
}
for (int i = 0; i < out.size(); i++) {
char c = out[i];
if (c == '?') {
continue;
} else if (c == 'b') {
dnums->set_output_batch_dimension(i);
} else if (c == 'f') {
dnums->set_output_feature_dimension(i);
} else if (c < '0' + out.size() && c >= '0') {
dnums->set_output_spatial_dimensions(c - '0', i);
} else {
return TokenError(StrFormat(
"expects [0-%dbf?] in output dimension numbers", out.size() - 1));
}
}
}
// lhs, rhs, and output should have the same number of spatial dimensions.
if (dnums->input_spatial_dimensions_size() !=
dnums->output_spatial_dimensions_size() ||
dnums->input_spatial_dimensions_size() !=
dnums->kernel_spatial_dimensions_size()) {
return TokenError(
StrFormat("input, kernel, and output must have same number of spatial "
"dimensions, but got %d, %d, %d, respectively.",
dnums->input_spatial_dimensions_size(),
dnums->kernel_spatial_dimensions_size(),
dnums->output_spatial_dimensions_size()));
}
lexer_.Lex();
return true;
}
// slice_ranges ::= '{' ranges '}'
// ranges
//   ::= /*empty*/
//   ::= range (',' range)*
// range ::= '[' start ':' limit (':' stride)? ']'
//
// The slice ranges are printed as:
//
// {[dim0_start:dim0_limit:dim0_stride], [dim1_start:dim1_limit], ...}
//
// This function extracts the starts, limits, and strides into the three
// vectors of the result. If a stride is not present, it defaults to 1. For
// example, if the slice ranges are printed as:
//
// {[2:3:4], [5:6:7], [8:9]}
//
// The parsed result will be:
//
// {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
//
bool HloParserImpl::ParseSliceRanges(SliceRanges* result) {
if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
return false;
}
std::vector<std::vector<int64_t>> ranges;
if (lexer_.GetKind() == TokKind::kRbrace) {
// empty
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
do {
LocTy loc = lexer_.GetLoc();
ranges.emplace_back();
if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
&ranges.back())) {
return false;
}
const auto& range = ranges.back();
if (range.size() != 2 && range.size() != 3) {
return Error(loc,
StrFormat("expects [start:limit:step] or [start:limit], "
"but sees %d elements.",
range.size()));
}
} while (EatIfPresent(TokKind::kComma));
for (const auto& range : ranges) {
result->starts.push_back(range[0]);
result->limits.push_back(range[1]);
result->strides.push_back(range.size() == 3 ? range[2] : 1);
}
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
// precisionlist ::= start precision_elements end
// precision_elements
// ::= /*empty*/
// ::= precision_val (delim precision_val)*
bool HloParserImpl::ParsePrecisionList(
std::vector<PrecisionConfig::Precision>* result) {
auto parse_and_add_item = [&]() {
PrecisionConfig::Precision item;
if (!ParsePrecision(&item)) {
return false;
}
result->push_back(item);
return true;
};
return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
parse_and_add_item);
}
bool HloParserImpl::ParseHloComputation(HloComputation** result) {
if (lexer_.GetKind() == TokKind::kLbrace) {
// This means it is a nested computation.
return ParseInstructionList(result, /*computation_name=*/"_");
}
// This means it is a computation name.
return ParseComputationName(result);
}
bool HloParserImpl::ParseHloComputationList(
std::vector<HloComputation*>* result) {
auto parse_and_add_item = [&]() {
HloComputation* computation;
if (!ParseHloComputation(&computation)) {
return false;
}
VLOG(3) << "parsed computation " << computation->name();
result->push_back(computation);
return true;
};
return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
parse_and_add_item);
}
// shapelist ::= '{' shapes '}'
// shapes
//   ::= /*empty*/
//   ::= shape (',' shape)*
bool HloParserImpl::ParseShapeList(std::vector<Shape>* result) {
auto parse_and_add_item = [&]() {
Shape shape;
if (!ParseShape(&shape)) {
return false;
}
result->push_back(std::move(shape));
return true;
};
return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
parse_and_add_item);
}
// int64_list ::= start int64_elements end
// int64_elements
//   ::= /*empty*/
//   ::= int64_val (delim int64_val)*
bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<int64_t>* result) {
auto parse_and_add_item = [&]() {
int64_t i;
if (!ParseInt64(&i)) {
return false;
}
result->push_back(i);
return true;
};
return ParseList(start, end, delim, parse_and_add_item);
}
// int64_list_list ::= start int64_list_elements end
// int64_list_elements
//   ::= /*empty*/
//   ::= int64_list (delim int64_list)*
// int64_list ::= start int64_elements end
// int64_elements
//   ::= /*empty*/
//   ::= int64_val (delim int64_val)*
bool HloParserImpl::ParseInt64ListList(
const TokKind start, const TokKind end, const TokKind delim,
std::vector<std::vector<int64_t>>* result) {
auto parse_and_add_item = [&]() {
std::vector<int64_t> item;
if (!ParseInt64List(start, end, delim, &item)) {
return false;
}
result->push_back(item);
return true;
};
return ParseList(start, end, delim, parse_and_add_item);
}
bool HloParserImpl::ParseList(const TokKind start, const TokKind end,
const TokKind delim,
const std::function<bool()>& parse_and_add_item) {
if (!ParseToken(start, StrCat("expects a list starting with ",
TokKindToString(start)))) {
return false;
}
if (lexer_.GetKind() == end) {
// empty
} else {
do {
if (!parse_and_add_item()) {
return false;
}
} while (EatIfPresent(delim));
}
return ParseToken(
end, StrCat("expects a list to end with ", TokKindToString(end)));
}
// param_list_to_shape ::= param_list '->' shape
bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
return false;
}
*shape_loc = lexer_.GetLoc();
return ParseShape(shape);
}
bool HloParserImpl::CanBeParamListToShape() {
return lexer_.GetKind() == TokKind::kLparen;
}
// param_list ::= '(' param_list1 ')'
// param_list1
// ::= /*empty*/
// ::= param (',' param)*
// param ::= name shape
bool HloParserImpl::ParseParamList() {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of param list")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
std::string name;
if (!ParseName(&name) || !ParseShape(&shape)) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
}
// dimension_sizes ::= '[' dimension_list ']'
// dimension_list
//   ::= /*empty*/
//   ::= <=? int64_t (',' <=? int64_t)*
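// For example, the (illustrative) dimension list "[10,<=20]" yields
// dimension_sizes = {10, 20} and dynamic_dimensions = {false, true}.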
bool HloParserImpl::ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
std::vector<bool>* dynamic_dimensions) {
auto parse_and_add_item = [&]() {
int64_t i;
bool is_dynamic = false;
if (lexer_.GetKind() == TokKind::kLeq) {
is_dynamic = true;
lexer_.Lex();
}
if (!ParseInt64(&i)) {
return false;
}
dimension_sizes->push_back(i);
dynamic_dimensions->push_back(is_dynamic);
return true;
};
return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
parse_and_add_item);
}
// tiles
// ::= /*empty*/
// ::= 'T' '(' dim_list ')'
// dim_list
// ::= /*empty*/
// ::= (int64_t | '*') (',' (int64_t | '*'))*
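// For example (illustrative): after the leading 'T', "(8,128)" yields a
// single 8x128 tile and "(8,128)(2,1)" yields two tiles; '*' marks a
// combined dimension.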
bool HloParserImpl::ParseTiles(std::vector<Tile>* tiles) {
auto parse_and_add_tile_dimension = [&]() {
int64_t i;
if (ParseInt64(&i)) {
tiles->back().add_dimensions(i);
return true;
}
if (lexer_.GetKind() == TokKind::kAsterisk) {
tiles->back().add_dimensions(Tile::kCombineDimension);
lexer_.Lex();
return true;
}
return false;
};
do {
tiles->push_back(Tile());
if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
parse_and_add_tile_dimension)) {
return false;
}
} while (lexer_.GetKind() == TokKind::kLparen);
return true;
}
// int_attribute
// ::= /*empty*/
// ::= attr_token '(' attr_value ')'
// attr_token
// ::= 'E' | 'S'
// attr_value
// ::= int64_t
bool HloParserImpl::ParseLayoutIntAttribute(
int64_t* attr_value, absl::string_view attr_description) {
if (!ParseToken(TokKind::kLparen,
StrCat("expects ", attr_description, " to start with ",
TokKindToString(TokKind::kLparen)))) {
return false;
}
if (!ParseInt64(attr_value)) {
return false;
}
if (!ParseToken(TokKind::kRparen,
StrCat("expects ", attr_description, " to end with ",
TokKindToString(TokKind::kRparen)))) {
return false;
}
return true;
}
// layout ::= '{' int64_list (':' tiles element_size_in_bits memory_space)? '}'
// element_size_in_bits
// ::= /*empty*/
// ::= 'E' '(' int64_t ')'
// memory_space
// ::= /*empty*/
// ::= 'S' '(' int64_t ')'
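// For example, the (illustrative) layout string "{1,0:T(8,128)E(32)S(1)}"
// describes minor-to-major order {1,0}, one 8x128 tile, 32-bit elements, and
// memory space 1.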
bool HloParserImpl::ParseLayout(Layout* layout) {
std::vector<int64_t> minor_to_major;
std::vector<Tile> tiles;
int64_t element_size_in_bits = 0;
int64_t memory_space = 0;
auto parse_and_add_item = [&]() {
int64_t i;
if (!ParseInt64(&i)) {
return false;
}
minor_to_major.push_back(i);
return true;
};
if (!ParseToken(TokKind::kLbrace,
StrCat("expects layout to start with ",
TokKindToString(TokKind::kLbrace)))) {
return false;
}
if (lexer_.GetKind() != TokKind::kRbrace) {
if (lexer_.GetKind() == TokKind::kInt) {
// Parse minor to major.
do {
if (!parse_and_add_item()) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
if (lexer_.GetKind() == TokKind::kColon) {
lexer_.Lex();
if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
lexer_.Lex();
ParseTiles(&tiles);
}
if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
lexer_.Lex();
ParseLayoutIntAttribute(&element_size_in_bits, "element size in bits");
}
if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
lexer_.Lex();
ParseLayoutIntAttribute(&memory_space, "memory space");
}
}
}
if (!ParseToken(TokKind::kRbrace,
StrCat("expects layout to end with ",
TokKindToString(TokKind::kRbrace)))) {
return false;
}
std::vector<Tile> vec_tiles(tiles.size());
for (int i = 0; i < tiles.size(); i++) {
vec_tiles[i] = Tile(tiles[i]);
}
*layout = LayoutUtil::MakeLayout(minor_to_major, vec_tiles,
element_size_in_bits, memory_space);
return true;
}
// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
// ::= /*empty*/
// ::= shape (',' shape)*
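// For example (illustrative): "f32[10,20]{1,0}" is a rank-2 array shape with
// an explicit layout, "pred[]" is a scalar, and "(f32[2], s32[])" is a tuple
// of two shapes.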
bool HloParserImpl::ParseShape(Shape* result) {
if (EatIfPresent(TokKind::kLparen)) { // Tuple
std::vector<Shape> shapes;
if (lexer_.GetKind() == TokKind::kRparen) {
/*empty*/
} else {
// shape (',' shape)*
do {
shapes.emplace_back();
if (!ParseShape(&shapes.back())) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
*result = ShapeUtil::MakeTupleShape(shapes);
return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
}
if (lexer_.GetKind() != TokKind::kPrimitiveType) {
return TokenError(absl::StrCat("expected primitive type, saw ",
TokKindToString(lexer_.GetKind())));
}
PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
lexer_.Lex();
// Each element contains a dimension size and a bool indicating whether this
// is a dynamic dimension.
std::vector<int64_t> dimension_sizes;
std::vector<bool> dynamic_dimensions;
if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
return false;
}
result->set_element_type(primitive_type);
for (int i = 0; i < dimension_sizes.size(); ++i) {
result->add_dimensions(dimension_sizes[i]);
result->set_dynamic_dimension(i, dynamic_dimensions[i]);
}
if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "invalid") {
lexer_.Lex();
if (lexer_.GetKind() != TokKind::kLbrace) {
return false;
}
lexer_.Lex();
if (lexer_.GetKind() != TokKind::kRbrace) {
return false;
}
lexer_.Lex();
result->mutable_layout()->Clear();
return true;
}
LayoutUtil::SetToDefaultLayout(result);
// We need to lookahead to see if a following open brace is the start of a
// layout. The specific problematic case is:
//
// ENTRY %foo (x: f32[42]) -> f32[123] {
// ...
// }
//
// The open brace could either be the start of a computation or the start of a
// layout for the f32[123] shape. We consider it the start of a layout if the
// next token after the open brace is an integer or a colon.
if (lexer_.GetKind() == TokKind::kLbrace &&
(lexer_.LookAhead() == TokKind::kInt ||
lexer_.LookAhead() == TokKind::kColon)) {
Layout layout;
if (!ParseLayout(&layout)) {
return false;
}
if (layout.minor_to_major_size() != result->rank()) {
return Error(
lexer_.GetLoc(),
StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
result->rank(), layout.minor_to_major_size()));
}
*result->mutable_layout() = layout;
}
return true;
}
bool HloParserImpl::CanBeShape() {
// A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
// with '('.
return lexer_.GetKind() == TokKind::kPrimitiveType ||
lexer_.GetKind() == TokKind::kLparen;
}
bool HloParserImpl::ParseName(std::string* result) {
VLOG(3) << "ParseName";
if (lexer_.GetKind() != TokKind::kIdent &&
lexer_.GetKind() != TokKind::kName) {
return TokenError("expects name");
}
*result = lexer_.GetStrVal();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseAttributeName(std::string* result) {
if (lexer_.GetKind() != TokKind::kAttributeName) {
return TokenError("expects attribute name");
}
*result = lexer_.GetStrVal();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseString(std::string* result) {
VLOG(3) << "ParseString";
if (lexer_.GetKind() != TokKind::kString) {
return TokenError("expects string");
}
*result = lexer_.GetStrVal();
lexer_.Lex();
return true;
}
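// Parses either a single integer, e.g. "2", or a kDxD token, e.g. "2x3x4"
// (illustrative values), into `result`.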
bool HloParserImpl::ParseDxD(const std::string& name,
std::vector<int64_t>* result) {
LocTy loc = lexer_.GetLoc();
if (!result->empty()) {
return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
}
// 1D
if (lexer_.GetKind() == TokKind::kInt) {
int64_t number;
if (!ParseInt64(&number)) {
return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
}
result->push_back(number);
return true;
}
// 2D or higher.
if (lexer_.GetKind() == TokKind::kDxD) {
std::string str = lexer_.GetStrVal();
if (!SplitToInt64s(str, 'x', result)) {
return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
}
lexer_.Lex();
return true;
}
return TokenError("expects token type kInt or kDxD");
}
bool HloParserImpl::ParseWindowPad(std::vector<std::vector<int64_t>>* pad) {
LocTy loc = lexer_.GetLoc();
if (!pad->empty()) {
return Error(loc, "sub-attribute 'pad=' already exists");
}
if (lexer_.GetKind() != TokKind::kPad) {
return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
}
std::string str = lexer_.GetStrVal();
for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<int64_t> low_high;
if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
low_high.size() != 2) {
return Error(loc,
"expects padding_low and padding_high separated by '_'");
}
pad->push_back(low_high);
}
lexer_.Lex();
return true;
}
// This is the inverse of xla::ToString(PaddingConfig). The padding config
// string looks like "0_0_0x3_3_1". The string is first split on 'x'; each
// substring represents one PaddingConfigDimension and consists of 3 (or 2)
// numbers joined by '_'.
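// For example, the (illustrative) string "0_0_0x3_3_1" yields two dimensions:
// the first with low/high/interior padding 0/0/0, the second with 3/3/1.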
bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) {
if (lexer_.GetKind() != TokKind::kPad) {
return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
}
LocTy loc = lexer_.GetLoc();
std::string str = lexer_.GetStrVal();
for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<int64_t> padding_dim;
if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
(padding_dim.size() != 2 && padding_dim.size() != 3)) {
return Error(loc,
"expects padding config pattern like 'low_high_interior' or "
"'low_high'");
}
auto* dim = padding->add_dimensions();
dim->set_edge_padding_low(padding_dim[0]);
dim->set_edge_padding_high(padding_dim[1]);
dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
}
lexer_.Lex();
return true;
}
// '{' metadata_string '}'
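// For example (illustrative):
//
//   {op_type="Conv2D" op_name="model/conv1" source_file="model.py"
//    source_line=42}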
bool HloParserImpl::ParseMetadata(OpMetadata* metadata) {
absl::flat_hash_map<std::string, AttrConfig> attrs;
optional<std::string> op_type;
optional<std::string> op_name;
optional<std::string> source_file;
optional<int32_t> source_line;
optional<std::vector<int64_t>> profile_type;
attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
attrs["profile_type"] = {/*required=*/false, AttrTy::kBracedInt64List,
&profile_type};
if (!ParseSubAttributes(attrs)) {
return false;
}
if (op_type) {
metadata->set_op_type(*op_type);
}
if (op_name) {
metadata->set_op_name(*op_name);
}
if (source_file) {
metadata->set_source_file(*source_file);
}
if (source_line) {
metadata->set_source_line(*source_line);
}
if (profile_type) {
for (const auto& type : *profile_type) {
if (!ProfileType_IsValid(type)) {
return false;
}
metadata->add_profile_type(static_cast<ProfileType>(type));
}
}
return true;
}
// ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}')
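// For example (illustrative), both a single metadata block {op_name="a"} and
// a list {{op_name="a"}, {op_name="b"}} are accepted.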
bool HloParserImpl::ParseSingleOrListMetadata(
tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata) {
if (lexer_.GetKind() == TokKind::kLbrace &&
lexer_.LookAhead() == TokKind::kLbrace) {
if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) {
return false;
}
if (lexer_.GetKind() != TokKind::kRbrace) {
do {
if (!ParseMetadata(metadata->Add())) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list");
}
return ParseMetadata(metadata->Add());
}
bool HloParserImpl::ParseOpShardingType(OpSharding::Type* type) {
switch (lexer_.GetKind()) {
case TokKind::kw_maximal:
*type = OpSharding::MAXIMAL;
lexer_.Lex();
break;
case TokKind::kw_replicated:
*type = OpSharding::REPLICATED;
lexer_.Lex();
break;
case TokKind::kw_manual:
*type = OpSharding::MANUAL;
lexer_.Lex();
break;
default:
return false;
}
return true;
}
bool HloParserImpl::ParseListShardingType(
std::vector<OpSharding::Type>* types) {
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start sharding type list")) {
return false;
}
if (lexer_.GetKind() != TokKind::kRbrace) {
do {
OpSharding::Type type;
if (!ParseOpShardingType(&type)) {
return false;
}
types->emplace_back(type);
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRbrace, "expected '}' to end sharding type list");
}
bool HloParserImpl::ParseOpcode(HloOpcode* result) {
VLOG(3) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects opcode");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToHloOpcode(val);
if (!status_or_result.ok()) {
return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseFftType(FftType* result) {
VLOG(3) << "ParseFftType";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects fft type");
}
std::string val = lexer_.GetStrVal();
if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
return TokenError(StrFormat("expects fft type but sees: %s", val));
}
lexer_.Lex();
return true;
}
bool HloParserImpl::ParsePaddingType(PaddingType* result) {
VLOG(3) << "ParsePaddingType";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects padding type");
}
std::string val = lexer_.GetStrVal();
if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
return TokenError(StrFormat("expects padding type but sees: %s", val));
}
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
VLOG(3) << "ParseComparisonDirection";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects comparison direction");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToComparisonDirection(val);
if (!status_or_result.ok()) {
return TokenError(
StrFormat("expects comparison direction but sees: %s", val));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
VLOG(1) << "ParseComparisonType";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects comparison type");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToComparisonType(val);
if (!status_or_result.ok()) {
return TokenError(StrFormat("expects comparison type but sees: %s", val));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
VLOG(3) << "ParseFusionKind";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects fusion kind");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToFusionKind(val);
if (!status_or_result.ok()) {
return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
val,
status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
VLOG(3) << "ParseRandomDistribution";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects random distribution");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToRandomDistribution(val);
if (!status_or_result.ok()) {
return TokenError(
StrFormat("expects random distribution but sees: %s, error: %s", val,
status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
VLOG(3) << "ParseRandomAlgorithm";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects random algorithm");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToRandomAlgorithm(val);
if (!status_or_result.ok()) {
return TokenError(
StrFormat("expects random algorithm but sees: %s, error: %s", val,
status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
VLOG(3) << "ParsePrecision";
if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects random distribution");
}
std::string val = lexer_.GetStrVal();
auto status_or_result = StringToPrecision(val);
if (!status_or_result.ok()) {
return TokenError(StrFormat("expects precision but sees: %s, error: %s",
val,
status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseInt64(int64_t* result) {
VLOG(3) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
return TokenError("expects integer");
}
*result = lexer_.GetInt64Val();
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseDouble(double* result) {
switch (lexer_.GetKind()) {
case TokKind::kDecimal: {
double val = lexer_.GetDecimalVal();
// If GetDecimalVal returns +/-inf, that means that we overflowed
// `double`.
if (std::isinf(val)) {
return TokenError(StrCat("Constant is out of range for double (+/-",
std::numeric_limits<double>::max(),
") and so is unparsable."));
}
*result = val;
break;
}
case TokKind::kInt:
*result = static_cast<double>(lexer_.GetInt64Val());
break;
case TokKind::kw_inf:
*result = std::numeric_limits<double>::infinity();
break;
case TokKind::kNegInf:
*result = -std::numeric_limits<double>::infinity();
break;
default:
return TokenError("expects decimal or integer");
}
lexer_.Lex();
return true;
}
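// Parses a complex literal of the form '(' real ',' imag ')'; e.g. the
// (illustrative) input "(1.5, -2.5)" yields std::complex<double>(1.5, -2.5).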
bool HloParserImpl::ParseComplex(std::complex<double>* result) {
if (lexer_.GetKind() != TokKind::kLparen) {
return TokenError("expects '(' before complex number");
}
lexer_.Lex();
double real;
LocTy loc = lexer_.GetLoc();
if (!ParseDouble(&real)) {
return Error(loc,
"expect floating-point value for real part of complex number");
}
if (lexer_.GetKind() != TokKind::kComma) {
return TokenError(
absl::StrFormat("expect comma after real part of complex literal"));
}
lexer_.Lex();
double imag;
loc = lexer_.GetLoc();
if (!ParseDouble(&imag)) {
return Error(
loc,
"expect floating-point value for imaginary part of complex number");
}
if (lexer_.GetKind() != TokKind::kRparen) {
return TokenError(absl::StrFormat("expect ')' after complex number"));
}
*result = std::complex<double>(real, imag);
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseBool(bool* result) {
if (lexer_.GetKind() != TokKind::kw_true &&
lexer_.GetKind() != TokKind::kw_false) {
return TokenError("expects true or false");
}
*result = lexer_.GetKind() == TokKind::kw_true;
lexer_.Lex();
return true;
}
bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) {
VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
if (lexer_.GetKind() != kind) {
return TokenError(msg);
}
lexer_.Lex();
return true;
}
bool HloParserImpl::EatIfPresent(TokKind kind) {
if (lexer_.GetKind() != kind) {
return false;
}
lexer_.Lex();
return true;
}
bool HloParserImpl::AddInstruction(const std::string& name,
HloInstruction* instruction,
LocTy name_loc) {
auto result = current_name_table().insert({name, {instruction, name_loc}});
if (!result.second) {
Error(name_loc, StrCat("instruction already exists: ", name));
return Error(/*loc=*/result.first->second.second,
"instruction previously defined here");
}
return true;
}
bool HloParserImpl::AddComputation(const std::string& name,
HloComputation* computation,
LocTy name_loc) {
auto result = computation_pool_.insert({name, {computation, name_loc}});
if (!result.second) {
Error(name_loc, StrCat("computation already exists: ", name));
return Error(/*loc=*/result.first->second.second,
"computation previously defined here");
}
return true;
}
StatusOr<Shape> HloParserImpl::ParseShapeOnly() {
lexer_.Lex();
Shape shape;
if (!ParseShape(&shape)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after shape");
}
return shape;
}
StatusOr<HloSharding> HloParserImpl::ParseShardingOnly() {
lexer_.Lex();
OpSharding op_sharding;
if (!ParseSharding(&op_sharding)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after sharding");
}
return HloSharding::FromProto(op_sharding);
}
StatusOr<FrontendAttributes> HloParserImpl::ParseFrontendAttributesOnly() {
lexer_.Lex();
FrontendAttributes attributes;
if (!ParseFrontendAttributes(&attributes)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument(
"Syntax error:\nExtra content after frontend attributes");
}
return attributes;
}
StatusOr<std::vector<bool>> HloParserImpl::ParseParameterReplicationOnly() {
lexer_.Lex();
ParameterReplication parameter_replication;
if (!ParseParameterReplication(&parameter_replication)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument(
"Syntax error:\nExtra content after parameter replication");
}
return std::vector<bool>(
parameter_replication.replicated_at_leaf_buffers().begin(),
parameter_replication.replicated_at_leaf_buffers().end());
}
StatusOr<std::vector<ReplicaGroup>> HloParserImpl::ParseReplicaGroupsOnly() {
lexer_.Lex();
std::vector<ReplicaGroup> replica_groups;
if (!ParseReplicaGroupsOnly(&replica_groups)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after replica groups");
}
return replica_groups;
}
StatusOr<Window> HloParserImpl::ParseWindowOnly() {
lexer_.Lex();
Window window;
if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after window");
}
return window;
}
StatusOr<ConvolutionDimensionNumbers>
HloParserImpl::ParseConvolutionDimensionNumbersOnly() {
lexer_.Lex();
ConvolutionDimensionNumbers dnums;
if (!ParseConvolutionDimensionNumbers(&dnums)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument(
"Syntax error:\nExtra content after convolution dnums");
}
return dnums;
}
StatusOr<PaddingConfig> HloParserImpl::ParsePaddingConfigOnly() {
lexer_.Lex();
PaddingConfig padding_config;
if (!ParsePaddingConfig(&padding_config)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
}
return padding_config;
}
bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
LOG(FATAL) << "Parser state is not clean. Please do not call any other "
"methods before calling ParseSingleInstruction.";
}
HloComputation::Builder builder(module->name());
// The missing instruction hook we register creates the shaped instruction on
// the fly as a parameter and returns it.
int64_t parameter_count = 0;
create_missing_instruction_ =
[this, &builder, &parameter_count](
const std::string& name,
const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
HloInstruction* parameter = builder.AddInstruction(
HloInstruction::CreateParameter(parameter_count++, shape, new_name));
current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
};
// Parse the instruction with the registered hook.
Scope scope(&scoped_name_tables_);
if (CanBeShape()) {
// This means that the instruction's left-hand side is probably omitted,
// e.g.
//
// f32[10] fusion(...), calls={...}
if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
return false;
}
} else {
// This means that the instruction's left-hand side might exist, e.g.
//
// foo = f32[10] fusion(...), calls={...}
std::string root_name;
if (!ParseInstruction(&builder, &root_name)) {
return false;
}
}
if (lexer_.GetKind() != TokKind::kEof) {
Error(
lexer_.GetLoc(),
"Syntax error:\nExpected eof after parsing single instruction. Did "
"you mean to write an HLO module and forget the \"HloModule\" header?");
return false;
}
module->AddEntryComputation(builder.Build());
for (auto& comp : computations_) {
module->AddEmbeddedComputation(std::move(comp));
}
TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
return true;
}
} // namespace
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str, const HloModuleConfig& config) {
auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
HloParserImpl parser(str);
TF_RETURN_IF_ERROR(parser.Run(module.get()));
return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str) {
return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
}
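// Example usage (illustrative module text):
//
//   StatusOr<std::unique_ptr<HloModule>> module =
//       ParseAndReturnUnverifiedModule(R"(
//         HloModule m
//         ENTRY e {
//           %p = f32[2] parameter(0)
//           ROOT %neg = f32[2] negate(%p)
//         }
//       )");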
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseShardingOnly();
}
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseFrontendAttributesOnly();
}
StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseParameterReplicationOnly();
}
StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseReplicaGroupsOnly();
}
StatusOr<Window> ParseWindow(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseConvolutionDimensionNumbersOnly();
}
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParsePaddingConfigOnly();
}
StatusOr<Shape> ParseShape(absl::string_view str) {
HloParserImpl parser(str);
return parser.ParseShapeOnly();
}
std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
absl::string_view str) {
return absl::make_unique<HloParserImpl>(str);
}
} // namespace xla