| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <assert.h> |
| |
| #include <sstream> |
| #include <string> |
| #include <vector> |
| |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/InitLLVM.h" |
| #include "llvm/Support/Signals.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Main.h" |
| #include "llvm/TableGen/Record.h" |
| #include "llvm/TableGen/TableGenBackend.h" |
| #include "mlir/TableGen/Attribute.h" // from @llvm-project |
| #include "mlir/TableGen/Format.h" // from @llvm-project |
| #include "mlir/TableGen/Operator.h" // from @llvm-project |
| #include "mlir/TableGen/Predicate.h" // from @llvm-project |
| |
| using llvm::DefInit; |
| using llvm::dyn_cast; |
| using llvm::formatv; |
| using llvm::LessRecord; |
| using llvm::raw_ostream; |
| using llvm::Record; |
| using llvm::RecordKeeper; |
| using llvm::RecordRecTy; |
| using llvm::SmallVector; |
| using llvm::StringInit; |
| using llvm::StringRef; |
| |
| enum ActionType { |
| OpConv, |
| RuntimeVerify, |
| }; |
| |
| // NOLINTNEXTLINE |
| llvm::cl::opt<ActionType> action( |
| llvm::cl::desc("Action to perform:"), |
| llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters", |
| "Generate operator converters"), |
| clEnumValN(RuntimeVerify, "gen-runtime-verifiers", |
| "Generate TFLite runtime verifiers"))); |
| |
| // Returns the associated option name for the given op definition. |
| static inline std::string GetOperatorOptionName(const Record &def) { |
| assert(def.getName().startswith("TFL_") && "unexpected op prefix"); |
| assert(def.getName().endswith("Op") && "unexpected op suffix"); |
| |
| auto *custom_option = dyn_cast<StringInit>(def.getValueInit("customOption")); |
| std::ostringstream oss; |
| if (custom_option) |
| oss << custom_option->getValue().str(); |
| else |
| oss << def.getName().drop_front(4).drop_back(2).str() << "Options"; |
| return oss.str(); |
| } |
| |
| // Returns the builder function name for the given op definition. |
| static inline std::string GetOperatorBuilderName(StringRef op_name) { |
| assert(op_name.startswith("TFL_") && "unexpected op prefix"); |
| assert(op_name.endswith("Op") && "unexpected op suffix"); |
| |
| // E.g., AddOp -> CreateAddOperator |
| std::ostringstream oss; |
| oss << "Create" << op_name.drop_front(4).str() << "erator"; |
| return oss.str(); |
| } |
| |
| static void EmitOptionBuilders(const RecordKeeper &record_keeper, |
| const std::vector<Record *> &defs, |
| raw_ostream *ostream) { |
| raw_ostream &os = *ostream; |
| |
| const auto attr_type = record_keeper.getClass("Attr"); |
| for (const auto *def : defs) { |
| // TFLite ops without options are skipped over. |
| if (!def->getValueAsBit("hasOptions")) continue; |
| |
| StringRef op_name = def->getName().drop_front(4); // Strip 'TFL_' prefix |
| std::string option_name = GetOperatorOptionName(*def); |
| std::string tflite_option_name = |
| option_name == "BasicLSTMOptions" ? "LSTMOptions" : option_name; |
| |
| os << "flatbuffers::Offset<tflite::" << tflite_option_name << "> Create" |
| << option_name << "(mlir::TFL::" << op_name |
| << " op, flatbuffers::FlatBufferBuilder *fbb) {\n"; |
| |
| // Construct all the builder option needed. |
| SmallVector<std::string, 8> options; |
| // Add options due to attributes (not-derived). |
| auto *arg_values = def->getValueAsDag("arguments"); |
| for (unsigned i = 0, e = arg_values->getNumArgs(); i != e; ++i) { |
| auto arg = arg_values->getArg(i); |
| DefInit *arg_def = dyn_cast<DefInit>(arg); |
| if (!arg_def) continue; |
| if (arg_def->getDef()->isSubClassOf(attr_type)) { |
| // This binds the name of the attribute in the TD file with the name |
| // of the add function of the builder and also with the conversion |
| // function to convert from the internal representation to the format |
| // expected by the flatbuffer builder. While this constrains the |
| // naming of the ops/attributes in the TD file, this also removes the |
| // need for specifying indirection. This tool is specific to TFLite |
| // conversion generation and so the simplicity was chosen over the |
| // flexibility. |
| StringRef arg_name = arg_values->getArgNameStr(i); |
| // Skip any "intermiadiateXXX" attribute as they are specially handled |
| // in the exporter. They are special because though they are attributes |
| // in the MLIR they are expressed as tensors in the flatbuffer instead |
| // of option. |
| if (op_name == "LSTMOp" && arg_name.take_back(12) == "intermediate") |
| continue; |
| os << formatv( |
| " auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n", |
| arg_name, mlir::tblgen::Attribute(arg_def).getAttrDefName()); |
| options.push_back(arg_name.str()); |
| } |
| } |
| |
| // Add options due to derived attributes. |
| for (const auto &val : def->getValues()) { |
| if (auto *record = dyn_cast<RecordRecTy>(val.getType())) { |
| if (record->isSubClassOf(attr_type)) { |
| if (record->getClasses().size() != 1) { |
| PrintFatalError( |
| def->getLoc(), |
| "unsupported attribute modelling, only single class expected"); |
| } |
| os << formatv( |
| " auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n", |
| val.getName(), record->getClasses()[0]->getName()); |
| options.push_back(std::string(val.getName())); |
| } |
| } |
| } |
| |
| os << " tflite::" << tflite_option_name << "Builder b(*fbb);\n"; |
| for (const auto &option : options) |
| os << formatv(" b.add_{0}(std::move({0}));\n", option); |
| os << " return b.Finish();\n}\n"; |
| } |
| } |
| |
| // For each TFLite op, emits a builder function that packs the TFLite op into |
| // the corresponding FlatBuffer object. |
| // |
| // TODO(hinsu): Revisit if only builtin_options and mutating_variable_inputs |
| // arguments that depend on op definitions should be auto-generated and then |
| // operator should be built by the caller because it does not require |
| // auto-generation. |
| static void EmitOperatorBuilders(const std::vector<Record *> &defs, |
| raw_ostream *ostream) { |
| raw_ostream &os = *ostream; |
| |
| for (const auto *def : defs) { |
| StringRef op_name = def->getName().drop_front(4); |
| |
| const bool has_intermediates = op_name == "LSTMOp"; |
| // Signature |
| os << "static flatbuffers::Offset<tflite::Operator> " |
| << GetOperatorBuilderName(def->getName()) << "(mlir::TFL::" << op_name |
| << " tflOp, uint32_t opcode_index, " |
| << "const std::vector<int32_t>& operands," |
| << "const std::vector<int32_t>& results," |
| << (has_intermediates ? "const std::vector<int32_t>& intermediate_index," |
| : "") |
| << "flatbuffers::FlatBufferBuilder *fbb) {\n"; |
| |
| // Inputs & outputs |
| os << " auto inputs = fbb->CreateVector(operands);\n" |
| " auto outputs = fbb->CreateVector(results);\n\n"; |
| // Intermediates for LSTM. |
| if (has_intermediates) { |
| os << " auto intermediates = fbb->CreateVector(intermediate_index);\n"; |
| } |
| |
| // Build the FlatBuffer operator |
| os << " return tflite::CreateOperator(\n" |
| " *fbb, opcode_index, inputs, outputs,\n"; |
| if (def->getValueAsBit("hasOptions")) { |
| auto option_name = GetOperatorOptionName(*def); |
| std::string tflite_option_name = |
| option_name == "BasicLSTMOptions" ? "LSTMOptions" : option_name; |
| os << " tflite::BuiltinOptions_" << tflite_option_name << ", " |
| << "Create" << option_name << "(tflOp, fbb).Union(),\n"; |
| } else { |
| os << " tflite::BuiltinOptions_NONE, /*builtin_options=*/0,\n"; |
| } |
| // Only builtin ops' builders are auto-generated. custom_options are only |
| // used by custom or flex ops and those ops are handled manually. |
| os << " /*custom_options=*/0, " |
| << "tflite::CustomOptionsFormat_FLEXBUFFERS,\n" |
| << " /*mutating_variable_inputs=*/0" |
| << (has_intermediates ? ", intermediates" : "") << ");\n}\n\n"; |
| } |
| } |
| |
| static inline std::string GetOperatorName(const Record &def) { |
| auto name = def.getValueAsString("opName"); |
| // Special case for basic_lstm. |
| if (name == "basic_lstm") { |
| return "LSTM"; |
| } |
| return name.upper(); |
| } |
| |
| // Emits a function that returns builtin operator code for each TFLite op. |
| // |
| // The signature of the function is: |
| // |
| // llvm::Optional<tflite::BuiltinOperator> |
| // mlir::GetBuiltinOpCode(mlir::Operation* op); |
| // |
| // TODO(hinsu): Consider converting this to a static constant associative |
| // container instead of a series of if conditions, if required. |
| static void EmitGetBuiltinOpCode(const std::vector<Record *> &defs, |
| raw_ostream *ostream) { |
| raw_ostream &os = *ostream; |
| |
| os << "llvm::Optional<tflite::BuiltinOperator> " |
| "mlir::GetBuiltinOpCode(mlir::Operation* op) {\n"; |
| |
| for (const auto *def : defs) { |
| StringRef op_name = def->getName().drop_front(4); |
| os << " if (isa<mlir::TFL::" << op_name << ">(op))\n" |
| << " return tflite::BuiltinOperator_" << GetOperatorName(*def) |
| << ";\n"; |
| } |
| |
| os << " return llvm::None;\n" |
| "}\n"; |
| } |
| |
| // Emits a builder function that returns the packed FlatBuffer object given |
| // a general mlir::Operation. |
| // |
| // The signature of the function is: |
| // |
| // llvm::Optional<Flatbuffers::Offset<tflite::Operator>> |
| // mlir::CreateFlatBufferOperator( |
| // mlir::Operation* op, |
| // uint32_t opcode_index, |
| // const std::vector<int32_t>& operands, |
| // const std::vector<int32_t>& results, |
| // const std::vector<int32_t>& intermediates, |
| // flatbuffers::FlatBufferBuilder *fbb); |
| static void EmitBuildOperator(const std::vector<Record *> &defs, |
| raw_ostream *ostream) { |
| raw_ostream &os = *ostream; |
| |
| // Signature |
| os << "llvm::Optional<flatbuffers::Offset<tflite::Operator>>\n" |
| "mlir::CreateFlatBufferOperator(mlir::Operation* op, " |
| "uint32_t opcode_index, " |
| "const std::vector<int32_t>& operands," |
| "const std::vector<int32_t>& results," |
| "const std::vector<int32_t>& intermediates," |
| "flatbuffers::FlatBufferBuilder *fbb) {\n"; |
| |
| for (const auto *def : defs) { |
| StringRef op_name = def->getName().drop_front(4); |
| |
| // Try to cast to each op case and call the corresponding op builder |
| os << " if (auto tflOp = llvm::dyn_cast<mlir::TFL::" << op_name |
| << ">(op))\n" |
| << " return " << GetOperatorBuilderName(def->getName()) |
| << "(tflOp, opcode_index, operands, results, " |
| << (op_name == "LSTMOp" ? "intermediates, " : "") << "fbb);\n"; |
| } |
| |
| os << " return llvm::None;\n" |
| "}\n"; |
| } |
| |
| // Emit a function that converts a BuiltinOptionsUnion to a vector of attributes |
| // Signature: |
| // void mlir::BuiltinOptionsToAttributes( |
| // tflite::BuiltinOptionsUnion op_union, |
| // mlir::Builder builder, |
| // llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes); |
| static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, |
| const std::vector<Record *> &defs, |
| raw_ostream *ostream) { |
| raw_ostream &os = *ostream; |
| |
| // Signature |
| os << "void mlir::BuiltinOptionsToAttributes(" |
| "tflite::BuiltinOptionsUnion op_union, " |
| "mlir::Builder builder, " |
| "llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes) {\n"; |
| |
| const auto attr_type = record_keeper.getClass("Attr"); |
| for (const auto *def : defs) { |
| if (!def->getValueAsBit("hasOptions")) continue; |
| auto option_name = GetOperatorOptionName(*def); |
| // Basic LSTM and LSTM ops share the same option to attribute converter. |
| if (option_name == "BasicLSTMOptions") { |
| continue; |
| } |
| |
| os << formatv(" if(const auto *op = op_union.As{0}()) {{\n", option_name); |
| |
| // We only care about options that are in arguments |
| auto *arg_values = def->getValueAsDag("arguments"); |
| for (unsigned i = 0, e = arg_values->getNumArgs(); i != e; ++i) { |
| auto arg = arg_values->getArg(i); |
| DefInit *arg_def = dyn_cast<DefInit>(arg); |
| if (!arg_def) continue; |
| if (arg_def->getDef()->isSubClassOf(attr_type)) { |
| StringRef arg_name = arg_values->getArgNameStr(i); |
| // Already handle this case in flatbuffer_import.cc. |
| if (option_name == "LSTMOptions" && |
| arg_name.take_back(12) == "intermediate") |
| continue; |
| StringRef attr_type = mlir::tblgen::Attribute(arg_def).getAttrDefName(); |
| os << formatv( |
| " attributes.emplace_back(builder.getNamedAttr(\"{0}\"," |
| " Build{1}(op->{0}, builder)));\n", |
| arg_name, attr_type); |
| } |
| } |
| |
| os << " return;\n"; |
| os << " }\n"; |
| } |
| // Fallthrough case is no attributes |
| os << "}"; |
| } |
| // The function below has a non-constant reference as that is required by LLVM's |
| // TableGenMain. |
| // NOLINTNEXTLINE |
| static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) { |
| emitSourceFileHeader("MLIR TFLite FlatBuffer Builders", os); |
| |
| // Retrieve all the definitions derived from TFL_Op and sort by record name. |
| std::vector<Record *> defs = records.getAllDerivedDefinitions("TFL_Op"); |
| llvm::sort(defs, LessRecord()); |
| |
| for (const auto *def : defs) { |
| // TFLite ops in the .td file are expected to follow the naming convention: |
| // TFL_<OpName>Op. |
| // The generated TFLite op C++ class should be TFL::<OpName>Op. |
| // The generated operator's options should be tflite::<OpName>Options. |
| // The option builder should be Create<OpName>Options. |
| if (!def->getName().startswith("TFL_")) |
| PrintFatalError(def->getLoc(), |
| "unexpected op name format: 'TFL_' prefix missing"); |
| if (!def->getName().endswith("Op")) |
| PrintFatalError(def->getLoc(), |
| "unexpected op name format: 'Op' suffix missing"); |
| } |
| |
| EmitOptionBuilders(records, defs, &os); |
| os << "\n\n"; |
| EmitOperatorBuilders(defs, &os); |
| os << "\n\n"; |
| EmitGetBuiltinOpCode(defs, &os); |
| os << "\n\n"; |
| EmitBuildOperator(defs, &os); |
| os << "\n\n"; |
| EmitBuiltinOptionsToAttributes(records, defs, &os); |
| |
| return false; |
| } |
| |
| static void GenOperandResultVerifier(raw_ostream &os, |
| llvm::ArrayRef<llvm::Init *> values, |
| StringRef valueKind) { |
| mlir::tblgen::FmtContext fctx; |
| |
| bool first = true; |
| for (auto static_value : llvm::enumerate(values)) { |
| auto *definit = llvm::cast<llvm::DefInit>(static_value.value()); |
| auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate"); |
| if (!val) continue; |
| |
| // Create code block on first type to verify. |
| if (first) { |
| os << " {\n"; |
| os << " unsigned index = " << static_value.index() << ";\n"; |
| first = false; |
| } |
| |
| mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue())); |
| auto desc = |
| definit->getDef()->getValueAsString("tflRuntimeTypeDescription"); |
| |
| // Emit a loop to check all the dynamic values in the pack. |
| os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n", |
| // Capitalize the first letter to match the function name |
| valueKind.substr(0, 1).upper(), valueKind.substr(1), |
| static_value.index()); |
| |
| os << " (void)v;\n" |
| << " if (!(" |
| << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n" |
| << " if (failure_on_operand_type_mismatch) {\n" |
| << formatv( |
| " return op->emitOpError(\"{0} #\") << index " |
| "<< \" must be {1}, but got \" << v.getType();\n", |
| valueKind, desc) |
| << " } else {\n" |
| << " return ::mlir::LogicalResult::Failure;\n" |
| << " }\n" |
| << " }\n" // if |
| << " ++index;\n" |
| << " }\n"; // for |
| } |
| |
| // Emit closing brace if needed. |
| if (!first) os << " }\n"; |
| } |
| |
| // NOLINTNEXTLINE |
| static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { |
| emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os); |
| |
| // Retrieve all the definitions derived from TFL_Op and sort by record name. |
| std::vector<Record *> defs = records.getAllDerivedDefinitions("Op"); |
| llvm::sort(defs, LessRecord()); |
| |
| // Iterate through all the ops defined. |
| for (const auto *def : defs) { |
| mlir::tblgen::Operator op(*def); |
| if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue; |
| |
| mlir::tblgen::FmtContext verify_ctx; |
| os << "::mlir::LogicalResult " << op.getCppClassName() |
| << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool " |
| "failure_on_operand_type_mismatch) {\n"; |
| os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; |
| verify_ctx.withOp("top"); |
| |
| for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
| auto &value = op.getOperand(i); |
| // Skip from from first variadic operands for now. Else getOperand index |
| // used below doesn't match. |
| if (value.isVariadic()) break; |
| if (!value.name.empty()) |
| verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i)); |
| } |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| auto &value = op.getResult(i); |
| // Skip from from first variadic results for now. Else getResult index |
| // used below doesn't match. |
| if (value.isVariadic()) break; |
| if (!value.name.empty()) |
| verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i)); |
| } |
| GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(), |
| "operand"); |
| GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(), |
| "result"); |
| |
| for (auto &trait : op.getTraits()) { |
| if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) { |
| continue; |
| } |
| if (trait.getDef().getValueAsString("trait") != |
| "OpTrait::TFLRuntimeOpTrait") { |
| continue; |
| } |
| |
| auto *val = trait.getDef().getValue("tflRuntimePredicate"); |
| if (!val) continue; |
| |
| mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue())); |
| os << tgfmt( |
| " if (!($0)) {\n " |
| " return ::mlir::LogicalResult::Failure;\n }\n", |
| &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx)); |
| } |
| os << " return top.verify();\n}\n"; |
| } |
| |
| return false; |
| } |
| |
| int main(int argc, char **argv) { |
| llvm::InitLLVM y(argc, argv); |
| llvm::cl::ParseCommandLineOptions(argc, argv); |
| if (action == ActionType::OpConv) |
| return TableGenMain(argv[0], &OperatorWritersMain); |
| return TableGenMain(argv[0], &RuntimeVerifierWriterMain); |
| } |