[NFC] Remove the unused layout_optimizer grappler optimization. The generic_layout_optimizer is the one currently in use.
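
For reference, layout optimization is still controlled by the same
RewriterConfig toggle; a minimal sketch of flipping it from C++, assuming
the standard generated proto accessors
(tensorflow/core/protobuf/rewriter_config.proto):

  #include "tensorflow/core/protobuf/config.pb.h"

  tensorflow::ConfigProto config;
  // Toggle the grappler layout pass (now backed by generic_layout_optimizer).
  config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_layout_optimizer(tensorflow::RewriterConfig::OFF);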
PiperOrigin-RevId: 396457524
Change-Id: Ideac28247aa6e3cd28e4fac14ff28de39f3f003f
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e1f6626..8008679 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -565,55 +565,6 @@
)
cc_library(
- name = "layout_optimizer",
- srcs = ["layout_optimizer.cc"],
- hdrs = [
- "layout_optimizer.h",
- ],
- visibility = ["//visibility:public"],
- deps = [
- ":graph_optimizer",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/grappler:devices",
- "//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:op_types",
- "//tensorflow/core/grappler:utils",
- "//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/grappler/costs:graph_properties",
- "//tensorflow/core/grappler/costs:virtual_placer",
- "//tensorflow/core/grappler/utils:frame",
- "@com_google_absl//absl/strings",
- ],
-)
-
-tf_cuda_cc_test(
- name = "layout_optimizer_test",
- srcs = ["layout_optimizer_test.cc"],
- deps = [
- ":layout_optimizer",
- "//tensorflow/cc:cc_ops",
- "//tensorflow/cc:cc_ops_internal",
- "//tensorflow/core:all_kernels",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- "//tensorflow/core/grappler:devices",
- "//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:utils",
- "//tensorflow/core/grappler/clusters:single_machine",
- "//tensorflow/core/grappler/clusters:virtual_cluster",
- "//tensorflow/core/grappler/costs:virtual_placer",
- "//tensorflow/core/grappler/utils:grappler_test",
- ],
-)
-
-cc_library(
name = "auto_mixed_precision",
srcs = ["auto_mixed_precision.cc"],
hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
deleted file mode 100644
index a379cb9..0000000
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ /dev/null
@@ -1,2287 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
-
-#include <deque>
-#include <unordered_set>
-
-#include "absl/strings/strip.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/memory_types.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/devices.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
-#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/util/device_name_utils.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-const char kSuffix[] = "LayoutOptimizer";
-const char kDefaultDevice[] = "DefaultDevice";
-const char kPermNHWCToNCHW[] = "PermConstNHWCToNCHW";
-const char kPermNCHWToNHWC[] = "PermConstNCHWToNHWC";
-const char kTransposeNHWCToNCHW[] = "TransposeNHWCToNCHW";
-const char kTransposeNCHWToNHWC[] = "TransposeNCHWToNHWC";
-const char kDimMapNHWCToNCHW[] = "DimMapNHWCToNCHW";
-const char kDimMapNCHWToNHWC[] = "DimMapNCHWToNHWC";
-const char kVecPermuteNHWCToNCHW[] = "VecPermuteNHWCToNCHW";
-const char kVecPermuteNCHWToNHWC[] = "VecPermuteNCHWToNHWC";
-const char kReshapeNHWCToNCHW[] = "ReshapeNHWCToNCHW";
-const char kReshapeConst[] = "ReshapeConst";
-
-std::set<string> GetOpsFormatSupported() {
- std::set<string> ops_format_supported = {
- "AvgPool",
- "AvgPoolGrad",
- "Conv2D",
- "Conv2DBackpropFilter",
- "Conv2DBackpropInput",
- "BiasAdd",
- "BiasAddGrad",
- "DepthwiseConv2dNative",
- "DepthwiseConv2dNativeBackpropInput",
- "DepthwiseConv2dNativeBackpropFilter",
- "FusedBatchNorm",
- "FusedBatchNormV2",
- "FusedBatchNormV3",
- "FusedBatchNormGrad",
- "FusedBatchNormGradV2",
- "FusedBatchNormGradV3",
- "FusedConv2DBiasActivation",
- "MaxPool",
- "MaxPoolV2",
- "MaxPoolGrad",
- "MaxPoolGradGrad",
- "MaxPoolGradV2",
- "MaxPoolGradGradV2",
- "SpaceToDepth",
- "DepthToSpace"};
- return ops_format_supported;
-}
-
-std::set<string> GetOpsFormatAgnostic() {
- std::set<string> ops_format_agnostic = {"Abs",
- "Add",
- "AddN",
- "AddV2",
- "Acos",
- "Acosh",
- "All",
- "Angle",
- "Any",
- "ApproximateEqual",
- "Asin",
- "Asinh",
- "Atan",
- "Atan2",
- "Atanh",
- "Betainc",
- "Bitcast",
- "Cast",
- "Ceil",
- "CheckNumerics",
- "Complex",
- "ComplexAbs",
- "Concat",
- "ConcatV2",
- "Conj",
- "Cos",
- "Cosh",
- "Digamma",
- "Div",
- "Elu",
- "EluGrad",
- "Enter",
- "Equal",
- "Erf",
- "Erfc",
- "Exit",
- "Exp",
- "Expm1",
- "FakeQuantWithMinMaxVars",
- "FakeQuantWithMinMaxArgs",
- "Fill",
- "Floor",
- "FloorDiv",
- "FloorMod",
- "Greater",
- "GreaterEqual",
- "GuaranteeConst",
- "HistogramSummary",
- "Identity",
- "IdentityN",
- "Igamma",
- "Igammac",
- "Imag",
- "Inv",
- "InvGrad",
- "IsFinite",
- "IsInf",
- "IsNan",
- "Less",
- "LessEqual",
- "Lgamma",
- "Log",
- "LogicalAnd",
- "LogicalNot",
- "LogicalOr",
- "Log1p",
- "Max",
- "Maximum",
- "Mean",
- "Merge",
- "Min",
- "Minimum",
- "Mod",
- "Mul",
- "Neg",
- "NextIteration",
- "NotEqual",
- "OnesLike",
- "Pad",
- "PreventGradient",
- "Prod",
- "Polygamma",
- "QuantizeAndDequantizeV2",
- "QuantizeAndDequantizeV3",
- "QuantizeAndDequantizeV4",
- "Pow",
- "Real",
- "RealDiv",
- "Reciprocal",
- "ReciprocalGrad",
- "Relu",
- "Relu6",
- "Relu6Grad",
- "ReluGrad",
- "Rint",
- "Select",
- "SelectV2",
- "Selu",
- "SeluGrad",
- "Shape",
- "ShapeN",
- "Sigmoid",
- "SigmoidGrad",
- "Sign",
- "Sin",
- "Sinh",
- "Slice",
- "Snapshot",
- "Softplus",
- "SoftplusGrad",
- "Split",
- "SplitV",
- "StridedSlice",
- "StridedSliceGrad",
- "Switch",
- "_SwitchN",
- "Tile",
- "TruncateDiv",
- "TruncateMod",
- "ReverseV2",
- "Round",
- "Rsqrt",
- "RsqrtGrad",
- "Sqrt",
- "SqrtGrad",
- "Square",
- "SquaredDifference",
- "Squeeze",
- "StopGradient",
- "Sub",
- "Sum",
- "Tan",
- "Tanh",
- "TanhGrad",
- "ZerosLike",
- "Zeta"};
- return ops_format_agnostic;
-}
-
-bool EndWith(const string& str, const string& ending) {
- if (str.size() < ending.size()) return false;
- if (str.substr(str.size() - ending.size(), ending.size()) == ending)
- return true;
- return false;
-}
-
-bool IsNodeByLayoutOptimizer(const string& node_name) {
- const string suffix = kSuffix;
- return EndWith(node_name, suffix);
-}
-
-bool IsNodeType(const string& node_name, const string& type) {
- const string suffix = strings::StrCat(type, "-", kSuffix);
- return EndWith(node_name, suffix);
-}
-
-bool IsTransposeNHWCToNCHW(const string& node_name) {
- return IsNodeType(node_name, kTransposeNHWCToNCHW);
-}
-
-bool IsTransposeNCHWToNHWC(const string& node_name) {
- return IsNodeType(node_name, kTransposeNCHWToNHWC);
-}
-
-bool IsDimMapNHWCToNCHW(const string& node_name) {
- return IsNodeType(node_name, kDimMapNHWCToNCHW);
-}
-
-bool IsDimMapNCHWToNHWC(const string& node_name) {
- return IsNodeType(node_name, kDimMapNCHWToNHWC);
-}
-
-bool IsVecPermuteNHWCToNCHW(const string& node_name) {
- return IsNodeType(node_name, kVecPermuteNHWCToNCHW);
-}
-
-bool IsVecPermuteNCHWToNHWC(const string& node_name) {
- return IsNodeType(node_name, kVecPermuteNCHWToNHWC);
-}
-
-bool IsConcat(const NodeDef& node) {
- const auto op = node.op();
- return op == "Concat" || op == "ConcatV2";
-}
-
-bool IsConcatV1(const NodeDef& node) {
- const auto op = node.op();
- return op == "Concat";
-}
-
-bool IsMaxPoolV2(const NodeDef& node) {
- const auto& op = node.op();
- return op == "MaxPoolV2";
-}
-
-bool IsMaxPoolGradV1(const NodeDef& node) {
- const auto& op = node.op();
- return op == "MaxPoolGrad";
-}
-
-bool IsMaxPoolGradV2(const NodeDef& node) {
- const auto& op = node.op();
- return op == "MaxPoolGradV2";
-}
-
-bool IsMaxPoolGradGradV1(const NodeDef& node) {
- const auto& op = node.op();
- return op == "MaxPoolGradGrad";
-}
-
-bool IsMaxPoolGradGradV2(const NodeDef& node) {
- const auto& op = node.op();
- return op == "MaxPoolGradGradV2";
-}
-
-bool IsUnaryGrad(const NodeDef& node) {
- bool is_unary_grad =
- IsEluGrad(node) || IsInvGrad(node) || IsReciprocalGrad(node) ||
- IsRelu6Grad(node) || IsReluGrad(node) || IsRsqrtGrad(node) ||
- IsSeluGrad(node) || IsSigmoidGrad(node) || IsSoftplusGrad(node) ||
- IsSoftsignGrad(node) || IsSqrtGrad(node) || IsTanhGrad(node);
- return is_unary_grad;
-}
-
-bool IsComparisonOp(const NodeDef& node) {
- bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
- IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
- IsLessEqual(node) || IsNotEqual(node);
- return is_compare;
-}
-
-bool IsReduceOp(const NodeDef& node) {
- return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
- IsMin(node) || IsAll(node) || IsAny(node);
-}
-
-bool IsBinaryOp(const NodeDef& node) {
- bool is_binary =
- IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
- IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
- IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
- IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
- IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
- IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
- return is_binary;
-}
-
-std::vector<int> NonControlInputs(const NodeDef& node) {
- std::vector<int> pos;
- for (int i = 0; i < node.input_size(); i++) {
- if (!IsControlInput(node.input(i))) {
- pos.push_back(i);
- }
- }
- return pos;
-}
-
-std::vector<int> DataInputPosConcat(const NodeDef& node) {
- int n = node.attr().at("N").i();
- std::vector<int> input_pos;
- int start = (IsConcatV1(node)) ? 1 : 0;
- int end = start + n;
- for (int i = start; i < end; i++) {
- input_pos.push_back(i);
- }
- return input_pos;
-}
-
-std::vector<int> DataInputPos(const NodeDef& node) {
- if (IsSplit(node) || IsHistogramSummary(node)) {
- return {1};
- }
- if (IsStridedSliceGrad(node)) {
- return {4};
- }
- if (IsBinaryOp(node) || IsUnaryGrad(node)) {
- return {0, 1};
- }
- if (IsBetainc(node) || IsSelect(node)) {
- return {0, 1, 2};
- }
- if (IsShapeN(node) || IsIdentityN(node) || IsAddN(node) || IsMerge(node)) {
- return NonControlInputs(node);
- }
- if (IsConcat(node)) {
- return DataInputPosConcat(node);
- }
- if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
- return {0};
- }
- return {};
-}
-
-bool IsHostMemory(const NodeDef& node, int output_port) {
- DeviceNameUtils::ParsedName parsed_name;
- if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
- DeviceType device_type(parsed_name.type);
- Status s = FindKernelDef(device_type, node, nullptr, nullptr);
- if (s.ok()) {
- tensorflow::MemoryTypeVector in_mtypes;
- tensorflow::MemoryTypeVector out_mtypes;
- s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
- node, &in_mtypes, &out_mtypes);
- if (s.ok()) {
- if (out_mtypes[output_port] == HOST_MEMORY) {
- return true;
- }
- }
- } else {
- return true;
- }
- }
- return false;
-}
-
-class GraphProcessor {
- public:
- GraphProcessor(const GraphProperties& graph_properties,
- const VirtualPlacer& virtual_placer,
- const std::unordered_set<string>& nodes_to_preserve,
- GraphDef* graph, NodeMap* node_map)
- : graph_properties_(graph_properties),
- virtual_placer_(virtual_placer),
- nodes_to_preserve_(nodes_to_preserve),
- graph_(graph),
- node_map_(node_map) {}
-
- protected:
- NodeDef* AddNodePermConst(const string& name, const string& device,
- const std::vector<int>& permutation) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- AttrValue attr_data_type;
- attr_data_type.set_type(DT_INT32);
- node->mutable_attr()->insert({"dtype", attr_data_type});
- AttrValue attr_tensor;
- Tensor tensor(DT_INT32, TensorShape({4}));
- for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
- tensor.flat<int>()(i) = permutation[i];
- }
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- string device_name;
- if (device.empty()) {
- device_name = virtual_placer_.get_canonical_device_name(*node);
- } else {
- device_name = device;
- }
- node->set_device(device_name);
- return node;
- }
-
- NodeDef* AddNodeConstScalar(const string& name, const string& device,
- DataType dtype, int value) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- AttrValue attr_data_type;
- attr_data_type.set_type(dtype);
- node->mutable_attr()->insert({"dtype", attr_data_type});
- AttrValue attr_tensor;
- Tensor tensor(dtype, TensorShape({}));
- tensor.scalar<int>()() = value;
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- string device_name;
- if (device.empty()) {
- device_name = virtual_placer_.get_canonical_device_name(*node);
- } else {
- device_name = device;
- }
- node->set_device(device_name);
- return node;
- }
-
- string LayoutOptimizerNode(const string& base_name) {
- return strings::StrCat(base_name, "-", kSuffix);
- }
-
- const GraphProperties& graph_properties_;
- const VirtualPlacer& virtual_placer_;
- const std::unordered_set<string>& nodes_to_preserve_;
- GraphDef* graph_;
- NodeMap* node_map_;
-};
-
-struct OptimizeContext {
- OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
- const GraphProperties& graph_properties,
- const VirtualPlacer& virtual_placer,
- const std::unordered_set<string>& nodes_to_preserve,
- bool is_in_frame,
- std::unordered_set<string>* devices_with_perm_const)
- : graph(graph),
- node(node),
- node_map(node_map),
- graph_properties(graph_properties),
- virtual_placer(virtual_placer),
- nodes_to_preserve(nodes_to_preserve),
- is_in_frame(is_in_frame),
- devices_with_perm_const(devices_with_perm_const) {}
- GraphDef* graph;
- NodeDef* node;
- NodeMap* node_map;
- const GraphProperties& graph_properties;
- const VirtualPlacer& virtual_placer;
- const std::unordered_set<string>& nodes_to_preserve;
- bool is_in_frame;
- std::unordered_set<string>* devices_with_perm_const; // not owned
-};
-
-class NodeProcessor : public GraphProcessor {
- public:
- explicit NodeProcessor(const OptimizeContext& opt_cxt)
- : GraphProcessor(opt_cxt.graph_properties, opt_cxt.virtual_placer,
- opt_cxt.nodes_to_preserve, opt_cxt.graph,
- opt_cxt.node_map),
- node_(opt_cxt.node),
- is_in_frame_(opt_cxt.is_in_frame),
- devices_with_perm_const_(opt_cxt.devices_with_perm_const) {}
- virtual ~NodeProcessor() {}
- virtual Status ConvertNode() {
- if (ShouldProcess()) {
- UpdateAttrDataFormat();
- UpdateAttrKSize();
- UpdateAttrStrides();
- UpdateAttrDilations();
- UpdateAttrExplicitPaddings();
- UpdateAttrShape();
- TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
- TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
- TF_RETURN_IF_ERROR(CustomizedProcessing());
- }
- return Status::OK();
- }
-
- protected:
- bool IsPortDimsN(const NodeDef& node, int port, int n) const {
- if (node.attr().find("_output_shapes") != node.attr().end()) {
- if (node.attr().at("_output_shapes").list().shape_size() > port) {
- auto shape = node.attr().at("_output_shapes").list().shape(port);
- if (shape.unknown_rank()) {
- return false;
- }
- if (shape.dim_size() == n) {
- return true;
- }
- }
- }
- return false;
- }
-
- bool IsPortZeroDimsN(const NodeDef& node, int n) const {
- return IsPortDimsN(node, 0, n);
- }
-
- bool IsPortZeroDimsFour(const NodeDef& node) const {
- return NodeProcessor::IsPortZeroDimsN(node, 4) ||
- IsTransposeNCHWToNHWC(node.name());
- }
-
- bool IsPortDimsFour(const NodeDef& node, int port) const {
- return NodeProcessor::IsPortDimsN(node, port, 4) ||
- IsTransposeNCHWToNHWC(node.name());
- }
-
- bool IsNHWC() const {
- if (node_->attr().find("data_format") != node_->attr().end()) {
- if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
- return true;
- }
- }
- return false;
- }
-
- bool HasOutputs() const {
- auto outputs = node_map_->GetOutputs(node_->name());
- return !outputs.empty();
- }
-
- Status HasAttribute(const NodeDef& node, const string& attr) const {
- if (node.attr().find(attr) == node.attr().end()) {
- return Status(error::INVALID_ARGUMENT,
- strings::StrCat("Missing attribute ", attr));
- }
- return Status::OK();
- }
-
- bool MustPreserve() const {
- return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
- }
-
- bool IsOnGPU() const {
- string device_name;
- if (node_->device().empty()) {
- device_name = virtual_placer_.get_canonical_device_name(*node_);
- } else {
- device_name = node_->device();
- }
- string device;
- string not_used;
- if (DeviceNameUtils::SplitDeviceName(device_name, ¬_used, &device) &&
- absl::StrContains(absl::AsciiStrToLower(device),
- absl::AsciiStrToLower(DEVICE_GPU))) {
- return true;
- }
- return false;
- }
-
- virtual bool ShouldProcess() const {
- return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
- HasOutputs() && IsOnGPU();
- }
-
- virtual void UpdateAttrShape() {
- if (node_->attr().find("_output_shapes") != node_->attr().end()) {
- for (const auto& pos : GetOutputPos()) {
- auto shape = node_->mutable_attr()
- ->at("_output_shapes")
- .mutable_list()
- ->mutable_shape(pos);
- if (shape->dim_size() == 4) {
- int64_t h = shape->dim(1).size();
- int64_t w = shape->dim(2).size();
- int64_t c = shape->dim(3).size();
- shape->mutable_dim(1)->set_size(c);
- shape->mutable_dim(2)->set_size(h);
- shape->mutable_dim(3)->set_size(w);
- }
- }
- }
- }
-
- Status UpdateAttrValueOfInput(int input_index, bool permute) {
- auto input_node = node_map_->GetNode(node_->input(input_index));
- // We created a copy of the node, so that we don't modify the original node,
- // which might be used elsewhere. Note that this copy also copies the
- // control dependency input in the case this node is inside a loop,
- // to ensure added_node is in the same frame with node_.
- NodeDef* added_node = graph_->add_node();
- *added_node = *input_node;
- string base_name = strings::StrCat(node_->name(), "-", input_index);
- string node_name = LayoutOptimizerNode(base_name);
- added_node->set_name(node_name);
- *node_->mutable_input(input_index) = node_name;
- node_map_->AddNode(node_name, added_node);
- node_map_->AddOutput(node_name, node_->name());
- return UpdateAttrValue(added_node, permute);
- }
-
- virtual std::vector<int> GetInputPos() const { return {0}; }
-
- virtual std::set<int> GetOutputPos() const {
- // For most nodes, no need to process control nodes or nodes that use an
- // output other than the first output: only the first output is of
- // 4D NCHW/NHWC format and thus relevant here.
- std::set<int> output_pos = {0};
- return output_pos;
- }
-
- virtual Status AddLayoutTransposeToInputs() {
- std::vector<int> input_pos = GetInputPos();
- for (const auto& pos : input_pos) {
- string node_name = LayoutOptimizerNode(
- strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
- DataType dtype =
- graph_properties_.GetInputProperties(node_->name())[pos].dtype();
- auto input_node = node_map_->GetNode(node_->input(pos));
- TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
- string const_name = GetOrAddNodePermNHWCToNCHW(pos);
- int output_pos;
- ParseNodeName(node_->input(pos), &output_pos);
- AddNodeTranspose(
- node_name, node_->input(pos), const_name, dtype,
- input_node->attr().at("_output_shapes").list().shape(output_pos),
- true);
- node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
- node_name);
- node_map_->AddOutput(node_name, node_->name());
- *node_->mutable_input(pos) = node_name;
- }
- return Status::OK();
- }
-
- Status AddTransformToOutputs(const string& op) {
- auto outputs = node_map_->GetOutputs(node_->name());
- string const_name = GetOrAddNodePermNCHWToNHWC();
- int output_count = 0;
- for (const auto& output : outputs) {
- int connections = 0;
- int connections_removed = 0;
- for (int i = 0; i < output->input_size(); i++) {
- auto& input = *output->mutable_input(i);
- int input_port;
- string input_name = ParseNodeName(input, &input_port);
- auto output_pos = GetOutputPos();
- if (input_name == node_->name()) {
- connections++;
- if (output_pos.find(input_port) != output_pos.end()) {
- connections_removed++;
- string added_node_base_name =
- strings::StrCat(node_->name(), "-", output_count, "-", i);
- string added_node_name;
- DataType dtype =
- graph_properties_.GetOutputProperties(node_->name())[input_port]
- .dtype();
- if (op == "Transpose") {
- added_node_name = LayoutOptimizerNode(strings::StrCat(
- added_node_base_name, "-", kTransposeNCHWToNHWC));
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
- AddNodeTranspose(
- added_node_name, input, const_name, dtype,
- node_->attr().at("_output_shapes").list().shape(input_port),
- false);
- } else if (op == "DataFormatVecPermute") {
- added_node_name = LayoutOptimizerNode(strings::StrCat(
- added_node_base_name, "-", kVecPermuteNCHWToNHWC));
- AddNodeDataFormatOp(added_node_name, input, op, dtype, false);
- } else {
- return errors::InvalidArgument("Unsupported op type: ", op);
- }
- input = added_node_name;
- node_map_->AddOutput(node_->name(), added_node_name);
- node_map_->AddOutput(added_node_name, output->name());
- }
- }
- }
- if (connections == connections_removed) {
- node_map_->RemoveOutput(node_->name(), output->name());
- }
- output_count++;
- }
- return Status::OK();
- }
-
- virtual Status AddLayoutTransposeToOutputs() {
- return AddTransformToOutputs("Transpose");
- }
-
- virtual Status CustomizedProcessing() { return Status::OK(); }
-
- Status UpdateOrTransformParamInput(int param_index, const string& op,
- DataType dtype) {
- auto param_node = node_map_->GetNode(node_->input(param_index));
- bool permute = (op == "DataFormatVecPermute") ? true : false;
- if (IsConstant(*param_node)) {
- TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(param_index, permute));
- } else {
- AddDataFormatTransformToParamInput(op, param_index, dtype);
- }
- return Status::OK();
- }
-
- NodeDef* node_;
- bool is_in_frame_;
- std::unordered_set<string>* devices_with_perm_const_; // not owned.
-
- private:
- string CompliantDeviceName(const string& device) {
- if (device.empty()) return string(kDefaultDevice);
- string ret(device);
- std::replace(ret.begin(), ret.end(), '/', '_');
- std::replace(ret.begin(), ret.end(), ':', '_');
- return string(absl::StripPrefix(ret, "_"));
- }
-
- void UpdateAttrKSize() {
- if (node_->attr().find("ksize") != node_->attr().end()) {
- auto list = node_->mutable_attr()->at("ksize").mutable_list();
- UpdateTuple(list);
- }
- }
-
- void UpdateAttrStrides() {
- if (node_->attr().find("strides") != node_->attr().end()) {
- auto list = node_->mutable_attr()->at("strides").mutable_list();
- UpdateTuple(list);
- }
- }
-
- void UpdateAttrDilations() {
- if (node_->attr().find("dilations") != node_->attr().end()) {
- auto list = node_->mutable_attr()->at("dilations").mutable_list();
- UpdateTuple(list);
- }
- }
-
- void UpdateAttrExplicitPaddings() {
- if (node_->attr().find("explicit_paddings") != node_->attr().end()) {
- auto list = node_->mutable_attr()->at("explicit_paddings").mutable_list();
- int size = list->i_size();
- if (size == 8) {
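- // In NHWC order the eight values are before/after pairs for
- // [N, N, H, H, W, W, C, C]; in NCHW they become [N, N, C, C, H, H, W, W].
- // Explicit padding on the N and C dimensions must be zero, so the new
- // C slots are simply zeroed.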
- int64_t height_before = list->i(2);
- int64_t height_after = list->i(3);
- int64_t width_before = list->i(4);
- int64_t width_after = list->i(5);
- list->set_i(2, 0);
- list->set_i(3, 0);
- list->set_i(4, height_before);
- list->set_i(5, height_after);
- list->set_i(6, width_before);
- list->set_i(7, width_after);
- } else if (size != 0) {
- LOG(ERROR) << "Cannot handle explicit_paddings attribute of size "
- << size;
- }
- }
- }
-
- void UpdateAttrDataFormat() {
- if (node_->attr().find("data_format") != node_->attr().end()) {
- if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
- string* data_format =
- node_->mutable_attr()->at("data_format").mutable_s();
- *data_format = "NCHW";
- }
- }
- }
-
- Status UpdateAttrValue(NodeDef* node, bool permute) {
- TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
- Tensor tensor;
- auto success =
- tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
- if (!success) {
- LOG(ERROR) << "Failed to parse TensorProto.";
- }
-
- if (permute) {
- if (tensor.dims() == 1) {
- if (tensor.flat<int>().size() == 4) {
- int c = tensor.flat<int>()(3);
- tensor.flat<int>()(3) = tensor.flat<int>()(2);
- tensor.flat<int>()(2) = tensor.flat<int>()(1);
- tensor.flat<int>()(1) = c;
- } else {
- return Status(error::INVALID_ARGUMENT,
- strings::StrCat("Unsupported tensor size: ",
- tensor.flat<int>().size()));
- }
- } else if (tensor.dims() == 2) {
- for (int i = 0; i < 2; i++) {
- int c = tensor.matrix<int>()(3, i);
- tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
- tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
- tensor.matrix<int>()(1, i) = c;
- }
- } else {
- return Status(
- error::INVALID_ARGUMENT,
- strings::StrCat("Unsupported dimension size: ", tensor.dims()));
- }
- } else {
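- // Here the tensor holds NHWC dimension indices rather than a shape.
- // Remap each index to NCHW: N (0) stays 0, H (1) -> 2, W (2) -> 3,
- // C (3) -> 1; negative indices are first normalized by adding the
- // rank (4).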
- for (int i = 0; i < tensor.flat<int>().size(); i++) {
- int value = tensor.flat<int>()(i);
- value = (value >= 0) ? value : value + 4;
- if (value == 1 || value == 2) {
- value = value + 1;
- } else if (value == 3) {
- value = 1;
- }
- tensor.flat<int>()(i) = value;
- }
- }
-
- if (tensor.dims() == 0) {
- tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
- } else {
- tensor.AsProtoTensorContent(
- node->mutable_attr()->at({"value"}).mutable_tensor());
- }
- return Status::OK();
- }
-
- NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
- const string& const_name, DataType data_type,
- const TensorShapeProto& input_shape,
- bool NHWCToNCHW) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(node_name, node);
- node->set_name(node_name);
- *node->add_input() = input_name;
- *node->add_input() = const_name;
- node->set_op("Transpose");
- node->set_device(node_->device());
- AttrValue attr_data_type;
- attr_data_type.set_type(data_type);
- node->mutable_attr()->insert({"T", attr_data_type});
- AttrValue attr_data_type_perm;
- attr_data_type_perm.set_type(DT_INT32);
- node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
- if (!input_shape.unknown_rank()) {
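- // Record the permuted output shape so that downstream passes keep static
- // shape information: NHWC -> NCHW takes dims (0, 3, 1, 2) of the input
- // shape, the reverse direction takes dims (0, 2, 3, 1).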
- AttrValue attr_output_shape;
- auto output_shape = attr_output_shape.mutable_list()->add_shape();
- if (NHWCToNCHW) {
- output_shape->add_dim()->set_size(input_shape.dim(0).size());
- output_shape->add_dim()->set_size(input_shape.dim(3).size());
- output_shape->add_dim()->set_size(input_shape.dim(1).size());
- output_shape->add_dim()->set_size(input_shape.dim(2).size());
- } else {
- output_shape->add_dim()->set_size(input_shape.dim(0).size());
- output_shape->add_dim()->set_size(input_shape.dim(2).size());
- output_shape->add_dim()->set_size(input_shape.dim(3).size());
- output_shape->add_dim()->set_size(input_shape.dim(1).size());
- }
- node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
- }
- return node;
- }
-
- NodeDef* AddNodePermNHWCToNCHW(const string& base_name,
- const string& depended_node,
- const string& device) {
- string name =
- LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNHWCToNCHW));
- auto const_node = AddNodePermConst(name, device, {0, 3, 1, 2});
- // This is to ensure the transpose node and the const node are in the same
- // frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- NodeDef* AddNodePermNCHWToNHWC(const string& base_name,
- const string& depended_node,
- const string& device) {
- string name =
- LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNCHWToNHWC));
- auto const_node = AddNodePermConst(name, device, {0, 2, 3, 1});
- // This is to ensure the transpose node and the const node are in the same
- // frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- // NOTE(zycao): We try to make sure each device has the permutation consts
- // iff the consts are really needed. Thus no unexpected inter-worker
- // connections are created and no redundant nodes are left behind.
- void AddNodePermConstOnDevice(const string& device) {
- string compliant_device_prefix;
- compliant_device_prefix = CompliantDeviceName(device);
- string node_name;
- // Permutation const for NHWCToNCHW
- node_name = strings::StrCat(compliant_device_prefix, "-",
- LayoutOptimizerNode(kPermNHWCToNCHW));
- AddNodePermConst(node_name, device, {0, 3, 1, 2});
-
- // Permutation const for NCHWToNHWC
- node_name = strings::StrCat(compliant_device_prefix, "-",
- LayoutOptimizerNode(kPermNCHWToNHWC));
- AddNodePermConst(node_name, device, {0, 2, 3, 1});
- }
-
- string GetOrAddNodePermNHWCToNCHW(int pos) {
- string const_name;
- if (is_in_frame_) {
- string base_name = strings::StrCat(node_->name(), "-", pos);
- string input = NodeName(node_->input(pos));
- string depended_node;
- if (!IsTransposeNCHWToNHWC(input)) {
- depended_node = input;
- } else {
- auto input_node = node_map_->GetNode(input);
- depended_node = NodeName(input_node->input(0));
- }
- auto const_node =
- AddNodePermNHWCToNCHW(base_name, depended_node, node_->device());
- const_name = const_node->name();
- } else {
- if (devices_with_perm_const_->find(node_->device()) ==
- devices_with_perm_const_->end()) {
- AddNodePermConstOnDevice(node_->device());
- devices_with_perm_const_->insert(node_->device());
- }
- const_name = strings::StrCat(CompliantDeviceName(node_->device()), "-",
- LayoutOptimizerNode(kPermNHWCToNCHW));
- }
- return const_name;
- }
-
- string GetOrAddNodePermNCHWToNHWC() {
- string const_name;
- if (is_in_frame_) {
- auto const_node =
- AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
- const_name = const_node->name();
- } else {
- if (devices_with_perm_const_->find(node_->device()) ==
- devices_with_perm_const_->end()) {
- AddNodePermConstOnDevice(node_->device());
- devices_with_perm_const_->insert(node_->device());
- }
- const_name = strings::StrCat(CompliantDeviceName(node_->device()), "-",
- LayoutOptimizerNode(kPermNCHWToNHWC));
- }
- return const_name;
- }
-
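- // Permutes a length-4 attribute list (such as ksize or strides) in place
- // from [N, H, W, C] order to [N, C, H, W] order.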
- void UpdateTuple(AttrValue_ListValue* list) {
- int64_t h = list->i(1);
- int64_t w = list->i(2);
- int64_t c = list->i(3);
- list->set_i(1, c);
- list->set_i(2, h);
- list->set_i(3, w);
- }
-
- bool IsInputOnHost(const string& input_name) const {
- string device = node_->device();
- DeviceNameUtils::ParsedName parsed_name;
- if (DeviceNameUtils::ParseFullName(device, &parsed_name)) {
- if (parsed_name.type != "CPU") {
- NodeDef* input = node_map_->GetNode(input_name);
- int port;
- ParseNodeName(input_name, &port);
- if (IsHostMemory(*input, port)) {
- return true;
- }
- }
- }
- return false;
- }
-
- NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
- const string& op, DataType dtype,
- bool nhwc_to_nchw) {
- NodeDef* added_node = graph_->add_node();
- added_node->set_name(name);
- added_node->set_op(op);
- node_map_->AddNode(added_node->name(), added_node);
- added_node->set_device(node_->device());
- // The inputs of a DataFormat op could be in host memory for ops such as
- // Reshape. In such cases, run the kernel on the host too.
- if (IsInputOnHost(input_name)) {
- AttrValue attr_kernel;
- attr_kernel.set_s("host");
- added_node->mutable_attr()->insert({"_kernel", attr_kernel});
- }
- AttrValue attr_data_type;
- attr_data_type.set_type(dtype);
- added_node->mutable_attr()->insert({"T", attr_data_type});
- string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
- string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
- AttrValue attr_format;
- attr_format.set_s(src_format);
- added_node->mutable_attr()->insert({"src_format", attr_format});
- attr_format.set_s(dst_format);
- added_node->mutable_attr()->insert({"dst_format", attr_format});
- *added_node->add_input() = input_name;
- return added_node;
- }
-
- void AddDataFormatTransformToParamInput(const string& op, int input_pos,
- DataType dtype) {
- string suffix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
- : kDimMapNHWCToNCHW;
- string name = LayoutOptimizerNode(
- strings::StrCat(node_->name(), "-", input_pos, "-", suffix));
- auto added_node =
- AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
- *node_->mutable_input(input_pos) = added_node->name();
- node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
- added_node->name());
- node_map_->AddOutput(added_node->name(), node_->name());
- }
-};
-
-class AvgPoolGradProcessor : public NodeProcessor {
- public:
- explicit AvgPoolGradProcessor(const OptimizeContext& opt_cxt)
- : NodeProcessor(opt_cxt) {}
-
- protected:
- std::vector<int> GetInputPos() const override { return {1}; }
- Status CustomizedProcessing() override {
- return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
- }
-};
-
-class BiasAddGradProcessor : public NodeProcessor {
- public:
- explicit BiasAddGradProcessor(const OptimizeContext& opt_cxt)
- : NodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- if (MustPreserve()) {
- return false;
- }
- if (!IsOnGPU()) {
- return false;
- }
- auto input = node_map_->GetNode(node_->input(0));
- if (input) {
- int port;
- ParseNodeName(node_->input(0), &port);
- if (IsNHWC() && IsPortDimsFour(*input, port)) {
- return true;
- }
- }
- return false;
- }
-
- Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-};
-
-class Conv2DProcessor : public NodeProcessor {
- public:
- Conv2DProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
- : NodeProcessor(opt_cxt), no_gemm_(no_gemm) {}
-
- protected:
- bool ShouldProcess() const override {
- return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
- HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU() &&
- IsDataTypeFloat();
- }
-
- TensorShapeProto GetShape(const string& input_name) const {
- string node_name;
- int output_pos;
- node_name = ParseNodeName(input_name, &output_pos);
- NodeDef* node = node_map_->GetNode(node_name);
- if (node->attr().find("_output_shapes") != node->attr().end()) {
- return node->attr().at("_output_shapes").list().shape(output_pos);
- }
- TensorShapeProto shape;
- return shape;
- }
-
- bool IsStrideOne() const {
- if (node_->attr().find("strides") != node_->attr().end()) {
- auto list = node_->attr().at("strides").list();
- return list.i(1) == 1 && list.i(2) == 1;
- }
- return false;
- }
-
- bool IsValidPadding() const {
- if (node_->attr().find("padding") != node_->attr().end()) {
- auto padding = node_->attr().at("padding").s();
- return padding == "VALID";
- }
- return false;
- }
-
- bool IsDataTypeFloat() const {
- if (node_->attr().find("T") != node_->attr().end()) {
- return kDataTypeIsFloating.Contains(node_->attr().at("T").type());
- }
- return false;
- }
-
- // The logic inside this function is based on the internal implementation of
- // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
- // needs to be updated accordingly if the internal implementation changes.
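- // Gemm is used for 1x1 filters with unit stride, and when the filter
- // spatially covers the whole input under VALID padding; in both cases the
- // op reduces to a single matrix multiplication, so a layout change buys
- // nothing.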
- bool IsGemmUsed(const TensorShapeProto& filter_shape,
- const TensorShapeProto& input_shape) const {
- if (filter_shape.dim_size() == 4) {
- if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 &&
- IsStrideOne()) {
- return true;
- }
- }
- if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) {
- if (input_shape.dim(1).size() == filter_shape.dim(0).size() &&
- input_shape.dim(2).size() == filter_shape.dim(1).size() &&
- IsValidPadding()) {
- return true;
- }
- }
- return false;
- }
-
- virtual bool IsGemmUsed() const {
- auto filter_shape = GetShape(node_->input(1));
- auto input_shape = GetShape(node_->input(0));
- return IsGemmUsed(filter_shape, input_shape);
- }
-
- bool no_gemm_;
-};
-
-class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
- public:
- Conv2DBackpropFilterProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
- : Conv2DProcessor(opt_cxt, no_gemm) {}
-
- protected:
- bool IsGemmUsed() const override {
- auto filter_shape = GetShape(node_->name());
- auto input_shape = GetShape(node_->input(0));
- return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
- }
-
- std::vector<int> GetInputPos() const override { return {0, 2}; }
-
- Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
- // No need to update output shape, as it is always of shape
- // [filter_height, filter_width, in_channels, out_channels], regardless of
- // whether NCHW or NHWC is used.
- void UpdateAttrShape() override {}
-};
-
-class Conv2DBackpropInputProcessor : public Conv2DProcessor {
- public:
- Conv2DBackpropInputProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
- : Conv2DProcessor(opt_cxt, no_gemm) {}
-
- protected:
- bool IsGemmUsed() const override {
- auto filter_shape = GetShape(node_->input(1));
- auto input_shape = GetShape(node_->name());
- return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
- }
-
- std::vector<int> GetInputPos() const override { return {2}; }
-
- Status CustomizedProcessing() override {
- return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
- }
-};
-
-class FusedBatchNormGradProcessor : public NodeProcessor {
- public:
- explicit FusedBatchNormGradProcessor(const OptimizeContext& opt_cxt)
- : NodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- return NodeProcessor::ShouldProcess() && IsTraining();
- }
-
- std::vector<int> GetInputPos() const override { return {0, 1}; }
-
- private:
- bool IsTraining() const {
- if (node_->attr().find("is_training") != node_->attr().end()) {
- if (node_->attr().at("is_training").b()) {
- return true;
- }
- }
- return false;
- }
-};
-
-class MaxPoolGradProcessor : public NodeProcessor {
- public:
- explicit MaxPoolGradProcessor(const OptimizeContext& opt_cxt)
- : NodeProcessor(opt_cxt) {}
-
- protected:
- std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
-};
-
-class MaxPoolGradV2Processor : public MaxPoolGradProcessor {
- public:
- explicit MaxPoolGradV2Processor(const OptimizeContext& opt_cxt)
- : MaxPoolGradProcessor(opt_cxt) {}
-
- protected:
- Status CustomizedProcessing() override {
- for (int i = 3; i <= 4; i++) {
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
- }
- return Status::OK();
- }
-};
-
-class MaxPoolV2Processor : public NodeProcessor {
- public:
- explicit MaxPoolV2Processor(const OptimizeContext& opt_cxt)
- : NodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- // We check the shape of the data input instead of this node's output
- // shape, because shape inference cannot infer the output shape of
- // MaxPoolV2 when ksize or strides is not constant.
- auto data_input = node_map_->GetNode(node_->input(0));
- int port;
- ParseNodeName(node_->input(0), &port);
- return !MustPreserve() && IsNHWC() && IsPortDimsFour(*data_input, port) &&
- HasOutputs() && IsOnGPU();
- }
-
- Status CustomizedProcessing() override {
- for (int i = 1; i <= 2; i++) {
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
- }
- return Status::OK();
- }
-};
-
-class AgnosticNodeProcessor : public NodeProcessor {
- public:
- explicit AgnosticNodeProcessor(const OptimizeContext& opt_cxt)
- : NodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
- IsNodeAfterNCHWToNHWC() && IsOnGPU();
- }
-
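- // Breadth-first walk backwards through data inputs, looking for an
- // NCHW-to-NHWC transform inserted earlier by this pass; only nodes fed
- // (possibly through other format-agnostic nodes) by such a transform are
- // worth converting.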
- bool IsNodeAfterNCHWToNHWC(const NodeDef& node) const {
- std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
- std::deque<NodeDef*> queue;
- auto data_node_pos = DataInputPos(node);
- std::unordered_set<string> visited;
- for (const auto& pos : data_node_pos) {
- auto input_node = node_map_->GetNode(node.input(pos));
- queue.push_back(input_node);
- visited.insert(input_node->name());
- }
- // In most cases this loop exits after a single iteration, because the
- // graph is already topologically sorted.
- while (!queue.empty()) {
- NodeDef* current_node = queue.front();
- queue.pop_front();
- if (IsTransposeNCHWToNHWC(current_node->name()) ||
- IsDimMapNCHWToNHWC(current_node->name()) ||
- IsVecPermuteNCHWToNHWC(current_node->name())) {
- return true;
- }
- // We only continue searching if the path is connected through
- // format-agnostic nodes.
- if (ops_format_agnostic.find(current_node->op()) !=
- ops_format_agnostic.end()) {
- auto current_node_pos = DataInputPos(*current_node);
- for (const auto& pos : current_node_pos) {
- auto input_node = node_map_->GetNode(current_node->input(pos));
- if (visited.find(input_node->name()) == visited.end()) {
- queue.push_back(input_node);
- visited.insert(input_node->name());
- }
- }
- }
- }
- return false;
- }
-
- bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
-};
-
-class AddNProcessor : public AgnosticNodeProcessor {
- public:
- explicit AddNProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- std::vector<int> GetInputPos() const override {
- return NonControlInputs(*node_);
- }
-};
-
-class BinaryOpProcessor : public AgnosticNodeProcessor {
- public:
- explicit BinaryOpProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
- IsNodeAfterNCHWToNHWC() &&
- (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) ||
- IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4) ||
- IsNDOperateWithMD(1, 4)) &&
- IsOnGPU();
- }
-
- std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos;
- auto input0 = node_map_->GetNode(node_->input(0));
- auto input1 = node_map_->GetNode(node_->input(1));
- int input0_port;
- ParseNodeName(node_->input(0), &input0_port);
- int input1_port;
- ParseNodeName(node_->input(1), &input1_port);
- if (IsPortDimsFour(*input0, input0_port)) {
- input_pos.push_back(0);
- }
- if (IsPortDimsFour(*input1, input1_port)) {
- input_pos.push_back(1);
- }
- return input_pos;
- }
-
- bool IsNDOperateWithMD(int n, int m) const {
- auto input0 = node_map_->GetNode(node_->input(0));
- auto input1 = node_map_->GetNode(node_->input(1));
- int input0_port;
- ParseNodeName(node_->input(0), &input0_port);
- int input1_port;
- ParseNodeName(node_->input(1), &input1_port);
-
- if (input0 && input1) {
- bool input0_is_n = (n == 4) ? IsPortDimsFour(*input0, input0_port)
- : IsPortDimsN(*input0, input0_port, n);
- bool input1_is_m = (m == 4) ? IsPortDimsFour(*input1, input1_port)
- : IsPortDimsN(*input1, input1_port, m);
- return input0_is_n && input1_is_m;
- }
- return false;
- }
-
- NodeDef* AddNodeShapeConst(const string& name, int num_channels,
- const string& depended_node) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- node->set_device(node_->device());
- AttrValue attr_data_type;
- attr_data_type.set_type(DT_INT32);
- node->mutable_attr()->insert({"dtype", attr_data_type});
-
- AttrValue attr_tensor;
- Tensor tensor(DT_INT32, TensorShape({4}));
- std::vector<int> shape = {1, num_channels, 1, 1};
- for (int i = 0; i < static_cast<int>(shape.size()); i++) {
- tensor.flat<int>()(i) = shape[i];
- }
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- if (is_in_frame_) {
- // This is to ensure the transpose node and the const node are in the
- // same frame.
- *node->add_input() = AsControlDependency(depended_node);
- }
- return node;
- }
-
- NodeDef* AddNodeReshape(const string& node_name, const string& input_name,
- const string& shape_const_node_name,
- DataType data_type) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(node_name, node);
- node->set_name(node_name);
- *node->add_input() = input_name;
- *node->add_input() = shape_const_node_name;
- node->set_op("Reshape");
- node->set_device(node_->device());
-
- AttrValue attr_type_indices;
- attr_type_indices.set_type(DT_INT32);
- node->mutable_attr()->insert({"Tshape", attr_type_indices});
-
- AttrValue attr_type_params;
- attr_type_params.set_type(data_type);
- node->mutable_attr()->insert({"T", attr_type_params});
- return node;
- }
-
- Status CustomizedProcessing() override {
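- // A vector operand broadcasts along the last (C) axis in NHWC. After the
- // layout change it must broadcast along the channel axis of NCHW instead,
- // which is achieved by reshaping it to [1, C, 1, 1].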
- int vector_index = -1;
- if (IsNDOperateWithMD(4, 1)) {
- vector_index = 1;
- } else if (IsNDOperateWithMD(1, 4)) {
- vector_index = 0;
- }
- if (vector_index != -1) {
- string base_name = strings::StrCat(node_->name(), "-", vector_index);
- string reshape_node_name = LayoutOptimizerNode(
- strings::StrCat(base_name, "-", kReshapeNHWCToNCHW));
- string shape_const_node_name =
- LayoutOptimizerNode(strings::StrCat(base_name, "-", kReshapeConst));
- auto input_node = node_map_->GetNode(node_->input(vector_index));
- TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
- int port;
- ParseNodeName(node_->input(vector_index), &port);
- int vector_size = input_node->attr()
- .at("_output_shapes")
- .list()
- .shape(port)
- .dim(0)
- .size();
- AddNodeShapeConst(shape_const_node_name, vector_size,
- NodeName(node_->input(vector_index)));
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
- AddNodeReshape(reshape_node_name, node_->input(vector_index),
- shape_const_node_name, node_->attr().at("T").type());
- node_map_->AddOutput(shape_const_node_name, reshape_node_name);
- node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
- node_->name(), reshape_node_name);
- node_map_->AddOutput(reshape_node_name, node_->name());
- *node_->mutable_input(vector_index) = reshape_node_name;
- }
- return Status::OK();
- }
-};
-
-class ConcatProcessor : public AgnosticNodeProcessor {
- public:
- explicit ConcatProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {
- // For Concat, the concat axis is the first input; for ConcatV2, it is the
- // last input. Note that with control inputs, the number of inputs is
- // larger than the integer attribute N.
- int n = node_->attr().at("N").i();
- axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
- }
-
- protected:
- std::vector<int> GetInputPos() const override {
- return DataInputPosConcat(*node_);
- }
-
- Status CustomizedProcessing() override {
- DataType dtype =
- (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
- return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
- dtype);
- }
-
- int axis_node_pos_;
-};
-
-class FillProcessor : public AgnosticNodeProcessor {
- public:
- explicit FillProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- std::vector<int> GetInputPos() const override { return {}; }
-
- Status CustomizedProcessing() override {
- DataType dtype = node_->attr().at("index_type").type();
- return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
- }
-};
-
-class HistogramSummaryProcessor : public AgnosticNodeProcessor {
- public:
- explicit HistogramSummaryProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- auto input1 = node_map_->GetNode(node_->input(1));
- int port;
- ParseNodeName(node_->input(1), &port);
- return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsPortDimsFour(*input1, port) && IsOnGPU();
- }
-
- std::vector<int> GetInputPos() const override { return {1}; }
-
- Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-};
-
-class IdentityNProcessor : public AgnosticNodeProcessor {
- public:
- explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {
- std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
- for (int i = 0; i < node_->input_size(); i++) {
- auto input = node_map_->GetNode(node_->input(i));
- int port;
- ParseNodeName(node_->input(i), &port);
- // Skip control input.
- if (port != -1) {
- bool is_agnostic =
- ops_format_agnostic.find(input->op()) != ops_format_agnostic.end();
- if (IsPortDimsFour(*input, port) &&
- ((IsNodeAfterNCHWToNHWC(*input) && is_agnostic) ||
- IsTransposeNCHWToNHWC(input->name()))) {
- input_pos_.push_back(i);
- }
- }
- }
- }
-
- protected:
- bool ShouldProcess() const override {
- return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsOnGPU();
- }
-
- std::vector<int> GetInputPos() const override { return input_pos_; }
-
- std::set<int> GetOutputPos() const override {
- std::set<int> output_pos{};
- for (const auto& input_pos : input_pos_) {
- output_pos.insert(input_pos);
- }
- return output_pos;
- }
-
- private:
- std::vector<int> input_pos_;
-};
-
-class ShapeProcessor : public IdentityNProcessor {
- public:
- explicit ShapeProcessor(const OptimizeContext& opt_cxt)
- : IdentityNProcessor(opt_cxt) {}
-
- protected:
- Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-
- Status CustomizedProcessing() override {
- return AddTransformToOutputs("DataFormatVecPermute");
- }
-};
-
-class MergeProcessor : public AgnosticNodeProcessor {
- public:
- explicit MergeProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
- IsEveryInputAfterNCHWToNHWC() && IsOnGPU();
- }
-
- std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos;
- int n = node_->attr().at("N").i();
- input_pos.reserve(n);
- for (int i = 0; i < n; i++) {
- input_pos.push_back(i);
- }
- return input_pos;
- }
-
- private:
- bool IsEveryInputAfterNCHWToNHWC() const {
- std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
- for (const auto& input : node_->input()) {
- auto input_node = node_map_->GetNode(input);
- int port;
- ParseNodeName(input, &port);
- bool is_agnostic = ops_format_agnostic.find(input_node->op()) !=
- ops_format_agnostic.end();
- if (IsPortDimsFour(*input_node, port) &&
- ((IsNodeAfterNCHWToNHWC(*input_node) && is_agnostic) ||
- IsTransposeNCHWToNHWC(input_node->name()))) {
- continue;
- }
- return false;
- }
- return true;
- }
-};
-
-class PadProcessor : public AgnosticNodeProcessor {
- public:
- explicit PadProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- Status CustomizedProcessing() override {
- DataType dtype = node_->attr().at("Tpaddings").type();
- return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
- }
-};
-
-class ReverseProcessor : public AgnosticNodeProcessor {
- public:
- explicit ReverseProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- Status CustomizedProcessing() override {
- DataType dtype = node_->attr().at("Tidx").type();
- return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
- }
-};
-
-class SplitProcessor : public AgnosticNodeProcessor {
- public:
- explicit SplitProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {
- axis_node_pos_ = 0;
- }
-
- protected:
- std::vector<int> GetInputPos() const override { return {1}; }
-
- std::set<int> GetOutputPos() const override {
- std::set<int> output_pos{0};
- if (HasAttribute(*node_, "num_split").ok()) {
- for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
- output_pos.insert(i);
- }
- }
- return output_pos;
- }
-
- Status CustomizedProcessing() override {
- return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
- DT_INT32);
- }
-
- int axis_node_pos_;
-};
-
-class SplitVProcessor : public SplitProcessor {
- public:
- explicit SplitVProcessor(const OptimizeContext& opt_cxt)
- : SplitProcessor(opt_cxt) {
- axis_node_pos_ = 2;
- }
-
- protected:
- std::vector<int> GetInputPos() const override { return {0}; }
-};
-
-class TernaryOpProcessor : public AgnosticNodeProcessor {
- public:
- explicit TernaryOpProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
-};
-
-class SelectProcessor : public AgnosticNodeProcessor {
- public:
- explicit SelectProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- auto input0 = node_map_->GetNode(node_->input(0));
- int input0_port;
- ParseNodeName(node_->input(0), &input0_port);
- bool is_input0_scalar_vector_4d = IsPortDimsN(*input0, input0_port, 0) ||
- IsPortDimsN(*input0, input0_port, 1) ||
- IsPortDimsN(*input0, input0_port, 4);
- return AgnosticNodeProcessor::ShouldProcess() && is_input0_scalar_vector_4d;
- }
-
- std::vector<int> GetInputPos() const override {
- auto input0 = node_map_->GetNode(node_->input(0));
- int input0_port;
- ParseNodeName(node_->input(0), &input0_port);
- // Input 0 can be a scalar, a vector whose size matches the first
- // dimension of inputs 1 and 2, or a tensor with the same shape as them.
- if (IsPortDimsFour(*input0, input0_port)) {
- return {0, 1, 2};
- } else {
- return {1, 2};
- }
- }
-};
-
-class UnaryGradProcessor : public AgnosticNodeProcessor {
- public:
- explicit UnaryGradProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- std::vector<int> GetInputPos() const override { return {0, 1}; }
-};
-
-class SliceProcessor : public AgnosticNodeProcessor {
- public:
- explicit SliceProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {
- // Skip the first input, which is the data to be sliced.
- start_ = 1;
- // Note that we can't use node_->input_size() here because there
- // could be control inputs.
- end_ = 2;
- }
-
- protected:
- Status ProcessInputs() {
- for (int i = start_; i <= end_; i++) {
- DataType dtype = node_->attr().at("Index").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(i, "DataFormatVecPermute", dtype));
- }
- return Status::OK();
- }
-
- Status CustomizedProcessing() override { return ProcessInputs(); }
-
- int start_;
- int end_;
-};
-
-class StridedSliceProcessor : public SliceProcessor {
- public:
- explicit StridedSliceProcessor(const OptimizeContext& opt_cxt)
- : SliceProcessor(opt_cxt) {
- start_ = 1;
- end_ = 3;
- }
-
- protected:
- bool ShouldProcess() const override {
- return AgnosticNodeProcessor::ShouldProcess() && IsOnlyBeginEndMask();
- }
-
- Status CustomizedProcessing() override {
- TF_RETURN_IF_ERROR(UpdateMask("begin_mask"));
- TF_RETURN_IF_ERROR(UpdateMask("end_mask"));
- TF_RETURN_IF_ERROR(ProcessInputs());
- return Status::OK();
- }
-
- private:
- bool IsMaskZero(const string& mask) const {
- return node_->attr().at(mask).i() == 0;
- }
-
- bool IsOnlyBeginEndMask() const {
- return IsMaskZero("ellipsis_mask") && IsMaskZero("new_axis_mask") &&
- IsMaskZero("shrink_axis_mask");
- }
-
- Status UpdateMask(const string& mask) {
- int i = node_->attr().at(mask).i();
- if (i < 0 || i > 15) {
- return errors::InvalidArgument("invalid mask value: ", i);
- }
- if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK();
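- // Each mask bit refers to one dimension. NHWC -> NCHW moves the H bit (1)
- // to position 2, the W bit (2) to position 3 and the C bit (3) to
- // position 1, while the N bit (0) stays put; 0, 1, 14 and 15 map to
- // themselves under this permutation, hence the early return above.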
- switch (i) {
- case 2:
- case 3:
- i += 2;
- break;
- case 4:
- case 5:
- i += 4;
- break;
- case 6:
- case 7:
- i += 6;
- break;
- case 8:
- case 9:
- i -= 6;
- break;
- case 10:
- case 11:
- i -= 4;
- break;
- case 12:
- case 13:
- i -= 2;
- break;
- }
- node_->mutable_attr()->at(mask).set_i(i);
- return Status::OK();
- }
-};
-
-class StridedSliceGradProcessor : public StridedSliceProcessor {
- public:
- explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt)
- : StridedSliceProcessor(opt_cxt) {
- start_ = 0;
- end_ = 3;
- }
-
- protected:
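- // Only input 4 (dy) carries layout-sensitive 4-D data; inputs 0-3
- // (shape, begin, end, strides) are parameter vectors handled by
- // ProcessInputs.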
- std::vector<int> GetInputPos() const override { return {4}; }
-};
-
-class SqueezeProcessor : public AgnosticNodeProcessor {
- public:
- explicit SqueezeProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
- (IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
- return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsInputConvertible() && is_dims_supported && IsOnGPU();
- }
-
- Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-
- Status CustomizedProcessing() override {
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
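- // squeeze_dims were written against the NHWC layout: {1, 2} squeezes H
- // and W, {0, 1, 2} squeezes N, H and W. Rewrite them to the matching
- // NCHW positions, {2, 3} and {0, 2, 3} respectively.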
- auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
- if (list->i_size() == 2) {
- list->set_i(0, 2);
- list->set_i(1, 3);
- } else if (list->i_size() == 3) {
- list->set_i(1, 2);
- list->set_i(2, 3);
- }
- return Status::OK();
- }
-
- private:
- bool IsInputConvertible() const {
- int input_port;
- auto input = node_map_->GetNode(node_->input(0));
- ParseNodeName(node_->input(0), &input_port);
- if (input->attr().find("_output_shapes") != input->attr().end()) {
- auto shape = input->attr().at("_output_shapes").list().shape(input_port);
- if (shape.dim_size() != 4) {
- return false;
- }
- if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
- return true;
- }
- if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
- shape.dim(2).size() == 1) {
- return true;
- }
- }
- return false;
- }
-
- bool IsAlongAxis(const std::vector<int>& axis) const {
- if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
- auto list = node_->attr().at("squeeze_dims").list();
- // If the list is empty, the Squeeze op squeezes all dimensions of size 1.
- if (list.i_size() == 0) return true;
- if (list.i_size() == axis.size()) {
- bool along_axis = true;
- for (int i = 0; i < axis.size(); i++) {
- along_axis = along_axis && (list.i(i) == axis[i]);
- }
- if (along_axis) return true;
- }
- }
- return false;
- }
- bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
- bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
-};
-
-class ReduceProcessor : public AgnosticNodeProcessor {
- public:
- explicit ReduceProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- auto input0 = node_map_->GetNode(node_->input(0));
- int port;
- ParseNodeName(node_->input(0), &port);
- return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsPortDimsFour(*input0, port) && IsReduceAxisSupported() &&
- IsOnGPU();
- }
-
- Status CustomizedProcessing() override {
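- // The reduction axes (input 1) are NHWC dimension indices;
- // DataFormatDimMap rewrites them to the corresponding NCHW indices.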
- if (IsReduceAxisSupported()) {
- DataType dtype = node_->attr().at("Tidx").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
- }
- return Status::OK();
- }
-
- Status AddLayoutTransposeToOutputs() override {
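- // Only a keep_dims reduction still produces a 4-D output that must be
- // transposed back to NHWC; otherwise the output rank drops and no
- // layout conversion is needed.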
- if (KeepDims()) {
- return AddTransformToOutputs("Transpose");
- }
- return Status::OK();
- }
-
- private:
- bool IsReduceAxisSupported() const {
- return KeepDims() || IsAlongAllFourDims() || IsAlongHWC() ||
- IsAlongNHW() || IsAlongHW() || IsAlongC();
- }
-
- bool IsAlongAxis(const std::vector<int>& axis) const {
- auto axis_node = node_map_->GetNode(node_->input(1));
- if (!IsConstant(*axis_node)) {
- return false;
- }
- if (HasAttribute(*axis_node, "value").ok()) {
- Tensor tensor;
- auto success = tensor.FromProto(axis_node->attr().at({"value"}).tensor());
- if (!success) {
- LOG(ERROR) << "Failed to parse TensorProto.";
- }
- if (tensor.dims() == 1 && tensor.dim_size(0) == axis.size()) {
- bool along_axis = true;
- for (int i = 0; i < axis.size(); i++) {
- along_axis = along_axis && (tensor.flat<int>()(i) == axis[i]);
- }
- if (along_axis) return true;
- }
- }
- return false;
- }
-
- bool IsAlongAllFourDims() const { return IsAlongAxis({0, 1, 2, 3}); }
-
- bool IsAlongHWC() const { return IsAlongAxis({1, 2, 3}); }
-
- bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
-
- bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
-
- bool IsAlongC() const { return IsAlongAxis({3}); }
-
- bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
-};
-
-class SwitchProcessor : public AgnosticNodeProcessor {
- public:
- explicit SwitchProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- std::set<int> GetOutputPos() const override {
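- // A plain Switch always has two outputs; ops that carry a "num_outs"
- // attribute (such as _SwitchN) declare their own output count, so fall
- // back to 2 when the attribute is absent.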
- std::set<int> output_pos;
- const int num_outs =
- node_->attr().count("num_outs") ? node_->attr().at("num_outs").i() : 2;
- for (int i = 0; i < num_outs; i++) {
- output_pos.insert(i);
- }
- return output_pos;
- }
-};
-
-class TileProcessor : public AgnosticNodeProcessor {
- public:
- explicit TileProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- Status CustomizedProcessing() override {
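- // Input 1 holds the per-dimension tile multiples in NHWC order;
- // DataFormatVecPermute reorders them to NCHW.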
- DataType dtype = node_->attr().at("Tmultiples").type();
- return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
- }
-};
-
-class DataLayoutOptimizer : GraphProcessor {
- public:
- explicit DataLayoutOptimizer(
- const GraphProperties& graph_properties,
- const VirtualPlacer& virtual_placer,
- const LayoutOptimizer::TuningConfig& config,
- const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
- NodeMap* node_map)
- : GraphProcessor(graph_properties, virtual_placer, nodes_to_preserve,
- graph, node_map),
- config_(config) {}
-
- Status Optimize() {
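- // Optimization runs in two phases: Expand converts eligible nodes to
- // NCHW and inserts layout-conversion nodes around them, and Collapse
- // then removes conversion pairs that cancel each other out.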
- VLOG(1) << "Number of nodes for original graph: " << graph_->node_size();
- TF_RETURN_IF_ERROR(Expand());
- VLOG(1) << "Number of nodes after Expand: " << graph_->node_size();
- TF_RETURN_IF_ERROR(Collapse());
- VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size();
- return Status::OK();
- }
-
- private:
- // Expand all nodes that are in NHWC but either support NCHW or are
- // layout agnostic.
- Status Expand() {
- int node_size_original = graph_->node_size();
- std::unordered_set<string> devices_with_perm_const;
-
- FrameView frame_view;
- TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph_));
-
- // This is the first pass where we expand the nodes which support NCHW.
- std::set<string> ops_format_supported = GetOpsFormatSupported();
- for (int i = 0; i < node_size_original; i++) {
- if (IsNodeByLayoutOptimizer(graph_->node(i).name())) {
- return Status(error::INVALID_ARGUMENT,
- "The graph is already optimized by layout optimizer.");
- }
- if (ops_format_supported.find(graph_->node(i).op()) !=
- ops_format_supported.end()) {
- auto node = graph_->mutable_node(i);
- bool is_in_frame = frame_view.IsInFrame(*node);
- OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
- virtual_placer_, nodes_to_preserve_,
- is_in_frame, &devices_with_perm_const);
- std::unique_ptr<NodeProcessor> node_processor;
- if (IsAvgPoolGrad(*node)) {
- node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
- } else if (IsBiasAddGrad(*node)) {
- node_processor.reset(new BiasAddGradProcessor(opt_cxt));
- } else if (IsConv2D(*node)) {
- node_processor.reset(new Conv2DProcessor(opt_cxt, config_.no_gemm));
- } else if (IsConv2DBackpropFilter(*node)) {
- node_processor.reset(
- new Conv2DBackpropFilterProcessor(opt_cxt, config_.no_gemm));
- } else if (IsConv2DBackpropInput(*node)) {
- node_processor.reset(
- new Conv2DBackpropInputProcessor(opt_cxt, config_.no_gemm));
- } else if (IsDepthwiseConv2dNative(*node)) {
- node_processor.reset(new Conv2DProcessor(opt_cxt, true));
- } else if (IsDepthwiseConv2dNativeBackpropFilter(*node)) {
- node_processor.reset(
- new Conv2DBackpropFilterProcessor(opt_cxt, true));
- } else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
- node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
- } else if (IsFusedBatchNormGrad(*node)) {
- node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
- } else if (IsMaxPoolV2(*node)) {
- node_processor.reset(new MaxPoolV2Processor(opt_cxt));
- } else if (IsMaxPoolGradV1(*node) || IsMaxPoolGradGradV1(*node)) {
- node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
- } else if (IsMaxPoolGradV2(*node) || IsMaxPoolGradGradV2(*node)) {
- node_processor.reset(new MaxPoolGradV2Processor(opt_cxt));
- } else {
- node_processor.reset(new NodeProcessor(opt_cxt));
- }
- TF_RETURN_IF_ERROR(node_processor->ConvertNode());
- }
- }
-
- // This is the second pass where we expand layout-agnostic nodes. This pass
- // only needs to be performed if at least one node in the previous pass is
- // expanded.
- if (graph_->node_size() > node_size_original) {
- std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
- for (int i = 0; i < graph_->node_size(); i++) {
- if (ops_format_agnostic.find(graph_->node(i).op()) !=
- ops_format_agnostic.end()) {
- auto node = graph_->mutable_node(i);
- bool is_in_frame = frame_view.IsInFrame(*node);
- OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
- virtual_placer_, nodes_to_preserve_,
- is_in_frame, &devices_with_perm_const);
- std::unique_ptr<NodeProcessor> node_processor;
- if (IsAddN(*node)) {
- node_processor.reset(new AddNProcessor(opt_cxt));
- } else if (IsBetainc(*node)) {
- node_processor.reset(new TernaryOpProcessor(opt_cxt));
- } else if (IsBinaryOp(*node)) {
- node_processor.reset(new BinaryOpProcessor(opt_cxt));
- } else if (IsConcat(*node)) {
- node_processor.reset(new ConcatProcessor(opt_cxt));
- } else if (IsFill(*node)) {
- node_processor.reset(new FillProcessor(opt_cxt));
- } else if (IsHistogramSummary(*node)) {
- node_processor.reset(new HistogramSummaryProcessor(opt_cxt));
- } else if (IsIdentityN(*node)) {
- node_processor.reset(new IdentityNProcessor(opt_cxt));
- } else if (IsMerge(*node)) {
- node_processor.reset(new MergeProcessor(opt_cxt));
- } else if (IsPad(*node) || IsMirrorPad(*node) ||
- IsMirrorPadGrad(*node)) {
- node_processor.reset(new PadProcessor(opt_cxt));
- } else if (IsReduceOp(*node)) {
- node_processor.reset(new ReduceProcessor(opt_cxt));
- } else if (IsReverseV2(*node)) {
- node_processor.reset(new ReverseProcessor(opt_cxt));
- } else if (IsSelect(*node)) {
- node_processor.reset(new SelectProcessor(opt_cxt));
- } else if (IsSlice(*node)) {
- node_processor.reset(new SliceProcessor(opt_cxt));
- } else if (IsStridedSlice(*node)) {
- node_processor.reset(new StridedSliceProcessor(opt_cxt));
- } else if (IsShape(*node) || IsShapeN(*node)) {
- node_processor.reset(new ShapeProcessor(opt_cxt));
- } else if (IsSplit(*node)) {
- node_processor.reset(new SplitProcessor(opt_cxt));
- } else if (IsSplitV(*node)) {
- node_processor.reset(new SplitVProcessor(opt_cxt));
- } else if (IsSqueeze(*node)) {
- node_processor.reset(new SqueezeProcessor(opt_cxt));
- } else if (IsStridedSliceGrad(*node)) {
- node_processor.reset(new StridedSliceGradProcessor(opt_cxt));
- } else if (IsSwitch(*node)) {
- node_processor.reset(new SwitchProcessor(opt_cxt));
- } else if (IsTile(*node)) {
- node_processor.reset(new TileProcessor(opt_cxt));
- } else if (IsUnaryGrad(*node)) {
- node_processor.reset(new UnaryGradProcessor(opt_cxt));
- } else {
- node_processor.reset(new AgnosticNodeProcessor(opt_cxt));
- }
- TF_RETURN_IF_ERROR(node_processor->ConvertNode());
- }
- }
- }
- return Status::OK();
- }
-
- // Remove all node pairs where an NCHW-to-NHWC conversion node is
- // immediately followed by an NHWC-to-NCHW conversion node.
- Status Collapse() {
- std::unordered_set<string> nodes_removable;
- for (int i = 0; i < graph_->node_size(); i++) {
- auto node = graph_->mutable_node(i);
- node->mutable_attr()->erase("_output_shapes");
- if (IsTransposeNHWCToNCHW(node->name()) ||
- IsDimMapNHWCToNCHW(node->name()) ||
- IsVecPermuteNHWCToNCHW(node->name())) {
- bool transpose_pair = IsTransposeNHWCToNCHW(node->name()) &&
- IsTransposeNCHWToNHWC(node->input(0));
- bool dim_map_pair = IsDimMapNHWCToNCHW(node->name()) &&
- IsDimMapNCHWToNHWC(node->input(0));
- bool vec_permute_pair = IsVecPermuteNHWCToNCHW(node->name()) &&
- IsVecPermuteNCHWToNHWC(node->input(0));
- if (transpose_pair || dim_map_pair || vec_permute_pair) {
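- // Bypass the cancelling pair: rewire the consumer of the second
- // conversion node directly to the producer feeding the first one,
- // then mark both conversion nodes for removal.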
- const string& trans_first = node->input(0);
- const string& trans_second = node->name();
- auto outputs = node_map_->GetOutputs(trans_second);
- CHECK(outputs.size() == 1)
- << "There is always only a single output for a Transpose node, "
- << "due to the way it is added by NodeProcessor.";
- NodeDef* output = *outputs.begin();
- string input = node_map_->GetNode(trans_first)->input(0);
- for (int i = 0; i < output->input_size(); i++) {
- if (output->input(i).compare(trans_second) == 0) {
- *output->mutable_input(i) = input;
- break;
- }
- }
- nodes_removable.insert(trans_first);
- nodes_removable.insert(trans_second);
- }
- }
- }
- graph_->mutable_node()->erase(
- std::remove_if(
- graph_->mutable_node()->begin(), graph_->mutable_node()->end(),
- [nodes_removable](const NodeDef& node) {
- return nodes_removable.find(node.name()) != nodes_removable.end();
- }),
- graph_->mutable_node()->end());
- return Status::OK();
- }
-
- const LayoutOptimizer::TuningConfig& config_;
-};
-
-int GetNumGPUs(const Cluster& cluster) {
- auto devices = cluster.GetDevices();
- int num_gpus = 0;
- for (const auto& device : devices) {
- if (device.second.type() == "GPU") {
- num_gpus++;
- }
- }
- return num_gpus;
-}
-} // namespace
-
-Status LayoutOptimizer::Tune(const GrapplerItem& item,
- const GraphProperties& graph_properties,
- const TuningConfig& config, GraphDef* output) {
- auto status = graph_properties.AnnotateOutputShapes(
- output, /*allow_symbolic_shapes=*/true);
- if (!status.ok()) {
- VLOG(1) << "Annotate shape return status: " << status.ToString();
- *output = item.graph;
- return status;
- }
- NodeMap node_map(output);
- DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
- config, nodes_to_preserve_, output,
- &node_map);
- status = layout_optimizer.Optimize();
- return status;
-}
-
-Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* output) {
- if (cluster == nullptr) {
- LOG(WARNING) << "layout optimizer was called with cluster == nullptr";
- return errors::Aborted("cluster == nullptr.");
- }
- if (GetNumGPUs(*cluster) < 1) {
- return errors::Aborted(
- "No GPUs found: LayoutOptimizer is currently only tuned for GPU.");
- }
-
- GraphProperties graph_properties(item);
- TF_RETURN_IF_ERROR(
- graph_properties.InferStatically(/*assume_valid_feeds=*/false,
- /*aggressive_shape_inference=*/false,
- /*include_tensor_values=*/false));
- GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
-
- virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));
- nodes_to_preserve_ = item.NodesToPreserve();
-
- TuningConfig config;
- config.no_gemm = true;
- // TODO(yaozhang): Enable tuning with various TuningConfig choices with
- // the measurement-based estimator.
- Status status = Tune(item, graph_properties, config, output);
- if (!status.ok()) {
- *output = item.graph;
- }
- return status;
-}
-
-} // end namespace grappler
-} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h
deleted file mode 100644
index ae6307e..0000000
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
-
-#include "tensorflow/core/grappler/costs/graph_properties.h"
-#include "tensorflow/core/grappler/costs/virtual_placer.h"
-#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
-
-namespace tensorflow {
-namespace grappler {
-// Convert the NHWC layout to NCHW for Conv-related ops on GPUs.
-class LayoutOptimizer : public GraphOptimizer {
- public:
- LayoutOptimizer() {}
- ~LayoutOptimizer() override {}
-
- string name() const override { return "layout"; };
-
- bool UsesFunctionLibrary() const override { return false; }
-
- struct TuningConfig {
- // If true, do not use the NHWC GEMM implementation. When filter size is
- // one or filter size is equal to input image size,
- // the NHWC implementation of Conv2D, Conv2DBackpropInput, and
- // Conv2DBackpropFilter will use a specialized GEMM implementation, which is
- // usually faster than the NCHW implementation. The downside is that this
- // might result in more non-cancellable layout conversion nodes (implemented
- // by the Transpose op).
- bool no_gemm;
- };
-
- Status Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* output) override;
-
- private:
- std::unique_ptr<VirtualPlacer> virtual_placer_;
- std::unordered_set<string> nodes_to_preserve_;
- Status Tune(const GrapplerItem& item, const GraphProperties& graph_properties,
- const TuningConfig& config, GraphDef* output);
-};
-
-} // end namespace grappler
-} // end namespace tensorflow
-
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
deleted file mode 100644
index f8aef8a..0000000
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ /dev/null
@@ -1,1297 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
-
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/kernel_shape_util.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/clusters/single_machine.h"
-#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
-#include "tensorflow/core/grappler/costs/virtual_placer.h"
-#include "tensorflow/core/grappler/devices.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/grappler_test.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/protobuf/device_properties.pb.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-class LayoutOptimizerTest : public GrapplerTest {
- protected:
- void SetUp() override {
- gpu_available_ = GetNumAvailableGPUs() > 0;
-
- if (gpu_available_) {
- virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
- } else {
- DeviceProperties device_properties;
- device_properties.set_type("GPU");
- device_properties.mutable_environment()->insert({"architecture", "6"});
- virtual_cluster_.reset(
- new VirtualCluster({{"/GPU:1", device_properties}}));
- }
- TF_CHECK_OK(virtual_cluster_->Provision());
- }
-
- void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
-
- template <typename T = float>
- Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
- const string& padding) {
- return SimpleConv2D<T>(s, input_size, filter_size, padding, "");
- }
-
- template <typename T = float>
- Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
- const string& padding, const string& device) {
- int batch_size = 8;
- int input_height = input_size;
- int input_width = input_size;
- int input_depth = 3;
- int filter_count = 2;
- int stride = 1;
- TensorShape input_shape(
- {batch_size, input_height, input_width, input_depth});
- Tensor input_data(DataTypeToEnum<T>::value, input_shape);
- test::FillIota<T>(&input_data, static_cast<T>(1));
- Output input =
- ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
-
- TensorShape filter_shape(
- {filter_size, filter_size, input_depth, filter_count});
- Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
- test::FillIota<T>(&filter_data, static_cast<T>(1));
- Output filter =
- ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
-
- ops::Conv2D::Attrs attrs;
- const int kExplicitPaddings[] = {0, 0, 1, 2, 3, 4, 0, 0};
- if (padding == "EXPLICIT") {
- attrs = attrs.ExplicitPaddings(kExplicitPaddings);
- }
-
- Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
- filter, {1, stride, stride, 1}, padding, attrs);
- return conv;
- }
-
- Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
- int filter_size, const string& padding) {
- return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true,
- true);
- }
-
- Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
- int filter_size, const string& padding,
- bool const_input_size, bool dilated) {
- int batch_size = 128;
- int input_height = input_size;
- int input_width = input_size;
- int input_depth = 3;
- int filter_count = 2;
- int stride = 1;
- int dilation = dilated ? 2 : 1;
- int64_t padding_top = 1;
- int64_t padding_bottom = 2;
- int64_t padding_left = 3;
- int64_t padding_right = 4;
- int64_t output_height;
- int64_t output_width;
- Padding padding_enum;
- if (padding == "SAME") {
- padding_enum = SAME;
- } else if (padding == "VALID") {
- padding_enum = VALID;
- } else {
- CHECK_EQ(padding, "EXPLICIT");
- padding_enum = EXPLICIT;
- }
- TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
- input_height, filter_size, dilation, stride, padding_enum,
- &output_height, &padding_top, &padding_bottom));
- TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
- input_width, filter_size, dilation, stride, padding_enum, &output_width,
- &padding_left, &padding_right));
- TensorShape input_sizes_shape({4});
- Tensor input_data(DT_INT32, input_sizes_shape);
- test::FillValues<int>(&input_data,
- {batch_size, input_height, input_width, input_depth});
- Output input_sizes =
- ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
-
- TensorShape filter_shape(
- {filter_size, filter_size, input_depth, filter_count});
- Output filter =
- ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);
-
- TensorShape output_shape(
- {batch_size, output_height, output_width, filter_count});
- Tensor output_data(DT_FLOAT, output_shape);
- test::FillIota<float>(&output_data, 1.0f);
- Output output =
- ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
-
- Output conv_backprop_input;
- Output input_sizes_i =
- ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
- std::vector<int> dilations{1, dilation, dilation, 1};
- std::vector<int> explicit_paddings;
- if (padding == "EXPLICIT") {
- explicit_paddings = {0,
- 0,
- static_cast<int>(padding_top),
- static_cast<int>(padding_bottom),
- static_cast<int>(padding_left),
- static_cast<int>(padding_right),
- 0,
- 0};
- }
- auto attrs =
- ops::Conv2DBackpropInput::Attrs().Dilations(dilations).ExplicitPaddings(
- explicit_paddings);
- if (const_input_size) {
- conv_backprop_input = ops::Conv2DBackpropInput(
- s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
- {1, stride, stride, 1}, padding, attrs);
- } else {
- conv_backprop_input = ops::Conv2DBackpropInput(
- s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
- {1, stride, stride, 1}, padding, attrs);
- }
- return conv_backprop_input;
- }
-
- Tensor GetAttrValue(const NodeDef& node) {
- Tensor tensor;
- CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
- return tensor;
- }
-
- TensorShape GetAttrShape(const NodeDef& node) {
- return TensorShape(node.attr().at({"shape"}).shape());
- }
-
- Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
- int batch_size = 16;
- int input_height = 8;
- int input_width = 8;
- int input_channels = 3;
- TensorShape shape({batch_size, input_height, input_width, input_channels});
- Tensor data(DT_FLOAT, shape);
- test::FillIota<float>(&data, 1.0f);
- Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
- Output y_backprop =
- ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));
-
- TensorShape shape_vector({input_channels});
- Tensor data_vector(DT_FLOAT, shape_vector);
- test::FillIota<float>(&data_vector, 2.0f);
- Output scale =
- ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
- Output reserve1 =
- ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
- Output reserve2 =
- ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));
-
- ops::FusedBatchNormGrad::Attrs attrs;
- attrs.is_training_ = is_training;
- auto output =
- ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
- x, scale, reserve1, reserve2, attrs);
- return output.x_backprop;
- }
-
- std::unique_ptr<Cluster> virtual_cluster_;
- bool gpu_available_;
-};
-
-TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "EXPLICIT");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
-
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
- auto input_sizes_node = node_map.GetNode(input_name);
- CHECK(input_sizes_node);
- auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
- CHECK(conv2d_backprop_node);
- EXPECT_EQ(input_name, conv2d_backprop_node->input(0));
- auto input_sizes = GetAttrValue(*input_sizes_node);
- Tensor input_sizes_expected(DT_INT32, {4});
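- // The NHWC input sizes {128, 7, 7, 3} are permuted to NCHW
- // {128, 3, 7, 7}.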
- test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
- test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
-
- if (gpu_available_) {
- TensorShape filter_shape = GetAttrShape(*node_map.GetNode("Filter"));
- Tensor filter_data = GenerateRandomTensor<DT_FLOAT>(filter_shape);
- std::vector<string> fetch = {"Fetch"};
- auto tensors_expected =
- EvaluateNodes(item.graph, fetch, {{"Filter", filter_data}});
- auto tensors = EvaluateNodes(output, fetch, {{"Filter", filter_data}});
- EXPECT_EQ(1, tensors_expected.size());
- EXPECT_EQ(1, tensors.size());
- test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
- }
-}
-
-TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false, false);
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
-
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
- CHECK(conv2d_backprop_node);
- EXPECT_EQ(conv2d_backprop_node->input(0),
- "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
- auto input_sizes_node = node_map.GetNode(
- "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
- CHECK(input_sizes_node);
- EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
- EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
-}
-
-TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 2, 1, "SAME");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 2, 2, "SAME");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 2, 2, "VALID");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 2, 2, "SAME");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, ExplicitPadding) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "EXPLICIT");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, DataTypeIsInt32) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D<int32>(&s, 4, 2, "EXPLICIT");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- EXPECT_FALSE(
- node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, Pad) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto c = ops::Const(s.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
- auto p = ops::Pad(s.WithOpName("p"), conv, c);
- auto o = ops::Identity(s.WithOpName("o"), p);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
-
- auto pad = node_map.GetNode("p");
- EXPECT_EQ(pad->input(0), "Conv2D");
-
- auto pad_const = node_map.GetNode("p-1-LayoutOptimizer");
- EXPECT_TRUE(pad_const);
- EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
- Tensor tensor;
- EXPECT_TRUE(
- tensor.FromProto(pad_const->mutable_attr()->at({"value"}).tensor()));
- Tensor tensor_expected(DT_INT32, {4, 2});
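- // The 4x2 paddings rows are reordered from NHWC (N, H, W, C) to NCHW
- // (N, C, H, W): {1,2},{3,4},{5,6},{7,8} -> {1,2},{7,8},{3,4},{5,6}.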
- test::FillValues<int>(&tensor_expected, {1, 2, 7, 8, 3, 4, 5, 6});
- test::ExpectTensorEqual<int>(tensor_expected, tensor);
-}
-
-TEST_F(LayoutOptimizerTest, Connectivity) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto i1 = ops::Identity(s.WithOpName("i1"), conv);
- auto i2 = ops::Identity(s.WithOpName("i2"), i1);
- auto i3 = ops::Identity(s.WithOpName("i3"), i2);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- // Make the graph not topologically ordered to test the handling of
- // multi-hop connectivity (here we say two nodes are connected if all
- // nodes in the middle are layout agnostic). If the graph is already in
- // topological order, the problem is easier: the layout optimizer only
- // needs to check single-hop connectivity.
- NodeMap node_map_original(&item.graph);
- auto node_i1 = node_map_original.GetNode("i1");
- auto node_i2 = node_map_original.GetNode("i2");
- node_i2->Swap(node_i1);
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map_output(&output);
- auto node_i2_output = node_map_output.GetNode("i2");
- // Layout optimizer should process i2, as it detects i2 is connected with the
- // Conv2D node two hops away. Similarly i1 is processed as well, as i1 is
- // directly connected to the Conv2D node. The two added transposes between
- // i1 and i2 should cancel each other, and as a result i2 is directly
- // connected to i1.
- EXPECT_EQ(node_i2_output->input(0), "i1");
-}
-
-TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto i1 = ops::Identity(s.WithOpName("i1"), conv);
- auto i2 = ops::Identity(s.WithOpName("i2"), i1);
- auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
- auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
- auto i3 = ops::Identity(s.WithOpName("i3"), sub);
- auto i4 = ops::Identity(s.WithOpName("i4"), i3);
- auto i5 = ops::Identity(s.WithOpName("i5"), i4);
- auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
- auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
- auto i6 = ops::Identity(s.WithOpName("i6"), mul);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- // Make the graph not topologically ordered to test the handling of
- // multi-hop connectivity (here we say two nodes are connected if all
- // nodes in the middle are layout agnostic). If the graph is already in
- // topological order, the problem is easier: the layout optimizer only
- // needs to check single-hop connectivity.
- NodeMap node_map_original(&item.graph);
- auto node_i1 = node_map_original.GetNode("i1");
- auto node_mul = node_map_original.GetNode("mul");
- node_mul->Swap(node_i1);
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map_output(&output);
- auto mul_node = node_map_output.GetNode("mul");
- EXPECT_EQ(mul_node->input(0), "scalar_mul");
- EXPECT_EQ(mul_node->input(1), "i5");
-}
-
-TEST_F(LayoutOptimizerTest, PreserveFetch) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto i = ops::Identity(s.WithOpName("i"), conv);
- GrapplerItem item;
- item.fetch.push_back("Conv2D");
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("Conv2D");
- EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, EmptyDevice) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("Conv2D");
- EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
-}
-
-TEST_F(LayoutOptimizerTest, GPUDevice) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv =
- SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("Conv2D");
- EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
-}
-
-TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv =
- SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("Conv2D");
- EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("Conv2D");
- EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("FusedBatchNormGrad");
- EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
-}
-
-TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
- Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("FusedBatchNormGrad");
- EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimC) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 5, 2, "VALID");
- auto c = ops::Const(s.WithOpName("c"), 3, {});
- auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
- auto i = ops::Identity(s.WithOpName("i"), split[0]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
- EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
- EXPECT_EQ(split_const->op(), "Const");
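- // Split axis 3 (C in NHWC) maps to axis 1 in NCHW.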
- EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimH) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 6, 2, "SAME");
- auto c = ops::Const(s.WithOpName("c"), 1, {});
- auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
- auto i = ops::Identity(s.WithOpName("i"), split[0]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
- EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
- EXPECT_EQ(split_const->op(), "Const");
- EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimW) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 5, 2, "VALID");
- auto c = ops::Const(s.WithOpName("c"), 2, {});
- auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
- auto i = ops::Identity(s.WithOpName("i"), split[0]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
- EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
- EXPECT_EQ(split_const->op(), "Const");
- EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimN) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 5, 2, "VALID");
- auto c = ops::Const(s.WithOpName("c"), 0, {});
- auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
- auto i = ops::Identity(s.WithOpName("i"), split[0]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
- EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
- EXPECT_EQ(split_const->op(), "Const");
- EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
-}
-
-TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 5, 2, "VALID");
- auto c = ops::Const(s.WithOpName("c"), 0, {});
- auto i1 = ops::Identity(s.WithOpName("i1"), c);
- auto split = ops::Split(s.WithOpName("split"), i1, conv, 2);
- auto i2 = ops::Identity(s.WithOpName("i"), split[0]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-0-DimMapNHWCToNCHW-LayoutOptimizer");
- EXPECT_EQ(split_node->input(1), "Conv2D");
- auto map_node = node_map.GetNode("split-0-DimMapNHWCToNCHW-LayoutOptimizer");
- EXPECT_EQ(map_node->op(), "DataFormatDimMap");
- EXPECT_EQ(map_node->input(0), "i1");
-}
-
-TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 5, 2, "VALID");
- auto axis = ops::Const(s.WithOpName("axis"), 3);
- auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
- auto concat =
- ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
- auto o = ops::Identity(s.WithOpName("o"), concat);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto concat_node = node_map.GetNode("concat");
- EXPECT_EQ(concat_node->input(0), "split:1");
- EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "split:1");
- EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
- EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimH) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "SAME");
- auto axis = ops::Const(s.WithOpName("axis"), 1);
- auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
- auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
- auto o = ops::Identity(s.WithOpName("o"), concat);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto concat_node = node_map.GetNode("concat");
- EXPECT_EQ(concat_node->input(0), "split");
- EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
- EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatNonConst) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "SAME");
- auto axis = ops::Const(s.WithOpName("axis"), 1);
- auto i = ops::Identity(s.WithOpName("i"), axis);
- auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
- auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, i);
- auto o = ops::Identity(s.WithOpName("o"), concat);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto concat_node = node_map.GetNode("concat");
- EXPECT_EQ(concat_node->input(0), "split");
- EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
- auto concat_dim =
- node_map.GetNode("concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
- EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
- EXPECT_EQ(concat_dim->input(0), "i");
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimW) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "SAME");
- auto axis = ops::Const(s.WithOpName("axis"), 2);
- auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
- auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
- auto o = ops::Identity(s.WithOpName("o"), concat);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto concat_node = node_map.GetNode("concat");
- EXPECT_EQ(concat_node->input(0), "split");
- EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
- EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimN) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto axis = ops::Const(s.WithOpName("axis"), 0);
- auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
- auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
- auto o = ops::Identity(s.WithOpName("o"), concat);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto concat_node = node_map.GetNode("concat");
- EXPECT_EQ(concat_node->input(0), "split");
- EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
- EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimC) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto axis = ops::Const(s.WithOpName("axis"), 3);
- auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
- auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
- auto o = ops::Identity(s.WithOpName("o"), concat);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto concat_node = node_map.GetNode("concat");
- EXPECT_EQ(concat_node->input(0), "split");
- EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
- EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
-}
-
-TEST_F(LayoutOptimizerTest, Sum) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto reduction_indices =
- ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
- auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
- auto o = ops::Identity(s.WithOpName("o"), sum);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
- // because of the worse performance in some cases.
- /*
- NodeMap node_map(&output);
- auto sum_node = node_map.GetNode("sum");
- EXPECT_EQ(sum_node->input(0), "Conv2D");
- EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
- auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
- Tensor tensor;
- EXPECT_TRUE(
- tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
- Tensor tensor_expected(DT_INT32, {3});
- test::FillValues<int>(&tensor_expected, {0, 2, 3});
- test::ExpectTensorEqual<int>(tensor_expected, tensor);
- */
-}
-
-TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
- auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
- auto o = ops::Identity(s.WithOpName("o"), mul);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto mul_node = node_map.GetNode("mul");
- EXPECT_EQ(mul_node->input(0), "scalar");
- EXPECT_EQ(mul_node->input(1), "Conv2D");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
- auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
- auto o = ops::Identity(s.WithOpName("o"), mul);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto mul_node = node_map.GetNode("mul");
- EXPECT_EQ(mul_node->input(0), "Conv2D");
- EXPECT_EQ(mul_node->input(1), "scalar");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto unknown_rank =
- ops::Placeholder(s.WithOpName("unknown"), DT_FLOAT,
- ops::Placeholder::Shape(PartialTensorShape()));
- Output c = ops::Const(s.WithOpName("c"), 3.0f, {8, 2, 2, 2});
- Output mul = ops::Mul(s.WithOpName("mul"), conv, unknown_rank);
- auto o = ops::AddN(s.WithOpName("o"), {mul, c});
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto mul_node = node_map.GetNode("mul");
- // Node mul should not be processed by layout optimizer, because one of its
- // inputs is of unknown rank.
- EXPECT_EQ(mul_node->input(0),
- "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(mul_node->input(1), "unknown");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto i = ops::Identity(s.WithOpName("i"), conv);
- auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
- auto o = ops::Identity(s.WithOpName("o"), mul);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto mul_node = node_map.GetNode("mul");
- EXPECT_EQ(mul_node->input(0), "Conv2D");
- EXPECT_EQ(mul_node->input(1), "i");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
- auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
- auto o = ops::Identity(s.WithOpName("o"), mul);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto mul_node = node_map.GetNode("mul");
- EXPECT_EQ(mul_node->input(0), "Conv2D");
- EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
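- // The length-2 vector broadcasts along the channel dimension, so it is
- // reshaped to {1, 2, 1, 1} for NCHW broadcasting.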
- auto mul_const = node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
- Tensor tensor;
- EXPECT_TRUE(
- tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
- Tensor tensor_expected(DT_INT32, {4});
- test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
- test::ExpectTensorEqual<int>(tensor_expected, tensor);
-}
-
-TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
- auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
- auto o = ops::Identity(s.WithOpName("o"), mul);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto mul_node = node_map.GetNode("mul");
- EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
- EXPECT_EQ(mul_node->input(1), "Conv2D");
- auto mul_const = node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
- Tensor tensor;
- EXPECT_TRUE(
- tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
- Tensor tensor_expected(DT_INT32, {4});
- test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
- test::ExpectTensorEqual<int>(tensor_expected, tensor);
-}
-
-TEST_F(LayoutOptimizerTest, SliceConst) {
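- // Constant begin/size inputs can be permuted at optimization time: new
- // constants are materialized with the values reordered from NHWC to NCHW.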
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 5, 2, "VALID");
- auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
- auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
- auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
- auto o = ops::Identity(s.WithOpName("o"), slice);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto slice_node = node_map.GetNode("slice");
- EXPECT_EQ(slice_node->input(0), "Conv2D");
- EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
- EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");
-
- auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
- Tensor begin_tensor;
- EXPECT_TRUE(begin_tensor.FromProto(
- begin_const->mutable_attr()->at({"value"}).tensor()));
- Tensor begin_tensor_expected(DT_INT32, {4});
- test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
- test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);
-
- auto size_const = node_map.GetNode("slice-2-LayoutOptimizer");
- Tensor size_tensor;
- EXPECT_TRUE(size_tensor.FromProto(
- size_const->mutable_attr()->at({"value"}).tensor()));
- Tensor size_tensor_expected(DT_INT32, {4});
- test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
- test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
-}
-
-TEST_F(LayoutOptimizerTest, SliceNonConst) {
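- // Non-constant begin/size inputs cannot be rewritten statically, so a
- // runtime DataFormatVecPermute node is inserted for each of them.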
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 5, 2, "VALID");
- auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
- auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
- auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
- auto isize = ops::Identity(s.WithOpName("isize"), size);
- auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
- auto o = ops::Identity(s.WithOpName("o"), slice);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto slice_node = node_map.GetNode("slice");
- EXPECT_EQ(slice_node->input(0), "Conv2D");
- EXPECT_EQ(slice_node->input(1),
- "slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
- EXPECT_EQ(slice_node->input(2),
- "slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
- auto perm1 = node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
- EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
- EXPECT_EQ(perm1->input(0), "ibegin");
- auto perm2 = node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
- EXPECT_EQ(perm2->op(), "DataFormatVecPermute");
- EXPECT_EQ(perm2->input(0), "isize");
-}
-
-TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
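- // A node name that already carries the "-LayoutOptimizer" suffix marks the
- // graph as processed; a second run must fail with InvalidArgument.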
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto scalar =
- ops::Const(s.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
- auto mul = ops::Mul(s.WithOpName("mul"), scalar, scalar);
- auto o = ops::Identity(s.WithOpName("o"), mul);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- EXPECT_TRUE(errors::IsInvalidArgument(status));
-}
-
-TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
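- // Both ShapeN inputs are converted to NCHW, so each shape output is
- // permuted back to NHWC order before reaching the consumer.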
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, conv});
- auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto shapen_node = node_map.GetNode("shapen");
- EXPECT_EQ(shapen_node->input(0), "Conv2D");
- EXPECT_EQ(shapen_node->input(1), "Conv2D");
- auto add_node = node_map.GetNode("add");
- EXPECT_EQ(add_node->input(0),
- "shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(add_node->input(1),
- "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
- auto vec_permute1 =
- node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(vec_permute1->input(0), "shapen");
- EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");
- auto vec_permute2 =
- node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(vec_permute2->input(0), "shapen:1");
- EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
-}
-
-TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
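- // Only the second ShapeN input is converted, so only output 1 needs a
- // DataFormatVecPermute back to NHWC; output 0 is consumed directly.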
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {7});
- auto shapen = ops::ShapeN(s.WithOpName("shapen"), {vector, conv});
- auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto shapen_node = node_map.GetNode("shapen");
- EXPECT_EQ(shapen_node->input(0), "vector");
- EXPECT_EQ(shapen_node->input(1), "Conv2D");
- auto add_node = node_map.GetNode("add");
- EXPECT_EQ(add_node->input(0), "shapen");
- EXPECT_EQ(add_node->input(1),
- "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
- auto vec_permute =
- node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(vec_permute->input(0), "shapen:1");
- EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
-}
-
-TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAndNoNeedToTransform4D) {
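- // The second input is a 4-D constant that needs no conversion, so neither
- // ShapeN input gets a transpose.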
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
- auto i1 = ops::Identity(s.WithOpName("i1"), tensor_4d);
- Output i2 = ops::Identity(s.WithOpName("i2"), i1);
- auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, i2});
- auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto shapen_node = node_map.GetNode("shapen");
- EXPECT_EQ(shapen_node->input(0), "Conv2D");
- EXPECT_EQ(shapen_node->input(1), "i2");
-}
-
-TEST_F(LayoutOptimizerTest, Switch) {
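- // Both Switch output ports feed NHWC consumers, so a transpose back to
- // NHWC is expected on each branch ("switch" and "switch:1").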
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- ops::Variable ctrl(s.WithOpName("ctrl"), {}, DT_BOOL);
- auto sw = ops::Switch(s.WithOpName("switch"), conv, ctrl);
- auto i1 = ops::Identity(s.WithOpName("i1"), sw.output_true);
- auto i2 = ops::Identity(s.WithOpName("i2"), sw.output_false);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto switch_node = node_map.GetNode("switch");
- EXPECT_EQ(switch_node->input(0), "Conv2D");
- EXPECT_EQ(switch_node->input(1), "ctrl");
- auto i1_node = node_map.GetNode("i1");
- auto i2_node = node_map.GetNode("i2");
- auto trans1 = node_map.GetNode(i1_node->input(0));
- EXPECT_EQ(trans1->input(0), "switch:1");
- auto trans2 = node_map.GetNode(i2_node->input(0));
- EXPECT_EQ(trans2->input(0), "switch");
-}
-
-TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
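- // Both Merge inputs are convertible, so Merge itself stays in NCHW and a
- // single transpose back to NHWC is placed on its output.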
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- Output i1 = ops::Identity(s.WithOpName("i1"), conv);
- auto merge = ops::Merge(s.WithOpName("merge"), {conv, i1});
- auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto merge_node = node_map.GetNode("merge");
- EXPECT_EQ(merge_node->input(0), "Conv2D");
- EXPECT_EQ(merge_node->input(1), "i1");
- auto i2_node = node_map.GetNode("i2");
- EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
- auto transpose =
- node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(transpose->input(0), "merge");
-}
-
-TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
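- // The constant input cannot be converted, so Merge stays in NHWC and the
- // convolution output is transposed back before entering Merge.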
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
- auto merge = ops::Merge(s.WithOpName("merge"), {tensor_4d, conv});
- auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto merge_node = node_map.GetNode("merge");
- EXPECT_EQ(merge_node->input(0), "tensor_4d");
- EXPECT_EQ(merge_node->input(1),
- "Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
-}
-
-TEST_F(LayoutOptimizerTest, Complex) {
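- // Complex is converted along with Conv2D; the transpose inserted on its
- // output must operate on DT_COMPLEX64 data.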
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
- auto i = ops::Identity(s.WithOpName("i"), comp);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto merge_node = node_map.GetNode("complex");
- EXPECT_EQ(merge_node->input(0), "Conv2D");
- EXPECT_EQ(merge_node->input(1), "Conv2D");
- auto trans =
- node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
-}
-
-TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
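- // Only the 4-D input of IdentityN is converted; its matching output gets a
- // transpose back to NHWC, while the vector output passes through untouched.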
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
- auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
- auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto i = node_map.GetNode("identity_n");
- EXPECT_EQ(i->input(0), "vector");
- EXPECT_EQ(i->input(1), "Conv2D");
- auto trans =
- node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(trans->input(0), "identity_n:1");
- auto add_node = node_map.GetNode("add");
- EXPECT_EQ(add_node->input(0), "identity_n");
- EXPECT_EQ(add_node->input(1),
- "identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
-}
-
-TEST_F(LayoutOptimizerTest, LoopNoLiveLock) {
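- // Regression test: the NextIteration back edge created below must not make
- // the optimizer traverse the graph forever (live-lock).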
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto c = ops::Const(s.WithOpName("const"), 3.0f, {8, 3, 3, 2});
- auto merge = ops::Merge(s.WithOpName("merge"), {c, c});
- auto i0 = ops::Identity(s.WithOpName("i0"), merge.output);
- ops::Variable v_ctrl(s.WithOpName("v_ctrl"), {}, DT_BOOL);
- auto sw = ops::Switch(s.WithOpName("switch"), i0, v_ctrl);
- auto next = ops::NextIteration(s.WithOpName("next"), sw.output_true);
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto mul = ops::Mul(s.WithOpName("mul"), conv, sw.output_false);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- NodeMap node_map_original(&item.graph);
- auto merge_node = node_map_original.GetNode("merge");
- // Modify the graph to create a loop via the NextIteration back edge.
- merge_node->set_input(1, "next");
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto conv_node = node_map.GetNode("Conv2D");
- EXPECT_EQ(conv_node->input(0),
- "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer");
- auto mul_node = node_map.GetNode("mul");
- EXPECT_EQ(mul_node->input(0),
- "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-}
-
-TEST_F(LayoutOptimizerTest, DevicePlacement) {
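- // Shape outputs are host-memory tensors, so the inserted
- // DataFormatVecPermute must be pinned to the host kernel via "_kernel".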
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2D(&s, 4, 2, "VALID");
- auto shape = ops::Shape(s.WithOpName("s"), conv);
- auto i = ops::Identity(s.WithOpName("i"), shape);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- VirtualPlacer virtual_placer(virtual_cluster_->GetDevices());
- for (auto& node : *item.graph.mutable_node()) {
- string device = virtual_placer.get_canonical_device_name(node);
- node.set_device(device);
- }
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto vec_permute =
- node_map.GetNode("s-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
- EXPECT_EQ(vec_permute->attr().at("_kernel").s(), "host");
-}
-
-TEST_F(LayoutOptimizerTest, PermConstWithDevice) {
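- // Permutation constants are created once per device, prefixed with the
- // canonicalized device name and placed on that device; no device-less
- // global constant should remain.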
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- const string worker0_gpu0 = "/job:w/replica:0/task:0/device:gpu:0";
- const string worker1_gpu1 = "/job:w/replica:0/task:1/device:gpu:1";
- const string worker0_node_prefix = "job_w_replica_0_task_0_device_gpu_0-";
- const string worker1_node_prefix = "job_w_replica_0_task_1_device_gpu_1-";
- const string perm_nchw2nhwc_str = "PermConstNCHWToNHWC-LayoutOptimizer";
- const string perm_nhwc2nchw_str = "PermConstNHWCToNCHW-LayoutOptimizer";
- auto conv_0 = SimpleConv2D(&s, 4, 2, "VALID", worker0_gpu0);
- auto shape_0 = ops::Shape(s.WithOpName("s"), conv_0);
- auto i_0 = ops::Identity(s.WithOpName("i"), shape_0);
- auto conv_1 = SimpleConv2D(&s, 4, 2, "VALID", worker1_gpu1);
- auto shape_1 = ops::Shape(s.WithOpName("s"), conv_1);
- auto i_1 = ops::Identity(s.WithOpName("i"), shape_1);
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- LayoutOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
- NodeMap node_map(&output);
- auto const_permute_0_0 =
- node_map.GetNode(worker0_node_prefix + perm_nchw2nhwc_str);
- auto const_permute_0_1 =
- node_map.GetNode(worker0_node_prefix + perm_nhwc2nchw_str);
- EXPECT_EQ(const_permute_0_0->device(), worker0_gpu0);
- EXPECT_EQ(const_permute_0_1->device(), worker0_gpu0);
- auto const_permute_1_0 =
- node_map.GetNode(worker1_node_prefix + perm_nchw2nhwc_str);
- auto const_permute_1_1 =
- node_map.GetNode(worker1_node_prefix + perm_nhwc2nchw_str);
- EXPECT_EQ(const_permute_1_0->device(), worker1_gpu1);
- EXPECT_EQ(const_permute_1_1->device(), worker1_gpu1);
- EXPECT_FALSE(node_map.GetNode(perm_nchw2nhwc_str));
- EXPECT_FALSE(node_map.GetNode(perm_nhwc2nchw_str));
-}
-} // namespace
-} // namespace grappler
-} // namespace tensorflow