Updating C++ Ops (experimental) to add first IO ops, along with required fixes to list types.
Specifically:
- adding support for additional types of list Attrs
- generalizing and cleaning up support for types and list-types
- fixing the list-type Ops to only depend on the output being a list
The IO ops now being generated are RestoreV2 and SaveV2.
PiperOrigin-RevId: 449260319
diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.cc b/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.cc
index af1feb1..acff747 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.cc
+++ b/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.cc
@@ -28,18 +28,29 @@
string AttrView::VariableName() const { return attr_.name(); }
string AttrView::VariableType() const {
+ // Completely special cases (e.g. strings are different when lists)
if (attr_.full_type() == "string") {
return "const char*";
}
- if (attr_.full_type() == "type") {
- return "DataType";
- }
- if (attr_.full_type() == "shape") {
- return "const PartialTensorShape";
- }
if (attr_.full_type() == "list(string)") {
return "absl::Span<string const>";
}
+
+ // Normal path: translate base type to C++ ...
+ string c_base_type = attr_.base_type();
+ if (attr_.base_type() == "type") {
+ c_base_type = "DataType";
+ } else if (attr_.base_type() == "shape") {
+ c_base_type = "const PartialTensorShape";
+ }
+
+ // ... and wrap in a Span<> if it's a list.
+ if (attr_.is_list()) {
+ return absl::Substitute("absl::Span<$0>", c_base_type);
+ } else {
+ return c_base_type;
+ }
+
return attr_.full_type();
}
@@ -80,6 +91,14 @@
return Call("strlen", {VariableName()});
}
+string AttrView::VariableSpanData() const {
+ return Call(VariableName(), "data", {}, ".");
+}
+
+string AttrView::VariableSpanLen() const {
+ return Call(VariableName(), "length", {}, ".");
+}
+
string AttrView::InputArg(bool with_default_value) const {
string default_value = DefaultValue();
if (!with_default_value || default_value.empty()) {
@@ -98,10 +117,14 @@
}
std::vector<string> AttrView::SetterArgs() const {
- if (attr_.full_type() != "string") {
- return {AttrNameString(), VariableName()};
- } else {
+ if (attr_.full_type() == "string") {
return {AttrNameString(), VariableName(), VariableStrLen()};
+ } else if (attr_.full_type() == "list(string)") {
+ return {AttrNameString(), VariableName()}; // accepts span directly
+ } else if (attr_.is_list()) {
+ return {AttrNameString(), VariableSpanData(), VariableSpanLen()};
+ } else {
+ return {AttrNameString(), VariableName()};
}
}
diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h b/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h
index 4281932..194c70e 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h
+++ b/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h
@@ -30,6 +30,8 @@
string VariableType() const;
string AttrNameString() const;
string VariableStrLen() const;
+ string VariableSpanData() const;
+ string VariableSpanLen() const;
string DefaultValue() const;
string InputArg(bool with_default_value) const;
string SetterMethod() const;
diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc b/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc
index 0ee532f..2f5f028 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc
+++ b/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc
@@ -85,8 +85,7 @@
// Context
bool OpView::IsListOp() const {
- return NumInputs() == 1 && OnlyInput().IsList() && NumOutputs() == 1 &&
- OnlyOutput().IsList();
+ return NumOutputs() == 1 && OnlyOutput().IsList();
}
} // namespace cpp
diff --git a/tensorflow/c/experimental/ops/io_ops.cc b/tensorflow/c/experimental/ops/io_ops.cc
new file mode 100644
index 0000000..5791a13
--- /dev/null
+++ b/tensorflow/c/experimental/ops/io_ops.cc
@@ -0,0 +1,93 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file is MACHINE GENERATED! Do not edit.
+
+#include "tensorflow/c/experimental/ops/io_ops.h"
+
+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/tracing_utils.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/errors.h"
+
+using tensorflow::tracing::MaybeSetOpName;
+
+namespace tensorflow {
+namespace ops {
+
+// Op: RestoreV2()
+// Summary: Restores tensors from a V2 checkpoint.
+//
+// Description:
+// For backward compatibility with the V1 format, this Op currently allows
+// restoring from a V1 checkpoint as well:
+// - This Op first attempts to find the V2 index file pointed to by
+// "prefix", and
+// if found proceed to read it as a V2 checkpoint;
+// - Otherwise the V1 read path is invoked.
+// Relying on this behavior is not recommended, as the ability to fall back to
+// read V1 might be deprecated and eventually removed.
+//
+// By default, restores the named tensors in full. If the caller wishes to
+// restore specific slices of stored tensors, "shape_and_slices" should be
+// non-empty strings and correspondingly well-formed.
+//
+// Callers must ensure all the named tensors are indeed stored in the
+// checkpoint.
+Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix,
+ AbstractTensorHandle* const tensor_names,
+ AbstractTensorHandle* const shape_and_slices,
+ absl::Span<AbstractTensorHandle*> tensors,
+ absl::Span<DataType> dtypes, const char* name,
+ const char* raw_device_name) {
+ AbstractOperationPtr op_ptr(ctx->CreateOperation());
+ TF_RETURN_IF_ERROR(op_ptr->Reset("RestoreV2", raw_device_name));
+ TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name));
+ TF_RETURN_IF_ERROR(op_ptr->AddInput(prefix));
+ TF_RETURN_IF_ERROR(op_ptr->AddInput(tensor_names));
+ TF_RETURN_IF_ERROR(op_ptr->AddInput(shape_and_slices));
+ TF_RETURN_IF_ERROR(
+ op_ptr->SetAttrTypeList("dtypes", dtypes.data(), dtypes.length()));
+ int num_retvals = tensors.size();
+ return op_ptr->Execute(tensors, &num_retvals);
+}
+
+// Op: SaveV2()
+// Summary: Saves tensors in V2 checkpoint format.
+//
+// Description:
+// By default, saves the named tensors in full. If the caller wishes to save
+// specific slices of full tensors, "shape_and_slices" should be non-empty
+// strings and correspondingly well-formed.
+Status SaveV2(AbstractContext* ctx, AbstractTensorHandle* const prefix,
+ AbstractTensorHandle* const tensor_names,
+ AbstractTensorHandle* const shape_and_slices,
+ absl::Span<AbstractTensorHandle* const> tensors, const char* name,
+ const char* raw_device_name) {
+ AbstractOperationPtr op_ptr(ctx->CreateOperation());
+ TF_RETURN_IF_ERROR(op_ptr->Reset("SaveV2", raw_device_name));
+ TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name));
+ TF_RETURN_IF_ERROR(op_ptr->AddInput(prefix));
+ TF_RETURN_IF_ERROR(op_ptr->AddInput(tensor_names));
+ TF_RETURN_IF_ERROR(op_ptr->AddInput(shape_and_slices));
+ TF_RETURN_IF_ERROR(op_ptr->AddInputList(tensors));
+ int num_retvals = 0;
+ std::vector<AbstractTensorHandle*> dummy_outputs;
+ return op_ptr->Execute(absl::MakeSpan(dummy_outputs), &num_retvals);
+}
+
+} // namespace ops
+} // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/io_ops.h b/tensorflow/c/experimental/ops/io_ops.h
new file mode 100644
index 0000000..ccba7ff
--- /dev/null
+++ b/tensorflow/c/experimental/ops/io_ops.h
@@ -0,0 +1,46 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file is MACHINE GENERATED! Do not edit.
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_
+#define TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_
+
+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+
+namespace tensorflow {
+namespace ops {
+
+// Restores tensors from a V2 checkpoint.
+Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix,
+ AbstractTensorHandle* const tensor_names,
+ AbstractTensorHandle* const shape_and_slices,
+ absl::Span<AbstractTensorHandle*> tensors,
+ absl::Span<DataType> dtypes, const char* name = nullptr,
+ const char* raw_device_name = nullptr);
+
+// Saves tensors in V2 checkpoint format.
+Status SaveV2(AbstractContext* ctx, AbstractTensorHandle* const prefix,
+ AbstractTensorHandle* const tensor_names,
+ AbstractTensorHandle* const shape_and_slices,
+ absl::Span<AbstractTensorHandle* const> tensors,
+ const char* name = nullptr,
+ const char* raw_device_name = nullptr);
+
+} // namespace ops
+} // namespace tensorflow
+
+#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_
diff --git a/tensorflow/c/experimental/ops/update_cpp_ops.sh b/tensorflow/c/experimental/ops/update_cpp_ops.sh
index 55f44dc..b5a47b2 100755
--- a/tensorflow/c/experimental/ops/update_cpp_ops.sh
+++ b/tensorflow/c/experimental/ops/update_cpp_ops.sh
@@ -23,6 +23,7 @@
api_dir="${current_dir}/../../../core/api_def/base_api"
generate="bazel run \
+ --test_output=all \
//tensorflow/c/experimental/ops/gen:generate_cpp -- \
--output_dir="${current_dir}" \
--api_dirs="${api_dir}" \
@@ -67,3 +68,8 @@
ReadVariableOp \
AssignVariableOp \
DestroyResourceOp
+
+${generate} \
+ --category=io \
+ RestoreV2 \
+ SaveV2