Moved shape inference out of RemoteExecuteNode
This allows to reuse it in ExecuteNode with zero changes to RemoteExecuteNode.
Note also, that RunShapeInference does not short-circuit on empty inputs
as in theory an Op could infer shapes from attributes in the NodeDef.
PiperOrigin-RevId: 272755801
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index e9c549a..bcde8de 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -192,6 +192,19 @@
}),
)
+cc_library(
+ name = "shape_inference",
+ srcs = ["shape_inference.cc"],
+ hdrs = ["shape_inference.h"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":tensor_handle",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
KERNEL_AND_DEVICE_DEPS = [
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/common_runtime/eager/shape_inference.cc b/tensorflow/core/common_runtime/eager/shape_inference.cc
new file mode 100644
index 0000000..43ef021
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/shape_inference.cc
@@ -0,0 +1,59 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/shape_inference.h"
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace eager {
+
+Status RunShapeInference(const NodeDef& ndef,
+ const FunctionLibraryDefinition& lib_def,
+ const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+ const gtl::InlinedVector<TensorHandle*, 2>& retvals) {
+ const tensorflow::OpRegistrationData* op_reg_data;
+ // TODO(b/141209983): Consider adding a shape inference cache.
+ // FunctionLibraryDefinition::LookUp delegates to global OpRegistry
+ // if op is not a function.
+ TF_RETURN_IF_ERROR(lib_def.LookUp(ndef.op(), &op_reg_data));
+ if (op_reg_data->shape_inference_fn == nullptr) return Status::OK();
+
+ shape_inference::InferenceContext ic(
+ TF_GRAPH_DEF_VERSION, &ndef, op_reg_data->op_def,
+ std::vector<shape_inference::ShapeHandle>(inputs.size()), {}, {}, {});
+ for (size_t i = 0; i < inputs.size(); i++) {
+ shape_inference::ShapeHandle shape;
+ TF_RETURN_IF_ERROR(inputs[i]->InferenceShape(&ic, &shape));
+ ic.SetInput(i, shape);
+ }
+
+ TF_RETURN_IF_ERROR(ic.Run(op_reg_data->shape_inference_fn));
+ CHECK_EQ(ic.num_outputs(), retvals.size());
+ for (int i = 0; i < ic.num_outputs(); i++) {
+ shape_inference::ShapeHandle shape_handle = ic.output(i);
+ retvals[i]->SetInferenceShape(&ic, shape_handle);
+ }
+ // TODO(slebedev): populate TensorHandle::handle_dtypes_and_shapes.
+ return Status::OK();
+}
+
+} // namespace eager
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/shape_inference.h b/tensorflow/core/common_runtime/eager/shape_inference.h
new file mode 100644
index 0000000..456bfc3
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/shape_inference.h
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_
+
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+namespace eager {
+
+Status RunShapeInference(const NodeDef& ndef,
+ const FunctionLibraryDefinition& lib_def,
+ const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+ const gtl::InlinedVector<TensorHandle*, 2>& retvals);
+
+} // namespace eager
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 2967761..108d4b5 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -72,12 +72,15 @@
hdrs = ["remote_execute_node.h"],
deps = [
":eager_client",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:shape_inference",
"//tensorflow/core/common_runtime/eager:tensor_handle",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
index a23a3be..0598086 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
@@ -17,44 +17,12 @@
#include <vector>
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/public/version.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
namespace tensorflow {
namespace eager {
-Status RemoteExecuteNode::Prepare() {
- if (retvals_.empty()) return Status::OK();
-
- // TODO(b/141209983): Consider adding a shape inference cache.
- const tensorflow::OpRegistrationData* op_reg_data;
- if (lib_def_->Find(ndef_.op()) == nullptr) {
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(ndef_.op(), &op_reg_data));
- } else {
- TF_RETURN_IF_ERROR(lib_def_->LookUp(ndef_.op(), &op_reg_data));
- }
-
- shape_inference::InferenceContext inference_context(
- TF_GRAPH_DEF_VERSION, &ndef_, op_reg_data->op_def,
- std::vector<shape_inference::ShapeHandle>(inputs_.size()), {}, {},
- std::vector<
- std::unique_ptr<std::vector<shape_inference::ShapeAndType>>>());
- for (size_t i = 0; i < inputs_.size(); i++) {
- shape_inference::ShapeHandle shape;
- TF_RETURN_IF_ERROR(inputs_[i]->InferenceShape(&inference_context, &shape));
- inference_context.SetInput(i, shape);
- }
-
- TF_RETURN_IF_ERROR(inference_context.Run(op_reg_data->shape_inference_fn));
- DCHECK_EQ(inference_context.num_outputs(), retvals_.size());
- for (int i = 0; i < inference_context.num_outputs(); i++) {
- shape_inference::ShapeHandle shape_handle = inference_context.output(i);
- retvals_[i]->SetInferenceShape(&inference_context, shape_handle);
- }
- return Status::OK();
-}
-
void RemoteExecuteNode::RunAsync(StatusCallback done) {
EnqueueResponse* response = new EnqueueResponse;
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index 22d5721..3736173 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -19,11 +19,14 @@
#include <cstddef>
#include "absl/types/span.h"
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/shape_inference.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
namespace tensorflow {
@@ -69,7 +72,9 @@
}
}
- Status Prepare() override;
+ Status Prepare() override {
+ return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_);
+ }
void RunAsync(StatusCallback done) override;