/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/ir/importexport/tests/roundtrip/roundtrip.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/ir/importexport/export.h"
#include "tensorflow/core/ir/importexport/import.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
using mlir::MLIRContext;

namespace tensorflow {

// Applies various normalizations to a NodeDef to make textual comparison
// possible (for example, splat constants are compressed, NaNs are replaced
// with a sentinel value, control inputs are sorted alphabetically, and
// default-valued attributes are stripped).
void NormalizeNode(NodeDef* node) {
for (auto& named_attr : (*node->mutable_attr())) {
AttrValue& attr_val = named_attr.second;
if (attr_val.has_tensor()) {
auto* tensor = attr_val.mutable_tensor();
switch (tensor->dtype()) {
// There is no compression or canonicalization for DT_STRING; strip the
// values entirely for now so they are ignored in the comparison.
case DT_STRING: {
const TensorShape shape(tensor->tensor_shape());
if (!tensor->tensor_content().empty()) {
tensor->mutable_tensor_content()->clear();
} else {
tensor->mutable_string_val()->Clear();
}
break;
}
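// For floating-point and complex types, compress splat constants and rewrite
// NaNs to a fixed sentinel so that NaN values, which never compare equal to
// themselves, do not cause spurious mismatches.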
case DT_FLOAT:
tensor::CompressTensorProtoInPlace(1, 1.0, tensor);
for (float& val : *tensor->mutable_float_val())
if (std::isnan(val)) val = -42.;
break;
case DT_DOUBLE:
tensor::CompressTensorProtoInPlace(1, 1.0, tensor);
for (double& val : *tensor->mutable_double_val())
if (std::isnan(val)) val = -42.;
break;
case DT_COMPLEX64:
tensor::CompressTensorProtoInPlace(1, 1.0, tensor);
for (float& val : *tensor->mutable_scomplex_val())
if (std::isnan(val)) val = -42.;
break;
case DT_COMPLEX128:
tensor::CompressTensorProtoInPlace(1, 1.0, tensor);
for (double& val : *tensor->mutable_dcomplex_val())
if (std::isnan(val)) val = -42.;
break;
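// Round-trip variant tensors through a Tensor so their encoding is
// canonicalized before comparison.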
case DT_VARIANT: {
Tensor t;
if (t.FromProto(*tensor)) t.AsProtoField(tensor);
break;
}
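// For all other dtypes, compressing splat constants is sufficient.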
default:
tensor::CompressTensorProtoInPlace(1, 1.0, tensor);
}
}
}
// Sort control inputs alphabetically.
for (auto it = node->mutable_input()->begin(),
end = node->mutable_input()->end();
it != end; ++it) {
if (it->empty() || it->front() != '^') continue;
std::sort(it, end);
}
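// Strip attributes that still carry their default value so that graphs
// produced with and without explicit defaults compare equal.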
const OpDef* op_def = nullptr;
(void)tensorflow::OpRegistry::Global()->LookUpOpDef(node->op(), &op_def);
if (op_def) StripDefaultsFromNodeDef(*op_def, node);
// TODO(aminim): Fix this
node->clear_experimental_debug_info();
}
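
// Normalizes a GraphDef for comparison: functions and nodes are sorted by
// name, every NodeDef is normalized, and empty per-argument attribute
// entries are dropped from functions.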
void NormalizeTensorData(GraphDef& graphdef) {
FunctionDefLibrary* library = graphdef.mutable_library();
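// Sort the functions in the library by name so that ordering differences do
// not affect the comparison.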
llvm::sort(*library->mutable_function(),
[](FunctionDef& lhs, FunctionDef& rhs) {
return lhs.signature().name() < rhs.signature().name();
});
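// Normalize the top-level nodes and sort them by name.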
for (int i = 0; i < graphdef.node_size(); ++i)
NormalizeNode(graphdef.mutable_node(i));
llvm::sort(*graphdef.mutable_node(),
[](const NodeDef& lhs, const NodeDef& rhs) {
return lhs.name() < rhs.name();
});
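// Apply the same normalization to every function in the library.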
for (int func_id = 0; func_id < library->function_size(); ++func_id) {
FunctionDef* func = library->mutable_function(func_id);
llvm::sort(*func->mutable_node_def(), [](NodeDef& lhs, NodeDef& rhs) {
return lhs.name() < rhs.name();
});
for (int node_id = 0; node_id < func->node_def_size(); ++node_id) {
NodeDef* node = func->mutable_node_def(node_id);
NormalizeNode(node);
}
// Eliminate empty arg_attr entries.
llvm::SmallVector<int> to_erase;
for (auto& arg_attr : *func->mutable_arg_attr()) {
if (arg_attr.second.attr().empty()) {
to_erase.push_back(arg_attr.first);
}
}
for (int idx : to_erase) func->mutable_arg_attr()->erase(idx);
}
}
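
// Imports the given GraphDef into a TFG MLIR module, exports it back to a
// GraphDef, and compares the normalized result against the normalized
// original. Returns an error if the import or export fails or if the two
// graphs differ.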
Status TestRoundTrip(GraphDef& graphdef) {
MLIRContext context;
GraphDebugInfo debug_info;
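// Import the GraphDef into the TFG dialect.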
auto errorOrModule =
mlir::tfg::ImportGraphDefToMlir(&context, debug_info, graphdef);
if (!errorOrModule.ok()) {
LOG(ERROR) << errorOrModule.status();
llvm::errs()
<< "\n\n=========\n=========\n=========\n=========\n=========\n"
<< graphdef.DebugString()
<< "=========\n=========\n=========\n=========\n";
return errorOrModule.status();
}
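// Export the imported module back to a GraphDef.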
GraphDef new_graph;
auto module = errorOrModule.ValueOrDie().get();
Status status = tensorflow::ExportMlirToGraphdef(module, &new_graph);
if (!status.ok()) {
LOG(ERROR) << "Error exporting MLIR module to GraphDef: " << status;
return status;
}
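// Rebuild the original GraphDef through a Graph with default attributes
// added so that it is in a form comparable to the exported graph.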
GraphDef original_graph;
{
GraphConstructorOptions options;
options.allow_internal_ops = true;
options.add_default_attributes = true;
Graph graph(OpRegistry::Global());
GraphDef preprocessed_graphdef(graphdef);
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
options, std::move(preprocessed_graphdef), &graph));
graph.ToGraphDef(&original_graph);
}
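// Normalize both graphs before the structural comparison.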
NormalizeTensorData(new_graph);
NormalizeTensorData(original_graph);
if (!tensorflow::protobuf::util::MessageDifferencer::Equivalent(
original_graph, new_graph)) {
LOG(ERROR) << "GraphDef didn't Roundtrip:";
llvm::errs()
<< "\n=========\n\n"
<< module
<< "\n\n=========\n=========\n=========\n=========\n=========\n"
<< graphdef.DebugString()
<< "=========\n=========\n=========\n=========\n";
return errors::InvalidArgument("GraphDef didn't roundtrip");
}
return {};
}

}  // namespace tensorflow