/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"
#include <algorithm>
#include <deque>
#include <numeric>
#include <set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace grappler {
namespace {
constexpr char kOptimizedSuffix[] = "LayoutOptimizer";
constexpr char kAttrKSize[] = "ksize";
constexpr char kAttrStrides[] = "strides";
constexpr char kAttrDilations[] = "dilations";
constexpr char kAttrExplicitPaddings[] = "explicit_paddings";
constexpr char kAttrDataFormat[] = "data_format";
constexpr char kAttrIsTraining[] = "is_training";
constexpr char kAttrValue[] = "value";
constexpr char kAttrN[] = "N";
constexpr char kAttrT[] = "T";
constexpr char kAttrNumSplit[] = "num_split";
constexpr char kAttrNumOuts[] = "num_outs";
constexpr char kAttrKeepDims[] = "keep_dims";
constexpr char kAttrSqueezeDims[] = "squeeze_dims";
constexpr char kOpTranspose[] = "Transpose";
constexpr char kOpDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kOpDataFormatDimMap[] = "DataFormatDimMap";
constexpr char kOpConst[] = "Const";
constexpr char kReshape[] = "Reshape";
constexpr char kReshapeConst[] = "ReshapeConst";
constexpr int kRank = 4;
inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
absl::string_view src_data_format,
bool* missing) {
const auto* attr = node.GetAttr(kAttrDataFormat);
if (attr != nullptr) {
return attr->s() == src_data_format;
}
*missing = true;
return false;
}
inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
absl::string_view src_data_format) {
bool missing = false;
return AttrDataFormatMatch(node, src_data_format, &missing);
}
// Utils for layout agnostic transposer.
bool IsComparisonOp(const NodeDef& node) {
bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
IsLessEqual(node) || IsNotEqual(node);
return is_compare;
}
std::vector<int> GetRegularFaninPorts(const utils::MutableNodeView& node) {
const int num_regular_fanins = node.NumRegularFanins();
std::vector<int> values(num_regular_fanins);
std::iota(values.begin(), values.end(), 0);
return values;
}
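// Returns the ports of the data fanins of Concat/ConcatV2. For "Concat" the
// axis is input 0 and the data inputs are ports [1, N]; otherwise (e.g.
// "ConcatV2") the data inputs are ports [0, N) and the axis comes last.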
std::vector<int> GetConcatDataFaninPorts(const utils::MutableNodeView& node) {
const auto* n_attr = node.GetAttr(kAttrN);
const int n = n_attr != nullptr ? n_attr->i() : 0;
const int start = (node.GetOp() == "Concat") ? 1 : 0;
const int end = start + n;
std::vector<int> values(end - start);
std::iota(values.begin(), values.end(), start);
return values;
}
struct ComparatorByNodeNameAndIndex {
bool operator()(const utils::MutableFaninView& node1,
const utils::MutableFaninView& node2) const {
auto* node1_view = node1.node_view();
auto* node2_view = node2.node_view();
auto name_compare = node1_view->GetName().compare(node2_view->GetName());
if (name_compare == 0) {
return node1.index() < node2.index();
}
return name_compare < 0;
}
};
bool IsHostMemory(const NodeDef& node, int output_port) {
DeviceNameUtils::ParsedName parsed_name;
if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
DeviceType device_type(parsed_name.type);
Status s = FindKernelDef(device_type, node, nullptr, nullptr);
if (s.ok()) {
tensorflow::MemoryTypeVector in_mtypes;
tensorflow::MemoryTypeVector out_mtypes;
s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
node, &in_mtypes, &out_mtypes);
if (s.ok()) {
if (out_mtypes[output_port] == HOST_MEMORY) {
return true;
}
}
} else {
return true;
}
}
return false;
}
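// Maps dimension labels to their indices. For example, with NHWC
// (dim_indices = {{'N', 0}, {'H', 1}, {'W', 2}, {'C', 3}}), labels
// {'H', 'W'} yield {1, 2}.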
std::vector<int> GetDimensionIndicesFromLabel(
const absl::flat_hash_map<char, int>& dim_indices,
absl::Span<const char> labels) {
std::vector<int> indices;
indices.reserve(labels.size());
for (const auto& label : labels) {
indices.push_back(dim_indices.at(label));
}
return indices;
}
} // namespace
// TransposeContext.
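// Typical usage (a sketch; the generic layout optimizer drives these calls):
//   TransposeContext context;
//   TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext(
//       item, cluster, &context));
//   context.AssignDeviceAndDataFormats("GPU", "NHWC", "NCHW");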
Status TransposeContext::InitializeTransposeContext(const GrapplerItem& item,
const Cluster* cluster,
TransposeContext* context) {
DCHECK(context != nullptr);
context->graph_properties = absl::make_unique<GraphProperties>(item);
TF_RETURN_IF_ERROR(context->graph_properties->InferStatically(false));
TF_RETURN_IF_ERROR(
context->graph_properties->AnnotateOutputShapes(&context->graph));
Status status;
context->graph_view =
absl::make_unique<utils::MutableGraphView>(&context->graph, &status);
TF_RETURN_IF_ERROR(status);
context->num_nodes = context->graph.node_size();
const auto& nodes_to_preserve = item.NodesToPreserve();
context->nodes_to_preserve = absl::flat_hash_set<string>(
nodes_to_preserve.begin(), nodes_to_preserve.end());
TF_RETURN_IF_ERROR(context->frames.InferFromGraph(context->graph));
if (cluster != nullptr) {
context->virtual_placer =
absl::make_unique<const VirtualPlacer>(cluster->GetDevices());
}
return Status::OK();
}
// Sets data formats to convert from and to for specified device type.
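// For example, with src_format = "NHWC" and dst_format = "NCHW", src_to_dst
// is {0, 3, 1, 2} and dst_to_src is {0, 2, 3, 1}.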
void TransposeContext::AssignDeviceAndDataFormats(
absl::string_view target_device, absl::string_view src_format,
absl::string_view dst_format) {
this->target_device = string(target_device);
this->src_format = string(src_format);
this->dst_format = string(dst_format);
this->src_dim_indices = GetDimensionIndices(src_format);
this->dst_dim_indices = GetDimensionIndices(dst_format);
this->src_to_dst = GetPermutation(this->src_dim_indices, dst_format);
this->dst_to_src = GetPermutation(this->dst_dim_indices, src_format);
}
// Transposer.
bool Transposer::ShouldProcess(const TransposeContext& context,
const utils::MutableNodeView& node) const {
const auto* node_def = node.node();
const string& device_name =
GetDeviceName(context.virtual_placer.get(), *node_def);
string device;
string task;
bool is_on_target_device =
DeviceNameUtils::SplitDeviceName(device_name, &task, &device) &&
absl::StrContains(absl::AsciiStrToLower(device),
absl::AsciiStrToLower(context.target_device));
// Only check data format for layout sensitive ops.
bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
AttrDataFormatMatch(node, context.src_format);
return is_on_target_device && data_format_match &&
!context.nodes_to_preserve.contains(node_def->name()) &&
!(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}
Status Transposer::CreateConstPermNode(TransposeContext* context,
absl::string_view node_name,
absl::string_view device,
absl::Span<const int> permutation,
absl::string_view control_node_name,
utils::MutationNewNode* added_node) {
auto* graph_view = context->graph_view.get();
DCHECK(!graph_view->HasNode(node_name));
NodeDef node;
node.set_name(string(node_name));
node.set_op(kOpConst);
node.set_device(string(device));
if (!control_node_name.empty()) {
node.add_input(string(control_node_name));
}
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
node.mutable_attr()->insert({"dtype", attr_data_type});
AttrValue attr_tensor;
Tensor tensor(DT_INT32, TensorShape({4}));
for (int i = 0; i < permutation.size(); i++) {
tensor.flat<int>()(i) = permutation[i];
}
tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
node.mutable_attr()->insert({"value", attr_tensor});
Status status;
*added_node =
graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
return status;
}
Status Transposer::CreateTransposeNode(
TransposeContext* context, absl::string_view name_format,
const DataType& data_type, absl::string_view device,
TensorShapeProto fanin_shape, absl::Span<const int> permutation,
absl::string_view control_node_name, utils::MutationNewNode* added_node,
string* transpose_node_name) {
const string node_name = absl::Substitute(name_format, kOpTranspose);
auto* graph_view = context->graph_view.get();
DCHECK(!graph_view->HasNode(node_name));
*transpose_node_name = node_name;
NodeDef node;
node.set_name(string(node_name));
node.set_op(kOpTranspose);
node.set_device(string(device));
AttrValue attr_data_type;
attr_data_type.set_type(data_type);
node.mutable_attr()->insert({"T", attr_data_type});
AttrValue attr_data_type_perm;
attr_data_type_perm.set_type(DT_INT32);
node.mutable_attr()->insert({"Tperm", attr_data_type_perm});
if (!fanin_shape.unknown_rank()) {
TF_RETURN_IF_ERROR(
PermuteSingle(absl::StrCat("fanin shape in ", node.name()), permutation,
fanin_shape.mutable_dim()));
AttrValue attr_output_shape;
*attr_output_shape.mutable_list()->add_shape() = fanin_shape;
node.mutable_attr()->insert({kAttrOutputShape, attr_output_shape});
}
// Create Const Node
utils::MutationNewNode const_perm_added_node;
const string const_perm_node_name =
absl::Substitute(name_format, "PermConst");
TF_RETURN_IF_ERROR(CreateConstPermNode(context, const_perm_node_name, device,
permutation, control_node_name,
&const_perm_added_node));
// Add placeholder for 1st input.
node.add_input("");
// Connect const_perm_node to 2nd input of transpose_node.
node.add_input(const_perm_node_name);
Status status;
*added_node =
graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
return status;
}
Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context,
absl::Span<const int> dst_ports,
utils::MutableNodeView* dst_node,
absl::string_view op) {
const bool is_in_frame = context->frames.IsInFrame(*dst_node->node());
for (int dst_port : dst_ports) {
auto& fanin_port = dst_node->GetRegularFanin(dst_port);
auto* fanin_node_view = fanin_port.node_view();
TF_RETURN_IF_ERROR(
UpdateEdge(context,
GetFaninNameFormat(dst_node->GetName(), dst_port,
context->src_format, context->dst_format),
op, /*input_shape=*/nullptr, /*is_in_frame=*/is_in_frame,
/*is_src_format_to_dst_format=*/true, fanin_port.index(),
dst_port, fanin_node_view, dst_node));
}
return Status::OK();
}
Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context,
absl::Span<const int> src_ports,
utils::MutableNodeView* src_node,
absl::string_view op) {
// Update attr _output_shapes for output ports.
const auto* output_shape_attr = src_node->GetAttr(kAttrOutputShape);
AttrValue shape_attr_copy;
if (op == kOpTranspose && output_shape_attr != nullptr) {
shape_attr_copy = *output_shape_attr;
for (int port : src_ports) {
auto* shape = shape_attr_copy.mutable_list()->mutable_shape(port);
if (shape->unknown_rank()) continue;
TF_RETURN_IF_ERROR(
PermuteSingle(absl::StrCat("output shape attribute at port ", port,
" in", src_node->GetName()),
context->src_to_dst, shape->mutable_dim()));
}
context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
src_node, kAttrOutputShape, shape_attr_copy);
}
const bool is_in_frame = context->frames.IsInFrame(*src_node->node());
// Fanouts may be modified while iterating, so make a copy first. Sort the
// copy by node name and index so that the generated transposer names stay
// deterministic.
for (int src_port : src_ports) {
const auto& fanouts_src_port = src_node->GetRegularFanout(src_port);
std::vector<utils::MutableFaninView> sorted_fanouts(
fanouts_src_port.begin(), fanouts_src_port.end());
std::sort(sorted_fanouts.begin(), sorted_fanouts.end(),
ComparatorByNodeNameAndIndex());
int num_downstream_transposers = 0;
for (const auto& fanout : sorted_fanouts) {
TF_RETURN_IF_ERROR(UpdateEdge(
context,
GetFanoutNameFormat(src_node->GetName(), src_port,
num_downstream_transposers++, context->src_format,
context->dst_format),
op, &shape_attr_copy, /*is_in_frame=*/is_in_frame,
/*is_src_format_to_dst_format=*/false, src_port, fanout.index(),
src_node, fanout.node_view()));
}
}
return Status::OK();
}
Status Transposer::CreateDataFormatNode(
TransposeContext* context, absl::string_view node_name,
absl::string_view op, absl::string_view device, const DataType& data_type,
bool is_fanin_on_host, bool is_src_format_to_dst_format,
utils::MutationNewNode* added_node) {
auto* graph_view = context->graph_view.get();
DCHECK(!graph_view->HasNode(node_name));
// Create the node
NodeDef node;
node.set_name(string(node_name));
// Set up parameters of node.
node.set_op(string(op));
node.set_device(string(device));
AttrValue attr_data_type;
attr_data_type.set_type(data_type);
node.mutable_attr()->insert({"T", attr_data_type});
// The inputs of a DataFormat op could be in host memory for ops such as
// Reshape. In such cases, run the kernel on the host too.
if (is_fanin_on_host) {
AttrValue attr_kernel;
attr_kernel.set_s("host");
node.mutable_attr()->insert({"_kernel", attr_kernel});
}
AttrValue src_format;
src_format.set_s(is_src_format_to_dst_format ? context->src_format
: context->dst_format);
node.mutable_attr()->insert({kAttrSrcFormat, src_format});
AttrValue dst_format;
dst_format.set_s(is_src_format_to_dst_format ? context->dst_format
: context->src_format);
node.mutable_attr()->insert({kAttrDstFormat, dst_format});
// Add placeholder for 1st input field.
node.add_input("");
Status status;
*added_node =
graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
return status;
}
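// Inserts a Transpose or DataFormat* node between src_node:src_port and
// dst_node:dst_port. `name_format` supplies the new node's name (with "$0"
// substituted by the op); the permutation direction follows
// is_src_format_to_dst_format.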
Status Transposer::UpdateEdge(
TransposeContext* context, absl::string_view name_format,
absl::string_view op, const AttrValue* input_shape, bool is_in_frame,
bool is_src_format_to_dst_format, const int src_port, const int dst_port,
utils::MutableNodeView* src_node, utils::MutableNodeView* dst_node) {
DCHECK(src_node != nullptr);
DCHECK(dst_node != nullptr);
auto* src_node_def = src_node->node();
auto* dst_node_def = dst_node->node();
// TODO(lyandy): Minimize device parsing/fetching.
const string device = GetDeviceName(
context->virtual_placer.get(),
is_src_format_to_dst_format ? *dst_node_def : *src_node_def);
DataType data_type =
is_src_format_to_dst_format
? context->graph_properties
->GetInputProperties(dst_node->GetName())[dst_port]
.dtype()
: context->graph_properties
->GetOutputProperties(src_node->GetName())[src_port]
.dtype();
utils::MutationNewNode added_node;
string added_node_name;
if (op == kOpTranspose) {
TensorShapeProto input_shape_proto;
input_shape_proto.set_unknown_rank(true);
if (input_shape != nullptr) {
input_shape_proto = input_shape->list().shape(src_port);
} else {
const auto* src_node_shape_attr = src_node->GetAttr(kAttrOutputShape);
if (src_node_shape_attr != nullptr) {
input_shape_proto = src_node_shape_attr->list().shape(src_port);
}
}
const string control_node_name =
is_in_frame ? AsControlDependency(src_node_def->name()) : "";
const std::vector<int>& permutation =
is_src_format_to_dst_format ? context->src_to_dst : context->dst_to_src;
TF_RETURN_IF_ERROR(CreateTransposeNode(
context, name_format, data_type, device, input_shape_proto, permutation,
control_node_name, &added_node, &added_node_name));
} else if (op == kOpDataFormatVecPermute || op == kOpDataFormatDimMap) {
DeviceNameUtils::ParsedName parsed_name;
bool is_fanin_on_host =
DeviceNameUtils::ParseFullName(
GetDeviceName(context->virtual_placer.get(), *src_node_def),
&parsed_name) &&
parsed_name.type != "CPU" && IsHostMemory(*src_node_def, src_port);
const string node_name = absl::Substitute(name_format, op);
TF_RETURN_IF_ERROR(CreateDataFormatNode(
context, node_name, op, device, data_type, is_fanin_on_host,
is_src_format_to_dst_format, &added_node));
added_node_name = node_name;
} else {
return Status(error::INVALID_ARGUMENT,
absl::StrCat("Unsupported op \"", op,
"\". Supported ops are Transpose, "
"DataFormatVecPerm, DataFormatDimMap."));
}
// Connect src_node to 1st input of added_node.
utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
mutation->AddOrUpdateRegularFanin(added_node, 0,
{src_node->GetName(), src_port});
// Connect output of added_node to dst_node:dst_port.
mutation->AddOrUpdateRegularFanin(dst_node, dst_port, {added_node_name, 0});
return Status::OK();
}
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
if (output_shape_attr == nullptr ||
output_shape_attr->list().shape_size() <= port) {
return false;
}
const auto& shape = output_shape_attr->list().shape(port);
return !shape.unknown_rank() && shape.dim_size() == n;
}
bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
absl::Span<const int> ports, int n) const {
for (const auto& port : ports) {
if (!IsFanoutPortRankN(node, port, n)) {
return false;
}
}
return true;
}
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
if (port < node.NumRegularFanins() && port >= 0) {
const auto& regular_fanin = node.GetRegularFanin(port);
return IsFanoutPortRankN(*regular_fanin.node_view(), regular_fanin.index(),
n);
}
return false;
}
bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
int port,
absl::Span<const int> dims) const {
if (port < node.NumRegularFanins() && port >= 0) {
const auto& regular_fanin = node.GetRegularFanin(port);
const auto* fanin_node_view = regular_fanin.node_view();
if (!IsConstant(*fanin_node_view->node())) {
return true;
}
// If fanin is a Const, check tensor to see if dimensions match.
const auto* value_attr = fanin_node_view->GetAttr(kAttrValue);
if (value_attr == nullptr) {
return false;
}
Tensor tensor;
if (!tensor.FromProto(value_attr->tensor())) {
return false;
}
if (tensor.dims() != dims.size()) {
return false;
}
for (int i = 0; i < dims.size(); ++i) {
if (tensor.dim_size(i) != dims[i]) {
return false;
}
}
return true;
}
return false;
}
bool Transposer::IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
absl::Span<const int> ports,
absl::Span<const int> dims) const {
for (const auto& port : ports) {
if (!IsFaninPortDimsNIfConst(node, port, dims)) {
return false;
}
}
return true;
}
bool Transposer::CanProcessNode(const TransposeContext& context,
const utils::MutableNodeView& node) const {
return !context.nodes_to_preserve.contains(node.GetName()) &&
!(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}
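// For example, GetFaninNameFormat("conv0", 0, "NHWC", "NCHW") yields the
// format "conv0-0-$0NHWCToNCHW-LayoutOptimizer"; absl::Substitute later
// fills "$0" with the op name (e.g. "Transpose").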
string Transposer::GetFaninNameFormat(absl::string_view node_name, int port,
absl::string_view src_format,
absl::string_view dst_format) {
return absl::StrCat(node_name, "-", port, "-$0", src_format, "To", dst_format,
"-", kOptimizedSuffix);
}
string Transposer::GetFanoutNameFormat(absl::string_view node_name, int port,
int index, absl::string_view src_format,
absl::string_view dst_format) {
return absl::StrCat(node_name, "-", port, "-", index, "-$0", dst_format, "To",
src_format, "-", kOptimizedSuffix);
}
string Transposer::LayoutOptimizerNode(absl::string_view node_name) {
return absl::StrCat(node_name, "-", kOptimizedSuffix);
}
string Transposer::GetReshapeNodeNameFormat(absl::string_view node_name,
int index,
absl::string_view src_format,
absl::string_view dst_format) {
return absl::StrCat(node_name, "-", index, "-", kReshape, src_format, "To",
dst_format);
}
string Transposer::GetShapeConstNodeNameFormat(absl::string_view node_name,
int index) {
return absl::StrCat(node_name, "-", index, "-", kReshapeConst);
}
// Layout sensitive transposer.
inline string GetLayoutSensitiveNodeDataFormat(
const utils::MutableNodeView& node) {
const auto* attr = node.GetAttr(kAttrDataFormat);
if (attr != nullptr) {
return attr->s();
}
return "";
}
Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
utils::MutableNodeView* node) {
utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
AttrValue data_format_attr;
data_format_attr.set_s(context->dst_format);
mutation->AddOrUpdateNodeAttr(node, kAttrDataFormat, data_format_attr);
auto permute_attr = [&context, &node,
&mutation](absl::string_view attr_name) {
const auto* attr = node->GetAttr(attr_name);
if (attr != nullptr) {
AttrValue attr_copy(*attr);
TF_RETURN_IF_ERROR(PermuteSingle(
absl::StrCat(attr_name, " attribute in ", node->GetName()),
context->src_to_dst, attr_copy.mutable_list()->mutable_i()));
mutation->AddOrUpdateNodeAttr(node, attr_name, attr_copy);
}
return Status::OK();
};
// Update attrs.
TF_RETURN_IF_ERROR(permute_attr(kAttrStrides));
TF_RETURN_IF_ERROR(permute_attr(kAttrKSize));
TF_RETURN_IF_ERROR(permute_attr(kAttrDilations));
const auto* explicit_paddings_attr = node->GetAttr(kAttrExplicitPaddings);
if (explicit_paddings_attr != nullptr && explicit_paddings_attr->has_list() &&
explicit_paddings_attr->list().i_size() > 0) {
AttrValue explicit_paddings_attr_copy(*explicit_paddings_attr);
TF_RETURN_IF_ERROR(PermuteDouble(
absl::StrCat("explicit_paddings attribute in", node->GetName()),
context->src_to_dst,
explicit_paddings_attr_copy.mutable_list()->mutable_i()));
mutation->AddOrUpdateNodeAttr(node, kAttrExplicitPaddings,
explicit_paddings_attr_copy);
}
return Status::OK();
}
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsAvgPoolGrad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 1, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBiasAddGrad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
// No need to update output shape, as it is always of shape 1-D with size the
// feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is
// used.
return context->graph_view->GetMutationBuilder()->Apply();
}
Status Conv2DBackpropFilterTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsConv2DBackpropFilter(*node->node()) ||
IsDepthwiseConv2dNativeBackpropFilter(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
// No need to update output shape, as it is always of shape
// [filter_height, filter_width, in_channels, out_channels], regardless of
// whether NCHW or NHWC is used.
return context->graph_view->GetMutationBuilder()->Apply();
}
Status Conv2DBackpropInputTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsConv2DBackpropInput(*node->node()) ||
IsDepthwiseConv2dNativeBackpropInput(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status FusedBatchNormExTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsFusedBatchNormEx(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
if (node->NumRegularFanins() == 6) {
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 5}, node, kOpTranspose));
} else {
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
}
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
bool FusedBatchNormGradTransposer::IsTraining(
const utils::MutableNodeView& node) const {
const auto* is_training_attr = node.GetAttr(kAttrIsTraining);
if (is_training_attr != nullptr) {
return is_training_attr->b();
}
return false;
}
Status FusedBatchNormGradTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsFusedBatchNormGrad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsTraining(*node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsMaxPoolV2(*node->node()));
// We check data_input's shape instead, because the shape inference of
// MaxPoolV2 is not able to infer the shape when ksize or strides is not
// constant.
const auto& data_fanin = node->GetRegularFanin(0);
auto* data_fanin_node = data_fanin.node_view();
if (!ShouldProcess(*context, *node) ||
!IsFanoutPortRankN(*data_fanin_node, data_fanin.index(), 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsMaxPoolGrad(*node->node()) || IsMaxPoolGradGradV1(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status MaxPoolGradV2Transposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsMaxPoolGradV2(*node->node()) || IsMaxPoolGradGradV2(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {3, 4}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
// Layout agnostic transposer.
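// Returns true if `node` is a Transpose whose permutation input is a Const
// holding exactly `permutation`.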
inline bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node,
absl::Span<const int> permutation) {
Tensor tensor;
if (!GetValueAttrFromConstInputNode(node, IsTranspose, 1, &tensor)) {
return false;
}
if (tensor.NumElements() != permutation.size()) {
return false;
}
const auto& tensor_data = tensor.unaligned_flat<int32>();
for (int i = 0; i < permutation.size(); i++) {
if (permutation[i] != tensor_data(i)) {
return false;
}
}
return true;
}
inline bool IsValidDataFormatNode(const utils::MutableNodeView& node,
absl::string_view src_format,
absl::string_view dst_format) {
if (!IsDataFormatOp(node)) {
return false;
}
const auto* src_format_attr = node.GetAttr(kAttrSrcFormat);
if (src_format_attr == nullptr || src_format_attr->s() != src_format) {
return false;
}
const auto* dst_format_attr = node.GetAttr(kAttrDstFormat);
if (dst_format_attr == nullptr || dst_format_attr->s() != dst_format) {
return false;
}
return true;
}
inline bool IsLayoutOptimizerAddedDstToSrcTranspose(
const TransposeContext& context, const utils::MutableNodeView& node) {
return node.node_index() >= context.num_nodes &&
IsValidConstPermTransposeNode(node, context.dst_to_src);
}
inline bool IsLayoutOptimizerAddedDstToSrcTransform(
const TransposeContext& context, const utils::MutableNodeView& node) {
return node.node_index() >= context.num_nodes &&
(IsValidConstPermTransposeNode(node, context.dst_to_src) ||
IsValidDataFormatNode(node, context.dst_format, context.src_format));
}
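// Returns true if some data fanin path of `node` reaches a transform (a
// dst-to-src Transpose or DataFormat* node) added by this optimizer,
// following only layout-agnostic ops.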
bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
const TransposeContext& context, const utils::MutableNodeView& node) const {
std::deque<utils::MutableNodeView*> queue;
absl::flat_hash_set<utils::MutableNodeView*> visited_nodes;
auto data_node_pos = GetDataFaninPorts(node);
for (const int pos : data_node_pos) {
const auto& fanin = node.GetRegularFanin(pos);
auto* fanin_node = fanin.node_view();
queue.push_back(fanin_node);
visited_nodes.insert(fanin_node);
}
// In most cases this loop exits after a single iteration, as the graph is
// already topologically sorted.
while (!queue.empty()) {
utils::MutableNodeView* current_node = queue.front();
queue.pop_front();
if (IsLayoutOptimizerAddedDstToSrcTransform(context, *current_node)) {
return true;
}
// We only continue searching if the path is connected through
// format-agnostic nodes.
if (IsLayoutAgnosticOp(*current_node->node())) {
auto current_node_pos = GetDataFaninPorts(*current_node);
for (const auto& pos : current_node_pos) {
const auto& fanin = current_node->GetRegularFanin(pos);
auto* fanin_node = fanin.node_view();
if (visited_nodes.insert(fanin_node).second) {
queue.push_back(fanin_node);
}
}
}
}
return false;
}
std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
const TransposeContext& context, const utils::MutableNodeView& node) const {
std::vector<int> ports;
const int num_regular_fanins = node.NumRegularFanins();
ports.reserve(num_regular_fanins);
for (int i = 0; i < num_regular_fanins; ++i) {
const auto& regular_fanin = node.GetRegularFanin(i);
auto* regular_fanin_node = regular_fanin.node_view();
int regular_fanin_port = regular_fanin.index();
if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, 4) &&
((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
IsLayoutOptimizerAddedDstToSrcTranspose(context,
*regular_fanin_node))) {
ports.push_back(i);
}
}
return ports;
}
Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status AddNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsAddN(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
int n, int m) {
return IsFaninPortRankN(node, 0, n) && IsFaninPortRankN(node, 1, m);
}
bool BinaryOpTransposer::IsFaninShapeSupported(
const utils::MutableNodeView& node) {
return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) ||
IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) ||
IsNDOperateWithMD(node, 1, 4));
}
std::vector<int> BinaryOpTransposer::Get4DDataFaninPorts(
const utils::MutableNodeView& node) {
std::vector<int> values;
if (IsFaninPortRankN(node, 0, 4)) {
values.push_back(0);
}
if (IsFaninPortRankN(node, 1, 4)) {
values.push_back(1);
}
return values;
}
Status BinaryOpTransposer::AddNodeReshape(
utils::Mutation* mutation, absl::string_view node_name,
absl::string_view node_device, absl::string_view input_name,
absl::string_view shape_const_node_name, const DataType& data_type) {
NodeDef new_node;
new_node.set_name(string(node_name));
new_node.add_input(string(input_name));
new_node.add_input(string(shape_const_node_name));
new_node.set_op(kReshape);
new_node.set_device(string(node_device));
AttrValue attr_type_indices;
attr_type_indices.set_type(DT_INT32);
new_node.mutable_attr()->insert({"Tshape", attr_type_indices});
AttrValue attr_type_params;
attr_type_params.set_type(data_type);
new_node.mutable_attr()->insert({"T", attr_type_params});
Status status;
mutation->AddNode(std::move(new_node), &status);
return status;
}
Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
absl::string_view node_name,
absl::string_view node_device,
bool node_in_frame,
int num_channels,
absl::string_view depended_node) {
NodeDef new_node;
new_node.set_name(string(node_name));
new_node.set_op(kOpConst);
new_node.set_device(string(node_device));
AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32);
new_node.mutable_attr()->insert({"dtype", attr_data_type});
AttrValue attr_tensor;
Tensor tensor(DT_INT32, TensorShape({4}));
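// Shape {1, num_channels, 1, 1} broadcasts the vector operand along the
// channel dimension of an NCHW tensor (channels at index 1).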
std::vector<int> shape = {1, num_channels, 1, 1};
for (int i = 0; i < static_cast<int>(shape.size()); i++) {
tensor.flat<int>()(i) = shape[i];
}
tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
new_node.mutable_attr()->insert({"value", attr_tensor});
if (node_in_frame) {
// This is to ensure the transpose node and the const node are in the same
// frame.
// TODO(halehri): Add Test that exercises this condition.
new_node.add_input(AsControlDependency(string(depended_node)));
}
Status status;
mutation->AddNode(std::move(new_node), &status);
return status;
}
Status BinaryOpTransposer::MaybeReshapeVectorFanin(
TransposeContext* context, utils::MutableNodeView* node) {
int vector_index = -1;
if (IsNDOperateWithMD(*node, 4, 1)) {
vector_index = 1;
} else if (IsNDOperateWithMD(*node, 1, 4)) {
vector_index = 0;
}
if (vector_index != -1) {
const string& node_name = node->GetName();
const string& node_device = node->GetDevice();
string reshape_node_name = LayoutOptimizerNode(GetReshapeNodeNameFormat(
node_name, vector_index, context->src_format, context->dst_format));
string shape_const_node_name = LayoutOptimizerNode(
GetShapeConstNodeNameFormat(node_name, vector_index));
const auto& fanin = node->GetRegularFanin(vector_index);
auto* fanin_node = fanin.node_view();
const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
if (output_shape_attr == nullptr) {
return errors::InvalidArgument("Missing attribute ", kAttrOutputShape);
}
int vector_size =
output_shape_attr->list().shape(fanin.index()).dim(0).size();
utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
TF_RETURN_IF_ERROR(
AddNodeShapeConst(mutation, shape_const_node_name, node_device,
context->frames.IsInFrame(*node->node()), vector_size,
fanin_node->GetName()));
const auto* t_attr = node->GetAttr(kAttrT);
if (t_attr == nullptr) {
return errors::InvalidArgument("Missing attribute ", kAttrT);
}
TF_RETURN_IF_ERROR(
AddNodeReshape(mutation, reshape_node_name, node_device,
TensorIdToString({fanin_node->GetName(), fanin.index()}),
shape_const_node_name, t_attr->type()));
mutation->AddOrUpdateRegularFanin(node, vector_index,
{reshape_node_name, 0});
}
return Status::OK();
}
Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBinaryOp(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, Get4DDataFaninPorts(*node),
node, kOpTranspose));
TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsConcat(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
context, GetConcatDataFaninPorts(*node), node, kOpTranspose));
int axis_node = 0;
if (node->GetOp() == "ConcatV2") {
const auto* n_attr = node->GetAttr(kAttrN);
if (n_attr != nullptr) {
axis_node = n_attr->i();
}
}
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status FillOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsFill(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninPortDimsNIfConst(*node, 0, {4}) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status IdentityNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsIdentityN(*node->node()));
const auto ports = GetVariadic4DFaninPorts(*context, *node);
if (!ShouldProcess(*context, *node) || ports.empty()) {
return Status::OK();
}
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform(
const TransposeContext& context, const utils::MutableNodeView& node) const {
for (const auto& regular_fanin : node.GetRegularFanins()) {
auto* regular_fanin_node = regular_fanin.node_view();
if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 4) &&
((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
IsLayoutOptimizerAddedDstToSrcTranspose(context,
*regular_fanin_node))) {
continue;
}
return false;
}
return true;
}
Status MergeTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsMerge(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsEveryFaninAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status PadTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsMirrorPad(*node->node()) || IsMirrorPadGrad(*node->node()) ||
IsPad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninPortDimsNIfConst(*node, 1, {4, 2}) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
bool ReduceTransposer::KeepDims(const utils::MutableNodeView& node) {
const auto* keep_dims_attr = node.GetAttr(kAttrKeepDims);
if (keep_dims_attr != nullptr) {
return keep_dims_attr->b();
}
return false;
}
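// Returns true if the 1-D `tensor` lists exactly the dimensions in `axis`
// (in any order); negative entries are wrapped by `rank`. For example, with
// NHWC, a reduction tensor of [1, 2] (or [-3, -2]) matches axis {H, W}.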
bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
absl::Span<const int> axis, int rank) {
if (tensor.dims() != 1 || tensor.dim_size(0) != axis.size()) {
return false;
}
for (int i = 0; i < axis.size(); ++i) {
int local_axis = tensor.flat<int>()(i);
if (local_axis < 0) {
local_axis += rank;
}
bool along_axis = false;
for (int dim : axis) {
if (local_axis == dim) {
along_axis = true;
break;
}
}
if (!along_axis) {
return false;
}
}
return true;
}
bool ReduceTransposer::IsReduceAxisSupported(
const TransposeContext& context, const utils::MutableNodeView& node) {
if (KeepDims(node)) {
return true;
}
const auto& regular_fanin_1 = node.GetRegularFanin(1);
auto* axis_node = regular_fanin_1.node_view();
if (!IsConstant(*axis_node->node())) {
return false;
}
const auto* value_attr = axis_node->GetAttr(kAttrValue);
if (value_attr == nullptr) {
return false;
}
Tensor tensor;
if (!tensor.FromProto(value_attr->tensor())) {
LOG(ERROR) << "Failed to parse TensorProto.";
return false;
}
auto indices = [&context](absl::Span<const char> labels) {
return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
};
return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), kRank) ||
IsAlongAxis(tensor, indices({'H', 'W', 'C'}), kRank) ||
IsAlongAxis(tensor, indices({'N', 'H', 'W'}), kRank) ||
IsAlongAxis(tensor, indices({'H', 'W'}), kRank) ||
IsAlongAxis(tensor, indices({'C'}), kRank);
}
Status ReduceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsReduceOp(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
!IsReduceAxisSupported(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
if (KeepDims(*node)) {
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
}
return context->graph_view->GetMutationBuilder()->Apply();
}
Status ReverseV2Transposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsReverseV2(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
bool SelectTransposer::IsFaninScalarVector4D(
const utils::MutableNodeView& fanin, int port) {
return IsFanoutPortRankN(fanin, port, 0) ||
IsFanoutPortRankN(fanin, port, 1) || IsFanoutPortRankN(fanin, port, 4);
}
std::vector<int> SelectTransposer::GetFaninPorts(
const utils::MutableNodeView& fanin, int port) {
// Input 0 may be a scalar, a vector whose size matches the first dimension
// of inputs 1 and 2, or a tensor with the same shape as inputs 1 and 2.
if (IsFanoutPortRankN(fanin, port, 4)) {
return {0, 1, 2};
}
return {1, 2};
}
Status SelectTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSelect(*node->node()));
const auto& regular_fanin_0 = node->GetRegularFanin(0);
auto* regular_fanin_0_node = regular_fanin_0.node_view();
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninScalarVector4D(*regular_fanin_0_node, regular_fanin_0.index()) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
context, GetFaninPorts(*regular_fanin_0_node, regular_fanin_0.index()),
node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status ShapeTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsShape(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status ShapeNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsShapeN(*node->node()));
const auto ports = GetVariadic4DFaninPorts(*context, *node);
if (!ShouldProcess(*context, *node) || ports.empty()) {
return Status::OK();
}
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports, node, kOpDataFormatVecPermute));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status SliceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSlice(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninPortsDimsNIfConst(*node, {1, 2}, {4}) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status SplitTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSplit(*node->node()));
const auto ports = GetDataFanoutPorts(*node);
if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatDimMap));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status SplitVTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSplitV(*node->node()));
const auto ports = GetDataFanoutPorts(*node);
if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {2}, node, kOpDataFormatDimMap));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
bool SqueezeTransposer::IsInputConvertible(
const TransposeContext& context, const utils::MutableNodeView& node) const {
const auto& regular_fanin_0 = node.GetRegularFanin(0);
auto* regular_fanin_0_node = regular_fanin_0.node_view();
const auto* output_shape_attr =
regular_fanin_0_node->GetAttr(kAttrOutputShape);
if (output_shape_attr != nullptr) {
auto& shape = output_shape_attr->list().shape(regular_fanin_0.index());
if (shape.dim_size() != kRank) {
return false;
}
const int height_dim = context.src_dim_indices.at('H');
const int width_dim = context.src_dim_indices.at('W');
if (shape.dim(height_dim).size() == 1 && shape.dim(width_dim).size() == 1) {
return true;
}
}
return false;
}
bool SqueezeTransposer::IsAlongAxis(const AttrValue& attr,
absl::Span<const int> axis,
int rank) const {
const auto& list = attr.list();
// If list is empty, Squeeze op will squeeze all dimensions of size 1.
if (list.i_size() == 0) {
return true;
} else if (list.i_size() != axis.size()) {
return false;
}
for (int i = 0; i < axis.size(); ++i) {
int local_axis = list.i(i);
if (local_axis < 0) {
local_axis += rank;
}
bool along_axis = false;
for (int dim : axis) {
if (local_axis == dim) {
along_axis = true;
break;
}
}
if (!along_axis) {
return false;
}
}
return true;
}
bool SqueezeTransposer::IsDimsSupported(
const TransposeContext& context, const utils::MutableNodeView& node) const {
auto indices = [&context](absl::Span<const char> labels) {
return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
};
const auto* squeeze_dims_attr = node.GetAttr(kAttrSqueezeDims);
if (squeeze_dims_attr == nullptr) {
return false;
}
return (IsFanoutPortRankN(node, 0, 2) &&
IsAlongAxis(*squeeze_dims_attr, indices({'H', 'W'}), kRank)) ||
(IsFanoutPortRankN(node, 0, 1) &&
IsAlongAxis(*squeeze_dims_attr, indices({'N', 'H', 'W'}), kRank));
}
Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context,
utils::MutableNodeView* node) {
const auto* squeeze_dims_attr = node->GetAttr(kAttrSqueezeDims);
if (squeeze_dims_attr == nullptr) {
return errors::InvalidArgument("Missing attribute ", kAttrSqueezeDims);
}
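// Remap each squeeze dim from the source to the destination format. For
// example, converting NHWC to NCHW, squeeze_dims {1, 2} (H and W in NHWC)
// map to {2, 3} (H and W in NCHW) via dst_to_src.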
const int max_num_squeeze_dim = context->src_format.length() - 1;
const int min_squeeze_dim = -(max_num_squeeze_dim + 1);
std::vector<int> squeeze_dims_mapped;
const int squeeze_dims_size = squeeze_dims_attr->list().i_size();
squeeze_dims_mapped.reserve(squeeze_dims_size);
for (int i = 0; i < squeeze_dims_size; ++i) {
int dim = squeeze_dims_attr->list().i(i);
if (dim < min_squeeze_dim || dim >= max_num_squeeze_dim) {
return errors::InvalidArgument(
"Attribute '", kAttrSqueezeDims, "' contains out of range index '",
dim, "', index must be between [", min_squeeze_dim, ", ",
max_num_squeeze_dim, ")");
}
if (dim < 0) {
dim += max_num_squeeze_dim;
}
squeeze_dims_mapped.push_back(context->dst_to_src[dim]);
}
std::sort(squeeze_dims_mapped.begin(), squeeze_dims_mapped.end());
AttrValue squeeze_dims;
squeeze_dims.mutable_list()->mutable_i()->Reserve(squeeze_dims_size);
for (const auto& dim : squeeze_dims_mapped) {
squeeze_dims.mutable_list()->mutable_i()->Add(dim);
}
context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
node, kAttrSqueezeDims, squeeze_dims);
return Status::OK();
}
Status SqueezeTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSqueeze(*node->node()));
if (!ShouldProcess(*context, *node) || !IsDimsSupported(*context, *node) ||
!IsInputConvertible(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateSqueezeDims(context, node));
return context->graph_view->GetMutationBuilder()->Apply();
}
bool StridedSliceTransposer::IsMaskZero(const utils::MutableNodeView& node,
absl::string_view mask) {
const auto* mask_attr = node.GetAttr(mask);
if (mask_attr != nullptr) {
return mask_attr->i() == 0;
}
return true;
}
bool StridedSliceTransposer::HasOnlyBeginEndMask(
const utils::MutableNodeView& node) {
return IsMaskZero(node, "ellipsis_mask") &&
IsMaskZero(node, "new_axis_mask") &&
IsMaskZero(node, "shrink_axis_mask");
}
Status StridedSliceTransposer::PermuteMask(TransposeContext* context,
utils::MutableNodeView* node,
absl::string_view mask) {
// Computes the permutation of the masks based on the src and dst format.
// For example:
//   src_format = NHWC
//   dst_format = NCHW
//   src_to_dst permutation = [0, 3, 1, 2]
//   mask   = 0010 (bit positions correspond to dimension indices, i.e. the
//            bits are in reverse order of the src format: CWHN)
//   result = 0100 (WHCN)
const auto* mask_attr = node->GetAttr(mask);
const int mask_i = mask_attr != nullptr ? mask_attr->i() : 0;
if (mask_i < 0 || mask_i > 15) {
return errors::InvalidArgument("invalid mask value: ", mask_i);
}
int result = 0;
for (int i = 0; i < context->src_to_dst.size(); i++) {
const int final_pos = context->src_to_dst[i];
const int position_mask = 1 << final_pos;
const int bit_i = (mask_i & position_mask) >> final_pos;
result |= bit_i << i;
}
AttrValue new_mask_attr;
new_mask_attr.set_i(result);
context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(node, mask,
new_mask_attr);
return Status::OK();
}
Status StridedSliceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsStridedSlice(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninPortsDimsNIfConst(*node, {1, 2, 3}, {4}) ||
!HasOnlyBeginEndMask(*node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(PermuteMask(context, node, "begin_mask"));
TF_RETURN_IF_ERROR(PermuteMask(context, node, "end_mask"));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node,
kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status SwitchTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSwitch(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, GetDataFanoutPorts(*node),
node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status TernaryOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsTernaryOp(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status TileTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsTile(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninPortDimsNIfConst(*node, 1, {4}) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
Status UnaryGradTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsUnaryGrad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
// Utils.
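// Returns the node's assigned device name, falling back to the virtual
// placer's canonical device name when the node carries no device and a
// placer is available.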
string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) {
return (node.device().empty() && virtual_placer != nullptr)
? virtual_placer->get_canonical_device_name(node)
: node.device();
}
bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
  // The lookup set is immutable, so build it once in a function-local static
  // instead of reconstructing it on every call.
  static const std::set<string>* default_layout_sensitive_ops =
      new std::set<string>(
          {"AvgPool", "BiasAdd", "Conv2D", "DepthwiseConv2dNative",
           "DepthToSpace", "FusedBatchNorm", "FusedBatchNormV2",
           "FusedBatchNormV3", "FusedConv2DBiasActivation", "MaxPool",
           "SpaceToDepth"});
  return default_layout_sensitive_ops->find(node.op()) !=
         default_layout_sensitive_ops->end();
}
bool IsLayoutSensitiveOp(const NodeDef& node) {
return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
IsBiasAddGrad(node) || IsConv2DBackpropFilter(node) ||
IsConv2DBackpropInput(node) ||
IsDepthwiseConv2dNativeBackpropFilter(node) ||
IsDepthwiseConv2dNativeBackpropInput(node) ||
IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||
IsMaxPoolV2(node) || IsMaxPoolGrad(node) || IsMaxPoolGradV2(node) ||
IsMaxPoolGradGradV1(node) || IsMaxPoolGradGradV2(node);
}
bool IsDefaultLayoutAgnosticOp(const NodeDef& node) {
  // As above, build the immutable lookup set only once.
  static const std::set<string>* agnostic_nodes = new std::set<string>(
      {"Abs",        "Acos",           "Acosh",         "Angle",
       "Asin",       "Asinh",          "Atan",          "Atanh",
       "Bitcast",    "Cast",           "Ceil",          "CheckNumerics",
       "ComplexAbs", "Conj",           "Cos",           "Cosh",
       "Digamma",    "Elu",            "Enter",         "Erf",
       "Erfc",       "Exit",           "Exp",           "Expm1",
       "Floor",      "GuaranteeConst", "Identity",      "Imag",
       "Inv",        "IsFinite",       "IsInf",         "IsNan",
       "Lgamma",     "Log",            "LogicalNot",    "Log1p",
       "Neg",        "NextIteration",  "OnesLike",      "PreventGradient",
       "Real",       "Reciprocal",     "Relu",          "Relu6",
       "Rint",       "Selu",           "Sigmoid",       "Sign",
       "Sin",        "Sinh",           "Snapshot",      "Softplus",
       "Round",      "Rsqrt",          "Sqrt",          "Square",
       "StopGradient", "Tan",          "Tanh",          "ZerosLike"});
  return agnostic_nodes->find(node.op()) != agnostic_nodes->end();
}
bool IsLayoutAgnosticOp(const NodeDef& node) {
return IsDefaultLayoutAgnosticOp(node) || IsAddN(node) || IsBinaryOp(node) ||
IsIdentityN(node) || IsMerge(node) || IsMirrorPad(node) ||
IsMirrorPadGrad(node) || IsPad(node) || IsSelect(node) ||
IsSwitch(node) || IsTernaryOp(node) || IsUnaryGrad(node) ||
IsConcat(node) || IsReverseV2(node) || IsTile(node) || IsShape(node) ||
IsShapeN(node) || IsFill(node) || IsSlice(node) || IsSplit(node) ||
IsSqueeze(node) || IsSplitV(node) || IsStridedSlice(node) ||
IsReduceOp(node);
}
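// Betainc is currently the only op treated as ternary by this optimizer.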
bool IsTernaryOp(const NodeDef& node) { return IsBetainc(node); }
bool IsUnaryGrad(const NodeDef& node) {
bool is_unary_grad =
IsEluGrad(node) || IsInvGrad(node) || IsReciprocalGrad(node) ||
IsRelu6Grad(node) || IsReluGrad(node) || IsRsqrtGrad(node) ||
IsSeluGrad(node) || IsSigmoidGrad(node) || IsSoftplusGrad(node) ||
IsSoftsignGrad(node) || IsSqrtGrad(node) || IsTanhGrad(node);
return is_unary_grad;
}
bool IsMaxPoolV2(const NodeDef& node) { return node.op() == "MaxPoolV2"; }
bool IsMaxPoolGradV2(const NodeDef& node) {
return node.op() == "MaxPoolGradV2";
}
bool IsMaxPoolGradGradV1(const NodeDef& node) {
return node.op() == "MaxPoolGradGrad";
}
bool IsMaxPoolGradGradV2(const NodeDef& node) {
return node.op() == "MaxPoolGradGradV2";
}
bool IsBinaryOp(const NodeDef& node) {
bool is_binary =
IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
return is_binary;
}
bool IsReduceOp(const NodeDef& node) {
return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
IsMin(node) || IsAll(node) || IsAny(node);
}
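// Returns the regular fanin ports that carry layout-dependent data; ports
// holding auxiliary arguments (e.g. the split dimension of Split or the
// original input shape of AvgPoolGrad) are excluded.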
std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node) {
const auto* node_def = node.node();
if (IsAvgPoolGrad(*node_def) || IsSplit(*node_def)) {
return {1};
}
if (IsStridedSliceGrad(*node_def)) {
return {4};
}
if (IsBinaryOp(*node_def) || IsUnaryGrad(*node_def)) {
return {0, 1};
}
if (IsTernaryOp(*node_def) || IsSelect(*node_def) ||
IsMaxPoolGrad(*node_def) || IsMaxPoolGradV2(*node_def) ||
IsMaxPoolGradGradV1(*node_def) || IsMaxPoolGradGradV2(*node_def)) {
return {0, 1, 2};
}
if (IsShapeN(*node_def) || IsIdentityN(*node_def) || IsAddN(*node_def) ||
IsMerge(*node_def)) {
return GetRegularFaninPorts(node);
}
if (IsConcat(*node_def)) {
return GetConcatDataFaninPorts(node);
}
if (node.NumRegularFanins() > 0) {
return {0};
}
return {};
}
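// Returns the output ports that produce layout-dependent data. IdentityN,
// Shape, and ShapeN mirror their data fanin ports; Split/SplitV produce
// `num_split` outputs; Switch produces one output per branch; all other ops
// are assumed to produce data only on port 0.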
std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node) {
const auto* node_def = node.node();
if (IsIdentityN(*node_def) || IsShape(*node_def) || IsShapeN(*node_def)) {
return GetDataFaninPorts(node);
}
if (IsSplit(*node_def) || IsSplitV(*node_def)) {
const auto* num_split_attr = node.GetAttr(kAttrNumSplit);
if (num_split_attr == nullptr) {
return {0};
}
std::vector<int> values(num_split_attr->i());
std::iota(values.begin(), values.end(), 0);
return values;
}
if (IsSwitch(*node_def)) {
const auto* num_outs_attr = node.GetAttr(kAttrNumOuts);
const int num_outs = num_outs_attr != nullptr ? num_outs_attr->i() : 2;
std::vector<int> values(num_outs);
std::iota(values.begin(), values.end(), 0);
return values;
}
return {0};
}
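// Reads the tensor held by the Const fanin at `index` into `tensor`. Returns
// true only if `node` satisfies `predicate`, the fanin is a Const, and its
// value attr holds a DT_INT32 tensor.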
bool GetValueAttrFromConstInputNode(
const utils::MutableNodeView& node,
const std::function<bool(const NodeDef&)>& predicate, int index,
Tensor* tensor) {
if (!predicate(*node.node())) {
return false;
}
const auto& regular_fanin = node.GetRegularFanin(index);
auto* regular_fanin_node = regular_fanin.node_view();
if (!IsConstant(*regular_fanin_node->node())) {
return false;
}
const auto* value_attr = regular_fanin_node->GetAttr(kAttrValue);
if (value_attr == nullptr || value_attr->tensor().dtype() != DT_INT32) {
return false;
}
if (!tensor->FromProto(value_attr->tensor())) {
return false;
}
return true;
}
bool IsDataFormatOp(const utils::MutableNodeView& node) {
const string& op = node.GetOp();
return op == kOpDataFormatDimMap || op == kOpDataFormatVecPermute;
}
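// Maps each dimension label of `data_format` to its index, e.g.
// GetDimensionIndices("NHWC") yields {N:0, H:1, W:2, C:3}; GetPermutation
// below consumes this map to build a transpose permutation.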
absl::flat_hash_map<char, int> GetDimensionIndices(
absl::string_view data_format) {
const int size = data_format.size();
absl::flat_hash_map<char, int> index;
index.reserve(size);
for (int i = 0; i < size; i++) {
index[data_format[i]] = i;
}
return index;
}
std::vector<int> GetPermutation(
const absl::flat_hash_map<char, int>& src_dim_indices,
absl::string_view dst_format) {
// Generate permutation for transformation between src and dst format.
// Example:
// src = NWHC, dst = NCWH
// index = { N:0 W:1 H:2 C:3 }
// permutation = [0, 3, 1, 2]
  DCHECK_EQ(src_dim_indices.size(), dst_format.size());
std::vector<int> permutation;
const int size = dst_format.size();
permutation.reserve(size);
for (int i = 0; i < size; i++) {
permutation.push_back(src_dim_indices.at(dst_format[i]));
}
return permutation;
}
} // namespace grappler
} // namespace tensorflow