Fix Const op tensor_content on s390x during save/load
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 056c99e..c48f898 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -45,6 +45,7 @@
deps = [
":constants",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/util/tensor_bundle",
] + if_not_mobile([
# TODO(b/111634734): :lib and :protos_all contain dependencies that
# cannot be built on mobile platforms. Instead, include the appropriate
diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc
index c1d4736..17f160a 100644
--- a/tensorflow/cc/saved_model/reader.cc
+++ b/tensorflow/cc/saved_model/reader.cc
@@ -24,6 +24,7 @@
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
+#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
namespace tensorflow {
namespace {
@@ -49,6 +50,33 @@
export_dir);
}
+// Swap tensor_content field of Const Op Tensors in the named functions
+static Status SwapTensorContent(MetaGraphDef* meta_graph_def) {
+ GraphDef graph_def = *meta_graph_def->mutable_graph_def();
+ for (auto& function : *meta_graph_def->mutable_graph_def()->mutable_library()->mutable_function()) {
+ for (auto& node : (*function.mutable_node_def())) {
+ if (node.op() == "Const") {
+ auto node_iterator = node.mutable_attr()->find("value");
+ if (node_iterator != node.mutable_attr()->end()) {
+ AttrValue node_value = node_iterator->second;
+ if (node_value.has_tensor()) {
+ auto tsize = node_value.mutable_tensor()->tensor_content().size();
+ auto p_type = node_value.mutable_tensor()->dtype();
+ // Swap only when there is something in tensor_content field
+ if (tsize!=0 && DataTypeCanUseMemcpy(p_type)) {
+ Tensor parsed(p_type);
+ DCHECK(parsed.FromProto(*node_value.mutable_tensor()));
+ TF_RETURN_IF_ERROR(ByteSwapTensor(&parsed));
+ (*node.mutable_attr())["value"].mutable_tensor()->set_tensor_content(string(reinterpret_cast<const char*>(parsed.tensor_data().data()), parsed.tensor_data().size()));
+ }
+ }
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
Status FindMetaGraphDef(const std::unordered_set<string>& tags,
SavedModel* saved_model_proto,
MetaGraphDef* meta_graph_def) {
@@ -63,6 +91,10 @@
// Match with the set of tags provided.
if (graph_tags == tags) {
*meta_graph_def = std::move(graph_def);
+ // Correct the endiness of Tensor content on big-endian system
+ if (!port::kLittleEndian) {
+ SwapTensorContent(meta_graph_def);
+ }
return Status::OK();
}
}
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index 1d513b4..c4eb751 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -20,6 +20,7 @@
import functools
import os
+import sys
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
@@ -874,6 +875,10 @@
if (len(saved_model_proto.meta_graphs) == 1 and
saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
meta_graph_def = saved_model_proto.meta_graphs[0]
+ # tensor_content field contains raw bytes in litle endian format which causes problems
+ # when loaded on big-endian systems requiring byteswap
+ if sys.byteorder == 'big':
+ saved_model_utils.swap_function_tensor_content(meta_graph_def, "little", "big")
if (tags is not None
and set(tags) != set(meta_graph_def.meta_info_def.tags)):
raise ValueError(
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 3725576..844796d 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -22,6 +22,7 @@
import functools
import gc
import os
+import sys
from absl import logging
from tensorflow.core.framework import versions_pb2
@@ -719,6 +720,9 @@
for signature_key, signature in signatures.items():
meta_graph_def.signature_def[signature_key].CopyFrom(signature)
meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
+ # store tensor_content in litle endian format
+ if sys.byteorder == 'big':
+ utils_impl.swap_function_tensor_content(meta_graph_def, "big", "little")
return asset_info, exported_graph
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 17ef2ee..70135e3 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -278,3 +278,46 @@
"""Returns path to the debug sub-directory in the SavedModel."""
return os.path.join(
compat.as_text(export_dir), compat.as_text(constants.DEBUG_DIRECTORY))
+
+# Based on tensor_bundle/byte_swap.cc
+byte_swappable = [
+ dtypes.float16,
+ dtypes.float32,
+ dtypes.float64,
+ dtypes.bfloat16,
+ dtypes.complex64,
+ dtypes.complex128,
+ dtypes.uint16,
+ dtypes.uint32,
+ dtypes.uint64,
+ dtypes.int16,
+ dtypes.int32,
+ dtypes.int64,
+ dtypes.qint16,
+ dtypes.quint16,
+ dtypes.qint32
+]
+
+def swap_function_tensor_content(meta_graph_def, from_endiness, to_endiness):
+ functions = meta_graph_def.graph_def.library.function
+ for function in functions:
+ node_def = function.node_def
+ for node in node_def:
+ if node.op == "Const":
+ tensor = node.attr["value"].tensor
+ byte_swap_tensor_content(tensor,from_endiness, to_endiness)
+
+def byte_swap_tensor_content(tensor, from_endiness, to_endiness):
+ """Byte swaps"""
+ if tensor.dtype in byte_swappable:
+ tshape = tensor.tensor_shape.dim
+ tensor_bytes = tensor.tensor_content
+ if tensor_bytes != b'':
+ tensor_size = 1
+ for sz in tshape:
+ tensor_size = tensor_size*sz.size
+ chunksize = int(len(tensor_bytes)/tensor_size)
+ #split tensor_data into chunks for byte swapping
+ to_swap = [tensor_bytes[i:i+chunksize] for i in range(0, len(tensor_bytes), chunksize)]
+ #swap and replace tensor_content
+ tensor.tensor_content = b''.join([int.from_bytes(byteswap, from_endiness).to_bytes(chunksize, to_endiness) for byteswap in to_swap])