Moved shape inference out of RemoteExecuteNode

This allows reusing it in ExecuteNode with zero changes to RemoteExecuteNode.
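
For illustration, a future ExecuteNode::Prepare could be as small as the
RemoteExecuteNode::Prepare override below (a sketch only, assuming ExecuteNode
carries the same ndef_, lib_def_, inputs_ and retvals_ members; the ExecuteNode
change itself is not part of this CL):

  // Hypothetical sketch, mirroring the new RemoteExecuteNode::Prepare.
  Status Prepare() override {
    return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_);
  }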

Note also that RunShapeInference does not short-circuit on empty inputs,
since in theory an op could infer shapes from attributes in the NodeDef.
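
For example, a zero-input op such as Placeholder still gets a fully inferred
output shape, because its shape function reads the "shape" attr (a sketch,
assuming a valid retvals vector holding one unset handle):

  // Hypothetical sketch: no inputs, yet the output shape comes from attrs.
  NodeDef ndef;
  ndef.set_op("Placeholder");
  ndef.set_name("x");
  AddNodeAttr("dtype", DT_FLOAT, &ndef);
  AddNodeAttr("shape", TensorShape({2, 3}), &ndef);
  // With empty inputs, RunShapeInference still sets retvals[0]'s
  // inference shape to [2, 3].
  TF_RETURN_IF_ERROR(RunShapeInference(ndef, lib_def, /*inputs=*/{}, retvals));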

PiperOrigin-RevId: 272755801
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index e9c549a..bcde8de 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -192,6 +192,19 @@
     }),
 )
 
+cc_library(
+    name = "shape_inference",
+    srcs = ["shape_inference.cc"],
+    hdrs = ["shape_inference.h"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":tensor_handle",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
 KERNEL_AND_DEVICE_DEPS = [
     "//tensorflow/core:core_cpu_lib",
     "//tensorflow/core:framework",
diff --git a/tensorflow/core/common_runtime/eager/shape_inference.cc b/tensorflow/core/common_runtime/eager/shape_inference.cc
new file mode 100644
index 0000000..43ef021
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/shape_inference.cc
@@ -0,0 +1,59 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/shape_inference.h"
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace eager {
+
+Status RunShapeInference(const NodeDef& ndef,
+                         const FunctionLibraryDefinition& lib_def,
+                         const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+                         const gtl::InlinedVector<TensorHandle*, 2>& retvals) {
+  const tensorflow::OpRegistrationData* op_reg_data;
+  // TODO(b/141209983): Consider adding a shape inference cache.
+  // FunctionLibraryDefinition::LookUp delegates to global OpRegistry
+  // if op is not a function.
+  TF_RETURN_IF_ERROR(lib_def.LookUp(ndef.op(), &op_reg_data));
+  if (op_reg_data->shape_inference_fn == nullptr) return Status::OK();
+
+  shape_inference::InferenceContext ic(
+      TF_GRAPH_DEF_VERSION, &ndef, op_reg_data->op_def,
+      std::vector<shape_inference::ShapeHandle>(inputs.size()), {}, {}, {});
+  for (size_t i = 0; i < inputs.size(); i++) {
+    shape_inference::ShapeHandle shape;
+    TF_RETURN_IF_ERROR(inputs[i]->InferenceShape(&ic, &shape));
+    ic.SetInput(i, shape);
+  }
+
+  TF_RETURN_IF_ERROR(ic.Run(op_reg_data->shape_inference_fn));
+  CHECK_EQ(ic.num_outputs(), retvals.size());
+  for (int i = 0; i < ic.num_outputs(); i++) {
+    shape_inference::ShapeHandle shape_handle = ic.output(i);
+    retvals[i]->SetInferenceShape(&ic, shape_handle);
+  }
+  // TODO(slebedev): populate TensorHandle::handle_dtypes_and_shapes.
+  return Status::OK();
+}
+
+}  // namespace eager
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/shape_inference.h b/tensorflow/core/common_runtime/eager/shape_inference.h
new file mode 100644
index 0000000..456bfc3
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/shape_inference.h
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_
+
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+namespace eager {
+
+Status RunShapeInference(const NodeDef& ndef,
+                         const FunctionLibraryDefinition& lib_def,
+                         const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+                         const gtl::InlinedVector<TensorHandle*, 2>& retvals);
+
+}  // namespace eager
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 2967761..108d4b5 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -72,12 +72,15 @@
     hdrs = ["remote_execute_node.h"],
     deps = [
         ":eager_client",
+        "//tensorflow/core:core_cpu",
         "//tensorflow/core:eager_service_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/common_runtime/eager:eager_executor",
+        "//tensorflow/core/common_runtime/eager:shape_inference",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
 )
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
index a23a3be..0598086 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
@@ -17,44 +17,12 @@
 
 #include <vector>
 
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/public/version.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
 
 namespace tensorflow {
 namespace eager {
 
-Status RemoteExecuteNode::Prepare() {
-  if (retvals_.empty()) return Status::OK();
-
-  // TODO(b/141209983): Consider adding a shape inference cache.
-  const tensorflow::OpRegistrationData* op_reg_data;
-  if (lib_def_->Find(ndef_.op()) == nullptr) {
-    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(ndef_.op(), &op_reg_data));
-  } else {
-    TF_RETURN_IF_ERROR(lib_def_->LookUp(ndef_.op(), &op_reg_data));
-  }
-
-  shape_inference::InferenceContext inference_context(
-      TF_GRAPH_DEF_VERSION, &ndef_, op_reg_data->op_def,
-      std::vector<shape_inference::ShapeHandle>(inputs_.size()), {}, {},
-      std::vector<
-          std::unique_ptr<std::vector<shape_inference::ShapeAndType>>>());
-  for (size_t i = 0; i < inputs_.size(); i++) {
-    shape_inference::ShapeHandle shape;
-    TF_RETURN_IF_ERROR(inputs_[i]->InferenceShape(&inference_context, &shape));
-    inference_context.SetInput(i, shape);
-  }
-
-  TF_RETURN_IF_ERROR(inference_context.Run(op_reg_data->shape_inference_fn));
-  DCHECK_EQ(inference_context.num_outputs(), retvals_.size());
-  for (int i = 0; i < inference_context.num_outputs(); i++) {
-    shape_inference::ShapeHandle shape_handle = inference_context.output(i);
-    retvals_[i]->SetInferenceShape(&inference_context, shape_handle);
-  }
-  return Status::OK();
-}
-
 void RemoteExecuteNode::RunAsync(StatusCallback done) {
   EnqueueResponse* response = new EnqueueResponse;
 
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index 22d5721..3736173 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -19,11 +19,14 @@
 #include <cstddef>
 
 #include "absl/types/span.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/shape_inference.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/protobuf/eager_service.pb.h"
 
 namespace tensorflow {
@@ -69,7 +72,9 @@
     }
   }
 
-  Status Prepare() override;
+  Status Prepare() override {
+    return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_);
+  }
 
   void RunAsync(StatusCallback done) override;