| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // Class and associated machinery for specifying an Op's OpDef and shape |
| // inference function for Op registration. |
| |
| #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ |
| #define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ |
| |
#include <functional>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/macros.h"
| |
| namespace tensorflow { |
| |
| class FunctionDefHelper; |
| |
| namespace shape_inference { |
| class InferenceContext; |
| } |
| typedef std::function<Status(shape_inference::InferenceContext* c)> |
| OpShapeInferenceFn; |
| |
| struct OpRegistrationData { |
| public: |
| OpRegistrationData() {} |
| OpRegistrationData(const OpDef& def) : op_def(def) {} |
| OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn, |
| bool is_function = false) |
| : op_def(def), shape_inference_fn(fn), is_function_op(is_function) {} |
| |
| OpDef op_def; |
| OpShapeInferenceFn shape_inference_fn; |
| bool is_function_op = false; |
| }; |
| |
| // Builder class passed to the REGISTER_OP() macro. |
| class OpDefBuilder { |
| public: |
| // Constructs an OpDef with just the name field set. |
| explicit OpDefBuilder(string op_name); |
| |
| // Adds an attr to this OpDefBuilder (and returns *this). The spec has |
| // format "<name>:<type>" or "<name>:<type>=<default>" |
| // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]* |
| // (by convention only using capital letters for attrs that can be inferred) |
| // <type> can be: |
| // "string", "int", "float", "bool", "type", "shape", or "tensor" |
| // "numbertype", "realnumbertype", "quantizedtype" |
| // (meaning "type" with a restriction on valid values) |
| // "{int32,int64}" or {realnumbertype,quantizedtype,string}" |
| // (meaning "type" with a restriction containing unions of value types) |
| // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}" |
| // (meaning "string" with a restriction on valid values) |
| // "list(string)", ..., "list(tensor)", "list(numbertype)", ... |
| // (meaning lists of the above types) |
| // "int >= 2" (meaning "int" with a restriction on valid values) |
| // "list(string) >= 2", "list(int) >= 2" |
| // (meaning "list(string)" / "list(int)" with length at least 2) |
| // <default>, if included, should use the Proto text format |
| // of <type>. For lists use [a, b, c] format. |
| // |
| // Note that any attr specifying the length of an input or output will |
| // get a default minimum of 1 unless the >= # syntax is used. |
| // |
| // TODO(josh11b): Perhaps support restrictions and defaults as optional |
| // extra arguments to Attr() instead of encoding them in the spec string. |
| // TODO(josh11b): Would like to have better dtype handling for tensor attrs: |
| // * Ability to say the type of an input/output matches the type of |
| // the tensor. |
| // * Ability to restrict the type of the tensor like the existing |
| // restrictions for type attrs. |
| // Perhaps by linking the type of the tensor to a type attr? |
| OpDefBuilder& Attr(string spec); |
| |
| // Adds an input or output to this OpDefBuilder (and returns *this). |
| // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)" |
| // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be: |
| // * For a single tensor: <type> |
| // * For a sequence of tensors with the same type: <number>*<type> |
| // * For a sequence of tensors with different types: <type-list> |
| // Where: |
| // <type> is either one of "float", "int32", "string", ... |
| // or the name of an attr (see above) with type "type". |
| // <number> is the name of an attr with type "int". |
| // <type-list> is the name of an attr with type "list(type)". |
| // TODO(josh11b): Indicate Ref() via an optional argument instead of |
| // in the spec? |
| // TODO(josh11b): SparseInput() and SparseOutput() matching the Python |
| // handling? |
| OpDefBuilder& Input(string spec); |
| OpDefBuilder& Output(string spec); |
| |
| // Turns on the indicated boolean flag in this OpDefBuilder (and |
| // returns *this). |
| OpDefBuilder& SetIsCommutative(); |
| OpDefBuilder& SetIsAggregate(); |
| OpDefBuilder& SetIsStateful(); |
| OpDefBuilder& SetAllowsUninitializedInput(); |
| |
| // Deprecate the op at a certain GraphDef version. |
| OpDefBuilder& Deprecated(int version, string explanation); |
| |
| // Adds docs to this OpDefBuilder (and returns *this). |
| // Docs have the format: |
| // <1-line summary> |
| // <rest of the description> |
| // <name>: <description of name> |
| // <name>: <description of name> |
| // <if long, indent the description on subsequent lines> |
| // Where <name> is the name of an attr, input, or output. Please |
| // wrap docs at 72 columns so that it may be indented in the |
| // generated output. For tensor inputs or outputs (not attrs), you |
| // may start the description with an "=" (like name:= <description>) |
| // to suppress the automatically-generated type documentation in |
| // generated output. |
| #ifndef TF_LEAN_BINARY |
| OpDefBuilder& Doc(string text); |
| #else |
| OpDefBuilder& Doc(string text) { return *this; } |
| #endif |
| |
| // Sets the shape function to be used for shape inference. |
| // |
| // Note that currently (October 2016), python code still requires a |
| // RegisterShape call to invoke this; see call_cpp_shape_fn in |
| // python/framework/common_shapes.py |
| OpDefBuilder& SetShapeFn(OpShapeInferenceFn fn); |
| |
| // Sets op_reg_data->op_def to the requested OpDef and |
| // op_reg_data->shape_inference_fn to the requested shape inference function, |
| // or returns an error. |
| // Must be called after all of the above methods. |
| // |
| // Note that OpDefBuilder only reports parsing errors. You should also |
| // call ValidateOpDef() to detect other problems. |
| Status Finalize(OpRegistrationData* op_reg_data) const; |
| |
| private: |
| friend class FunctionDefHelper; |
| |
| // Adds control output to this OpDefBuilder (and returns *this). |
| // The <name> must be a valid node name (matches regexp |
| // [a-zA-Z][a-zA-Z0-9_]*). Named control output can only exist for functions. |
| OpDefBuilder& ControlOutput(string name); |
| |
| OpDef* op_def() { return &op_reg_data_.op_def; } |
| |
| OpRegistrationData op_reg_data_; |
| std::vector<string> attrs_; |
| std::vector<string> inputs_; |
| std::vector<string> outputs_; |
| std::vector<string> control_outputs_; |
| string doc_; |
| std::vector<string> errors_; |
| }; |
| |
| } // namespace tensorflow |
| |
| #endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ |