| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/hlo_parser.h" |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/memory/memory.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_format.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/strings/str_split.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/literal_util.h" |
| #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" |
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/service/hlo_schedule.h" |
| #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/core/lib/gtl/map_util.h" |
| |
| namespace xla { |
| |
| namespace { |
| |
| using absl::nullopt; |
| using absl::optional; |
| using absl::StrAppend; |
| using absl::StrCat; |
| using absl::StrFormat; |
| using absl::StrJoin; |
| |
| const double kF16max = 65504; |
| |
| // Creates and returns a schedule created using the order of the instructions in |
| // the HloComputation::instructions() vectors in the module. |
| HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { |
| HloSchedule schedule(module); |
| for (const HloComputation* computation : module->computations()) { |
| if (!computation->IsFusionComputation()) { |
| for (const HloInstruction* instruction : computation->instructions()) { |
| schedule.GetOrCreateSequence(computation).push_back(instruction); |
| } |
| } |
| } |
| return schedule; |
| } |
| |
| // Parser for the HloModule::ToString() format text. |
| class HloParser { |
| public: |
| using LocTy = HloLexer::LocTy; |
| |
| explicit HloParser(absl::string_view str) : lexer_(str) {} |
| |
| // Runs the parser and constructs the resulting HLO in the given (empty) |
| // HloModule. Returns false if an error occurred. |
| bool Run(HloModule* module); |
| |
| // Returns the error information. |
| string GetError() const { return StrJoin(error_, "\n"); } |
| |
| // Stand alone parsing utils for various aggregate data types. |
| StatusOr<HloSharding> ParseShardingOnly(); |
| StatusOr<Window> ParseWindowOnly(); |
| StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly(); |
| StatusOr<PaddingConfig> ParsePaddingConfigOnly(); |
| |
| // Stand-alone parsing utility for a single instruction worth of text. |
| Status ParseSingleInstruction(HloComputation::Builder* builder, |
| string* root_name); |
| |
| private: |
| // Locates an instruction with the given name in the instruction_pool_ or |
| // returns nullptr. |
| // |
| // If the missing_instruction_hook_ is registered and a "shape" is provided, |
| // the hook will be called and may satisfy the request for the given |
| // instruction. This is useful when we reify parameters as they're resolved; |
| // i.e. for ParseSingleInstruction. |
| std::pair<HloInstruction*, LocTy>* FindInstruction( |
| const string& name, const optional<Shape>& shape = nullopt); |
| |
| // ParseXXX returns false if an error occurred. |
| bool ParseHloModule(HloModule* module); |
| bool ParseComputations(HloModule* module); |
| bool ParseComputation(HloComputation** entry_computation); |
| bool ParseInstructionList(HloComputation::Builder* builder, |
| string* root_name); |
| bool ParseInstruction(HloComputation::Builder* builder, string* root_name); |
| bool ParseControlPredecessors(HloInstruction* instruction); |
| bool ParseLiteral(Literal* literal, const Shape& shape); |
| bool ParseTupleLiteral(Literal* literal, const Shape& shape); |
| bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); |
| bool ParseDenseLiteral(Literal* literal, const Shape& shape); |
| bool ParseSparseLiteral(Literal* literal, const Shape& shape); |
| template <typename LiteralNativeT> |
| bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); |
| |
| // Sets the sub-value of literal at the given index to the given value. The |
| // literal's shape must have the default layout. |
| bool SetValueInLiteral(tensorflow::int64 value, |
| tensorflow::int64 linear_index, Literal* literal); |
| bool SetValueInLiteral(double value, tensorflow::int64 linear_index, |
| Literal* literal); |
| bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, |
| Literal* literal); |
| template <typename LiteralNativeT, typename ParsedElemT> |
| bool SetValueInLiteralHelper(ParsedElemT value, |
| tensorflow::int64 linear_index, |
| Literal* literal); |
| |
| bool ParseOperands(std::vector<HloInstruction*>* operands); |
| // Fills parsed operands into 'operands' and expects a certain number of |
| // operands. |
| bool ParseOperands(std::vector<HloInstruction*>* operands, |
| const int expected_size); |
| |
| // Describes the start, limit, and stride on every dimension of the operand |
| // being sliced. |
| struct SliceRanges { |
| std::vector<tensorflow::int64> starts; |
| std::vector<tensorflow::int64> limits; |
| std::vector<tensorflow::int64> strides; |
| }; |
| |
| // The data parsed for the kDomain instruction. |
| struct DomainData { |
| std::unique_ptr<DomainMetadata> entry_metadata; |
| std::unique_ptr<DomainMetadata> exit_metadata; |
| }; |
| |
| // Types of attributes. |
| enum class AttrTy { |
| kBool, |
| kInt64, |
| kInt32, |
| kFloat, |
| kString, |
| kBracedInt64List, |
| kBracedInt64ListList, |
| kHloComputation, |
| kFftType, |
| kWindow, |
| kConvolutionDimensionNumbers, |
| kSharding, |
| kInstructionList, |
| kSliceRanges, |
| kPaddingConfig, |
| kMetadata, |
| kFusionKind, |
| kDistribution, |
| kDomain, |
| kPrecisionList, |
| }; |
| |
| struct AttrConfig { |
| bool required; // whether it's required or optional |
| AttrTy attr_type; // what type it is |
| void* result; // where to store the parsed result. |
| }; |
| |
| // attributes ::= (',' attribute)* |
| // |
| // Parses attributes given names and configs of the attributes. Each parsed |
| // result is passed back through the result pointer in corresponding |
| // AttrConfig. Note that the result pointer must point to a optional<T> typed |
| // variable which outlives this function. Returns false on error. You should |
| // not use the any of the results if this function failed. |
| // |
| // Example usage: |
| // |
| // std::unordered_map<string, AttrConfig> attrs; |
| // optional<int64> foo; |
| // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; |
| // optional<Window> bar; |
| // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar}; |
| // if (!ParseAttributes(attrs)) { |
| // return false; // Do not use 'foo' 'bar' if failed. |
| // } |
| // // Do something with 'bar'. |
| // if (foo) { // If attr foo is seen, do something with 'foo'. } |
| // |
| bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs); |
| |
| // sub_attributes ::= '{' (','? attribute)* '}' |
| // |
| // Usage is the same as ParseAttributes. See immediately above. |
| bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs); |
| |
| // Parses one attribute. If it has already been seen, return error. Returns |
| // true and adds to seen_attrs on success. |
| // |
| // Do not call this except in ParseAttributes or ParseSubAttributes. |
| bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs, |
| std::unordered_set<string>* seen_attrs); |
| |
| // Parses a name and finds the corresponding hlo computation. |
| bool ParseComputationName(HloComputation** value); |
| // Parses a list of names and finds the corresponding hlo instructions. |
| bool ParseInstructionNames(std::vector<HloInstruction*>* instructions); |
| // Pass expect_outer_curlies == true when parsing a Window in the context of a |
| // larger computation. Pass false when parsing a stand-alone Window string. |
| bool ParseWindow(Window* window, bool expect_outer_curlies); |
| bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); |
| bool ParsePaddingConfig(PaddingConfig* padding); |
| bool ParseMetadata(OpMetadata* metadata); |
| bool ParseSharding(OpSharding* sharding); |
| bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); |
| |
| // Parses the metadata behind a kDOmain instruction. |
| bool ParseDomain(DomainData* domain); |
| |
| // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. |
| bool ParseDxD(const string& name, std::vector<tensorflow::int64>* result); |
| // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. |
| bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad); |
| |
| bool ParseSliceRanges(SliceRanges* result); |
| bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result); |
| bool ParseInt64List(const TokKind start, const TokKind end, |
| const TokKind delim, |
| std::vector<tensorflow::int64>* result); |
| // 'parse_and_add_item' is an lambda to parse an element in the list and add |
| // the parsed element to the result. It's supposed to capture the result. |
| bool ParseList(const TokKind start, const TokKind end, const TokKind delim, |
| const std::function<bool()>& parse_and_add_item); |
| |
| bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); |
| bool ParseParamList(); |
| bool ParseName(string* result); |
| bool ParseAttributeName(string* result); |
| bool ParseString(string* result); |
| bool ParseShape(Shape* result); |
| bool ParseOpcode(HloOpcode* result); |
| bool ParseFftType(FftType* result); |
| bool ParseFusionKind(HloInstruction::FusionKind* result); |
| bool ParseRandomDistribution(RandomDistribution* result); |
| bool ParsePrecision(PrecisionConfig::Precision* result); |
| bool ParseInt64(tensorflow::int64* result); |
| bool ParseDouble(double* result); |
| bool ParseBool(bool* result); |
| bool ParseToken(TokKind kind, const string& msg); |
| |
| // Returns true if the current token is the beginning of a shape. |
| bool CanBeShape(); |
| // Returns true if the current token is the beginning of a |
| // param_list_to_shape. |
| bool CanBeParamListToShape(); |
| |
| // Logs the current parsing line and the given message. Always returns false. |
| bool TokenError(absl::string_view msg); |
| bool Error(LocTy loc, absl::string_view msg); |
| |
| // If the current token is 'kind', eats it (i.e. lexes the next token) and |
| // returns true. |
| bool EatIfPresent(TokKind kind); |
| // Parses a shape, and returns true if the result is compatible with the given |
| // shape. |
| bool EatShapeAndCheckCompatible(const Shape& shape); |
| |
| // Adds the instruction to the pool. Returns false and emits an error if the |
| // instruction already exists. |
| bool AddInstruction(const string& name, HloInstruction* instruction, |
| LocTy name_loc); |
| // Adds the computation to the pool. Returns false and emits an error if the |
| // computation already exists. |
| bool AddComputation(const string& name, HloComputation* computation, |
| LocTy name_loc); |
| |
| // The map from the instruction/computation name to the |
| // instruction/computation itself and it's location. This does not own the |
| // pointers. |
| std::unordered_map<string, std::pair<HloInstruction*, LocTy>> |
| instruction_pool_; |
| std::unordered_map<string, std::pair<HloComputation*, LocTy>> |
| computation_pool_; |
| |
| HloLexer lexer_; |
| std::vector<std::unique_ptr<HloComputation>> computations_; |
| std::vector<string> error_; |
| |
| // Function that gets invoked when we try to resolve an instruction |
| // instruction_pool_ but fail to do so. |
| std::function<std::pair<HloInstruction*, LocTy>*(string, |
| const optional<Shape>&)> |
| missing_instruction_hook_; |
| }; |
| |
| bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) { |
| for (const auto& split : absl::StrSplit(s, delim)) { |
| int64 val; |
| if (!absl::SimpleAtoi(split, &val)) { |
| return false; |
| } |
| out->push_back(val); |
| } |
| return true; |
| } |
| |
| // Creates replica groups from the provided nested array. groups[i] represents |
| // the replica ids for group 'i'. |
| std::vector<ReplicaGroup> CreateReplicaGroups( |
| absl::Span<const std::vector<int64>> groups) { |
| std::vector<ReplicaGroup> replica_groups; |
| absl::c_transform(groups, std::back_inserter(replica_groups), |
| [](const std::vector<int64>& ids) { |
| ReplicaGroup group; |
| *group.mutable_replica_ids() = {ids.begin(), ids.end()}; |
| return group; |
| }); |
| return replica_groups; |
| } |
| |
| bool HloParser::Error(LocTy loc, absl::string_view msg) { |
| auto line_col = lexer_.GetLineAndColumn(loc); |
| const unsigned line = line_col.first; |
| const unsigned col = line_col.second; |
| std::vector<string> error_lines; |
| error_lines.push_back( |
| StrCat("was parsing ", line, ":", col, ": error: ", msg)); |
| error_lines.emplace_back(lexer_.GetLine(loc)); |
| error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); |
| |
| error_.push_back(StrJoin(error_lines, "\n")); |
| VLOG(1) << "Error: " << error_.back(); |
| return false; |
| } |
| |
| bool HloParser::TokenError(absl::string_view msg) { |
| return Error(lexer_.GetLoc(), msg); |
| } |
| |
| bool HloParser::Run(HloModule* module) { |
| lexer_.Lex(); |
| return ParseHloModule(module); |
| } |
| |
| std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction( |
| const string& name, const optional<Shape>& shape) { |
| std::pair<HloInstruction*, LocTy>* instr = |
| tensorflow::gtl::FindOrNull(instruction_pool_, name); |
| // Potentially call the missing instruction hook. |
| if (instr == nullptr && missing_instruction_hook_ != nullptr) { |
| return missing_instruction_hook_(name, shape); |
| } |
| return instr; |
| } |
| |
| // ::= 'HloModule' name computations |
| bool HloParser::ParseHloModule(HloModule* module) { |
| if (lexer_.GetKind() != TokKind::kw_HloModule) { |
| return TokenError("expects HloModule"); |
| } |
| // Eat 'HloModule' |
| lexer_.Lex(); |
| |
| string name; |
| if (!ParseName(&name)) { |
| return false; |
| } |
| |
| absl::optional<bool> is_scheduled; |
| std::unordered_map<string, AttrConfig> attrs; |
| attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; |
| if (!ParseAttributes(attrs)) { |
| return false; |
| } |
| |
| module->set_name(name); |
| if (!ParseComputations(module)) { |
| return false; |
| } |
| |
| if (is_scheduled.has_value() && *is_scheduled) { |
| TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module))); |
| } |
| |
| return true; |
| } |
| |
| // computations ::= (computation)+ |
| bool HloParser::ParseComputations(HloModule* module) { |
| HloComputation* entry_computation = nullptr; |
| do { |
| if (!ParseComputation(&entry_computation)) { |
| return false; |
| } |
| } while (lexer_.GetKind() != TokKind::kEof); |
| |
| for (int i = 0; i < computations_.size(); i++) { |
| // If entry_computation is not nullptr, it means the computation it pointed |
| // to is marked with "ENTRY"; otherwise, no computation is marked with |
| // "ENTRY", and we use the last computation as the entry computation. We |
| // add the non-entry computations as embedded computations to the module. |
| if ((entry_computation != nullptr && |
| computations_[i].get() != entry_computation) || |
| (entry_computation == nullptr && i != computations_.size() - 1)) { |
| module->AddEmbeddedComputation(std::move(computations_[i])); |
| continue; |
| } |
| auto computation = module->AddEntryComputation(std::move(computations_[i])); |
| // The parameters and result layouts were set to default layout. Here we |
| // set the layouts to what the hlo text says. |
| for (int p = 0; p < computation->num_parameters(); p++) { |
| const Shape& param_shape = computation->parameter_instruction(p)->shape(); |
| TF_CHECK_OK(module->mutable_entry_computation_layout() |
| ->mutable_parameter_layout(p) |
| ->CopyLayoutFromShape(param_shape)); |
| } |
| const Shape& result_shape = computation->root_instruction()->shape(); |
| TF_CHECK_OK(module->mutable_entry_computation_layout() |
| ->mutable_result_layout() |
| ->CopyLayoutFromShape(result_shape)); |
| } |
| return true; |
| } |
| |
| // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list |
| bool HloParser::ParseComputation(HloComputation** entry_computation) { |
| LocTy maybe_entry_loc = lexer_.GetLoc(); |
| const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); |
| |
| string name; |
| LocTy name_loc = lexer_.GetLoc(); |
| if (!ParseName(&name)) { |
| return false; |
| } |
| auto builder = absl::make_unique<HloComputation::Builder>(name); |
| |
| LocTy shape_loc = nullptr; |
| Shape shape; |
| if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) { |
| return false; |
| } |
| |
| string root_name; |
| if (!ParseInstructionList(builder.get(), &root_name)) { |
| return false; |
| } |
| |
| std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name); |
| // This means some instruction was marked as ROOT but we didn't find it in the |
| // pool, which should not happen. |
| if (!root_name.empty() && root_node == nullptr) { |
| LOG(FATAL) << "instruction " << root_name |
| << " was marked as ROOT but the parser has not seen it before"; |
| } |
| |
| HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; |
| // Now root can be either an existing instruction or a nullptr. If it's a |
| // nullptr, the implementation of Builder will set the last instruction as |
| // root instruction. |
| computations_.emplace_back(builder->Build(root)); |
| HloComputation* computation = computations_.back().get(); |
| |
| if (!root) { |
| root = computation->root_instruction(); |
| } else { |
| CHECK_EQ(root, computation->root_instruction()); |
| } |
| |
| // If param_list_to_shape was present, check compatibility. |
| if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { |
| return Error( |
| shape_loc, |
| StrCat("Shape of computation ", name, ", ", |
| ShapeUtil::HumanString(shape), |
| ", is not compatible with that of its root instruction ", |
| root_name, ", ", ShapeUtil::HumanString(root->shape()))); |
| } |
| |
| if (is_entry_computation) { |
| if (*entry_computation != nullptr) { |
| return Error(maybe_entry_loc, "expects only one ENTRY"); |
| } |
| *entry_computation = computation; |
| } |
| instruction_pool_.clear(); |
| |
| return AddComputation(name, computation, name_loc); |
| } |
| |
| // instruction_list ::= '{' instruction_list1 '}' |
| // instruction_list1 ::= (instruction)+ |
| bool HloParser::ParseInstructionList(HloComputation::Builder* builder, |
| string* root_name) { |
| if (!ParseToken(TokKind::kLbrace, |
| "expects '{' at the beginning of instruction list.")) { |
| return false; |
| } |
| do { |
| if (!ParseInstruction(builder, root_name)) { |
| return false; |
| } |
| } while (lexer_.GetKind() != TokKind::kRbrace); |
| return ParseToken(TokKind::kRbrace, |
| "expects '}' at the end of instruction list."); |
| } |
| |
| // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* |
| bool HloParser::ParseInstruction(HloComputation::Builder* builder, |
| string* root_name) { |
| string name; |
| Shape shape; |
| HloOpcode opcode; |
| std::vector<HloInstruction*> operands; |
| |
| LocTy maybe_root_loc = lexer_.GetLoc(); |
| bool is_root = EatIfPresent(TokKind::kw_ROOT); |
| |
| const LocTy name_loc = lexer_.GetLoc(); |
| if (!ParseName(&name) || |
| !ParseToken(TokKind::kEqual, "expects '=' in instruction") || |
| !ParseShape(&shape) || !ParseOpcode(&opcode)) { |
| return false; |
| } |
| |
| if (is_root) { |
| if (!root_name->empty()) { |
| return Error(maybe_root_loc, "one computation should have only one ROOT"); |
| } |
| *root_name = name; |
| } |
| |
| // Add optional attributes. |
| std::unordered_map<string, AttrConfig> attrs; |
| optional<OpSharding> sharding; |
| attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; |
| optional<std::vector<HloInstruction*>> predecessors; |
| attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, |
| &predecessors}; |
| optional<OpMetadata> metadata; |
| attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; |
| |
| optional<string> backend_config; |
| attrs["backend_config"] = {/*required=*/false, AttrTy::kString, |
| &backend_config}; |
| |
| HloInstruction* instruction; |
| switch (opcode) { |
| case HloOpcode::kParameter: { |
| tensorflow::int64 parameter_number; |
| if (!ParseToken(TokKind::kLparen, |
| "expects '(' before parameter number") || |
| !ParseInt64(¶meter_number) || |
| !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateParameter(parameter_number, shape, name)); |
| break; |
| } |
| case HloOpcode::kConstant: { |
| Literal literal; |
| if (!ParseToken(TokKind::kLparen, |
| "expects '(' before constant literal") || |
| !ParseLiteral(&literal, shape) || |
| !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateConstant(std::move(literal))); |
| break; |
| } |
| case HloOpcode::kIota: { |
| optional<tensorflow::int64> iota_dimension; |
| attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, |
| &iota_dimension}; |
| if (!ParseOperands(&operands, /*expected_size=*/0) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateIota(shape, *iota_dimension)); |
| break; |
| } |
| // Unary ops. |
| case HloOpcode::kAbs: |
| case HloOpcode::kRoundNearestAfz: |
| case HloOpcode::kBitcast: |
| case HloOpcode::kCeil: |
| case HloOpcode::kClz: |
| case HloOpcode::kCopy: |
| case HloOpcode::kCos: |
| case HloOpcode::kExp: |
| case HloOpcode::kExpm1: |
| case HloOpcode::kImag: |
| case HloOpcode::kIsFinite: |
| case HloOpcode::kFloor: |
| case HloOpcode::kLog: |
| case HloOpcode::kLog1p: |
| case HloOpcode::kNot: |
| case HloOpcode::kNegate: |
| case HloOpcode::kReal: |
| case HloOpcode::kSign: |
| case HloOpcode::kSin: |
| case HloOpcode::kTanh: { |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateUnary(shape, opcode, operands[0])); |
| break; |
| } |
| // Binary ops. |
| case HloOpcode::kAdd: |
| case HloOpcode::kDivide: |
| case HloOpcode::kMultiply: |
| case HloOpcode::kSubtract: |
| case HloOpcode::kAtan2: |
| case HloOpcode::kComplex: |
| case HloOpcode::kEq: |
| case HloOpcode::kGe: |
| case HloOpcode::kGt: |
| case HloOpcode::kLe: |
| case HloOpcode::kLt: |
| case HloOpcode::kNe: |
| case HloOpcode::kMaximum: |
| case HloOpcode::kMinimum: |
| case HloOpcode::kPower: |
| case HloOpcode::kRemainder: |
| case HloOpcode::kAnd: |
| case HloOpcode::kOr: |
| case HloOpcode::kXor: |
| case HloOpcode::kShiftLeft: |
| case HloOpcode::kShiftRightArithmetic: |
| case HloOpcode::kShiftRightLogical: { |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateBinary( |
| shape, opcode, operands[0], operands[1])); |
| break; |
| } |
| // Ternary ops. |
| case HloOpcode::kClamp: |
| case HloOpcode::kSelect: |
| case HloOpcode::kTupleSelect: { |
| if (!ParseOperands(&operands, /*expected_size=*/3) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateTernary( |
| shape, opcode, operands[0], operands[1], operands[2])); |
| break; |
| } |
| // Other supported ops. |
| case HloOpcode::kConvert: { |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateConvert(shape, operands[0])); |
| break; |
| } |
| case HloOpcode::kBitcastConvert: { |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateBitcastConvert(shape, operands[0])); |
| break; |
| } |
| case HloOpcode::kCrossReplicaSum: { |
| optional<std::vector<std::vector<int64>>> tmp_groups; |
| optional<HloComputation*> to_apply; |
| optional<std::vector<int64>> replica_group_ids; |
| optional<string> barrier; |
| optional<int64> all_reduce_id; |
| attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, |
| &to_apply}; |
| attrs["replica_groups"] = {/*required=*/false, |
| AttrTy::kBracedInt64ListList, &tmp_groups}; |
| attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; |
| attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, |
| &all_reduce_id}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| std::vector<ReplicaGroup> replica_groups; |
| if (tmp_groups) { |
| replica_groups = CreateReplicaGroups(*tmp_groups); |
| } |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( |
| shape, operands, *to_apply, replica_groups, |
| barrier ? *barrier : "", all_reduce_id)); |
| break; |
| } |
| case HloOpcode::kAllToAll: { |
| optional<std::vector<std::vector<int64>>> tmp_groups; |
| optional<string> barrier; |
| attrs["replica_groups"] = {/*required=*/false, |
| AttrTy::kBracedInt64ListList, &tmp_groups}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| std::vector<ReplicaGroup> replica_groups; |
| if (tmp_groups) { |
| replica_groups = CreateReplicaGroups(*tmp_groups); |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateAllToAll(shape, operands, replica_groups)); |
| break; |
| } |
| case HloOpcode::kCollectivePermute: { |
| optional<std::vector<std::vector<int64>>> source_targets; |
| attrs["source_target_pairs"] = { |
| /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| std::vector<std::pair<int64, int64>> pairs(source_targets->size()); |
| for (int i = 0; i < pairs.size(); i++) { |
| if ((*source_targets)[i].size() != 2) { |
| return TokenError( |
| "expects 'source_target_pairs=' to be a list of pairs"); |
| } |
| pairs[i].first = (*source_targets)[i][0]; |
| pairs[i].second = (*source_targets)[i][1]; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); |
| break; |
| } |
| case HloOpcode::kReshape: { |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateReshape(shape, operands[0])); |
| break; |
| } |
| case HloOpcode::kAfterAll: { |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| if (operands.empty()) { |
| instruction = builder->AddInstruction(HloInstruction::CreateToken()); |
| } else { |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); |
| } |
| break; |
| } |
| case HloOpcode::kSort: { |
| auto loc = lexer_.GetLoc(); |
| |
| optional<std::vector<tensorflow::int64>> dimensions; |
| attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &dimensions}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs) || |
| dimensions->size() != 1) { |
| return false; |
| } |
| switch (operands.size()) { |
| case 1: |
| instruction = builder->AddInstruction(HloInstruction::CreateSort( |
| shape, dimensions->at(0), /*keys=*/operands[0])); |
| break; |
| case 2: |
| instruction = builder->AddInstruction(HloInstruction::CreateSort( |
| shape, dimensions->at(0), |
| /*keys=*/operands[0], /*values=*/operands[1])); |
| break; |
| default: |
| return Error(loc, StrCat("expects either 1 or 2 operands, but has ", |
| operands.size(), " operands")); |
| } |
| break; |
| } |
| case HloOpcode::kTuple: { |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateTuple(operands)); |
| break; |
| } |
| case HloOpcode::kWhile: { |
| optional<HloComputation*> condition; |
| optional<HloComputation*> body; |
| attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation, |
| &condition}; |
| attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateWhile( |
| shape, *condition, *body, /*init=*/operands[0])); |
| break; |
| } |
| case HloOpcode::kRecv: { |
| optional<tensorflow::int64> channel_id; |
| // If the is_host_transfer attribute is not present then default to false. |
| optional<bool> is_host_transfer = false; |
| attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; |
| attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, |
| &is_host_transfer}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| // If the is_host_transfer attribute is not present then default to false. |
| instruction = builder->AddInstruction(HloInstruction::CreateRecv( |
| shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer)); |
| break; |
| } |
| case HloOpcode::kRecvDone: { |
| optional<tensorflow::int64> channel_id; |
| // If the is_host_transfer attribute is not present then default to false. |
| optional<bool> is_host_transfer = false; |
| attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; |
| attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, |
| &is_host_transfer}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| if (channel_id != operands[0]->channel_id()) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateRecvDone(operands[0], *is_host_transfer)); |
| break; |
| } |
| case HloOpcode::kSend: { |
| optional<tensorflow::int64> channel_id; |
| // If the is_host_transfer attribute is not present then default to false. |
| optional<bool> is_host_transfer = false; |
| attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; |
| attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, |
| &is_host_transfer}; |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateSend( |
| operands[0], operands[1], *channel_id, *is_host_transfer)); |
| break; |
| } |
| case HloOpcode::kSendDone: { |
| optional<tensorflow::int64> channel_id; |
| // If the is_host_transfer attribute is not present then default to false. |
| optional<bool> is_host_transfer = false; |
| attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; |
| attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, |
| &is_host_transfer}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| if (channel_id != operands[0]->channel_id()) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateSendDone(operands[0], *is_host_transfer)); |
| break; |
| } |
| case HloOpcode::kGetTupleElement: { |
| optional<tensorflow::int64> index; |
| attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateGetTupleElement(shape, operands[0], *index)); |
| break; |
| } |
| case HloOpcode::kCall: { |
| optional<HloComputation*> to_apply; |
| attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, |
| &to_apply}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateCall(shape, operands, *to_apply)); |
| break; |
| } |
| case HloOpcode::kReduceWindow: { |
| optional<HloComputation*> reduce_computation; |
| optional<Window> window; |
| attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; |
| attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, |
| &reduce_computation}; |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| if (!window) { |
| window.emplace(); |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow( |
| shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window, |
| *reduce_computation)); |
| break; |
| } |
| case HloOpcode::kConvolution: { |
| optional<Window> window; |
| optional<ConvolutionDimensionNumbers> dnums; |
| optional<int64> feature_group_count; |
| attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; |
| attrs["dim_labels"] = {/*required=*/true, |
| AttrTy::kConvolutionDimensionNumbers, &dnums}; |
| attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, |
| &feature_group_count}; |
| optional<std::vector<PrecisionConfig::Precision>> operand_precision; |
| attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, |
| &operand_precision}; |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| if (!window) { |
| window.emplace(); |
| } |
| if (!feature_group_count) { |
| feature_group_count = 1; |
| } |
| PrecisionConfig precision_config; |
| if (operand_precision) { |
| *precision_config.mutable_operand_precision() = { |
| operand_precision->begin(), operand_precision->end()}; |
| } else { |
| precision_config.mutable_operand_precision()->Resize( |
| operands.size(), PrecisionConfig::DEFAULT); |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateConvolve( |
| shape, /*lhs=*/operands[0], /*rhs=*/operands[1], |
| feature_group_count.value(), *window, *dnums, precision_config)); |
| break; |
| } |
| case HloOpcode::kFft: { |
| optional<FftType> fft_type; |
| optional<std::vector<tensorflow::int64>> fft_length; |
| attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; |
| attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &fft_length}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateFft( |
| shape, operands[0], *fft_type, *fft_length)); |
| break; |
| } |
| case HloOpcode::kBroadcast: { |
| optional<std::vector<tensorflow::int64>> broadcast_dimensions; |
| attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &broadcast_dimensions}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateBroadcast( |
| shape, operands[0], *broadcast_dimensions)); |
| break; |
| } |
| case HloOpcode::kConcatenate: { |
| optional<std::vector<tensorflow::int64>> dimensions; |
| attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &dimensions}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs) || |
| dimensions->size() != 1) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateConcatenate( |
| shape, operands, dimensions->at(0))); |
| break; |
| } |
| case HloOpcode::kMap: { |
| optional<HloComputation*> to_apply; |
| attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, |
| &to_apply}; |
| optional<std::vector<tensorflow::int64>> dimensions; |
| attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, |
| &dimensions}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateMap(shape, operands, *to_apply)); |
| break; |
| } |
| case HloOpcode::kReduce: { |
| auto loc = lexer_.GetLoc(); |
| |
| optional<HloComputation*> reduce_computation; |
| attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, |
| &reduce_computation}; |
| optional<std::vector<tensorflow::int64>> dimensions_to_reduce; |
| attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &dimensions_to_reduce}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| if (operands.size() % 2) { |
| return Error(loc, StrCat("expects an even number of operands, but has ", |
| operands.size(), " operands")); |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateReduce( |
| shape, /*operands=*/ |
| absl::Span<HloInstruction* const>(operands).subspan( |
| 0, operands.size() / 2), |
| /*init_values=*/ |
| absl::Span<HloInstruction* const>(operands).subspan( |
| operands.size() / 2, operands.size()), |
| *dimensions_to_reduce, *reduce_computation)); |
| break; |
| } |
| case HloOpcode::kReverse: { |
| optional<std::vector<tensorflow::int64>> dimensions; |
| attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &dimensions}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateReverse(shape, operands[0], *dimensions)); |
| break; |
| } |
| case HloOpcode::kSelectAndScatter: { |
| optional<HloComputation*> select; |
| attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select}; |
| optional<HloComputation*> scatter; |
| attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter}; |
| optional<Window> window; |
| attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; |
| if (!ParseOperands(&operands, /*expected_size=*/3) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| if (!window) { |
| window.emplace(); |
| } |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateSelectAndScatter( |
| shape, /*operand=*/operands[0], *select, *window, |
| /*source=*/operands[1], /*init_value=*/operands[2], *scatter)); |
| break; |
| } |
| case HloOpcode::kSlice: { |
| optional<SliceRanges> slice_ranges; |
| attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateSlice( |
| shape, operands[0], slice_ranges->starts, slice_ranges->limits, |
| slice_ranges->strides)); |
| break; |
| } |
| case HloOpcode::kDynamicSlice: { |
| optional<std::vector<tensorflow::int64>> dynamic_slice_sizes; |
| attrs["dynamic_slice_sizes"] = { |
| /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( |
| shape, /*operand=*/operands[0], /*start_indices=*/operands[1], |
| *dynamic_slice_sizes)); |
| break; |
| } |
| case HloOpcode::kDynamicUpdateSlice: { |
| if (!ParseOperands(&operands, /*expected_size=*/3) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( |
| shape, /*operand=*/operands[0], /*update=*/operands[1], |
| /*start_indices=*/operands[2])); |
| break; |
| } |
| case HloOpcode::kTranspose: { |
| optional<std::vector<tensorflow::int64>> dimensions; |
| attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &dimensions}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateTranspose(shape, operands[0], *dimensions)); |
| break; |
| } |
| case HloOpcode::kBatchNormTraining: { |
| optional<float> epsilon; |
| attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; |
| optional<tensorflow::int64> feature_index; |
| attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, |
| &feature_index}; |
| if (!ParseOperands(&operands, /*expected_size=*/3) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateBatchNormTraining( |
| shape, /*operand=*/operands[0], /*scale=*/operands[1], |
| /*offset=*/operands[2], *epsilon, *feature_index)); |
| break; |
| } |
| case HloOpcode::kBatchNormInference: { |
| optional<float> epsilon; |
| attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; |
| optional<tensorflow::int64> feature_index; |
| attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, |
| &feature_index}; |
| if (!ParseOperands(&operands, /*expected_size=*/5) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateBatchNormInference( |
| shape, /*operand=*/operands[0], /*scale=*/operands[1], |
| /*offset=*/operands[2], /*mean=*/operands[3], |
| /*variance=*/operands[4], *epsilon, *feature_index)); |
| break; |
| } |
| case HloOpcode::kBatchNormGrad: { |
| optional<float> epsilon; |
| attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; |
| optional<tensorflow::int64> feature_index; |
| attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, |
| &feature_index}; |
| if (!ParseOperands(&operands, /*expected_size=*/5) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad( |
| shape, /*operand=*/operands[0], /*scale=*/operands[1], |
| /*mean=*/operands[2], /*variance=*/operands[3], |
| /*grad_output=*/operands[4], *epsilon, *feature_index)); |
| break; |
| } |
| case HloOpcode::kPad: { |
| optional<PaddingConfig> padding; |
| attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding}; |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreatePad( |
| shape, operands[0], /*padding_value=*/operands[1], *padding)); |
| break; |
| } |
| case HloOpcode::kFusion: { |
| optional<HloComputation*> fusion_computation; |
| attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation, |
| &fusion_computation}; |
| optional<HloInstruction::FusionKind> fusion_kind; |
| attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateFusion( |
| shape, *fusion_kind, operands, *fusion_computation)); |
| break; |
| } |
| case HloOpcode::kInfeed: { |
| optional<string> config; |
| attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| // We need to know the infeed data shape to construct the infeed |
| // instruction. This is the zero-th element of the tuple-shaped output of |
| // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail |
| // if the shape is not a non-empty tuple, so add guard so an error message |
| // can be emitted instead of a check fail |
| if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { |
| return Error(lexer_.GetLoc(), |
| "infeed must have a non-empty tuple shape"); |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateInfeed( |
| ShapeUtil::GetTupleElementShape(shape, 0), operands[0], |
| config ? *config : "")); |
| break; |
| } |
| case HloOpcode::kOutfeed: { |
| optional<string> config; |
| attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], |
| operands[1], config ? *config : "")); |
| break; |
| } |
| case HloOpcode::kRng: { |
| optional<RandomDistribution> distribution; |
| attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, |
| &distribution}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction( |
| HloInstruction::CreateRng(shape, *distribution, operands)); |
| break; |
| } |
| case HloOpcode::kReducePrecision: { |
| optional<tensorflow::int64> exponent_bits; |
| optional<tensorflow::int64> mantissa_bits; |
| attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, |
| &exponent_bits}; |
| attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, |
| &mantissa_bits}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = |
| builder->AddInstruction(HloInstruction::CreateReducePrecision( |
| shape, operands[0], static_cast<int>(*exponent_bits), |
| static_cast<int>(*mantissa_bits))); |
| break; |
| } |
| case HloOpcode::kConditional: { |
| optional<HloComputation*> true_computation; |
| optional<HloComputation*> false_computation; |
| attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, |
| &true_computation}; |
| attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation, |
| &false_computation}; |
| if (!ParseOperands(&operands, /*expected_size=*/3) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateConditional( |
| shape, /*pred=*/operands[0], |
| /*true_computation_arg=*/operands[1], *true_computation, |
| /*false_computation_arg=*/operands[2], *false_computation)); |
| break; |
| } |
| case HloOpcode::kCustomCall: { |
| optional<string> custom_call_target; |
| optional<Window> window; |
| optional<ConvolutionDimensionNumbers> dnums; |
| optional<int64> feature_group_count; |
| attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, |
| &custom_call_target}; |
| attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; |
| attrs["dim_labels"] = {/*required=*/false, |
| AttrTy::kConvolutionDimensionNumbers, &dnums}; |
| attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, |
| &feature_group_count}; |
| if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( |
| shape, operands, *custom_call_target)); |
| if (window.has_value()) { |
| instruction->set_window(*window); |
| } |
| if (dnums.has_value()) { |
| instruction->set_convolution_dimension_numbers(*dnums); |
| } |
| if (feature_group_count.has_value()) { |
| instruction->set_feature_group_count(*feature_group_count); |
| } |
| break; |
| } |
| case HloOpcode::kDot: { |
| optional<std::vector<tensorflow::int64>> lhs_contracting_dims; |
| attrs["lhs_contracting_dims"] = { |
| /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; |
| optional<std::vector<tensorflow::int64>> rhs_contracting_dims; |
| attrs["rhs_contracting_dims"] = { |
| /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; |
| optional<std::vector<tensorflow::int64>> lhs_batch_dims; |
| attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, |
| &lhs_batch_dims}; |
| optional<std::vector<tensorflow::int64>> rhs_batch_dims; |
| attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, |
| &rhs_batch_dims}; |
| optional<std::vector<PrecisionConfig::Precision>> operand_precision; |
| attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, |
| &operand_precision}; |
| |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| |
| DotDimensionNumbers dnum; |
| if (lhs_contracting_dims) { |
| *dnum.mutable_lhs_contracting_dimensions() = { |
| lhs_contracting_dims->begin(), lhs_contracting_dims->end()}; |
| } |
| if (rhs_contracting_dims) { |
| *dnum.mutable_rhs_contracting_dimensions() = { |
| rhs_contracting_dims->begin(), rhs_contracting_dims->end()}; |
| } |
| if (lhs_batch_dims) { |
| *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(), |
| lhs_batch_dims->end()}; |
| } |
| if (rhs_batch_dims) { |
| *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(), |
| rhs_batch_dims->end()}; |
| } |
| |
| PrecisionConfig precision_config; |
| if (operand_precision) { |
| *precision_config.mutable_operand_precision() = { |
| operand_precision->begin(), operand_precision->end()}; |
| } else { |
| precision_config.mutable_operand_precision()->Resize( |
| operands.size(), PrecisionConfig::DEFAULT); |
| } |
| |
| instruction = builder->AddInstruction(HloInstruction::CreateDot( |
| shape, operands[0], operands[1], dnum, precision_config)); |
| break; |
| } |
| case HloOpcode::kGather: { |
| optional<std::vector<tensorflow::int64>> offset_dims; |
| attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &offset_dims}; |
| optional<std::vector<tensorflow::int64>> collapsed_slice_dims; |
| attrs["collapsed_slice_dims"] = { |
| /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; |
| optional<std::vector<tensorflow::int64>> start_index_map; |
| attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &start_index_map}; |
| optional<tensorflow::int64> index_vector_dim; |
| attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, |
| &index_vector_dim}; |
| optional<std::vector<tensorflow::int64>> slice_sizes; |
| attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, |
| &slice_sizes}; |
| |
| if (!ParseOperands(&operands, /*expected_size=*/2) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| |
| GatherDimensionNumbers dim_numbers = |
| HloGatherInstruction::MakeGatherDimNumbers( |
| /*offset_dims=*/*offset_dims, |
| /*collapsed_slice_dims=*/*collapsed_slice_dims, |
| /*start_index_map=*/*start_index_map, |
| /*index_vector_dim=*/*index_vector_dim); |
| |
| instruction = builder->AddInstruction(HloInstruction::CreateGather( |
| shape, /*operand=*/operands[0], /*start_indices=*/operands[1], |
| dim_numbers, *slice_sizes)); |
| break; |
| } |
| case HloOpcode::kScatter: { |
| optional<std::vector<tensorflow::int64>> update_window_dims; |
| attrs["update_window_dims"] = { |
| /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; |
| optional<std::vector<tensorflow::int64>> inserted_window_dims; |
| attrs["inserted_window_dims"] = { |
| /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; |
| optional<std::vector<tensorflow::int64>> scatter_dims_to_operand_dims; |
| attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, |
| AttrTy::kBracedInt64List, |
| &scatter_dims_to_operand_dims}; |
| optional<tensorflow::int64> index_vector_dim; |
| attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, |
| &index_vector_dim}; |
| |
| optional<HloComputation*> update_computation; |
| attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, |
| &update_computation}; |
| |
| if (!ParseOperands(&operands, /*expected_size=*/3) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| |
| ScatterDimensionNumbers dim_numbers = |
| HloScatterInstruction::MakeScatterDimNumbers( |
| /*update_window_dims=*/*update_window_dims, |
| /*inserted_window_dims=*/*inserted_window_dims, |
| /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims, |
| /*index_vector_dim=*/*index_vector_dim); |
| |
| instruction = builder->AddInstruction(HloInstruction::CreateScatter( |
| shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], |
| /*updates=*/operands[2], *update_computation, dim_numbers)); |
| break; |
| } |
| case HloOpcode::kDomain: { |
| DomainData domain; |
| attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; |
| if (!ParseOperands(&operands, /*expected_size=*/1) || |
| !ParseAttributes(attrs)) { |
| return false; |
| } |
| instruction = builder->AddInstruction(HloInstruction::CreateDomain( |
| shape, operands[0], std::move(domain.exit_metadata), |
| std::move(domain.entry_metadata))); |
| break; |
| } |
| case HloOpcode::kTrace: |
| return TokenError(StrCat("parsing not yet implemented for op: ", |
| HloOpcodeString(opcode))); |
| } |
| |
| instruction->SetAndSanitizeName(name); |
| if (instruction->name() != name) { |
| return Error(name_loc, |
| StrCat("illegal instruction name: ", name, |
| "; suggest renaming to: ", instruction->name())); |
| } |
| |
| // Add shared attributes like metadata to the instruction, if they were seen. |
| if (sharding) { |
| instruction->set_sharding( |
| HloSharding::FromProto(sharding.value()).ValueOrDie()); |
| } |
| if (predecessors) { |
| for (auto* pre : *predecessors) { |
| Status status = pre->AddControlDependencyTo(instruction); |
| if (!status.ok()) { |
| return Error(name_loc, StrCat("error adding control dependency for: ", |
| name, " status: ", status.ToString())); |
| } |
| } |
| } |
| if (metadata) { |
| instruction->set_metadata(*metadata); |
| } |
| if (backend_config) { |
| instruction->set_raw_backend_config_string(std::move(*backend_config)); |
| } |
| return AddInstruction(name, instruction, name_loc); |
| } // NOLINT(readability/fn_size) |
| |
| // ::= '{' (single_sharding | tuple_sharding) '}' |
| // |
| // tuple_sharding ::= single_sharding* (',' single_sharding)* |
| bool HloParser::ParseSharding(OpSharding* sharding) { |
| // A single sharding starts with '{' and is not followed by '{'. |
| // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for |
| // an empty tuple. |
| if (!ParseToken(TokKind::kLbrace, |
| "expected '{' to start sharding attribute")) { |
| return false; |
| } |
| |
| if (lexer_.GetKind() != TokKind::kLbrace && |
| lexer_.GetKind() != TokKind::kRbrace) { |
| return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true); |
| } |
| |
| // Tuple sharding. |
| // Allow empty tuple shardings. |
| if (lexer_.GetKind() != TokKind::kRbrace) { |
| do { |
| if (!ParseSingleSharding(sharding->add_tuple_shardings(), |
| /*lbrace_pre_lexed=*/false)) { |
| return false; |
| } |
| } while (EatIfPresent(TokKind::kComma)); |
| } |
| sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE); |
| |
| return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); |
| } |
| |
| // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? |
| // ('devices=' ('[' dims ']')* device_list)? '}' |
| // dims ::= int_list device_list ::= int_list |
| bool HloParser::ParseSingleSharding(OpSharding* sharding, |
| bool lbrace_pre_lexed) { |
| if (!lbrace_pre_lexed && |
| !ParseToken(TokKind::kLbrace, |
| "expected '{' to start sharding attribute")) { |
| return false; |
| } |
| |
| LocTy loc = lexer_.GetLoc(); |
| bool maximal = false; |
| bool replicated = false; |
| std::vector<tensorflow::int64> devices; |
| std::vector<tensorflow::int64> tile_assignment_dimensions; |
| while (lexer_.GetKind() != TokKind::kRbrace) { |
| switch (lexer_.GetKind()) { |
| case TokKind::kw_maximal: |
| maximal = true; |
| lexer_.Lex(); |
| break; |
| case TokKind::kw_replicated: |
| replicated = true; |
| lexer_.Lex(); |
| break; |
| case TokKind::kAttributeName: { |
| if (lexer_.GetStrVal() == "device") { |
| if (lexer_.Lex() != TokKind::kInt) { |
| return TokenError("device= attribute must be an integer"); |
| } |
| devices = {lexer_.GetInt64Val()}; |
| lexer_.Lex(); |
| } else if (lexer_.GetStrVal() == "devices") { |
| lexer_.Lex(); |
| if (!ParseToken(TokKind::kLsquare, |
| "expected '[' to start sharding devices shape")) { |
| return false; |
| } |
| |
| do { |
| tensorflow::int64 dim; |
| if (!ParseInt64(&dim)) { |
| return false; |
| } |
| tile_assignment_dimensions.push_back(dim); |
| } while (EatIfPresent(TokKind::kComma)); |
| |
| if (!ParseToken(TokKind::kRsquare, |
| "expected ']' to start sharding devices shape")) { |
| return false; |
| } |
| do { |
| tensorflow::int64 device; |
| if (!ParseInt64(&device)) { |
| return false; |
| } |
| devices.push_back(device); |
| } while (EatIfPresent(TokKind::kComma)); |
| } else { |
| return TokenError( |
| "unknown attribute in sharding: expected device= or devices="); |
| } |
| break; |
| } |
| case TokKind::kShape: |
| // TODO(b/112302613): Left here for backward compatibility to ignore the |
| // removed tile shape data. |
| lexer_.Lex(); |
| break; |
| case TokKind::kRbrace: |
| break; |
| default: |
| return TokenError("unexpected token"); |
| } |
| } |
| |
| if (replicated) { |
| if (!devices.empty()) { |
| return Error(loc, |
| "replicated shardings should not have any devices assigned"); |
| } |
| sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); |
| } else if (maximal) { |
| if (devices.size() != 1) { |
| return Error(loc, |
| "maximal shardings should have exactly one device assigned"); |
| } |
| sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); |
| sharding->add_tile_assignment_devices(devices[0]); |
| } else { |
| if (devices.size() <= 1) { |
| return Error( |
| loc, "non-maximal shardings must have more than one device assigned"); |
| } |
| if (tile_assignment_dimensions.empty()) { |
| return Error( |
| loc, |
| "non-maximal shardings must have a tile assignment list including " |
| "dimensions"); |
| } |
| sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); |
| for (tensorflow::int64 dim : tile_assignment_dimensions) { |
| sharding->add_tile_assignment_dimensions(dim); |
| } |
| for (tensorflow::int64 device : devices) { |
| sharding->add_tile_assignment_devices(device); |
| } |
| } |
| |
| lexer_.Lex(); |
| return true; |
| } |
| |
| // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' |
| // 'exit=' exit_sharding '}' |
| bool HloParser::ParseDomain(DomainData* domain) { |
| std::unordered_map<string, AttrConfig> attrs; |
| optional<string> kind; |
| optional<OpSharding> entry_sharding; |
| optional<OpSharding> exit_sharding; |
| attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; |
| attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; |
| attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; |
| if (!ParseSubAttributes(attrs)) { |
| return false; |
| } |
| if (*kind == ShardingMetadata::KindName()) { |
| auto entry_sharding_ptr = absl::make_unique<HloSharding>( |
| HloSharding::FromProto(*entry_sharding).ValueOrDie()); |
| auto exit_sharding_ptr = absl::make_unique<HloSharding>( |
| HloSharding::FromProto(*exit_sharding).ValueOrDie()); |
| domain->entry_metadata = |
| absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr)); |
| domain->exit_metadata = |
| absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr)); |
| } else { |
| return TokenError(StrCat("unsupported domain kind: ", *kind)); |
| } |
| return true; |
| } |
| |
| // '{' name+ '}' |
| bool HloParser::ParseInstructionNames( |
| std::vector<HloInstruction*>* instructions) { |
| if (!ParseToken(TokKind::kLbrace, |
| "expects '{' at the beginning of instruction name list")) { |
| return false; |
| } |
| LocTy loc = lexer_.GetLoc(); |
| do { |
| string name; |
| if (!ParseName(&name)) { |
| return Error(loc, "expects a instruction name"); |
| } |
| std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name); |
| if (!instr) { |
| return TokenError(StrFormat("instruction '%s' is not defined", name)); |
| } |
| instructions->push_back(instr->first); |
| } while (EatIfPresent(TokKind::kComma)); |
| |
| return ParseToken(TokKind::kRbrace, |
| "expects '}' at the end of instruction name list"); |
| } |
| |
| bool HloParser::SetValueInLiteral(tensorflow::int64 value, |
| tensorflow::int64 linear_index, |
| Literal* literal) { |
| const Shape& shape = literal->shape(); |
| switch (shape.element_type()) { |
| case S8: |
| return SetValueInLiteralHelper<tensorflow::int8>(value, linear_index, |
| literal); |
| case S16: |
| return SetValueInLiteralHelper<tensorflow::int16>(value, linear_index, |
| literal); |
| case S32: |
| return SetValueInLiteralHelper<tensorflow::int32>(value, linear_index, |
| literal); |
| case S64: |
| return SetValueInLiteralHelper<tensorflow::int64>(value, linear_index, |
| literal); |
| case U8: |
| return SetValueInLiteralHelper<tensorflow::uint8>(value, linear_index, |
| literal); |
| case U16: |
| return SetValueInLiteralHelper<tensorflow::uint16>(value, linear_index, |
| literal); |
| case U32: |
| return SetValueInLiteralHelper<tensorflow::uint32>(value, linear_index, |
| literal); |
| case U64: |
| return SetValueInLiteralHelper<tensorflow::uint64>(value, linear_index, |
| literal); |
| default: |
| LOG(FATAL) << "unknown integral primitive type " |
| << PrimitiveType_Name(shape.element_type()); |
| } |
| } |
| |
| bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, |
| Literal* literal) { |
| const Shape& shape = literal->shape(); |
| switch (shape.element_type()) { |
| case F16: |
| return SetValueInLiteralHelper<Eigen::half>(value, linear_index, literal); |
| case BF16: |
| return SetValueInLiteralHelper<tensorflow::bfloat16>(value, linear_index, |
| literal); |
| case F32: |
| return SetValueInLiteralHelper<float>(value, linear_index, literal); |
| case F64: |
| return SetValueInLiteralHelper<double>(value, linear_index, literal); |
| default: |
| LOG(FATAL) << "unknown floating point primitive type " |
| << PrimitiveType_Name(shape.element_type()); |
| } |
| } |
| |
| bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, |
| Literal* literal) { |
| const Shape& shape = literal->shape(); |
| switch (shape.element_type()) { |
| case PRED: |
| return SetValueInLiteralHelper<bool>(value, linear_index, literal); |
| default: |
| LOG(FATAL) << PrimitiveType_Name(shape.element_type()) |
| << " is not PRED type"; |
| } |
| } |
| |
| template <typename LiteralNativeT, typename ParsedElemT> |
| bool HloParser::SetValueInLiteralHelper(ParsedElemT value, |
| tensorflow::int64 linear_index, |
| Literal* literal) { |
| // Check that linear_index is in range. |
| if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { |
| return TokenError( |
| StrCat("trys to set value ", value, " to a literal in shape ", |
| ShapeUtil::HumanString(literal->shape()), " at linear index ", |
| linear_index, ", but the index is out of range")); |
| } |
| |
| if (std::isnan(value) || |
| (std::numeric_limits<ParsedElemT>::has_infinity && |
| (std::numeric_limits<ParsedElemT>::infinity() == value || |
| -std::numeric_limits<ParsedElemT>::infinity() == value))) { |
| // Skip range checking for non-finite value. |
| } else if (literal->shape().element_type() == F16 || |
| literal->shape().element_type() == BF16) { |
| if (value > kF16max || value < -kF16max) { |
| return TokenError(StrCat( |
| "value ", value, " is out of range for literal's primitive type ", |
| PrimitiveType_Name(literal->shape().element_type()))); |
| } |
| } else if (std::is_unsigned<LiteralNativeT>::value) { |
| CHECK((std::is_same<ParsedElemT, tensorflow::int64>::value || |
| std::is_same<ParsedElemT, bool>::value)) |
| << "Unimplemented checking for ParsedElemT"; |
| |
| ParsedElemT upper_bound; |
| if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { |
| upper_bound = std::numeric_limits<ParsedElemT>::max(); |
| } else { |
| upper_bound = |
| static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max()); |
| } |
| if (value > upper_bound || value < 0) { |
| // Value is out of range for LiteralNativeT. |
| return TokenError(StrCat( |
| "value ", value, " is out of range for literal's primitive type ", |
| PrimitiveType_Name(literal->shape().element_type()))); |
| } |
| } else if (value > static_cast<ParsedElemT>( |
| std::numeric_limits<LiteralNativeT>::max()) || |
| value < static_cast<ParsedElemT>( |
| std::numeric_limits<LiteralNativeT>::lowest())) { |
| // Value is out of range for LiteralNativeT. |
| return TokenError(StrCat( |
| "value ", value, " is out of range for literal's primitive type ", |
| PrimitiveType_Name(literal->shape().element_type()))); |
| } |
| |
| literal->data<LiteralNativeT>().at(linear_index) = |
| static_cast<LiteralNativeT>(value); |
| return true; |
| } |
| |
| bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { |
| Shape new_shape; |
| if (!ParseShape(&new_shape)) { |
| return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); |
| } |
| if (!ShapeUtil::Compatible(shape, new_shape)) { |
| return TokenError(StrCat( |
| "expects shape ", ShapeUtil::HumanString(shape), |
| ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); |
| } |
| return true; |
| } |
| |
| // literal |
| // ::= tuple |
| // ::= non_tuple |
| bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { |
| return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) |
| : ParseNonTupleLiteral(literal, shape); |
| } |
| |
| // tuple |
| // ::= shape '(' literal_list ')' |
| // literal_list |
| // ::= /*empty*/ |
| // ::= literal (',' literal)* |
| bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { |
| if (!EatShapeAndCheckCompatible(shape)) { |
| return TokenError(StrCat("expects tuple constant in shape ", |
| ShapeUtil::HumanString(shape))); |
| } |
| if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { |
| return false; |
| } |
| std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape)); |
| |
| if (lexer_.GetKind() == TokKind::kRparen) { |
| // empty |
| } else { |
| // literal, (',' literal)* |
| for (int i = 0; i < elements.size(); i++) { |
| if (i > 0) { |
| ParseToken(TokKind::kComma, "exepcts ',' to separate tuple elements"); |
| } |
| if (!ParseLiteral(&elements[i], |
| ShapeUtil::GetTupleElementShape(shape, i))) { |
| return TokenError(StrCat("expects the ", i, "th element")); |
| } |
| } |
| } |
| *literal = LiteralUtil::MakeTupleOwned(std::move(elements)); |
| return ParseToken(TokKind::kRparen, |
| StrCat("expects ')' at the end of the tuple with ", |
| ShapeUtil::TupleElementCount(shape), "elements")); |
| } |
| |
| // non_tuple |
| // ::= rank01 |
| // ::= rank2345 |
| // rank2345 ::= shape sparse_or_nested_array |
| bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { |
| if (LayoutUtil::IsSparseArray(shape)) { |
| return ParseSparseLiteral(literal, shape); |
| } |
| |
| CHECK(LayoutUtil::IsDenseArray(shape)); |
| return ParseDenseLiteral(literal, shape); |
| } |
| |
| bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { |
| const tensorflow::int64 rank = ShapeUtil::Rank(shape); |
| if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { |
| return false; |
| } |
| |
| // Create a literal with the given shape in default layout. |
| *literal = LiteralUtil::CreateFromDimensions( |
| shape.element_type(), AsInt64Slice(shape.dimensions())); |
| tensorflow::int64 nest_level = 0; |
| tensorflow::int64 linear_index = 0; |
| // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for |
| // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, |
| // when we are parsing the 2nd '{' (right before '1'), we are seeing a |
| // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at |
| // the first '}' (right after '3'), it means the sub-array ends, and the |
| // sub-array is supposed to contain exactly 3 elements, so check if |
| // elems_seen_per_dim[1] is 3. |
| std::vector<tensorflow::int64> elems_seen_per_dim(rank); |
| auto get_index_str = [&elems_seen_per_dim](int dim) -> string { |
| std::vector<tensorflow::int64> elems_seen_until_dim( |
| elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); |
| return StrCat("[", |
| StrJoin(elems_seen_until_dim, ",", |
| [](string* out, const tensorflow::int64& num_elems) { |
| StrAppend(out, num_elems - 1); |
| }), |
| "]"); |
| }; |
| do { |
| switch (lexer_.GetKind()) { |
| default: |
| return TokenError("unexpected token type in a literal"); |
| case TokKind::kLbrace: { |
| nest_level++; |
| if (nest_level > rank) { |
| return TokenError(absl::StrFormat( |
| "expects nested array in rank %d, but sees larger", rank)); |
| } |
| if (nest_level > 1) { |
| elems_seen_per_dim[nest_level - 2]++; |
| if (elems_seen_per_dim[nest_level - 2] > |
| shape.dimensions(nest_level - 2)) { |
| return TokenError(absl::StrFormat( |
| "expects %d elements in the %sth element, but sees more", |
| shape.dimensions(nest_level - 2), |
| get_index_str(nest_level - 2))); |
| } |
| } |
| lexer_.Lex(); |
| break; |
| } |
| case TokKind::kRbrace: { |
| nest_level--; |
| if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { |
| return TokenError(absl::StrFormat( |
| "expects %d elements in the %sth element, but sees %d", |
| shape.dimensions(nest_level), get_index_str(nest_level), |
| elems_seen_per_dim[nest_level])); |
| } |
| elems_seen_per_dim[nest_level] = 0; |
| lexer_.Lex(); |
| break; |
| } |
| case TokKind::kComma: |
| // Skip. |
| lexer_.Lex(); |
| break; |
| case TokKind::kw_true: |
| case TokKind::kw_false: |
| case TokKind::kInt: |
| case TokKind::kDecimal: |
| case TokKind::kw_nan: |
| case TokKind::kw_inf: |
| case TokKind::kNegInf: { |
| if (rank > 0) { |
| if (nest_level != rank) { |
| return TokenError( |
| absl::StrFormat("expects nested array in rank %d, but sees %d", |
| rank, nest_level)); |
| } |
| elems_seen_per_dim[rank - 1]++; |
| if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { |
| return TokenError(absl::StrFormat( |
| "expects %d elements on the minor-most dimension, but " |
| "sees more", |
| shape.dimensions(rank - 1))); |
| } |
| } |
| if (lexer_.GetKind() == TokKind::kw_true || |
| lexer_.GetKind() == TokKind::kw_false) { |
| // TODO(congliu): bool type literals with rank >= 1 are actually |
| // printed in a compact form instead of "true" or "false". Fix that. |
| if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, |
| linear_index++, literal)) { |
| return false; |
| } |
| lexer_.Lex(); |
| } else if (primitive_util::IsIntegralType(shape.element_type())) { |
| LocTy loc = lexer_.GetLoc(); |
| tensorflow::int64 value; |
| if (!ParseInt64(&value)) { |
| return Error(loc, StrCat("expects integer for primitive type: ", |
| PrimitiveType_Name(shape.element_type()))); |
| } |
| if (!SetValueInLiteral(value, linear_index++, literal)) { |
| return false; |
| } |
| } else if (primitive_util::IsFloatingPointType(shape.element_type())) { |
| LocTy loc = lexer_.GetLoc(); |
| double value; |
| if (!ParseDouble(&value)) { |
| return Error( |
| loc, StrCat("expect floating point value for primitive type: ", |
| PrimitiveType_Name(shape.element_type()))); |
| } |
| if (!SetValueInLiteral(value, linear_index++, literal)) { |
| return false; |
| } |
| } else { |
| return TokenError(StrCat("unsupported primitive type ", |
| PrimitiveType_Name(shape.element_type()))); |
| } |
| break; |
| } |
| } // end of switch |
| } while (nest_level > 0); |
| |
| *literal = literal->Relayout(shape.layout()); |
| return true; |
| } |
| |
| bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { |
| if (!EatShapeAndCheckCompatible(shape)) { |
| return false; |
| } |
| |
| switch (shape.element_type()) { |
| case PRED: |
| return ParseSparseLiteralHelper<tensorflow::uint8>(literal, shape); |
| case S8: |
| return ParseSparseLiteralHelper<tensorflow::int8>(literal, shape); |
| case S16: |
| return ParseSparseLiteralHelper<tensorflow::int16>(literal, shape); |
| case S32: |
| return ParseSparseLiteralHelper<tensorflow::int32>(literal, shape); |
| case S64: |
| return ParseSparseLiteralHelper<tensorflow::int64>(literal, shape); |
| case U8: |
| return ParseSparseLiteralHelper<tensorflow::uint8>(literal, shape); |
| case U16: |
| return ParseSparseLiteralHelper<tensorflow::uint16>(literal, shape); |
| case U32: |
| return ParseSparseLiteralHelper<tensorflow::uint32>(literal, shape); |
| case U64: |
| return ParseSparseLiteralHelper<tensorflow::uint64>(literal, shape); |
| case F16: |
| return ParseSparseLiteralHelper<Eigen::half>(literal, shape); |
| case F32: |
| return ParseSparseLiteralHelper<float>(literal, shape); |
| case BF16: |
| return ParseSparseLiteralHelper<tensorflow::bfloat16>(literal, shape); |
| case F64: |
| return ParseSparseLiteralHelper<double>(literal, shape); |
| default: |
| return Error(lexer_.GetLoc(), |
| StrCat("invalid primitive type for sparse literal: ", |
| PrimitiveType_Name(shape.element_type()))); |
| } |
| } |
| |
| template <typename LiteralNativeT> |
| bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { |
| std::vector<tensorflow::int64> index; |
| |
| tensorflow::int64 rank = ShapeUtil::Rank(shape); |
| |
| *literal = Literal(shape); |
| |
| if (!ParseToken(TokKind::kLbrace, |
| "expects '{' at the beginning of a sparse literal")) { |
| return false; |
| } |
| |
| for (;;) { |
| if (lexer_.GetKind() == TokKind::kRbrace) { |
| lexer_.Lex(); |
| break; |
| } |
| |
| LocTy index_loc = lexer_.GetLoc(); |
| index.clear(); |
| if (lexer_.GetKind() == TokKind::kInt) { |
| tensorflow::int64 single_index = lexer_.GetInt64Val(); |
| lexer_.Lex(); |
| if (rank != 1) { |
| return Error( |
| index_loc, |
| StrCat("invalid single-dimensional index for shape with rank ", |
| rank, ": ", single_index)); |
| } |
| index.push_back(single_index); |
| } else { |
| if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, |
| &index)) { |
| return false; |
| } |
| if (index.size() != rank) { |
| return Error( |
| index_loc, |
| StrCat("invalid multi-dimension index for shape with rank ", rank, |
| ": [", StrJoin(index, ", "), "]")); |
| } |
| } |
| if (!ParseToken(TokKind::kColon, |
| "expects ':' after after the sparse array index and before " |
| "the sparse array value")) { |
| return false; |
| } |
| LocTy value_loc = lexer_.GetLoc(); |
| LiteralNativeT value; |
| if (lexer_.GetKind() == TokKind::kw_true || |
| lexer_.GetKind() == TokKind::kw_false) { |
| value = static_cast<LiteralNativeT>(lexer_.GetKind() == TokKind::kw_true); |
| lexer_.Lex(); |
| } else if (primitive_util::IsIntegralType(shape.element_type())) { |
| tensorflow::int64 value_s64; |
| if (!ParseInt64(&value_s64)) { |
| return Error(value_loc, |
| StrCat("expects integer for primitive type: ", |
| PrimitiveType_Name(shape.element_type()))); |
| } |
| value = static_cast<LiteralNativeT>(value_s64); |
| } else if (primitive_util::IsFloatingPointType(shape.element_type())) { |
| double value_f64; |
| if (!ParseDouble(&value_f64)) { |
| return Error(value_loc, |
| StrCat("expects floating point value for primitive type: ", |
| PrimitiveType_Name(shape.element_type()))); |
| } |
| value = static_cast<LiteralNativeT>(value_f64); |
| } else { |
| LOG(FATAL) << "Unexpected element type: " |
| << PrimitiveType_Name(shape.element_type()); |
| } |
| if (lexer_.GetKind() != TokKind::kRbrace && |
| !ParseToken(TokKind::kComma, |
| "expects ',' separator between sparse array elements")) { |
| return false; |
| } |
| |
| if (literal->sparse_element_count() + 1 == |
| LayoutUtil::MaxSparseElements(shape.layout())) { |
| return Error( |
| lexer_.GetLoc(), |
| StrCat("number of sparse elements exceeds maximum for layout: ", |
| ShapeUtil::HumanStringWithLayout(shape))); |
| } |
| |
| literal->AppendSparseElement(index, value); |
| } |
| |
| literal->SortSparseElements(); |
| return true; |
| } |
| |
| // operands ::= '(' operands1 ')' |
| // operands1 |
| // ::= /*empty*/ |
| // ::= operand (, operand)* |
| // operand ::= (shape)? name |
| bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { |
| CHECK(operands != nullptr); |
| if (!ParseToken(TokKind::kLparen, |
| "expects '(' at the beginning of operands")) { |
| return false; |
| } |
| if (lexer_.GetKind() == TokKind::kRparen) { |
| // empty |
| } else { |
| do { |
| LocTy loc = lexer_.GetLoc(); |
| string name; |
| optional<Shape> shape; |
| if (CanBeShape()) { |
| shape.emplace(); |
| if (!ParseShape(&shape.value())) { |
| return false; |
| } |
| } |
| if (!ParseName(&name)) { |
| return false; |
| } |
| std::pair<HloInstruction*, LocTy>* instruction = |
| FindInstruction(name, shape); |
| if (instruction == nullptr) { |
| return Error(loc, StrCat("instruction does not exist: ", name)); |
| } |
| operands->push_back(instruction->first); |
| } while (EatIfPresent(TokKind::kComma)); |
| } |
| return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); |
| } |
| |
| bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands, |
| const int expected_size) { |
| CHECK(operands != nullptr); |
| LocTy loc = lexer_.GetLoc(); |
| if (!ParseOperands(operands)) { |
| return false; |
| } |
| if (expected_size != operands->size()) { |
| return Error(loc, StrCat("expects ", expected_size, " operands, but has ", |
| operands->size(), " operands")); |
| } |
| return true; |
| } |
| |
| // sub_attributes ::= '{' (','? attribute)* '}' |
| bool HloParser::ParseSubAttributes( |
| const std::unordered_map<string, AttrConfig>& attrs) { |
| LocTy loc = lexer_.GetLoc(); |
| if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { |
| return false; |
| } |
| std::unordered_set<string> seen_attrs; |
| if (lexer_.GetKind() == TokKind::kRbrace) { |
| // empty |
| } else { |
| do { |
| EatIfPresent(TokKind::kComma); |
| if (!ParseAttributeHelper(attrs, &seen_attrs)) { |
| return false; |
| } |
| } while (lexer_.GetKind() != TokKind::kRbrace); |
| } |
| // Check that all required attrs were seen. |
| for (const auto& attr_it : attrs) { |
| if (attr_it.second.required && |
| seen_attrs.find(attr_it.first) == seen_attrs.end()) { |
| return Error(loc, StrFormat("sub-attribute %s is expected but not seen", |
| attr_it.first)); |
| } |
| } |
| return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); |
| } |
| |
| // attributes ::= (',' attribute)* |
| bool HloParser::ParseAttributes( |
| const std::unordered_map<string, AttrConfig>& attrs) { |
| LocTy loc = lexer_.GetLoc(); |
| std::unordered_set<string> seen_attrs; |
| while (EatIfPresent(TokKind::kComma)) { |
| if (!ParseAttributeHelper(attrs, &seen_attrs)) { |
| return false; |
| } |
| } |
| // Check that all required attrs were seen. |
| for (const auto& attr_it : attrs) { |
| if (attr_it.second.required && |
| seen_attrs.find(attr_it.first) == seen_attrs.end()) { |
| return Error(loc, StrFormat("attribute %s is expected but not seen", |
| attr_it.first)); |
| } |
| } |
| return true; |
| } |
| |
| bool HloParser::ParseAttributeHelper( |
| const std::unordered_map<string, AttrConfig>& attrs, |
| std::unordered_set<string>* seen_attrs) { |
| LocTy loc = lexer_.GetLoc(); |
| string name; |
| if (!ParseAttributeName(&name)) { |
| return Error(loc, "error parsing attributes"); |
| } |
| VLOG(1) << "Parsing attribute " << name; |
| if (!seen_attrs->insert(name).second) { |
| return Error(loc, StrFormat("attribute %s already exists", name)); |
| } |
| auto attr_it = attrs.find(name); |
| if (attr_it == attrs.end()) { |
| string allowed_attrs; |
| if (attrs.empty()) { |
| allowed_attrs = "No attributes are allowed here."; |
| } else { |
| allowed_attrs = StrCat( |
| "Allowed attributes: ", |
| StrJoin(attrs, ", ", |
| [&](string* out, const std::pair<string, AttrConfig>& kv) { |
| StrAppend(out, kv.first); |
| })); |
| } |
| return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name, |
| allowed_attrs)); |
| } |
| AttrTy attr_type = attr_it->second.attr_type; |
| void* attr_out_ptr = attr_it->second.result; |
| bool success = [&] { |
| LocTy attr_loc = lexer_.GetLoc(); |
| switch (attr_type) { |
| case AttrTy::kBool: { |
| bool result; |
| if (!ParseBool(&result)) { |
| return false; |
| } |
| static_cast<optional<bool>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kInt64: { |
| tensorflow::int64 result; |
| if (!ParseInt64(&result)) { |
| return false; |
| } |
| static_cast<optional<tensorflow::int64>*>(attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| case AttrTy::kInt32: { |
| tensorflow::int64 result; |
| if (!ParseInt64(&result)) { |
| return false; |
| } |
| if (result != static_cast<tensorflow::int32>(result)) { |
| return Error(attr_loc, "value out of range for int32"); |
| } |
| static_cast<optional<tensorflow::int32>*>(attr_out_ptr) |
| ->emplace(static_cast<tensorflow::int32>(result)); |
| return true; |
| } |
| case AttrTy::kFloat: { |
| double result; |
| if (!ParseDouble(&result)) { |
| return false; |
| } |
| if (result > std::numeric_limits<float>::max() || |
| result < std::numeric_limits<float>::lowest()) { |
| return Error(attr_loc, "value out of range for float"); |
| } |
| static_cast<optional<float>*>(attr_out_ptr) |
| ->emplace(static_cast<float>(result)); |
| return true; |
| } |
| case AttrTy::kHloComputation: { |
| HloComputation* result; |
| if (!ParseComputationName(&result)) { |
| return false; |
| } |
| static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kFftType: { |
| FftType result; |
| if (!ParseFftType(&result)) { |
| return false; |
| } |
| static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kWindow: { |
| Window result; |
| if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { |
| return false; |
| } |
| static_cast<optional<Window>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kConvolutionDimensionNumbers: { |
| ConvolutionDimensionNumbers result; |
| if (!ParseConvolutionDimensionNumbers(&result)) { |
| return false; |
| } |
| static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| case AttrTy::kSharding: { |
| OpSharding sharding; |
| if (!ParseSharding(&sharding)) { |
| return false; |
| } |
| static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding); |
| return true; |
| } |
| case AttrTy::kInstructionList: { |
| std::vector<HloInstruction*> result; |
| if (!ParseInstructionNames(&result)) { |
| return false; |
| } |
| static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| case AttrTy::kFusionKind: { |
| HloInstruction::FusionKind result; |
| if (!ParseFusionKind(&result)) { |
| return false; |
| } |
| static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| case AttrTy::kBracedInt64List: { |
| std::vector<tensorflow::int64> result; |
| if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, |
| &result)) { |
| return false; |
| } |
| static_cast<optional<std::vector<tensorflow::int64>>*>(attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| case AttrTy::kBracedInt64ListList: { |
| std::vector<std::vector<tensorflow::int64>> result; |
| auto parse_and_add_item = [&]() { |
| std::vector<tensorflow::int64> item; |
| if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, |
| TokKind::kComma, &item)) { |
| return false; |
| } |
| result.push_back(item); |
| return true; |
| }; |
| if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, |
| parse_and_add_item)) { |
| return false; |
| } |
| static_cast<optional<std::vector<std::vector<tensorflow::int64>>>*>( |
| attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| case AttrTy::kSliceRanges: { |
| SliceRanges result; |
| if (!ParseSliceRanges(&result)) { |
| return false; |
| } |
| static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kPaddingConfig: { |
| PaddingConfig result; |
| if (!ParsePaddingConfig(&result)) { |
| return false; |
| } |
| static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kString: { |
| string result; |
| if (!ParseString(&result)) { |
| return false; |
| } |
| static_cast<optional<string>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kMetadata: { |
| OpMetadata result; |
| if (!ParseMetadata(&result)) { |
| return false; |
| } |
| static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result); |
| return true; |
| } |
| case AttrTy::kDistribution: { |
| RandomDistribution result; |
| if (!ParseRandomDistribution(&result)) { |
| return false; |
| } |
| static_cast<optional<RandomDistribution>*>(attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| case AttrTy::kDomain: { |
| return ParseDomain(static_cast<DomainData*>(attr_out_ptr)); |
| } |
| case AttrTy::kPrecisionList: { |
| std::vector<PrecisionConfig::Precision> result; |
| if (!ParsePrecisionList(&result)) { |
| return false; |
| } |
| static_cast<optional<std::vector<PrecisionConfig::Precision>>*>( |
| attr_out_ptr) |
| ->emplace(result); |
| return true; |
| } |
| } |
| }(); |
| if (!success) { |
| return Error(loc, StrFormat("error parsing attribute %s", name)); |
| } |
| return true; |
| } |
| |
| bool HloParser::ParseComputationName(HloComputation** value) { |
| string name; |
| LocTy loc = lexer_.GetLoc(); |
| if (!ParseName(&name)) { |
| return Error(loc, "expects computation name"); |
| } |
| std::pair<HloComputation*, LocTy>* computation = |
| tensorflow::gtl::FindOrNull(computation_pool_, name); |
| if (computation == nullptr) { |
| return Error(loc, StrCat("computation does not exist: ", name)); |
| } |
| *value = computation->first; |
| return true; |
| } |
| |
| // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' |
| // The subattributes can appear in any order. 'size=' is required, others are |
| // optional. |
| bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { |
| LocTy loc = lexer_.GetLoc(); |
| if (expect_outer_curlies && |
| !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { |
| return false; |
| } |
| |
| std::vector<int64> size; |
| std::vector<int64> stride; |
| std::vector<std::vector<int64>> pad; |
| std::vector<int64> lhs_dilate; |
| std::vector<int64> rhs_dilate; |
| std::vector<int64> rhs_reversal; |
| const auto end_token = |
| expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof; |
| while (lexer_.GetKind() != end_token) { |
| LocTy attr_loc = lexer_.GetLoc(); |
| string field_name; |
| if (!ParseAttributeName(&field_name)) { |
| return Error(attr_loc, "expects sub-attributes in window"); |
| } |
| bool ok = [&] { |
| if (field_name == "size") { |
| return ParseDxD("size", &size); |
| } |
| if (field_name == "stride") { |
| return ParseDxD("stride", &stride); |
| } |
| if (field_name == "lhs_dilate") { |
| return ParseDxD("lhs_dilate", &lhs_dilate); |
| } |
| if (field_name == "rhs_dilate") { |
| return ParseDxD("rls_dilate", &rhs_dilate); |
| } |
| if (field_name == "pad") { |
| return ParseWindowPad(&pad); |
| } |
| if (field_name == "rhs_reversal") { |
| return ParseDxD("rhs_reversal", &rhs_reversal); |
| } |
| return Error(attr_loc, StrCat("unexpected attribute name: ", field_name)); |
| }(); |
| if (!ok) { |
| return false; |
| } |
| } |
| |
| if (size.empty()) { |
| return Error(loc, |
| "sub-attribute 'size=' is required in the window attribute"); |
| } |
| if (!stride.empty() && stride.size() != size.size()) { |
| return Error(loc, "expects 'stride=' has the same size as 'size='"); |
| } |
| if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { |
| return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='"); |
| } |
| if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { |
| return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='"); |
| } |
| if (!pad.empty() && pad.size() != size.size()) { |
| return Error(loc, "expects 'pad=' has the same size as 'size='"); |
| } |
| |
| for (int i = 0; i < size.size(); i++) { |
| window->add_dimensions()->set_size(size[i]); |
| if (!pad.empty()) { |
| window->mutable_dimensions(i)->set_padding_low(pad[i][0]); |
| window->mutable_dimensions(i)->set_padding_high(pad[i][1]); |
| } |
| // If some field is not present, it has the default value. |
| window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]); |
| window->mutable_dimensions(i)->set_base_dilation( |
| lhs_dilate.empty() ? 1 : lhs_dilate[i]); |
| window->mutable_dimensions(i)->set_window_dilation( |
| rhs_dilate.empty() ? 1 : rhs_dilate[i]); |
| window->mutable_dimensions(i)->set_window_reversal( |
| rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); |
| } |
| return !expect_outer_curlies || |
| ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); |
| } |
| |
| // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. |
| // The string looks like "dim_labels=0bf_0io->0bf". |
| bool HloParser::ParseConvolutionDimensionNumbers( |
| ConvolutionDimensionNumbers* dnums) { |
| if (lexer_.GetKind() != TokKind::kDimLabels) { |
| return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'"); |
| } |
| string str = lexer_.GetStrVal(); |
| |
| // The str is expected to have 3 items, lhs, rhs, out, and it must look like |
| // lhs_rhs->out, that is, the first separator is "_" and the second is "->". |
| std::vector<string> split1 = absl::StrSplit(str, "_"); |
| if (split1.size() != 2) { |
| LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " |
| << str; |
| } |
| std::vector<string> split2 = absl::StrSplit(split1[1], "->"); |
| if (split2.size() != 2) { |
| LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " |
| << str; |
| } |
| absl::string_view lhs = split1[0]; |
| absl::string_view rhs = split2[0]; |
| absl::string_view out = split2[1]; |
| |
| const tensorflow::int64 rank = lhs.length(); |
| if (rank != rhs.length() || rank != out.length()) { |
| return TokenError( |
| "convolution lhs, rhs, and output must have the same rank"); |
| } |
| if (rank < 2) { |
| return TokenError("convolution rank must >=2"); |
| } |
| |
| auto is_unique = [](string str) -> bool { |
| std::sort(str.begin(), str.end()); |
| return std::unique(str.begin(), str.end()) == str.end(); |
| }; |
| |
| // lhs |
| { |
| if (!is_unique(string(lhs))) { |
| return TokenError( |
| StrCat("expects unique lhs dimension numbers, but sees ", lhs)); |
| } |
| for (int i = 0; i < rank - 2; i++) { |
| dnums->add_input_spatial_dimensions(-1); |
| } |
| for (int i = 0; i < rank; i++) { |
| char c = lhs[i]; |
| if (c == 'b') { |
| dnums->set_input_batch_dimension(i); |
| } else if (c == 'f') { |
| dnums->set_input_feature_dimension(i); |
| } else if (c < '0' + rank && c >= '0') { |
| dnums->set_input_spatial_dimensions(c - '0', i); |
| } else { |
| return TokenError( |
| StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1)); |
| } |
| } |
| } |
| // rhs |
| { |
| if (!is_unique(string(rhs))) { |
| return TokenError( |
| StrCat("expects unique rhs dimension numbers, but sees ", rhs)); |
| } |
| for (int i = 0; i < rank - 2; i++) { |
| dnums->add_kernel_spatial_dimensions(-1); |
| } |
| for (int i = 0; i < rank; i++) { |
| char c = rhs[i]; |
| if (c == 'i') { |
| dnums->set_kernel_input_feature_dimension(i); |
| } else if (c == 'o') { |
| dnums->set_kernel_output_feature_dimension(i); |
| } else if (c < '0' + rank && c >= '0') { |
| dnums->set_kernel_spatial_dimensions(c - '0', i); |
| } else { |
| return TokenError( |
| StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1)); |
| } |
| } |
| } |
| // output |
| { |
| if (!is_unique(string(out))) { |
| return TokenError( |
| StrCat("expects unique output dimension numbers, but sees ", out)); |
| } |
| for (int i = 0; i < rank - 2; i++) { |
| dnums->add_output_spatial_dimensions(-1); |
| } |
| for (int i = 0; i < rank; i++) { |
| char c = out[i]; |
| if (c == 'b') { |
| dnums->set_output_batch_dimension(i); |
| } else if (c == 'f') { |
| dnums->set_output_feature_dimension(i); |
| } else if (c < '0' + rank && c >= '0') { |
| dnums->set_output_spatial_dimensions(c - '0', i); |
| } else { |
| return TokenError(StrFormat( |
| "expects [0-%dbf] in output dimension numbers", rank - 1)); |
| } |
| } |
| } |
| |
| lexer_.Lex(); |
| return true; |
| } |
| |
| // ::= '{' ranges '}' |
| // ::= /*empty*/ |
| // ::= range (',' range)* |
| // range ::= '[' start ':' limit (':' stride)? ']' |
| // |
| // The slice ranges are printed as: |
| // |
| // {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...} |
| // |
| // This function extracts the starts, limits, and strides as 3 vectors to the |
| // result. If stride is not present, stride is 1. For example, if the slice |
| // ranges is printed as: |
| // |
| // {[2:3:4], [5:6:7], [8:9]} |
| // |
| // The parsed result will be: |
| // |
| // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} |
| // |
| bool HloParser::ParseSliceRanges(SliceRanges* result) { |
| if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { |
| return false; |
| } |
| std::vector<std::vector<tensorflow::int64>> ranges; |
| if (lexer_.GetKind() == TokKind::kRbrace) { |
| // empty |
| return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); |
| } |
| do { |
| LocTy loc = lexer_.GetLoc(); |
| ranges.emplace_back(); |
| if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, |
| &ranges.back())) { |
| return false; |
| } |
| const auto& range = ranges.back(); |
| if (range.size() != 2 && range.size() != 3) { |
| return Error(loc, |
| StrFormat("expects [start:limit:step] or [start:limit], " |
| "but sees %d elements.", |
| range.size())); |
| } |
| } while (EatIfPresent(TokKind::kComma)); |
| |
| for (const auto& range : ranges) { |
| result->starts.push_back(range[0]); |
| result->limits.push_back(range[1]); |
| result->strides.push_back(range.size() == 3 ? range[2] : 1); |
| } |
| return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); |
| } |
| |
| // precisionlist ::= start precision_elements end |
| // precision_elements |
| // ::= /*empty*/ |
| // ::= precision_val (delim precision_val)* |
| bool HloParser::ParsePrecisionList( |
| std::vector<PrecisionConfig::Precision>* result) { |
| auto parse_and_add_item = [&]() { |
| PrecisionConfig::Precision item; |
| if (!ParsePrecision(&item)) { |
| return false; |
| } |
| result->push_back(item); |
| return true; |
| }; |
| return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, |
| parse_and_add_item); |
| } |
| |
| // int64list ::= start int64_elements end |
| // int64_elements |
| // ::= /*empty*/ |
| // ::= int64_val (delim int64_val)* |
| bool HloParser::ParseInt64List(const TokKind start, const TokKind end, |
| const TokKind delim, |
| std::vector<tensorflow::int64>* result) { |
| if (!ParseToken(start, StrCat("expects an int64 list starting with ", |
| TokKindToString(start)))) { |
| return false; |
| } |
| if (lexer_.GetKind() == end) { |
| // empty |
| } else { |
| do { |
| tensorflow::int64 i; |
| if (!ParseInt64(&i)) { |
| return false; |
| } |
| result->push_back(i); |
| } while (EatIfPresent(delim)); |
| } |
| return ParseToken( |
| end, StrCat("expects an int64 list to end with ", TokKindToString(end))); |
| } |
| |
| bool HloParser::ParseList(const TokKind start, const TokKind end, |
| const TokKind delim, |
| const std::function<bool()>& parse_and_add_item) { |
| if (!ParseToken(start, StrCat("expects a list starting with ", |
| TokKindToString(start)))) { |
| return false; |
| } |
| if (lexer_.GetKind() == end) { |
| // empty |
| } else { |
| do { |
| if (!parse_and_add_item()) { |
| return false; |
| } |
| } while (EatIfPresent(delim)); |
| } |
| return ParseToken( |
| end, StrCat("expects a list to end with ", TokKindToString(end))); |
| } |
| |
| // param_list_to_shape ::= param_list '->' shape |
| bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { |
| if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { |
| return false; |
| } |
| *shape_loc = lexer_.GetLoc(); |
| return ParseShape(shape); |
| } |
| |
| bool HloParser::CanBeParamListToShape() { |
| return lexer_.GetKind() == TokKind::kLparen; |
| } |
| |
| // param_list ::= '(' param_list1 ')' |
| // param_list1 |
| // ::= /*empty*/ |
| // ::= param (',' param)* |
| // param ::= name shape |
| bool HloParser::ParseParamList() { |
| if (!ParseToken(TokKind::kLparen, |
| "expects '(' at the beginning of param list")) { |
| return false; |
| } |
| |
| if (lexer_.GetKind() == TokKind::kRparen) { |
| // empty |
| } else { |
| do { |
| Shape shape; |
| string name; |
| if (!ParseName(&name) || !ParseShape(&shape)) { |
| return false; |
| } |
| } while (EatIfPresent(TokKind::kComma)); |
| } |
| return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); |
| } |
| |
| // shape ::= shape_val_ |
| // shape ::= '(' tuple_elements ')' |
| // tuple_elements |
| // ::= /*empty*/ |
| // ::= shape (',' shape)* |
| bool HloParser::ParseShape(Shape* result) { |
| if (EatIfPresent(TokKind::kLparen)) { // Tuple |
| std::vector<Shape> shapes; |
| if (lexer_.GetKind() == TokKind::kRparen) { |
| /*empty*/ |
| } else { |
| // shape (',' shape)* |
| do { |
| shapes.emplace_back(); |
| if (!ParseShape(&shapes.back())) { |
| return false; |
| } |
| } while (EatIfPresent(TokKind::kComma)); |
| } |
| *result = ShapeUtil::MakeTupleShape(shapes); |
| return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); |
| } |
| |
| if (lexer_.GetKind() != TokKind::kShape) { |
| return TokenError("expects shape"); |
| } |
| *result = lexer_.GetShapeVal(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::CanBeShape() { |
| // A non-tuple shape starts with a kShape token; a tuple shape starts with |
| // '('. |
| return lexer_.GetKind() == TokKind::kShape || |
| lexer_.GetKind() == TokKind::kLparen; |
| } |
| |
| bool HloParser::ParseName(string* result) { |
| VLOG(1) << "ParseName"; |
| if (lexer_.GetKind() != TokKind::kIdent && |
| lexer_.GetKind() != TokKind::kName) { |
| return TokenError("expects name"); |
| } |
| *result = lexer_.GetStrVal(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseAttributeName(string* result) { |
| if (lexer_.GetKind() != TokKind::kAttributeName) { |
| return TokenError("expects attribute name"); |
| } |
| *result = lexer_.GetStrVal(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseString(string* result) { |
| VLOG(1) << "ParseString"; |
| if (lexer_.GetKind() != TokKind::kString) { |
| return TokenError("expects string"); |
| } |
| *result = lexer_.GetStrVal(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseDxD(const string& name, |
| std::vector<tensorflow::int64>* result) { |
| LocTy loc = lexer_.GetLoc(); |
| if (!result->empty()) { |
| return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); |
| } |
| // 1D |
| if (lexer_.GetKind() == TokKind::kInt) { |
| tensorflow::int64 number; |
| if (!ParseInt64(&number)) { |
| return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); |
| } |
| result->push_back(number); |
| return true; |
| } |
| // 2D or higher. |
| if (lexer_.GetKind() == TokKind::kDxD) { |
| string str = lexer_.GetStrVal(); |
| if (!SplitToInt64s(str, 'x', result)) { |
| return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name)); |
| } |
| lexer_.Lex(); |
| return true; |
| } |
| return TokenError("expects token type kInt or kDxD"); |
| } |
| |
| bool HloParser::ParseWindowPad( |
| std::vector<std::vector<tensorflow::int64>>* pad) { |
| LocTy loc = lexer_.GetLoc(); |
| if (!pad->empty()) { |
| return Error(loc, "sub-attribute 'pad=' already exists"); |
| } |
| if (lexer_.GetKind() != TokKind::kPad) { |
| return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); |
| } |
| string str = lexer_.GetStrVal(); |
| for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { |
| std::vector<tensorflow::int64> low_high; |
| if (!SplitToInt64s(padding_dim_str, '_', &low_high) || |
| low_high.size() != 2) { |
| return Error(loc, |
| "expects padding_low and padding_high separated by '_'"); |
| } |
| pad->push_back(low_high); |
| } |
| lexer_.Lex(); |
| return true; |
| } |
| |
| // This is the inverse xla::ToString(PaddingConfig). The padding config string |
| // looks like "0_0_0x3_3_1". The string is first separated by 'x', each |
| // substring represents one PaddingConfigDimension. The substring is 3 (or 2) |
| // numbers joined by '_'. |
| bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { |
| if (lexer_.GetKind() != TokKind::kPad) { |
| return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); |
| } |
| LocTy loc = lexer_.GetLoc(); |
| string str = lexer_.GetStrVal(); |
| for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { |
| std::vector<tensorflow::int64> padding_dim; |
| if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || |
| (padding_dim.size() != 2 && padding_dim.size() != 3)) { |
| return Error(loc, |
| "expects padding config pattern like 'low_high_interior' or " |
| "'low_high'"); |
| } |
| auto* dim = padding->add_dimensions(); |
| dim->set_edge_padding_low(padding_dim[0]); |
| dim->set_edge_padding_high(padding_dim[1]); |
| dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0); |
| } |
| lexer_.Lex(); |
| return true; |
| } |
| |
| // '{' metadata_string '}' |
| bool HloParser::ParseMetadata(OpMetadata* metadata) { |
| std::unordered_map<string, AttrConfig> attrs; |
| optional<string> op_type; |
| optional<string> op_name; |
| optional<string> source_file; |
| optional<tensorflow::int32> source_line; |
| attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; |
| attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; |
| attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; |
| attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line}; |
| if (!ParseSubAttributes(attrs)) { |
| return false; |
| } |
| if (op_type) { |
| metadata->set_op_type(*op_type); |
| } |
| if (op_name) { |
| metadata->set_op_name(*op_name); |
| } |
| if (source_file) { |
| metadata->set_source_file(*source_file); |
| } |
| if (source_line) { |
| metadata->set_source_line(*source_line); |
| } |
| return true; |
| } |
| |
| bool HloParser::ParseOpcode(HloOpcode* result) { |
| VLOG(1) << "ParseOpcode"; |
| if (lexer_.GetKind() != TokKind::kIdent) { |
| return TokenError("expects opcode"); |
| } |
| string val = lexer_.GetStrVal(); |
| auto status_or_result = StringToHloOpcode(val); |
| if (!status_or_result.ok()) { |
| return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val, |
| status_or_result.status().error_message())); |
| } |
| *result = status_or_result.ValueOrDie(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseFftType(FftType* result) { |
| VLOG(1) << "ParseFftType"; |
| if (lexer_.GetKind() != TokKind::kIdent) { |
| return TokenError("expects fft type"); |
| } |
| string val = lexer_.GetStrVal(); |
| if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { |
| return TokenError(StrFormat("expects fft type but sees: %s", val)); |
| } |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { |
| VLOG(1) << "ParseFusionKind"; |
| if (lexer_.GetKind() != TokKind::kIdent) { |
| return TokenError("expects fusion kind"); |
| } |
| string val = lexer_.GetStrVal(); |
| auto status_or_result = StringToFusionKind(val); |
| if (!status_or_result.ok()) { |
| return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s", |
| val, |
| status_or_result.status().error_message())); |
| } |
| *result = status_or_result.ValueOrDie(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseRandomDistribution(RandomDistribution* result) { |
| VLOG(1) << "ParseRandomDistribution"; |
| if (lexer_.GetKind() != TokKind::kIdent) { |
| return TokenError("expects random distribution"); |
| } |
| string val = lexer_.GetStrVal(); |
| auto status_or_result = StringToRandomDistribution(val); |
| if (!status_or_result.ok()) { |
| return TokenError( |
| StrFormat("expects random distribution but sees: %s, error: %s", val, |
| status_or_result.status().error_message())); |
| } |
| *result = status_or_result.ValueOrDie(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { |
| VLOG(1) << "ParsePrecision"; |
| if (lexer_.GetKind() != TokKind::kIdent) { |
| return TokenError("expects random distribution"); |
| } |
| string val = lexer_.GetStrVal(); |
| auto status_or_result = StringToPrecision(val); |
| if (!status_or_result.ok()) { |
| return TokenError(StrFormat("expects precision but sees: %s, error: %s", |
| val, |
| status_or_result.status().error_message())); |
| } |
| *result = status_or_result.ValueOrDie(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseInt64(tensorflow::int64* result) { |
| VLOG(1) << "ParseInt64"; |
| if (lexer_.GetKind() != TokKind::kInt) { |
| return TokenError("expects integer"); |
| } |
| *result = lexer_.GetInt64Val(); |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseDouble(double* result) { |
| switch (lexer_.GetKind()) { |
| case TokKind::kDecimal: |
| *result = lexer_.GetDecimalVal(); |
| break; |
| case TokKind::kInt: |
| *result = static_cast<double>(lexer_.GetInt64Val()); |
| break; |
| case TokKind::kw_nan: |
| *result = std::numeric_limits<double>::quiet_NaN(); |
| break; |
| case TokKind::kw_inf: |
| *result = std::numeric_limits<double>::infinity(); |
| break; |
| case TokKind::kNegInf: |
| *result = -std::numeric_limits<double>::infinity(); |
| break; |
| default: |
| return TokenError("expects decimal or integer"); |
| } |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseBool(bool* result) { |
| if (lexer_.GetKind() != TokKind::kw_true && |
| lexer_.GetKind() != TokKind::kw_false) { |
| return TokenError("expects true or false"); |
| } |
| *result = lexer_.GetKind() == TokKind::kw_true; |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::ParseToken(TokKind kind, const string& msg) { |
| VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg; |
| if (lexer_.GetKind() != kind) { |
| return TokenError(msg); |
| } |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::EatIfPresent(TokKind kind) { |
| if (lexer_.GetKind() != kind) { |
| return false; |
| } |
| lexer_.Lex(); |
| return true; |
| } |
| |
| bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, |
| LocTy name_loc) { |
| auto result = instruction_pool_.insert({name, {instruction, name_loc}}); |
| if (!result.second) { |
| Error(name_loc, StrCat("instruction already exists: ", name)); |
| return Error(/*loc=*/result.first->second.second, |
| "instruction previously defined here"); |
| } |
| return true; |
| } |
| |
| bool HloParser::AddComputation(const string& name, HloComputation* computation, |
| LocTy name_loc) { |
| auto result = computation_pool_.insert({name, {computation, name_loc}}); |
| if (!result.second) { |
| Error(name_loc, StrCat("computation already exists: ", name)); |
| return Error(/*loc=*/result.first->second.second, |
| "computation previously defined here"); |
| } |
| return true; |
| } |
| |
| StatusOr<HloSharding> HloParser::ParseShardingOnly() { |
| lexer_.Lex(); |
| OpSharding op_sharding; |
| if (!ParseSharding(&op_sharding)) { |
| return InvalidArgument("Syntax error:\n%s", GetError()); |
| } |
| if (lexer_.GetKind() != TokKind::kEof) { |
| return InvalidArgument("Syntax error:\nExtra content after sharding"); |
| } |
| return HloSharding::FromProto(op_sharding); |
| } |
| |
| StatusOr<Window> HloParser::ParseWindowOnly() { |
| lexer_.Lex(); |
| Window window; |
| if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { |
| return InvalidArgument("Syntax error:\n%s", GetError()); |
| } |
| if (lexer_.GetKind() != TokKind::kEof) { |
| return InvalidArgument("Syntax error:\nExtra content after window"); |
| } |
| return window; |
| } |
| |
| StatusOr<ConvolutionDimensionNumbers> |
| HloParser::ParseConvolutionDimensionNumbersOnly() { |
| lexer_.Lex(); |
| ConvolutionDimensionNumbers dnums; |
| if (!ParseConvolutionDimensionNumbers(&dnums)) { |
| return InvalidArgument("Syntax error:\n%s", GetError()); |
| } |
| if (lexer_.GetKind() != TokKind::kEof) { |
| return InvalidArgument( |
| "Syntax error:\nExtra content after convolution dnums"); |
| } |
| return dnums; |
| } |
| |
| StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() { |
| lexer_.Lex(); |
| PaddingConfig padding_config; |
| if (!ParsePaddingConfig(&padding_config)) { |
| return InvalidArgument("Syntax error:\n%s", GetError()); |
| } |
| if (lexer_.GetKind() != TokKind::kEof) { |
| return InvalidArgument("Syntax error:\nExtra content after PaddingConfig"); |
| } |
| return padding_config; |
| } |
| |
| Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, |
| string* root_name) { |
| TF_RET_CHECK(missing_instruction_hook_ == nullptr); |
| |
| // The missing instruction hook we register creates the shaped instruction on |
| // the fly as a parameter and returns it. |
| int64 parameter_count = 0; |
| missing_instruction_hook_ = |
| [this, builder, ¶meter_count]( |
| string name, |
| const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* { |
| if (!shape.has_value()) { |
| Error(lexer_.GetLoc(), |
| StrCat("Operand ", name, |
| " had no shape in HLO text; cannot create parameter for " |
| "single-instruction module.")); |
| return nullptr; |
| } |
| HloInstruction* parameter = builder->AddInstruction( |
| HloInstruction::CreateParameter(parameter_count++, *shape, name)); |
| instruction_pool_[name] = {parameter, lexer_.GetLoc()}; |
| return tensorflow::gtl::FindOrNull(instruction_pool_, name); |
| }; |
| |
| // Prime the lexer. |
| lexer_.Lex(); |
| |
| // Parse the instruction with the registered hook. |
| if (!ParseInstruction(builder, root_name)) { |
| return InvalidArgument("Syntax error:\n%s", GetError()); |
| } |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| StatusOr<std::unique_ptr<HloModule>> ParseHloString( |
| absl::string_view str, const HloModuleConfig& config) { |
| auto module = absl::make_unique<HloModule>(/*name=*/"", config); |
| HloParser parser(str); |
| if (!parser.Run(module.get())) { |
| return InvalidArgument("Syntax error:\n%s", parser.GetError()); |
| } |
| return std::move(module); |
| } |
| |
| StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) { |
| auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig()); |
| HloParser parser(str); |
| if (!parser.Run(module.get())) { |
| return InvalidArgument("Syntax error:\n%s", parser.GetError()); |
| } |
| return std::move(module); |
| } |
| |
| Status ParseHloString(absl::string_view str, HloModule* module) { |
| TF_RET_CHECK(module->computation_count() == 0); |
| HloParser parser(str); |
| if (!parser.Run(module)) { |
| return InvalidArgument("Syntax error:\n%s", parser.GetError()); |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( |
| absl::string_view str, absl::string_view name) { |
| HloParser parser(str); |
| auto builder = absl::make_unique<HloComputation::Builder>(string(name)); |
| string root_name; |
| TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); |
| std::unique_ptr<HloComputation> computation = builder->Build(); |
| auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig()); |
| module->AddEntryComputation(std::move(computation)); |
| return std::move(module); |
| } |
| |
| StatusOr<HloSharding> ParseSharding(absl::string_view str) { |
| HloParser parser(str); |
| return parser.ParseShardingOnly(); |
| } |
| |
| StatusOr<Window> ParseWindow(absl::string_view str) { |
| HloParser parser(str); |
| return parser.ParseWindowOnly(); |
| } |
| |
| StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers( |
| absl::string_view str) { |
| HloParser parser(str); |
| return parser.ParseConvolutionDimensionNumbersOnly(); |
| } |
| |
| StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) { |
| HloParser parser(str); |
| return parser.ParsePaddingConfigOnly(); |
| } |
| |
| } // namespace xla |