Adding TFConcreteFunction class for ConcreteFunction reloading in the SavedModel C API.
PiperOrigin-RevId: 320127741
Change-Id: I002c77dd23dc8d85d1b088e2dbe6dfc8088d2b77
diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD
index 34d9167..5452907 100644
--- a/tensorflow/c/experimental/saved_model/core/BUILD
+++ b/tensorflow/c/experimental/saved_model/core/BUILD
@@ -19,6 +19,9 @@
cc_library(
name = "concrete_function",
+ srcs = [
+ "concrete_function.cc",
+ ],
hdrs = [
"concrete_function.h",
],
@@ -26,6 +29,7 @@
":function_metadata",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
+ "//tensorflow/core:protos_all_cc",
],
)
@@ -56,13 +60,10 @@
"saved_model_utils.h",
],
deps = [
- ":function_metadata",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
- "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
- "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
@@ -91,18 +92,6 @@
)
cc_library(
- name = "tf_concrete_function_test_protos",
- testonly = True,
- srcs = ["tf_concrete_function_test_protos.cc"],
- hdrs = ["tf_concrete_function_test_protos.h"],
- deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
name = "tf_saved_model_impl",
srcs = [
"tf_saved_model_impl.cc",
@@ -125,16 +114,12 @@
"saved_model_api.h",
],
visibility = ["//tensorflow/python:__pkg__"],
- deps = [
- "//tensorflow/c/eager:immediate_execution_operation",
- "//tensorflow/c/eager:immediate_execution_tensor_handle",
- "//tensorflow/core:lib",
- ],
)
filegroup(
name = "mobile_srcs_only_runtime",
srcs = [
+ "concrete_function.cc",
"concrete_function.h",
"function_metadata.h",
"saved_model_api.h",
@@ -187,28 +172,3 @@
"//tensorflow/core/common_runtime/eager:core",
],
)
-
-tf_cc_test(
- name = "tf_concrete_function_loading_test",
- srcs = [
- "tf_concrete_function_loading_test.cc",
- ],
- deps = [
- ":saved_model_utils",
- ":test_utils",
- ":tf_concrete_function_test_protos",
- "//tensorflow/c:tensor_interface",
- "//tensorflow/c/eager:immediate_execution_tensor_handle",
- "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
- "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
- "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core/common_runtime:core_cpu_lib",
- "//tensorflow/core/common_runtime/eager:context",
- "//tensorflow/core/common_runtime/eager:core",
- ],
-)
diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.cc b/tensorflow/c/experimental/saved_model/core/concrete_function.cc
new file mode 100644
index 0000000..41bae43
--- /dev/null
+++ b/tensorflow/c/experimental/saved_model/core/concrete_function.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
+
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
+
+namespace tensorflow {
+
+const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
+ConcreteFunction::GetCaptures() const {
+ return captures_;
+}
+
+const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
+ return metadata_;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h
index 2cc627b..2253564 100644
--- a/tensorflow/c/experimental/saved_model/core/concrete_function.h
+++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h
@@ -16,12 +16,12 @@
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
-#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
+#include "tensorflow/core/framework/function.pb.h"
namespace tensorflow {
@@ -35,14 +35,19 @@
// and have only a single implementation.
class ConcreteFunction {
public:
- virtual ~ConcreteFunction() = default;
+ virtual ~ConcreteFunction() = 0;
// This method returns the "Call" Op used to execute the function.
- virtual Status GetCallOp(ImmediateOpPtr* out) = 0;
+ virtual ImmediateExecutionOperation* GetCallOp() = 0;
- virtual const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
- const = 0;
- virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
+ const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
+ const;
+ const FunctionMetadata& GetFunctionMetadata() const;
+
+ private:
+ FunctionMetadata metadata_;
+ std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
+ FunctionDef* function_;
};
} // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
index 8bb1567..84fad2e 100644
--- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
@@ -58,24 +58,3 @@
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)
-
-cc_library(
- name = "tf_concrete_function",
- srcs = [
- "tf_concrete_function.cc",
- ],
- hdrs = [
- "tf_concrete_function.h",
- ],
- deps = [
- ":tensorhandle_convertible",
- "//tensorflow/c/eager:immediate_execution_context",
- "//tensorflow/c/eager:immediate_execution_operation",
- "//tensorflow/c/eager:immediate_execution_tensor_handle",
- "//tensorflow/c/experimental/saved_model/core:concrete_function",
- "//tensorflow/c/experimental/saved_model/core:function_metadata",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/common_runtime/eager:context",
- ],
-)
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc
deleted file mode 100644
index aa6f0e7..0000000
--- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
-
-#include <memory>
-#include <string>
-
-#include "tensorflow/c/eager/immediate_execution_operation.h"
-#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
-#include "tensorflow/core/common_runtime/eager/context.h"
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
-#include "tensorflow/core/protobuf/struct.pb.h"
-
-namespace tensorflow {
-
-TFConcreteFunction::TFConcreteFunction(
- const std::string& name,
- std::vector<ImmediateExecutionTensorHandle*> captures,
- FunctionMetadata metadata, ImmediateExecutionContext* ctx)
- : name_(name),
- captures_(std::move(captures)),
- metadata_(std::move(metadata)),
- ctx_(ctx) {}
-
-TFConcreteFunction::~TFConcreteFunction() {
- Status status = ctx_->RemoveFunction(name_);
- if (!status.ok()) {
- LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
- << status.error_message();
- }
-}
-
-Status TFConcreteFunction::Create(
- const FunctionDef* function_def,
- std::vector<ImmediateExecutionTensorHandle*> captures,
- FunctionMetadata metadata, ImmediateExecutionContext* ctx,
- std::unique_ptr<TFConcreteFunction>* out) {
- TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
- out->reset(new TFConcreteFunction(function_def->signature().name(),
- std::move(captures), std::move(metadata),
- ctx));
- return Status();
-}
-
-const std::vector<ImmediateExecutionTensorHandle*>&
-TFConcreteFunction::GetCaptures() const {
- return captures_;
-}
-
-const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
- return metadata_;
-}
-
-Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) {
- out->reset(ctx_->CreateOperation());
- // In eager mode, TF2 python executes functions by constructing an op with
- // the name of the functiondef:
- // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
- // In graph mode, we create a PartitionedCallOp instead:
- // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
-
- // TODO(bmzhao): After discussing with Allen, we should execute this via a
- // PartitionedCallOp for compatibility with "tooling that assumes functions in
- // graphs are PartitionedCallOps".
- TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
- return Status();
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h
deleted file mode 100644
index 71c8322..0000000
--- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_
-#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_
-
-#include <functional>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/c/eager/immediate_execution_context.h"
-#include "tensorflow/c/eager/immediate_execution_operation.h"
-#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
-#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
-#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
-
-namespace tensorflow {
-
-// TF Eager Runtime-based implementation of a "ConcreteFunction" loaded from a
-// saved model.
-class TFConcreteFunction : public ConcreteFunction {
- public:
- // Factory function for creating a TFConcreteFunction.
- //
- // Params:
- // function_def - The function_def associated with the created
- // TFConcreteFunction. TFConcreteFunction will register this
- // function_def with `ctx` on creation, and de-register it on
- // destruction. function_def must be non-null, but
- // otherwise has no lifetime requirements.
- // captures - The captured TensorHandles associated with this
- // TFConcreteFunction.
- // metadata - The FunctionMetadata associated with this TFConcreteFunction.
- // ctx - A handle to the Tensorflow runtime. This MUST be non-null and
- // outlive TFConcreteFunction.
- // out - The output TFConcreteFunction.
- static Status Create(const FunctionDef* function_def,
- std::vector<ImmediateExecutionTensorHandle*> captures,
- FunctionMetadata metadata,
- ImmediateExecutionContext* ctx,
- std::unique_ptr<TFConcreteFunction>* out);
-
- // This method returns the "Call" Op used to execute the function.
- Status GetCallOp(ImmediateOpPtr* out) override;
-
- const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
- const override;
-
- const FunctionMetadata& GetFunctionMetadata() const override;
-
- ~TFConcreteFunction() override;
-
- private:
- TFConcreteFunction(const std::string& name,
- std::vector<ImmediateExecutionTensorHandle*> captures,
- FunctionMetadata metadata, ImmediateExecutionContext* ctx);
-
- TFConcreteFunction(const TFConcreteFunction&) = delete;
- TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
-
- // Name of the FunctionDef corresponding to this TFConcreteFunction
- std::string name_;
- std::vector<ImmediateExecutionTensorHandle*> captures_;
- FunctionMetadata metadata_;
- ImmediateExecutionContext* ctx_;
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_
diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
index 4b1d767..196420e 100644
--- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
+++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
@@ -17,125 +17,14 @@
#include <memory>
-#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
-#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace internal {
-namespace {
-
-// This returns the size of `tf.nest.flatten(value)`, on values that are
-// used in tf.function's input_signatures.
-int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) {
- // This follows the logic from
- // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
- switch (value.kind_case()) {
- case StructuredValue::kDictValue: {
- const DictValue& dict = value.dict_value();
- int size = 0;
- for (const auto& field : dict.fields()) {
- size += FlattenedSize(field.second, status);
- }
- return size;
- }
- case StructuredValue::kTupleValue: {
- const TupleValue& tuple = value.tuple_value();
- int size = 0;
- for (const StructuredValue& value : tuple.values()) {
- size += FlattenedSize(value, status);
- }
- return size;
- }
- case StructuredValue::kListValue: {
- const ListValue& list = value.list_value();
- int size = 0;
- for (const StructuredValue& value : list.values()) {
- size += FlattenedSize(value, status);
- }
- return size;
- }
- case StructuredValue::kTensorSpecValue: {
- return 1;
- }
- case StructuredValue::kNoneValue: {
- // Base case: do nothing.
- // This arises, for example, as the top-level object of an output
- // signature when there are no return values.
- return 0;
- }
- default: {
- status->Update(errors::Internal("Unhandled structured value kind ",
- value.kind_case()));
- return 0;
- }
- }
-}
-
-// Perform some basic sanity checks on SavedConcreteFunction's input and
-// output signatures with respect to the corresponding FunctionDef's input
-// and output args.
-Status ValidateSavedFunctionCompatibleWithFunctionDef(
- const SavedConcreteFunction& saved_concrete_function,
- const FunctionDef* function_def) {
- // tf.functions go through many transformations before becoming FunctionDefs
- // 1. flatten user-provided inputs:
- // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675
- // 2. convert user-provided inputs to tensors:
- // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688
- // 3. filter any non-tensor, non-variable inputs:
- // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841
- // 4. concatenate any captured inputs:
- // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912
-
- // Since our API is limited to tf.functions annotated with input signatures,
- // conditions 2 and 3 are trivially satisfied.
- // We need to ensure that:
- // flatten(input_signature).size() + captures.size() = fdef.signature().size()
- // A concrete function's serialized "canonicalized_input_signature" comes
- // from encoding its "structured_input_signature" field:
- // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71
- // The "structured_input_signature" is guaranteed to be a tuple of the python
- // args, kwargs that correspond to the tf.function:
- // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979
-
- const std::string& name = function_def->signature().name();
- const StructuredValue& input_signature =
- saved_concrete_function.canonicalized_input_signature();
- Status status;
- int input_signature_size = FlattenedSize(input_signature, &status);
- TF_RETURN_IF_ERROR(status);
- if (input_signature_size + saved_concrete_function.bound_inputs_size() !=
- function_def->signature().input_arg_size()) {
- return errors::FailedPrecondition(
- "FunctionDef ", name, " has ",
- function_def->signature().input_arg_size(),
- " inputs, but the SavedConcreteFunction has ", input_signature_size,
- " flattened user inputs and ",
- saved_concrete_function.bound_inputs_size(), " captured inputs.");
- }
-
- const StructuredValue& output_signature =
- saved_concrete_function.output_signature();
- int output_signature_size = FlattenedSize(output_signature, &status);
- TF_RETURN_IF_ERROR(status);
- if (output_signature_size != function_def->signature().output_arg_size()) {
- return errors::FailedPrecondition(
- "FunctionDef ", name, " has ",
- function_def->signature().output_arg_size(),
- " outputs, but the SavedConcreteFunction has ", output_signature_size,
- " flattened outputs.");
- }
-
- return status;
-}
-
-} // namespace
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
@@ -165,31 +54,5 @@
return Status();
}
-Status LoadTFConcreteFunction(
- const SavedConcreteFunction& saved_concrete_function,
- const FunctionDef* function_def,
- const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
- captured_objects,
- ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out) {
- TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef(
- saved_concrete_function, function_def));
-
- // Copy over captures
- std::vector<ImmediateExecutionTensorHandle*> captures;
- captures.reserve(saved_concrete_function.bound_inputs_size());
- for (int bound_input : saved_concrete_function.bound_inputs()) {
- auto iter = captured_objects.find(bound_input);
- if (iter == captured_objects.end()) {
- return errors::FailedPrecondition("Failed to find bound_input ",
- bound_input,
- " for SavedConcreteFunction");
- }
- captures.push_back(iter->second->handle());
- }
-
- return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx,
- out);
-}
-
} // namespace internal
} // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h
index 89a959a..ab15317 100644
--- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h
+++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h
@@ -21,7 +21,6 @@
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
@@ -44,14 +43,6 @@
const SavedVariable& variable,
std::unique_ptr<Variable>* output);
-// Creates a TFConcreteFunction from a SavedConcreteFunction.
-Status LoadTFConcreteFunction(
- const SavedConcreteFunction& saved_concrete_function,
- const FunctionDef* function_def,
- const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
- captured_objects,
- ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out);
-
} // namespace internal
} // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc
deleted file mode 100644
index 05fbac1..0000000
--- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc
+++ /dev/null
@@ -1,271 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <unordered_map>
-
-#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
-#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
-#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
-#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
-
-namespace tensorflow {
-namespace {
-
-class SavedConcreteFunctionLoadingTest : public ::testing::Test {
- public:
- SavedConcreteFunctionLoadingTest()
- : device_mgr_(testing::CreateTestingDeviceMgr()),
- ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
-
- EagerContext* context() { return ctx_.get(); }
-
- private:
- std::unique_ptr<StaticDeviceMgr> device_mgr_;
- EagerContextPtr ctx_;
-};
-
-class DummyCapture : public TensorHandleConvertible {
- public:
- DummyCapture(ImmediateExecutionContext* ctx, int8 value)
- : TensorHandleConvertible(
- testing::CreateTensorHandle(ctx, DT_FLOAT, {2, 4}, value)) {}
-};
-
-FunctionDef FuncDefWithNumInputsOutputs(int num_inputs, int num_outputs) {
- FunctionDef func;
- OpDef* signature = func.mutable_signature();
- for (int i = 0; i < num_inputs; ++i) {
- signature->add_input_arg();
- }
- for (int i = 0; i < num_outputs; ++i) {
- signature->add_output_arg();
- }
- return func;
-}
-
-// A SavedConcreteFunction whose canonicalized input signature
-// has less inputs than its corresponding FunctionDef should cause an error.
-TEST_F(SavedConcreteFunctionLoadingTest, TooFewInputsInSavedConcreteFunction) {
- // `saved` has 1 input
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::SingleArgInputSignature();
- *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
-
- // `func` has 2 inputs
- FunctionDef func = FuncDefWithNumInputsOutputs(2, 0);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status =
- internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
- EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
- << status.error_message();
-}
-
-// A SavedConcreteFunction whose canonicalized input signature length +
-// captures is less than its corresponding FunctionDef should cause an error.
-TEST_F(SavedConcreteFunctionLoadingTest,
- TooFewInputsWithCapturesInSavedConcreteFunction) {
- // `saved` has 1 input, and 1 capture, for a total of 2 inputs
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::SingleArgInputSignature();
- *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
- saved.add_bound_inputs(5);
-
- // `func` has 3 inputs
- FunctionDef func = FuncDefWithNumInputsOutputs(3, 0);
-
- std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
- captures[5] = std::make_unique<DummyCapture>(context(), 10);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
- context(), &result);
- EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
- << status.error_message();
-}
-
-// A SavedConcreteFunction whose canonicalized input signature
-// has more inputs than its corresponding FunctionDef should cause an error.
-TEST_F(SavedConcreteFunctionLoadingTest, TooManyInputsInSavedConcreteFunction) {
- // `saved` has 3 inputs
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::ThreeArgInputSignature();
- *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
-
- // `func` has 2 inputs
- FunctionDef func = FuncDefWithNumInputsOutputs(2, 0);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status =
- internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
- EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
- << status.error_message();
-}
-
-// A SavedConcreteFunction whose canonicalized input signature
-// has the same number of inputs than its corresponding FunctionDef, but has
-// additional captures should cause an error.
-TEST_F(SavedConcreteFunctionLoadingTest,
- TooManyInputsWithCaptureInSavedConcreteFunction) {
- // `saved` has 3 inputs, and 1 capture, for a total of 4 inputs.
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::ThreeArgInputSignature();
- *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
- saved.add_bound_inputs(5);
-
- // `func` has 3 inputs.
- FunctionDef func = FuncDefWithNumInputsOutputs(3, 0);
-
- std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
- captures[5] = std::make_unique<DummyCapture>(context(), 10);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
- context(), &result);
- EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
- << status.error_message();
-}
-
-// A SavedConcreteFunction whose capture refers to an index not in the capture
-// map should cause an error.
-TEST_F(SavedConcreteFunctionLoadingTest, ImproperCaptureIndex) {
- // `saved` has 3 inputs, 1 capture, for a total of 4 inputs
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::ThreeArgInputSignature();
- *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
- // Capture is at index "10"
- saved.add_bound_inputs(10);
-
- // `func` has 4 inputs
- FunctionDef func = FuncDefWithNumInputsOutputs(4, 0);
-
- // `captures` only has a capture for index 5
- std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
- captures[5] = std::make_unique<DummyCapture>(context(), 10);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
- context(), &result);
- EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
- << status.error_message();
-}
-
-// A SavedConcreteFunction whose outputs are fewer than its corresponding
-// functiondef should cause an error.
-TEST_F(SavedConcreteFunctionLoadingTest, TooFewOutputsInSavedConcreteFunction) {
- // `saved` has 0 inputs, 1 output
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::ZeroArgInputSignature();
- *saved.mutable_output_signature() = testing::SingleReturnOutputSignature();
-
- // `func` has 0 inputs, 2 outputs
- FunctionDef func = FuncDefWithNumInputsOutputs(0, 2);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status =
- internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
- EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
- << status.error_message();
-}
-
-// A SavedConcreteFunction whose outputs exceed its corresponding functiondef
-// should cause an error.
-TEST_F(SavedConcreteFunctionLoadingTest,
- TooManyOutputsInSavedConcreteFunction) {
- // `saved` has 1 input, 3 outputs
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::SingleArgInputSignature();
- *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature();
-
- // `func` has 1 input, 2 outputs
- FunctionDef func = FuncDefWithNumInputsOutputs(1, 2);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status =
- internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
- EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
- << status.error_message();
-}
-
-// A SavedConcreteFunction whose (inputs + captures) = functiondef inputs,
-// and whose outputs = functiondef outputs should successfully load.
-TEST_F(SavedConcreteFunctionLoadingTest, SuccessfulLoad) {
- // `saved` has 1 input, 2 captures, 3 outputs
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::SingleArgInputSignature();
- *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature();
- saved.add_bound_inputs(2);
- saved.add_bound_inputs(5);
-
- // `func` has 3 inputs, 3 outputs
- FunctionDef func = FuncDefWithNumInputsOutputs(3, 3);
-
- std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
- captures[2] = std::make_unique<DummyCapture>(context(), 1);
- captures[5] = std::make_unique<DummyCapture>(context(), 10);
-
- std::unique_ptr<TFConcreteFunction> result;
- Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
- context(), &result);
- TF_EXPECT_OK(status) << status.error_message();
-}
-
-// A TFConcreteFunction should register functiondefs on creation, and
-// remove them upon deletion.
-TEST_F(SavedConcreteFunctionLoadingTest, RegistersAndRemovesFunctionDefs) {
- std::string func_name = "FooBarBazWombatFunction";
-
- SavedConcreteFunction saved;
- *saved.mutable_canonicalized_input_signature() =
- testing::ZeroArgInputSignature();
- *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
- FunctionDef func = FuncDefWithNumInputsOutputs(0, 0);
- *func.mutable_signature()->mutable_name() = func_name;
-
- {
- std::unique_ptr<TFConcreteFunction> result;
- Status status =
- internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
- TF_EXPECT_OK(status) << status.error_message();
- // The function should be registered with context.
- EXPECT_TRUE(context()->FindFunctionByName(func_name));
- }
-
- // After `result's` destructor runs, the function should no longer be
- // registered with context.
- EXPECT_FALSE(context()->FindFunctionByName(func_name));
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc
deleted file mode 100644
index dc69cf6..0000000
--- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc
+++ /dev/null
@@ -1,212 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
-
-#include <string>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/protobuf/struct.pb.h"
-
-namespace tensorflow {
-namespace testing {
-namespace {
-
-constexpr absl::string_view kZeroArgInputSignatureTextProto = R"(
-tuple_value: {
- values: {
- tuple_value: {
- }
- }
- values: {
- dict_value: {
- }
- }
-}
-)";
-
-constexpr absl::string_view kSingleArgInputSignatureTextProto = R"(
-tuple_value: {
- values: {
- tuple_value: {
- values: {
- tensor_spec_value: {
- name : "x"
- shape: {
- dim: {
- size: 1
- }
- dim: {
- size: 10
- }
- }
- dtype: DT_FLOAT
- }
- }
- }
- }
- values: {
- dict_value: {
- }
- }
-}
-)";
-
-constexpr absl::string_view kThreeArgInputSignatureTextProto = R"(
-tuple_value: {
- values: {
- tuple_value: {
- values: {
- tensor_spec_value: {
- name : "x"
- shape: {
- dim: {
- size: 1
- }
- }
- dtype: DT_FLOAT
- }
- }
- values: {
- tensor_spec_value: {
- name : "y"
- shape: {
- dim: {
- size: 1
- }
- }
- dtype: DT_FLOAT
- }
- }
- values: {
- tensor_spec_value: {
- name : "z"
- shape: {
- dim: {
- size: 1
- }
- }
- dtype: DT_FLOAT
- }
- }
- }
- }
- values: {
- dict_value: {
- }
- }
-}
-
-)";
-
-constexpr absl::string_view kZeroReturnOutputSignatureTextProto = R"(
-none_value: {}
-)";
-
-constexpr absl::string_view kSingleReturnOutputSignatureTextProto = R"(
-tensor_spec_value: {
- shape: {
- dim: {
- size: 1
- }
- }
- dtype: DT_FLOAT
-}
-)";
-
-constexpr absl::string_view kThreeReturnOutputSignatureTextProto = R"(
-tuple_value: {
- values: {
- dict_value: {
- fields: {
- key : "a"
- value: {
- tensor_spec_value: {
- name : "0/a"
- shape: {
- dim: {
- size: 1
- }
- }
- dtype: DT_FLOAT
- }
- }
- }
- fields: {
- key : "b"
- value: {
- tensor_spec_value: {
- name : "0/b"
- shape: {
- dim: {
- size: 1
- }
- }
- dtype: DT_FLOAT
- }
- }
- }
- }
- }
- values: {
- tensor_spec_value: {
- name : "1"
- shape: {
- dim: {
- size: 1
- }
- }
- dtype: DT_FLOAT
- }
- }
-}
-)";
-
-StructuredValue ParseStructuredValue(absl::string_view text_proto) {
- StructuredValue value;
- CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto),
- &value));
- return value;
-}
-
-} // namespace
-
-StructuredValue ZeroArgInputSignature() {
- return ParseStructuredValue(kZeroArgInputSignatureTextProto);
-}
-
-StructuredValue SingleArgInputSignature() {
- return ParseStructuredValue(kSingleArgInputSignatureTextProto);
-}
-
-StructuredValue ThreeArgInputSignature() {
- return ParseStructuredValue(kThreeArgInputSignatureTextProto);
-}
-
-StructuredValue ZeroReturnOutputSignature() {
- return ParseStructuredValue(kZeroReturnOutputSignatureTextProto);
-}
-
-StructuredValue SingleReturnOutputSignature() {
- return ParseStructuredValue(kSingleReturnOutputSignatureTextProto);
-}
-
-StructuredValue ThreeReturnOutputSignature() {
- return ParseStructuredValue(kThreeReturnOutputSignatureTextProto);
-}
-
-} // namespace testing
-} // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h
deleted file mode 100644
index 8aa7d56..0000000
--- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_
-#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_
-
-#include "tensorflow/core/protobuf/struct.pb.h"
-
-namespace tensorflow {
-namespace testing {
-
-// Returns a StructuredValue corresponding to the serialized InputSignature of a
-// tf.function with 0 inputs
-StructuredValue ZeroArgInputSignature();
-
-// Returns a StructuredValue corresponding to the serialized InputSignature of a
-// tf.function with 1 input
-StructuredValue SingleArgInputSignature();
-
-// Returns a StructuredValue corresponding to the serialized InputSignature of a
-// tf.function with 3 inputs
-StructuredValue ThreeArgInputSignature();
-
-// Returns a StructuredValue corresponding to the serialized OutputSignature of
-// a tf.function with no return values
-StructuredValue ZeroReturnOutputSignature();
-
-// Returns a StructuredValue corresponding to the serialized OutputSignature of
-// a tf.function with a single tensor output
-StructuredValue SingleReturnOutputSignature();
-
-// Returns a StructuredValue corresponding to the serialized OutputSignature of
-// a tf.function with three tensor outputs
-StructuredValue ThreeReturnOutputSignature();
-
-} // namespace testing
-} // namespace tensorflow
-#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_
diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD
index 6be2a02..888c284 100644
--- a/tensorflow/c/experimental/saved_model/internal/BUILD
+++ b/tensorflow/c/experimental/saved_model/internal/BUILD
@@ -41,13 +41,11 @@
":tensorhandle_list",
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
- "//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:c_api",
- "//tensorflow/c/eager:immediate_execution_operation",
+ "//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata",
- "//tensorflow/core:lib",
],
)
@@ -207,13 +205,9 @@
],
deps = [
"//tensorflow/c:tf_status",
- "//tensorflow/c:tf_tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
- "//tensorflow/c/eager:c_api_test_util",
- "//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
- "//tensorflow/c/experimental/saved_model/public:tensorhandle_list",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc
index 12d4921..dd54416 100644
--- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc
+++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc
@@ -15,15 +15,12 @@
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
-#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
-#include "tensorflow/c/tf_status_internal.h"
-#include "tensorflow/core/platform/status.h"
extern "C" {
@@ -37,11 +34,8 @@
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
}
-TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
- TF_Status* status) {
- tensorflow::ImmediateOpPtr call_op(nullptr);
- status->status = tensorflow::unwrap(func)->GetCallOp(&call_op);
- return tensorflow::wrap(call_op.release());
+TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
+ return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
}
} // end extern "C"
diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h
index 944ddec..2a87214 100644
--- a/tensorflow/c/experimental/saved_model/public/concrete_function.h
+++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h
@@ -41,7 +41,7 @@
// Returns a TFE_Op suitable for executing this function.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
- TF_ConcreteFunction* func, TF_Status* status);
+ TF_ConcreteFunction* func);
#ifdef __cplusplus
} // end extern "C"