Add annotations to fallback ops
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 062a9aa..cdef72a 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -168,7 +168,8 @@
bool AddEagerFallbackCode(const string& parameters,
const std::vector<string>& output_sizes,
const string& num_outputs_expr,
- const string& eager_not_allowed_error);
+ const string& eager_not_allowed_error,
+ std::unordered_map<string, string>& type_annotations);
void AddEagerFastPathExecute();
void AddEagerInferredAttrs(const string& indentation);
@@ -355,18 +356,19 @@
}
string parameters;
+ // Param can be an input or an attr
for (const auto& param : params_no_default_) {
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
strings::StrAppend(¶meters, param.GetRenameTo());
// Add type annotations to param
if (type_annotations.find(param.GetName()) != type_annotations.end()) {
- if (!type_annotations[param.GetName()].empty()) {
- strings::StrAppend(¶meters, ": ", type_annotations[param.GetName()]);
- }
+ strings::StrAppend(¶meters, ": ", type_annotations[param.GetName()]);
}
}
+ // Append to parameters and parameters_with_defaults because multiple functions
+ // are generated (op and fallback op)
string parameters_with_defaults = parameters;
for (const auto& param_and_default : params_with_default_) {
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
@@ -375,14 +377,12 @@
// Add type annotations to param_and_default
if (type_annotations.find(param_and_default.first.GetName()) != type_annotations.end()) {
- if (!type_annotations[param_and_default.first.GetName()].empty()) {
- strings::StrAppend(¶meters, ": ", type_annotations[param_and_default.first.GetName()]);
- strings::StrAppend(¶meters_with_defaults,
- param_and_default.first.GetRenameTo(), ": ",
- type_annotations[param_and_default.first.GetName()], " ",
- "= ", param_and_default.second);
- continue;
- }
+ const string param_type = type_annotations[param_and_default.first.GetName()];
+ strings::StrAppend(¶meters, param_and_default.first.GetRenameTo(), ": ", param_type);
+ strings::StrAppend(¶meters_with_defaults,
+ param_and_default.first.GetRenameTo(), ": ",
+ param_type, " = ", param_and_default.second);
+ continue;
}
strings::StrAppend(¶meters, param_and_default.first.GetRenameTo());
@@ -425,7 +425,7 @@
}
if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
- eager_not_allowed_error)) {
+ eager_not_allowed_error, type_annotations)) {
return result_;
}
@@ -449,35 +449,29 @@
for (const auto& arg : op_def_.input_arg()) {
// Do not add type annotations to args that accept a sequence of Tensors
if (!arg.number_attr().empty()) continue;
- string type_annotation;
if (type_annotations.find(arg.type_attr()) != type_annotations.end()) {
// Get the correct TypeVar if input maps to an attr
- strings::StrAppend(&type_annotation, "_ops.Tensor[", type_annotations[arg.type_attr()], "]");
+ type_annotations[arg.name()] = "_ops.Tensor[" + type_annotations[arg.type_attr()] + "]";
} else {
// Get the dtype of the Tensor
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
if (dtype_type.find(py_dtype) != dtype_type.end()) {
- strings::StrAppend(&type_annotation, "_ops.Tensor[", dtype_type[py_dtype], "]");
+ type_annotations[arg.name()] = "_ops.Tensor[" + dtype_type[py_dtype] + "]";
}
}
-
- type_annotations[arg.name()] = type_annotation;
}
// Mapping output Tensor to its type
if (op_def_.output_arg_size() == 1) {
const auto& arg = op_def_.output_arg(0);
- string type_annotation;
if (type_annotations.find(arg.type_attr()) != type_annotations.end()) {
- strings::StrAppend(&type_annotation, "_ops.Tensor[", type_annotations[arg.type_attr()], "]");
+ type_annotations[arg.name()] = "_ops.Tensor[" + type_annotations[arg.type_attr()] + "]";
} else {
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
if (dtype_type.find(py_dtype) != dtype_type.end()) {
- strings::StrAppend(&type_annotation, "_ops.Tensor[", dtype_type[py_dtype], "]");
+ type_annotations[arg.name()] = "_ops.Tensor[" + dtype_type[py_dtype] + "]";
}
}
-
- type_annotations[arg.name()] = type_annotation;
}
return type_annotations;
@@ -521,19 +515,20 @@
if (added_typevar) strings::StrAppend(&result_, "\n");
}
+// TODO(rahulkamat): Modify AddDefLine() to add return type annotation
void GenEagerPythonOp::AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations) {
if (op_def_.output_arg_size() == 1) {
const auto& arg = op_def_.output_arg(0);
// Add type annotations to param
if (type_annotations.find(arg.name()) != type_annotations.end()) {
- if (!type_annotations[arg.name()].empty()) {
- result_.erase(result_.length() - 2);
- strings::StrAppend(&result_, " -> ", type_annotations[arg.name()], ":\n");
- }
+ result_.erase(result_.length() - 2);
+ strings::StrAppend(&result_, " -> ", type_annotations[arg.name()], ":\n");
}
}
}
+
+
void GenEagerPythonOp::HandleGraphMode(
const string& function_setup, const std::vector<string>& output_sizes) {
strings::StrAppend(&result_, " # Add nodes to the TensorFlow graph.\n");
@@ -903,11 +898,14 @@
bool GenEagerPythonOp::AddEagerFallbackCode(
const string& parameters, const std::vector<string>& output_sizes,
- const string& num_outputs_expr, const string& eager_not_allowed_error) {
+ const string& num_outputs_expr, const string& eager_not_allowed_error,
+ std::unordered_map<string, string>& type_annotations) {
AddDefLine(
strings::StrCat(function_name_, kEagerFallbackSuffix),
strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
-
+ if (type_annotate_ops.find(op_def_.name()) != type_annotate_ops.end()) {
+ AddReturnTypeAnnotation(type_annotations);
+ }
if (!eager_not_allowed_error.empty()) {
strings::StrAppend(&result_, " ", eager_not_allowed_error);
return true;