[NFC] Remove unused layout_optimizer grappler optimization. There is generic_layout_optimizer which is currently in use.

PiperOrigin-RevId: 396457524
Change-Id: Ideac28247aa6e3cd28e4fac14ff28de39f3f003f
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e1f6626..8008679 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -565,55 +565,6 @@
 )
 
 cc_library(
-    name = "layout_optimizer",
-    srcs = ["layout_optimizer.cc"],
-    hdrs = [
-        "layout_optimizer.h",
-    ],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":graph_optimizer",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/grappler:devices",
-        "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:op_types",
-        "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/grappler/costs:graph_properties",
-        "//tensorflow/core/grappler/costs:virtual_placer",
-        "//tensorflow/core/grappler/utils:frame",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
-tf_cuda_cc_test(
-    name = "layout_optimizer_test",
-    srcs = ["layout_optimizer_test.cc"],
-    deps = [
-        ":layout_optimizer",
-        "//tensorflow/cc:cc_ops",
-        "//tensorflow/cc:cc_ops_internal",
-        "//tensorflow/core:all_kernels",
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-        "//tensorflow/core/grappler:devices",
-        "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/grappler/clusters:single_machine",
-        "//tensorflow/core/grappler/clusters:virtual_cluster",
-        "//tensorflow/core/grappler/costs:virtual_placer",
-        "//tensorflow/core/grappler/utils:grappler_test",
-    ],
-)
-
-cc_library(
     name = "auto_mixed_precision",
     srcs = ["auto_mixed_precision.cc"],
     hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
deleted file mode 100644
index a379cb9..0000000
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ /dev/null
@@ -1,2287 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
-
-#include <deque>
-#include <unordered_set>
-
-#include "absl/strings/strip.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/memory_types.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/devices.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
-#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/util/device_name_utils.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-const char kSuffix[] = "LayoutOptimizer";
-const char kDefaultDevice[] = "DefaultDevice";
-const char kPermNHWCToNCHW[] = "PermConstNHWCToNCHW";
-const char kPermNCHWToNHWC[] = "PermConstNCHWToNHWC";
-const char kTransposeNHWCToNCHW[] = "TransposeNHWCToNCHW";
-const char kTransposeNCHWToNHWC[] = "TransposeNCHWToNHWC";
-const char kDimMapNHWCToNCHW[] = "DimMapNHWCToNCHW";
-const char kDimMapNCHWToNHWC[] = "DimMapNCHWToNHWC";
-const char kVecPermuteNHWCToNCHW[] = "VecPermuteNHWCToNCHW";
-const char kVecPermuteNCHWToNHWC[] = "VecPermuteNCHWToNHWC";
-const char kReshapeNHWCToNCHW[] = "ReshapeNHWCToNCHW";
-const char kReshapeConst[] = "ReshapeConst";
-
-std::set<string> GetOpsFormatSupported() {
-  std::set<string> ops_format_supported = {
-      "AvgPool",
-      "AvgPoolGrad",
-      "Conv2D",
-      "Conv2DBackpropFilter",
-      "Conv2DBackpropInput",
-      "BiasAdd",
-      "BiasAddGrad",
-      "DepthwiseConv2dNative",
-      "DepthwiseConv2dNativeBackpropInput",
-      "DepthwiseConv2dNativeBackpropFilter",
-      "FusedBatchNorm",
-      "FusedBatchNormV2",
-      "FusedBatchNormV3",
-      "FusedBatchNormGrad",
-      "FusedBatchNormGradV2",
-      "FusedBatchNormGradV3",
-      "FusedConv2DBiasActivation",
-      "MaxPool",
-      "MaxPoolV2",
-      "MaxPoolGrad",
-      "MaxPoolGradGrad",
-      "MaxPoolGradV2",
-      "MaxPoolGradGradV2",
-      "SpaceToDepth",
-      "DepthToSpace"};
-  return ops_format_supported;
-}
-
-std::set<string> GetOpsFormatAgnostic() {
-  std::set<string> ops_format_agnostic = {"Abs",
-                                          "Add",
-                                          "AddN",
-                                          "AddV2",
-                                          "Acos",
-                                          "Acosh",
-                                          "All",
-                                          "Angle",
-                                          "Any",
-                                          "ApproximateEqual",
-                                          "Asin",
-                                          "Asinh",
-                                          "Atan",
-                                          "Atan2",
-                                          "Atanh",
-                                          "Betainc",
-                                          "Bitcast",
-                                          "Cast",
-                                          "Ceil",
-                                          "CheckNumerics",
-                                          "Complex",
-                                          "ComplexAbs",
-                                          "Concat",
-                                          "ConcatV2",
-                                          "Conj",
-                                          "Cos",
-                                          "Cosh",
-                                          "Digamma",
-                                          "Div",
-                                          "Elu",
-                                          "EluGrad",
-                                          "Enter",
-                                          "Equal",
-                                          "Erf",
-                                          "Erfc",
-                                          "Exit",
-                                          "Exp",
-                                          "Expm1",
-                                          "FakeQuantWithMinMaxVars",
-                                          "FakeQuantWithMinMaxArgs",
-                                          "Fill",
-                                          "Floor",
-                                          "FloorDiv",
-                                          "FloorMod",
-                                          "Greater",
-                                          "GreaterEqual",
-                                          "GuaranteeConst",
-                                          "HistogramSummary",
-                                          "Identity",
-                                          "IdentityN",
-                                          "Igamma",
-                                          "Igammac",
-                                          "Imag",
-                                          "Inv",
-                                          "InvGrad",
-                                          "IsFinite",
-                                          "IsInf",
-                                          "IsNan",
-                                          "Less",
-                                          "LessEqual",
-                                          "Lgamma",
-                                          "Log",
-                                          "LogicalAnd",
-                                          "LogicalNot",
-                                          "LogicalOr",
-                                          "Log1p",
-                                          "Max",
-                                          "Maximum",
-                                          "Mean",
-                                          "Merge",
-                                          "Min",
-                                          "Minimum",
-                                          "Mod",
-                                          "Mul",
-                                          "Neg",
-                                          "NextIteration",
-                                          "NotEqual",
-                                          "OnesLike",
-                                          "Pad",
-                                          "PreventGradient",
-                                          "Prod",
-                                          "Polygamma",
-                                          "QuantizeAndDequantizeV2",
-                                          "QuantizeAndDequantizeV3",
-                                          "QuantizeAndDequantizeV4",
-                                          "Pow",
-                                          "Real",
-                                          "RealDiv",
-                                          "Reciprocal",
-                                          "ReciprocalGrad",
-                                          "Relu",
-                                          "Relu6",
-                                          "Relu6Grad",
-                                          "ReluGrad",
-                                          "Rint",
-                                          "Select",
-                                          "SelectV2",
-                                          "Selu",
-                                          "SeluGrad",
-                                          "Shape",
-                                          "ShapeN",
-                                          "Sigmoid",
-                                          "SigmoidGrad",
-                                          "Sign",
-                                          "Sin",
-                                          "Sinh",
-                                          "Slice",
-                                          "Snapshot",
-                                          "Softplus",
-                                          "SoftplusGrad",
-                                          "Split",
-                                          "SplitV",
-                                          "StridedSlice",
-                                          "StridedSliceGrad",
-                                          "Switch",
-                                          "_SwitchN",
-                                          "Tile",
-                                          "TruncateDiv",
-                                          "TruncateMod",
-                                          "ReverseV2",
-                                          "Round",
-                                          "Rsqrt",
-                                          "RsqrtGrad",
-                                          "Sqrt",
-                                          "SqrtGrad",
-                                          "Square",
-                                          "SquaredDifference",
-                                          "Squeeze",
-                                          "StopGradient",
-                                          "Sub",
-                                          "Sum",
-                                          "Tan",
-                                          "Tanh",
-                                          "TanhGrad",
-                                          "ZerosLike",
-                                          "Zeta"};
-  return ops_format_agnostic;
-}
-
-bool EndWith(const string& str, const string& ending) {
-  if (str.size() < ending.size()) return false;
-  if (str.substr(str.size() - ending.size(), ending.size()) == ending)
-    return true;
-  return false;
-}
-
-bool IsNodeByLayoutOptimizer(const string& node_name) {
-  const string suffix = kSuffix;
-  return EndWith(node_name, suffix);
-}
-
-bool IsNodeType(const string& node_name, const string& type) {
-  const string suffix = strings::StrCat(type, "-", kSuffix);
-  return EndWith(node_name, suffix);
-}
-
-bool IsTransposeNHWCToNCHW(const string& node_name) {
-  return IsNodeType(node_name, kTransposeNHWCToNCHW);
-}
-
-bool IsTransposeNCHWToNHWC(const string& node_name) {
-  return IsNodeType(node_name, kTransposeNCHWToNHWC);
-}
-
-bool IsDimMapNHWCToNCHW(const string& node_name) {
-  return IsNodeType(node_name, kDimMapNHWCToNCHW);
-}
-
-bool IsDimMapNCHWToNHWC(const string& node_name) {
-  return IsNodeType(node_name, kDimMapNCHWToNHWC);
-}
-
-bool IsVecPermuteNHWCToNCHW(const string& node_name) {
-  return IsNodeType(node_name, kVecPermuteNHWCToNCHW);
-}
-
-bool IsVecPermuteNCHWToNHWC(const string& node_name) {
-  return IsNodeType(node_name, kVecPermuteNCHWToNHWC);
-}
-
-bool IsConcat(const NodeDef& node) {
-  const auto op = node.op();
-  return op == "Concat" || op == "ConcatV2";
-}
-
-bool IsConcatV1(const NodeDef& node) {
-  const auto op = node.op();
-  return op == "Concat";
-}
-
-bool IsMaxPoolV2(const NodeDef& node) {
-  const auto& op = node.op();
-  return op == "MaxPoolV2";
-}
-
-bool IsMaxPoolGradV1(const NodeDef& node) {
-  const auto& op = node.op();
-  return op == "MaxPoolGrad";
-}
-
-bool IsMaxPoolGradV2(const NodeDef& node) {
-  const auto& op = node.op();
-  return op == "MaxPoolGradV2";
-}
-
-bool IsMaxPoolGradGradV1(const NodeDef& node) {
-  const auto& op = node.op();
-  return op == "MaxPoolGradGrad";
-}
-
-bool IsMaxPoolGradGradV2(const NodeDef& node) {
-  const auto& op = node.op();
-  return op == "MaxPoolGradGradV2";
-}
-
-bool IsUnaryGrad(const NodeDef& node) {
-  bool is_unary_grad =
-      IsEluGrad(node) || IsInvGrad(node) || IsReciprocalGrad(node) ||
-      IsRelu6Grad(node) || IsReluGrad(node) || IsRsqrtGrad(node) ||
-      IsSeluGrad(node) || IsSigmoidGrad(node) || IsSoftplusGrad(node) ||
-      IsSoftsignGrad(node) || IsSqrtGrad(node) || IsTanhGrad(node);
-  return is_unary_grad;
-}
-
-bool IsComparisonOp(const NodeDef& node) {
-  bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
-                    IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
-                    IsLessEqual(node) || IsNotEqual(node);
-  return is_compare;
-}
-
-bool IsReduceOp(const NodeDef& node) {
-  return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
-         IsMin(node) || IsAll(node) || IsAny(node);
-}
-
-bool IsBinaryOp(const NodeDef& node) {
-  bool is_binary =
-      IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
-      IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
-      IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
-      IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
-      IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
-      IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
-  return is_binary;
-}
-
-std::vector<int> NonControlInputs(const NodeDef& node) {
-  std::vector<int> pos;
-  for (int i = 0; i < node.input_size(); i++) {
-    if (!IsControlInput(node.input(i))) {
-      pos.push_back(i);
-    }
-  }
-  return pos;
-}
-
-std::vector<int> DataInputPosConcat(const NodeDef& node) {
-  int n = node.attr().at("N").i();
-  std::vector<int> input_pos;
-  int start = (IsConcatV1(node)) ? 1 : 0;
-  int end = start + n;
-  for (int i = start; i < end; i++) {
-    input_pos.push_back(i);
-  }
-  return input_pos;
-}
-
-std::vector<int> DataInputPos(const NodeDef& node) {
-  if (IsSplit(node) || IsHistogramSummary(node)) {
-    return {1};
-  }
-  if (IsStridedSliceGrad(node)) {
-    return {4};
-  }
-  if (IsBinaryOp(node) || IsUnaryGrad(node)) {
-    return {0, 1};
-  }
-  if (IsBetainc(node) || IsSelect(node)) {
-    return {0, 1, 2};
-  }
-  if (IsShapeN(node) || IsIdentityN(node) || IsAddN(node) || IsMerge(node)) {
-    return NonControlInputs(node);
-  }
-  if (IsConcat(node)) {
-    return DataInputPosConcat(node);
-  }
-  if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
-    return {0};
-  }
-  return {};
-}
-
-bool IsHostMemory(const NodeDef& node, int output_port) {
-  DeviceNameUtils::ParsedName parsed_name;
-  if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
-    DeviceType device_type(parsed_name.type);
-    Status s = FindKernelDef(device_type, node, nullptr, nullptr);
-    if (s.ok()) {
-      tensorflow::MemoryTypeVector in_mtypes;
-      tensorflow::MemoryTypeVector out_mtypes;
-      s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
-                                         node, &in_mtypes, &out_mtypes);
-      if (s.ok()) {
-        if (out_mtypes[output_port] == HOST_MEMORY) {
-          return true;
-        }
-      }
-    } else {
-      return true;
-    }
-  }
-  return false;
-}
-
-class GraphProcessor {
- public:
-  GraphProcessor(const GraphProperties& graph_properties,
-                 const VirtualPlacer& virtual_placer,
-                 const std::unordered_set<string>& nodes_to_preserve,
-                 GraphDef* graph, NodeMap* node_map)
-      : graph_properties_(graph_properties),
-        virtual_placer_(virtual_placer),
-        nodes_to_preserve_(nodes_to_preserve),
-        graph_(graph),
-        node_map_(node_map) {}
-
- protected:
-  NodeDef* AddNodePermConst(const string& name, const string& device,
-                            const std::vector<int>& permutation) {
-    NodeDef* node = graph_->add_node();
-    node_map_->AddNode(name, node);
-    node->set_name(name);
-    node->set_op("Const");
-    AttrValue attr_data_type;
-    attr_data_type.set_type(DT_INT32);
-    node->mutable_attr()->insert({"dtype", attr_data_type});
-    AttrValue attr_tensor;
-    Tensor tensor(DT_INT32, TensorShape({4}));
-    for (int i = 0; static_cast<size_t>(i) < permutation.size(); i++) {
-      tensor.flat<int>()(i) = permutation[i];
-    }
-    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
-    node->mutable_attr()->insert({"value", attr_tensor});
-    string device_name;
-    if (device.empty()) {
-      device_name = virtual_placer_.get_canonical_device_name(*node);
-    } else {
-      device_name = device;
-    }
-    node->set_device(device_name);
-    return node;
-  }
-
-  NodeDef* AddNodeConstScalar(const string& name, const string& device,
-                              DataType dtype, int value) {
-    NodeDef* node = graph_->add_node();
-    node_map_->AddNode(name, node);
-    node->set_name(name);
-    node->set_op("Const");
-    AttrValue attr_data_type;
-    attr_data_type.set_type(dtype);
-    node->mutable_attr()->insert({"dtype", attr_data_type});
-    AttrValue attr_tensor;
-    Tensor tensor(dtype, TensorShape({}));
-    tensor.scalar<int>()() = value;
-    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
-    node->mutable_attr()->insert({"value", attr_tensor});
-    string device_name;
-    if (device.empty()) {
-      device_name = virtual_placer_.get_canonical_device_name(*node);
-    } else {
-      device_name = device;
-    }
-    node->set_device(device_name);
-    return node;
-  }
-
-  string LayoutOptimizerNode(const string& base_name) {
-    return strings::StrCat(base_name, "-", kSuffix);
-  }
-
-  const GraphProperties& graph_properties_;
-  const VirtualPlacer& virtual_placer_;
-  const std::unordered_set<string>& nodes_to_preserve_;
-  GraphDef* graph_;
-  NodeMap* node_map_;
-};
-
-struct OptimizeContext {
-  OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
-                  const GraphProperties& graph_properties,
-                  const VirtualPlacer& virtual_placer,
-                  const std::unordered_set<string>& nodes_to_preserve,
-                  bool is_in_frame,
-                  std::unordered_set<string>* devices_with_perm_const)
-      : graph(graph),
-        node(node),
-        node_map(node_map),
-        graph_properties(graph_properties),
-        virtual_placer(virtual_placer),
-        nodes_to_preserve(nodes_to_preserve),
-        is_in_frame(is_in_frame),
-        devices_with_perm_const(devices_with_perm_const) {}
-  GraphDef* graph;
-  NodeDef* node;
-  NodeMap* node_map;
-  const GraphProperties& graph_properties;
-  const VirtualPlacer& virtual_placer;
-  const std::unordered_set<string>& nodes_to_preserve;
-  bool is_in_frame;
-  std::unordered_set<string>* devices_with_perm_const;  // not owned
-};
-
-class NodeProcessor : public GraphProcessor {
- public:
-  explicit NodeProcessor(const OptimizeContext& opt_cxt)
-      : GraphProcessor(opt_cxt.graph_properties, opt_cxt.virtual_placer,
-                       opt_cxt.nodes_to_preserve, opt_cxt.graph,
-                       opt_cxt.node_map),
-        node_(opt_cxt.node),
-        is_in_frame_(opt_cxt.is_in_frame),
-        devices_with_perm_const_(opt_cxt.devices_with_perm_const) {}
-  virtual ~NodeProcessor() {}
-  virtual Status ConvertNode() {
-    if (ShouldProcess()) {
-      UpdateAttrDataFormat();
-      UpdateAttrKSize();
-      UpdateAttrStrides();
-      UpdateAttrDilations();
-      UpdateAttrExplicitPaddings();
-      UpdateAttrShape();
-      TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
-      TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
-      TF_RETURN_IF_ERROR(CustomizedProcessing());
-    }
-    return Status::OK();
-  }
-
- protected:
-  bool IsPortDimsN(const NodeDef& node, int port, int n) const {
-    if (node.attr().find("_output_shapes") != node.attr().end()) {
-      if (node.attr().at("_output_shapes").list().shape_size() > port) {
-        auto shape = node.attr().at("_output_shapes").list().shape(port);
-        if (shape.unknown_rank()) {
-          return false;
-        }
-        if (shape.dim_size() == n) {
-          return true;
-        }
-      }
-    }
-    return false;
-  }
-
-  bool IsPortZeroDimsN(const NodeDef& node, int n) const {
-    return IsPortDimsN(node, 0, n);
-  }
-
-  bool IsPortZeroDimsFour(const NodeDef& node) const {
-    return NodeProcessor::IsPortZeroDimsN(node, 4) ||
-           IsTransposeNCHWToNHWC(node.name());
-  }
-
-  bool IsPortDimsFour(const NodeDef& node, int port) const {
-    return NodeProcessor::IsPortDimsN(node, port, 4) ||
-           IsTransposeNCHWToNHWC(node.name());
-  }
-
-  bool IsNHWC() const {
-    if (node_->attr().find("data_format") != node_->attr().end()) {
-      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  bool HasOutputs() const {
-    auto outputs = node_map_->GetOutputs(node_->name());
-    return !outputs.empty();
-  }
-
-  Status HasAttribute(const NodeDef& node, const string& attr) const {
-    if (node.attr().find(attr) == node.attr().end()) {
-      return Status(error::INVALID_ARGUMENT,
-                    strings::StrCat("Missing attribute ", attr));
-    }
-    return Status::OK();
-  }
-
-  bool MustPreserve() const {
-    return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
-  }
-
-  bool IsOnGPU() const {
-    string device_name;
-    if (node_->device().empty()) {
-      device_name = virtual_placer_.get_canonical_device_name(*node_);
-    } else {
-      device_name = node_->device();
-    }
-    string device;
-    string not_used;
-    if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
-        absl::StrContains(absl::AsciiStrToLower(device),
-                          absl::AsciiStrToLower(DEVICE_GPU))) {
-      return true;
-    }
-    return false;
-  }
-
-  virtual bool ShouldProcess() const {
-    return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
-           HasOutputs() && IsOnGPU();
-  }
-
-  virtual void UpdateAttrShape() {
-    if (node_->attr().find("_output_shapes") != node_->attr().end()) {
-      for (const auto& pos : GetOutputPos()) {
-        auto shape = node_->mutable_attr()
-                         ->at("_output_shapes")
-                         .mutable_list()
-                         ->mutable_shape(pos);
-        if (shape->dim_size() == 4) {
-          int64_t h = shape->dim(1).size();
-          int64_t w = shape->dim(2).size();
-          int64_t c = shape->dim(3).size();
-          shape->mutable_dim(1)->set_size(c);
-          shape->mutable_dim(2)->set_size(h);
-          shape->mutable_dim(3)->set_size(w);
-        }
-      }
-    }
-  }
-
-  Status UpdateAttrValueOfInput(int input_index, bool permute) {
-    auto input_node = node_map_->GetNode(node_->input(input_index));
-    // We created a copy of the node, so that we don't modify the original node,
-    // which might be used elsewhere. Note that this copy also copies the
-    // control dependency input in the case this node is inside a loop,
-    // to ensure added_node is in the same frame with node_.
-    NodeDef* added_node = graph_->add_node();
-    *added_node = *input_node;
-    string base_name = strings::StrCat(node_->name(), "-", input_index);
-    string node_name = LayoutOptimizerNode(base_name);
-    added_node->set_name(node_name);
-    *node_->mutable_input(input_index) = node_name;
-    node_map_->AddNode(node_name, added_node);
-    node_map_->AddOutput(node_name, node_->name());
-    return UpdateAttrValue(added_node, permute);
-  }
-
-  virtual std::vector<int> GetInputPos() const { return {0}; }
-
-  virtual std::set<int> GetOutputPos() const {
-    // For most nodes, no need to process control nodes or nodes that use an
-    // output other than the first output: only the first output is of
-    // 4D NCHW/NHWC format and thus relevant here.
-    std::set<int> output_pos = {0};
-    return output_pos;
-  }
-
-  virtual Status AddLayoutTransposeToInputs() {
-    std::vector<int> input_pos = GetInputPos();
-    for (const auto& pos : input_pos) {
-      string node_name = LayoutOptimizerNode(
-          strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
-      DataType dtype =
-          graph_properties_.GetInputProperties(node_->name())[pos].dtype();
-      auto input_node = node_map_->GetNode(node_->input(pos));
-      TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
-      string const_name = GetOrAddNodePermNHWCToNCHW(pos);
-      int output_pos;
-      ParseNodeName(node_->input(pos), &output_pos);
-      AddNodeTranspose(
-          node_name, node_->input(pos), const_name, dtype,
-          input_node->attr().at("_output_shapes").list().shape(output_pos),
-          true);
-      node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
-                              node_name);
-      node_map_->AddOutput(node_name, node_->name());
-      *node_->mutable_input(pos) = node_name;
-    }
-    return Status::OK();
-  }
-
-  Status AddTransformToOutputs(const string& op) {
-    auto outputs = node_map_->GetOutputs(node_->name());
-    string const_name = GetOrAddNodePermNCHWToNHWC();
-    int output_count = 0;
-    for (const auto& output : outputs) {
-      int connections = 0;
-      int connections_removed = 0;
-      for (int i = 0; i < output->input_size(); i++) {
-        auto& input = *output->mutable_input(i);
-        int input_port;
-        string input_name = ParseNodeName(input, &input_port);
-        auto output_pos = GetOutputPos();
-        if (input_name == node_->name()) {
-          connections++;
-          if (output_pos.find(input_port) != output_pos.end()) {
-            connections_removed++;
-            string added_node_base_name =
-                strings::StrCat(node_->name(), "-", output_count, "-", i);
-            string added_node_name;
-            DataType dtype =
-                graph_properties_.GetOutputProperties(node_->name())[input_port]
-                    .dtype();
-            if (op == "Transpose") {
-              added_node_name = LayoutOptimizerNode(strings::StrCat(
-                  added_node_base_name, "-", kTransposeNCHWToNHWC));
-              TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
-              AddNodeTranspose(
-                  added_node_name, input, const_name, dtype,
-                  node_->attr().at("_output_shapes").list().shape(input_port),
-                  false);
-            } else if (op == "DataFormatVecPermute") {
-              added_node_name = LayoutOptimizerNode(strings::StrCat(
-                  added_node_base_name, "-", kVecPermuteNCHWToNHWC));
-              AddNodeDataFormatOp(added_node_name, input, op, dtype, false);
-            } else {
-              return errors::InvalidArgument("Unsupported op type: ", op);
-            }
-            input = added_node_name;
-            node_map_->AddOutput(node_->name(), added_node_name);
-            node_map_->AddOutput(added_node_name, output->name());
-          }
-        }
-      }
-      if (connections == connections_removed) {
-        node_map_->RemoveOutput(node_->name(), output->name());
-      }
-      output_count++;
-    }
-    return Status::OK();
-  }
-
-  virtual Status AddLayoutTransposeToOutputs() {
-    return AddTransformToOutputs("Transpose");
-  }
-
-  virtual Status CustomizedProcessing() { return Status::OK(); }
-
-  Status UpdateOrTransformParamInput(int param_index, const string& op,
-                                     DataType dtype) {
-    auto param_node = node_map_->GetNode(node_->input(param_index));
-    bool permute = (op == "DataFormatVecPermute") ? true : false;
-    if (IsConstant(*param_node)) {
-      TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(param_index, permute));
-    } else {
-      AddDataFormatTransformToParamInput(op, param_index, dtype);
-    }
-    return Status::OK();
-  }
-
-  NodeDef* node_;
-  bool is_in_frame_;
-  std::unordered_set<string>* devices_with_perm_const_;  // not owned.
-
- private:
-  string CompliantDeviceName(const string& device) {
-    if (device.empty()) return string(kDefaultDevice);
-    string ret(device);
-    std::replace(ret.begin(), ret.end(), '/', '_');
-    std::replace(ret.begin(), ret.end(), ':', '_');
-    return string(absl::StripPrefix(ret, "_"));
-  }
-
-  void UpdateAttrKSize() {
-    if (node_->attr().find("ksize") != node_->attr().end()) {
-      auto list = node_->mutable_attr()->at("ksize").mutable_list();
-      UpdateTuple(list);
-    }
-  }
-
-  void UpdateAttrStrides() {
-    if (node_->attr().find("strides") != node_->attr().end()) {
-      auto list = node_->mutable_attr()->at("strides").mutable_list();
-      UpdateTuple(list);
-    }
-  }
-
-  void UpdateAttrDilations() {
-    if (node_->attr().find("dilations") != node_->attr().end()) {
-      auto list = node_->mutable_attr()->at("dilations").mutable_list();
-      UpdateTuple(list);
-    }
-  }
-
-  void UpdateAttrExplicitPaddings() {
-    if (node_->attr().find("explicit_paddings") != node_->attr().end()) {
-      auto list = node_->mutable_attr()->at("explicit_paddings").mutable_list();
-      int size = list->i_size();
-      if (size == 8) {
-        int64_t height_before = list->i(2);
-        int64_t height_after = list->i(3);
-        int64_t width_before = list->i(4);
-        int64_t width_after = list->i(5);
-        list->set_i(2, 0);
-        list->set_i(3, 0);
-        list->set_i(4, height_before);
-        list->set_i(5, height_after);
-        list->set_i(6, width_before);
-        list->set_i(7, width_after);
-      } else if (size != 0) {
-        LOG(ERROR) << "Cannot handle explicit_paddings attribute of size "
-                   << size;
-      }
-    }
-  }
-
-  void UpdateAttrDataFormat() {
-    if (node_->attr().find("data_format") != node_->attr().end()) {
-      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
-        string* data_format =
-            node_->mutable_attr()->at("data_format").mutable_s();
-        *data_format = "NCHW";
-      }
-    }
-  }
-
-  Status UpdateAttrValue(NodeDef* node, bool permute) {
-    TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
-    Tensor tensor;
-    auto success =
-        tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
-    if (!success) {
-      LOG(ERROR) << "Failed to parse TensorProto.";
-    }
-
-    if (permute) {
-      if (tensor.dims() == 1) {
-        if (tensor.flat<int>().size() == 4) {
-          int c = tensor.flat<int>()(3);
-          tensor.flat<int>()(3) = tensor.flat<int>()(2);
-          tensor.flat<int>()(2) = tensor.flat<int>()(1);
-          tensor.flat<int>()(1) = c;
-        } else {
-          return Status(error::INVALID_ARGUMENT,
-                        strings::StrCat("Unsupported tensor size: ",
-                                        tensor.flat<int>().size()));
-        }
-      } else if (tensor.dims() == 2) {
-        for (int i = 0; i < 2; i++) {
-          int c = tensor.matrix<int>()(3, i);
-          tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
-          tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
-          tensor.matrix<int>()(1, i) = c;
-        }
-      } else {
-        return Status(
-            error::INVALID_ARGUMENT,
-            strings::StrCat("Unsupported dimension size: ", tensor.dims()));
-      }
-    } else {
-      for (int i = 0; i < tensor.flat<int>().size(); i++) {
-        int value = tensor.flat<int>()(i);
-        value = (value >= 0) ? value : value + 4;
-        if (value == 1 || value == 2) {
-          value = value + 1;
-        } else if (value == 3) {
-          value = 1;
-        }
-        tensor.flat<int>()(i) = value;
-      }
-    }
-
-    if (tensor.dims() == 0) {
-      tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
-    } else {
-      tensor.AsProtoTensorContent(
-          node->mutable_attr()->at({"value"}).mutable_tensor());
-    }
-    return Status::OK();
-  }
-
-  NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
-                            const string& const_name, DataType data_type,
-                            const TensorShapeProto& input_shape,
-                            bool NHWCToNCHW) {
-    NodeDef* node = graph_->add_node();
-    node_map_->AddNode(node_name, node);
-    node->set_name(node_name);
-    *node->add_input() = input_name;
-    *node->add_input() = const_name;
-    node->set_op("Transpose");
-    node->set_device(node_->device());
-    AttrValue attr_data_type;
-    attr_data_type.set_type(data_type);
-    node->mutable_attr()->insert({"T", attr_data_type});
-    AttrValue attr_data_type_perm;
-    attr_data_type_perm.set_type(DT_INT32);
-    node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
-    if (!input_shape.unknown_rank()) {
-      AttrValue attr_output_shape;
-      auto output_shape = attr_output_shape.mutable_list()->add_shape();
-      if (NHWCToNCHW) {
-        output_shape->add_dim()->set_size(input_shape.dim(0).size());
-        output_shape->add_dim()->set_size(input_shape.dim(3).size());
-        output_shape->add_dim()->set_size(input_shape.dim(1).size());
-        output_shape->add_dim()->set_size(input_shape.dim(2).size());
-      } else {
-        output_shape->add_dim()->set_size(input_shape.dim(0).size());
-        output_shape->add_dim()->set_size(input_shape.dim(2).size());
-        output_shape->add_dim()->set_size(input_shape.dim(3).size());
-        output_shape->add_dim()->set_size(input_shape.dim(1).size());
-      }
-      node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
-    }
-    return node;
-  }
-
-  NodeDef* AddNodePermNHWCToNCHW(const string& base_name,
-                                 const string& depended_node,
-                                 const string& device) {
-    string name =
-        LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNHWCToNCHW));
-    auto const_node = AddNodePermConst(name, device, {0, 3, 1, 2});
-    // This is to ensure the transpose node and the const node are in the same
-    // frame.
-    *const_node->add_input() = AsControlDependency(depended_node);
-    return const_node;
-  }
-
-  NodeDef* AddNodePermNCHWToNHWC(const string& base_name,
-                                 const string& depended_node,
-                                 const string& device) {
-    string name =
-        LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNCHWToNHWC));
-    auto const_node = AddNodePermConst(name, device, {0, 2, 3, 1});
-    // This is to ensure the transpose node and the const node are in the same
-    // frame.
-    *const_node->add_input() = AsControlDependency(depended_node);
-    return const_node;
-  }
-
-  // NOTE(zycao): We try to make sure each device has the permutation consts
-  // iff the consts are really needed. Thus no unexpected inter-worker
-  // connections and no redundant nodes would be existed.
-  void AddNodePermConstOnDevice(const string& device) {
-    string compliant_device_prefix;
-    compliant_device_prefix = CompliantDeviceName(device);
-    string node_name;
-    // Permutation const for NHWCToNCHW
-    node_name = strings::StrCat(compliant_device_prefix, "-",
-                                LayoutOptimizerNode(kPermNHWCToNCHW));
-    AddNodePermConst(node_name, device, {0, 3, 1, 2});
-
-    // Permutation const for NCHWToNHWC
-    node_name = strings::StrCat(compliant_device_prefix, "-",
-                                LayoutOptimizerNode(kPermNCHWToNHWC));
-    AddNodePermConst(node_name, device, {0, 2, 3, 1});
-  }
-
-  string GetOrAddNodePermNHWCToNCHW(int pos) {
-    string const_name;
-    if (is_in_frame_) {
-      string base_name = strings::StrCat(node_->name(), "-", pos);
-      string input = NodeName(node_->input(pos));
-      string depended_node;
-      if (!IsTransposeNCHWToNHWC(input)) {
-        depended_node = input;
-      } else {
-        auto input_node = node_map_->GetNode(input);
-        depended_node = NodeName(input_node->input(0));
-      }
-      auto const_node =
-          AddNodePermNHWCToNCHW(base_name, depended_node, node_->device());
-      const_name = const_node->name();
-    } else {
-      if (devices_with_perm_const_->find(node_->device()) ==
-          devices_with_perm_const_->end()) {
-        AddNodePermConstOnDevice(node_->device());
-        devices_with_perm_const_->insert(node_->device());
-      }
-      const_name = strings::StrCat(CompliantDeviceName(node_->device()), "-",
-                                   LayoutOptimizerNode(kPermNHWCToNCHW));
-    }
-    return const_name;
-  }
-
-  string GetOrAddNodePermNCHWToNHWC() {
-    string const_name;
-    if (is_in_frame_) {
-      auto const_node =
-          AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
-      const_name = const_node->name();
-    } else {
-      if (devices_with_perm_const_->find(node_->device()) ==
-          devices_with_perm_const_->end()) {
-        AddNodePermConstOnDevice(node_->device());
-        devices_with_perm_const_->insert(node_->device());
-      }
-      const_name = strings::StrCat(CompliantDeviceName(node_->device()), "-",
-                                   LayoutOptimizerNode(kPermNCHWToNHWC));
-    }
-    return const_name;
-  }
-
-  void UpdateTuple(AttrValue_ListValue* list) {
-    int64_t h = list->i(1);
-    int64_t w = list->i(2);
-    int64_t c = list->i(3);
-    list->set_i(1, c);
-    list->set_i(2, h);
-    list->set_i(3, w);
-  }
-
-  bool IsInputOnHost(const string& input_name) const {
-    string device = node_->device();
-    DeviceNameUtils::ParsedName parsed_name;
-    if (DeviceNameUtils::ParseFullName(device, &parsed_name)) {
-      if (parsed_name.type != "CPU") {
-        NodeDef* input = node_map_->GetNode(input_name);
-        int port;
-        ParseNodeName(input_name, &port);
-        if (IsHostMemory(*input, port)) {
-          return true;
-        }
-      }
-    }
-    return false;
-  }
-
-  NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
-                               const string& op, DataType dtype,
-                               bool nhwc_to_nchw) {
-    NodeDef* added_node = graph_->add_node();
-    added_node->set_name(name);
-    added_node->set_op(op);
-    node_map_->AddNode(added_node->name(), added_node);
-    added_node->set_device(node_->device());
-    // The inputs of a DataFormat op could be in host memory for ops such as
-    // Reshape. In such cases, run the kernel on the host too.
-    if (IsInputOnHost(input_name)) {
-      AttrValue attr_kernel;
-      attr_kernel.set_s("host");
-      added_node->mutable_attr()->insert({"_kernel", attr_kernel});
-    }
-    AttrValue attr_data_type;
-    attr_data_type.set_type(dtype);
-    added_node->mutable_attr()->insert({"T", attr_data_type});
-    string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
-    string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
-    AttrValue attr_format;
-    attr_format.set_s(src_format);
-    added_node->mutable_attr()->insert({"src_format", attr_format});
-    attr_format.set_s(dst_format);
-    added_node->mutable_attr()->insert({"dst_format", attr_format});
-    *added_node->add_input() = input_name;
-    return added_node;
-  }
-
-  void AddDataFormatTransformToParamInput(const string& op, int input_pos,
-                                          DataType dtype) {
-    string suffix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
-                                                   : kDimMapNHWCToNCHW;
-    string name = LayoutOptimizerNode(
-        strings::StrCat(node_->name(), "-", input_pos, "-", suffix));
-    auto added_node =
-        AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
-    *node_->mutable_input(input_pos) = added_node->name();
-    node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
-                            added_node->name());
-    node_map_->AddOutput(added_node->name(), node_->name());
-  }
-};
-
-class AvgPoolGradProcessor : public NodeProcessor {
- public:
-  explicit AvgPoolGradProcessor(const OptimizeContext& opt_cxt)
-      : NodeProcessor(opt_cxt) {}
-
- protected:
-  std::vector<int> GetInputPos() const override { return {1}; }
-  Status CustomizedProcessing() override {
-    return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
-  }
-};
-
-class BiasAddGradProcessor : public NodeProcessor {
- public:
-  explicit BiasAddGradProcessor(const OptimizeContext& opt_cxt)
-      : NodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    if (MustPreserve()) {
-      return false;
-    }
-    if (!IsOnGPU()) {
-      return false;
-    }
-    auto input = node_map_->GetNode(node_->input(0));
-    if (input) {
-      int port;
-      ParseNodeName(node_->input(0), &port);
-      if (IsNHWC() && IsPortDimsFour(*input, port)) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-};
-
-class Conv2DProcessor : public NodeProcessor {
- public:
-  Conv2DProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
-      : NodeProcessor(opt_cxt), no_gemm_(no_gemm) {}
-
- protected:
-  bool ShouldProcess() const override {
-    return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
-           HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU() &&
-           IsDataTypeFloat();
-  }
-
-  TensorShapeProto GetShape(const string& input_name) const {
-    string node_name;
-    int output_pos;
-    node_name = ParseNodeName(input_name, &output_pos);
-    NodeDef* node = node_map_->GetNode(node_name);
-    if (node->attr().find("_output_shapes") != node->attr().end()) {
-      return node->attr().at("_output_shapes").list().shape(output_pos);
-    }
-    TensorShapeProto shape;
-    return shape;
-  }
-
-  bool IsStrideOne() const {
-    if (node_->attr().find("strides") != node_->attr().end()) {
-      auto list = node_->attr().at("strides").list();
-      return list.i(1) == 1 && list.i(2) == 1;
-    }
-    return false;
-  }
-
-  bool IsValidPadding() const {
-    if (node_->attr().find("padding") != node_->attr().end()) {
-      auto padding = node_->attr().at("padding").s();
-      return padding == "VALID";
-    }
-    return false;
-  }
-
-  bool IsDataTypeFloat() const {
-    if (node_->attr().find("T") != node_->attr().end()) {
-      return kDataTypeIsFloating.Contains(node_->attr().at("T").type());
-    }
-    return false;
-  }
-
-  // The logic inside this function is based on the internal implementation of
-  // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
-  // needs to be updated accordingly if the internal implementation changes.
-  bool IsGemmUsed(const TensorShapeProto& filter_shape,
-                  const TensorShapeProto& input_shape) const {
-    if (filter_shape.dim_size() == 4) {
-      if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 &&
-          IsStrideOne()) {
-        return true;
-      }
-    }
-    if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) {
-      if (input_shape.dim(1).size() == filter_shape.dim(0).size() &&
-          input_shape.dim(2).size() == filter_shape.dim(1).size() &&
-          IsValidPadding()) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  virtual bool IsGemmUsed() const {
-    auto filter_shape = GetShape(node_->input(1));
-    auto input_shape = GetShape(node_->input(0));
-    return IsGemmUsed(filter_shape, input_shape);
-  }
-
-  bool no_gemm_;
-};
-
-class Conv2DBackpropFilterProcessor : public Conv2DProcessor {
- public:
-  Conv2DBackpropFilterProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
-      : Conv2DProcessor(opt_cxt, no_gemm) {}
-
- protected:
-  bool IsGemmUsed() const override {
-    auto filter_shape = GetShape(node_->name());
-    auto input_shape = GetShape(node_->input(0));
-    return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
-  }
-
-  std::vector<int> GetInputPos() const override { return {0, 2}; }
-
-  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-  // No need to update output shape, as it is always of shape
-  // [filter_height, filter_width, in_channels, out_channels], regardless of
-  // whether NCHW or NHWC is used.
-  void UpdateAttrShape() override {}
-};
-
-class Conv2DBackpropInputProcessor : public Conv2DProcessor {
- public:
-  Conv2DBackpropInputProcessor(const OptimizeContext& opt_cxt, bool no_gemm)
-      : Conv2DProcessor(opt_cxt, no_gemm) {}
-
- protected:
-  bool IsGemmUsed() const override {
-    auto filter_shape = GetShape(node_->input(1));
-    auto input_shape = GetShape(node_->name());
-    return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape);
-  }
-
-  std::vector<int> GetInputPos() const override { return {2}; }
-
-  Status CustomizedProcessing() override {
-    return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
-  }
-};
-
-class FusedBatchNormGradProcessor : public NodeProcessor {
- public:
-  explicit FusedBatchNormGradProcessor(const OptimizeContext& opt_cxt)
-      : NodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    return NodeProcessor::ShouldProcess() && IsTraining();
-  }
-
-  std::vector<int> GetInputPos() const override { return {0, 1}; }
-
- private:
-  bool IsTraining() const {
-    if (node_->attr().find("is_training") != node_->attr().end()) {
-      if (node_->attr().at("is_training").b()) {
-        return true;
-      }
-    }
-    return false;
-  }
-};
-
-class MaxPoolGradProcessor : public NodeProcessor {
- public:
-  explicit MaxPoolGradProcessor(const OptimizeContext& opt_cxt)
-      : NodeProcessor(opt_cxt) {}
-
- protected:
-  std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
-};
-
-class MaxPoolGradV2Processor : public MaxPoolGradProcessor {
- public:
-  explicit MaxPoolGradV2Processor(const OptimizeContext& opt_cxt)
-      : MaxPoolGradProcessor(opt_cxt) {}
-
- protected:
-  Status CustomizedProcessing() override {
-    for (int i = 3; i <= 4; i++) {
-      TF_RETURN_IF_ERROR(
-          UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
-    }
-    return Status::OK();
-  }
-};
-
-class MaxPoolV2Processor : public NodeProcessor {
- public:
-  explicit MaxPoolV2Processor(const OptimizeContext& opt_cxt)
-      : NodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    // We check data_input's shape instead, because the shape inference of
-    // MaxPoolV2 is not able to infer the shape when ksize or strides is not
-    // constant.
-    auto data_input = node_map_->GetNode(node_->input(0));
-    int port;
-    ParseNodeName(node_->input(0), &port);
-    return !MustPreserve() && IsNHWC() && IsPortDimsFour(*data_input, port) &&
-           HasOutputs() && IsOnGPU();
-  }
-
-  Status CustomizedProcessing() override {
-    for (int i = 1; i <= 2; i++) {
-      TF_RETURN_IF_ERROR(
-          UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
-    }
-    return Status::OK();
-  }
-};
-
-class AgnosticNodeProcessor : public NodeProcessor {
- public:
-  explicit AgnosticNodeProcessor(const OptimizeContext& opt_cxt)
-      : NodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
-           IsNodeAfterNCHWToNHWC() && IsOnGPU();
-  }
-
-  bool IsNodeAfterNCHWToNHWC(const NodeDef& node) const {
-    std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
-    std::deque<NodeDef*> queue;
-    auto data_node_pos = DataInputPos(node);
-    std::unordered_set<string> visited;
-    for (const auto& pos : data_node_pos) {
-      auto input_node = node_map_->GetNode(node.input(pos));
-      queue.push_back(input_node);
-      visited.insert(input_node->name());
-    }
-    // The code will exit this while loop in one iteration in most cases, as the
-    // graph is already topologically sorted.
-    while (!queue.empty()) {
-      NodeDef* current_node = queue.front();
-      queue.pop_front();
-      if (IsTransposeNCHWToNHWC(current_node->name()) ||
-          IsDimMapNCHWToNHWC(current_node->name()) ||
-          IsVecPermuteNCHWToNHWC(current_node->name())) {
-        return true;
-      }
-      // We only continue searching if the path is connected through
-      // format-agnostic nodes.
-      if (ops_format_agnostic.find(current_node->op()) !=
-          ops_format_agnostic.end()) {
-        auto current_node_pos = DataInputPos(*current_node);
-        for (const auto& pos : current_node_pos) {
-          auto input_node = node_map_->GetNode(current_node->input(pos));
-          if (visited.find(input_node->name()) == visited.end()) {
-            queue.push_back(input_node);
-            visited.insert(input_node->name());
-          }
-        }
-      }
-    }
-    return false;
-  }
-
-  bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
-};
-
-class AddNProcessor : public AgnosticNodeProcessor {
- public:
-  explicit AddNProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  std::vector<int> GetInputPos() const override {
-    return NonControlInputs(*node_);
-  }
-};
-
-class BinaryOpProcessor : public AgnosticNodeProcessor {
- public:
-  explicit BinaryOpProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
-           IsNodeAfterNCHWToNHWC() &&
-           (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) ||
-            IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4) ||
-            IsNDOperateWithMD(1, 4)) &&
-           IsOnGPU();
-  }
-
-  std::vector<int> GetInputPos() const override {
-    std::vector<int> input_pos;
-    auto input0 = node_map_->GetNode(node_->input(0));
-    auto input1 = node_map_->GetNode(node_->input(1));
-    int input0_port;
-    ParseNodeName(node_->input(0), &input0_port);
-    int input1_port;
-    ParseNodeName(node_->input(1), &input1_port);
-    if (IsPortDimsFour(*input0, input0_port)) {
-      input_pos.push_back(0);
-    }
-    if (IsPortDimsFour(*input1, input1_port)) {
-      input_pos.push_back(1);
-    }
-    return input_pos;
-  }
-
-  bool IsNDOperateWithMD(int n, int m) const {
-    auto input0 = node_map_->GetNode(node_->input(0));
-    auto input1 = node_map_->GetNode(node_->input(1));
-    int input0_port;
-    ParseNodeName(node_->input(0), &input0_port);
-    int input1_port;
-    ParseNodeName(node_->input(1), &input1_port);
-
-    if (input0 && input1) {
-      bool input0_is_n = (n == 4) ? IsPortDimsFour(*input0, input0_port)
-                                  : IsPortDimsN(*input0, input0_port, n);
-      bool input1_is_m = (m == 4) ? IsPortDimsFour(*input1, input1_port)
-                                  : IsPortDimsN(*input1, input1_port, m);
-      return input0_is_n && input1_is_m;
-    }
-    return false;
-  }
-
-  NodeDef* AddNodeShapeConst(const string& name, int num_channels,
-                             const string& depended_node) {
-    NodeDef* node = graph_->add_node();
-    node_map_->AddNode(name, node);
-    node->set_name(name);
-    node->set_op("Const");
-    node->set_device(node_->device());
-    AttrValue attr_data_type;
-    attr_data_type.set_type(DT_INT32);
-    node->mutable_attr()->insert({"dtype", attr_data_type});
-
-    AttrValue attr_tensor;
-    Tensor tensor(DT_INT32, TensorShape({4}));
-    std::vector<int> shape = {1, num_channels, 1, 1};
-    for (int i = 0; i < static_cast<int>(shape.size()); i++) {
-      tensor.flat<int>()(i) = shape[i];
-    }
-    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
-    node->mutable_attr()->insert({"value", attr_tensor});
-    if (is_in_frame_) {
-      // This is to ensure the transpose node and the const node are in the
-      // same frame.
-      *node->add_input() = AsControlDependency(depended_node);
-    }
-    return node;
-  }
-
-  NodeDef* AddNodeReshape(const string& node_name, const string& input_name,
-                          const string& shape_const_node_name,
-                          DataType data_type) {
-    NodeDef* node = graph_->add_node();
-    node_map_->AddNode(node_name, node);
-    node->set_name(node_name);
-    *node->add_input() = input_name;
-    *node->add_input() = shape_const_node_name;
-    node->set_op("Reshape");
-    node->set_device(node_->device());
-
-    AttrValue attr_type_indices;
-    attr_type_indices.set_type(DT_INT32);
-    node->mutable_attr()->insert({"Tshape", attr_type_indices});
-
-    AttrValue attr_type_params;
-    attr_type_params.set_type(data_type);
-    node->mutable_attr()->insert({"T", attr_type_params});
-    return node;
-  }
-
-  Status CustomizedProcessing() override {
-    int vector_index = -1;
-    if (IsNDOperateWithMD(4, 1)) {
-      vector_index = 1;
-    } else if (IsNDOperateWithMD(1, 4)) {
-      vector_index = 0;
-    }
-    if (vector_index != -1) {
-      string base_name = strings::StrCat(node_->name(), "-", vector_index);
-      string reshape_node_name = LayoutOptimizerNode(
-          strings::StrCat(base_name, "-", kReshapeNHWCToNCHW));
-      string shape_const_node_name =
-          LayoutOptimizerNode(strings::StrCat(base_name, "-", kReshapeConst));
-      auto input_node = node_map_->GetNode(node_->input(vector_index));
-      TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
-      int port;
-      ParseNodeName(node_->input(vector_index), &port);
-      int vector_size = input_node->attr()
-                            .at("_output_shapes")
-                            .list()
-                            .shape(port)
-                            .dim(0)
-                            .size();
-      AddNodeShapeConst(shape_const_node_name, vector_size,
-                        NodeName(node_->input(vector_index)));
-      TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
-      AddNodeReshape(reshape_node_name, node_->input(vector_index),
-                     shape_const_node_name, node_->attr().at("T").type());
-      node_map_->AddOutput(shape_const_node_name, reshape_node_name);
-      node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
-                              node_->name(), reshape_node_name);
-      node_map_->AddOutput(reshape_node_name, node_->name());
-      *node_->mutable_input(vector_index) = reshape_node_name;
-    }
-    return Status::OK();
-  }
-};
-
-class ConcatProcessor : public AgnosticNodeProcessor {
- public:
-  explicit ConcatProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {
-    // For Concat,  the concat axis is the first input; for ConcatV2,
-    // the last input. Note that if with control inputs, the number of inputs
-    // is larger than the integer attribute N.
-    int n = node_->attr().at("N").i();
-    axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
-  }
-
- protected:
-  std::vector<int> GetInputPos() const override {
-    return DataInputPosConcat(*node_);
-  }
-
-  Status CustomizedProcessing() override {
-    DataType dtype =
-        (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
-    return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
-                                       dtype);
-  }
-
-  int axis_node_pos_;
-};
-
-class FillProcessor : public AgnosticNodeProcessor {
- public:
-  explicit FillProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  std::vector<int> GetInputPos() const override { return {}; }
-
-  Status CustomizedProcessing() override {
-    DataType dtype = node_->attr().at("index_type").type();
-    return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
-  }
-};
-
-class HistogramSummaryProcessor : public AgnosticNodeProcessor {
- public:
-  explicit HistogramSummaryProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    auto input1 = node_map_->GetNode(node_->input(1));
-    int port;
-    ParseNodeName(node_->input(1), &port);
-    return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
-           IsPortDimsFour(*input1, port) && IsOnGPU();
-  }
-
-  std::vector<int> GetInputPos() const override { return {1}; }
-
-  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-};
-
-class IdentityNProcessor : public AgnosticNodeProcessor {
- public:
-  explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {
-    std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
-    for (int i = 0; i < node_->input_size(); i++) {
-      auto input = node_map_->GetNode(node_->input(i));
-      int port;
-      ParseNodeName(node_->input(i), &port);
-      // Skip control input.
-      if (port != -1) {
-        bool is_agnostic =
-            ops_format_agnostic.find(input->op()) != ops_format_agnostic.end();
-        if (IsPortDimsFour(*input, port) &&
-            ((IsNodeAfterNCHWToNHWC(*input) && is_agnostic) ||
-             IsTransposeNCHWToNHWC(input->name()))) {
-          input_pos_.push_back(i);
-        }
-      }
-    }
-  }
-
- protected:
-  bool ShouldProcess() const override {
-    return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
-           IsOnGPU();
-  }
-
-  std::vector<int> GetInputPos() const override { return input_pos_; }
-
-  std::set<int> GetOutputPos() const override {
-    std::set<int> output_pos{};
-    for (const auto& input_pos : input_pos_) {
-      output_pos.insert(input_pos);
-    }
-    return output_pos;
-  }
-
- private:
-  std::vector<int> input_pos_;
-};
-
-class ShapeProcessor : public IdentityNProcessor {
- public:
-  explicit ShapeProcessor(const OptimizeContext& opt_cxt)
-      : IdentityNProcessor(opt_cxt) {}
-
- protected:
-  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-
-  Status CustomizedProcessing() override {
-    return AddTransformToOutputs("DataFormatVecPermute");
-  }
-};
-
-class MergeProcessor : public AgnosticNodeProcessor {
- public:
-  explicit MergeProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
-           IsEveryInputAfterNCHWToNHWC() && IsOnGPU();
-  }
-
-  std::vector<int> GetInputPos() const override {
-    std::vector<int> input_pos;
-    int n = node_->attr().at("N").i();
-    input_pos.reserve(n);
-    for (int i = 0; i < n; i++) {
-      input_pos.push_back(i);
-    }
-    return input_pos;
-  }
-
- private:
-  bool IsEveryInputAfterNCHWToNHWC() const {
-    std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
-    for (const auto& input : node_->input()) {
-      auto input_node = node_map_->GetNode(input);
-      int port;
-      ParseNodeName(input, &port);
-      bool is_agnostic = ops_format_agnostic.find(input_node->op()) !=
-                         ops_format_agnostic.end();
-      if (IsPortDimsFour(*input_node, port) &&
-          ((IsNodeAfterNCHWToNHWC(*input_node) && is_agnostic) ||
-           IsTransposeNCHWToNHWC(input_node->name()))) {
-        continue;
-      }
-      return false;
-    }
-    return true;
-  }
-};
-
-class PadProcessor : public AgnosticNodeProcessor {
- public:
-  explicit PadProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  Status CustomizedProcessing() override {
-    DataType dtype = node_->attr().at("Tpaddings").type();
-    return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
-  }
-};
-
-class ReverseProcessor : public AgnosticNodeProcessor {
- public:
-  explicit ReverseProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  Status CustomizedProcessing() override {
-    DataType dtype = node_->attr().at("Tidx").type();
-    return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
-  }
-};
-
-class SplitProcessor : public AgnosticNodeProcessor {
- public:
-  explicit SplitProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {
-    axis_node_pos_ = 0;
-  }
-
- protected:
-  std::vector<int> GetInputPos() const override { return {1}; }
-
-  std::set<int> GetOutputPos() const override {
-    std::set<int> output_pos{0};
-    if (HasAttribute(*node_, "num_split").ok()) {
-      for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
-        output_pos.insert(i);
-      }
-    }
-    return output_pos;
-  }
-
-  Status CustomizedProcessing() override {
-    return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
-                                       DT_INT32);
-  }
-
-  int axis_node_pos_;
-};
-
-class SplitVProcessor : public SplitProcessor {
- public:
-  explicit SplitVProcessor(const OptimizeContext& opt_cxt)
-      : SplitProcessor(opt_cxt) {
-    axis_node_pos_ = 2;
-  }
-
- protected:
-  std::vector<int> GetInputPos() const override { return {0}; }
-};
-
-class TernaryOpProcessor : public AgnosticNodeProcessor {
- public:
-  explicit TernaryOpProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
-};
-
-class SelectProcessor : public AgnosticNodeProcessor {
- public:
-  explicit SelectProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    auto input0 = node_map_->GetNode(node_->input(0));
-    int input0_port;
-    ParseNodeName(node_->input(0), &input0_port);
-    bool is_input0_scalar_vector_4d = IsPortDimsN(*input0, input0_port, 0) ||
-                                      IsPortDimsN(*input0, input0_port, 1) ||
-                                      IsPortDimsN(*input0, input0_port, 4);
-    return AgnosticNodeProcessor::ShouldProcess() && is_input0_scalar_vector_4d;
-  }
-
-  std::vector<int> GetInputPos() const override {
-    auto input0 = node_map_->GetNode(node_->input(0));
-    int input0_port;
-    ParseNodeName(node_->input(0), &input0_port);
-    // Input 0 could be a scalar, a vector with size matching the first
-    // dimension of input 1 and 2, or must have the same shape as input 1 and 2.
-    if (IsPortDimsFour(*input0, input0_port)) {
-      return {0, 1, 2};
-    } else {
-      return {1, 2};
-    }
-  }
-};
-
-class UnaryGradProcessor : public AgnosticNodeProcessor {
- public:
-  explicit UnaryGradProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  std::vector<int> GetInputPos() const override { return {0, 1}; }
-};
-
-class SliceProcessor : public AgnosticNodeProcessor {
- public:
-  explicit SliceProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {
-    // Skip the first input, which is the data to be sliced.
-    start_ = 1;
-    // Note that we can't use node_->input_size() here because there
-    // could be control inputs.
-    end_ = 2;
-  }
-
- protected:
-  Status ProcessInputs() {
-    for (int i = start_; i <= end_; i++) {
-      DataType dtype = node_->attr().at("Index").type();
-      TF_RETURN_IF_ERROR(
-          UpdateOrTransformParamInput(i, "DataFormatVecPermute", dtype));
-    }
-    return Status::OK();
-  }
-
-  Status CustomizedProcessing() override { return ProcessInputs(); }
-
-  int start_;
-  int end_;
-};
-
-class StridedSliceProcessor : public SliceProcessor {
- public:
-  explicit StridedSliceProcessor(const OptimizeContext& opt_cxt)
-      : SliceProcessor(opt_cxt) {
-    start_ = 1;
-    end_ = 3;
-  }
-
- protected:
-  bool ShouldProcess() const override {
-    return AgnosticNodeProcessor::ShouldProcess() && IsOnlyBeginEndMask();
-  }
-
-  Status CustomizedProcessing() override {
-    TF_RETURN_IF_ERROR(UpdateMask("begin_mask"));
-    TF_RETURN_IF_ERROR(UpdateMask("end_mask"));
-    TF_RETURN_IF_ERROR(ProcessInputs());
-    return Status::OK();
-  }
-
- private:
-  bool IsMaskZero(const string& mask) const {
-    return node_->attr().at(mask).i() == 0;
-  }
-
-  bool IsOnlyBeginEndMask() const {
-    return IsMaskZero("ellipsis_mask") && IsMaskZero("new_axis_mask") &&
-           IsMaskZero("shrink_axis_mask");
-  }
-
-  Status UpdateMask(const string& mask) {
-    int i = node_->attr().at(mask).i();
-    if (i < 0 || i > 15) {
-      return errors::InvalidArgument("invalid mask value: ", i);
-    }
-    if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK();
-    switch (i) {
-      case 2:
-      case 3:
-        i += 2;
-        break;
-      case 4:
-      case 5:
-        i += 4;
-        break;
-      case 6:
-      case 7:
-        i += 6;
-        break;
-      case 8:
-      case 9:
-        i -= 6;
-        break;
-      case 10:
-      case 11:
-        i -= 4;
-        break;
-      case 12:
-      case 13:
-        i -= 2;
-        break;
-    }
-    node_->mutable_attr()->at(mask).set_i(i);
-    return Status::OK();
-  }
-};
-
-class StridedSliceGradProcessor : public StridedSliceProcessor {
- public:
-  explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt)
-      : StridedSliceProcessor(opt_cxt) {
-    start_ = 0;
-    end_ = 3;
-  }
-
- protected:
-  std::vector<int> GetInputPos() const override { return {4}; }
-};
-
-class SqueezeProcessor : public AgnosticNodeProcessor {
- public:
-  explicit SqueezeProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
-                             (IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
-    return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
-           IsInputConvertible() && is_dims_supported && IsOnGPU();
-  }
-
-  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-
-  Status CustomizedProcessing() override {
-    TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
-    auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
-    if (list->i_size() == 2) {
-      list->set_i(0, 2);
-      list->set_i(1, 3);
-    } else if (list->i_size() == 3) {
-      list->set_i(1, 2);
-      list->set_i(2, 3);
-    }
-    return Status::OK();
-  }
-
- private:
-  bool IsInputConvertible() const {
-    int input_port;
-    auto input = node_map_->GetNode(node_->input(0));
-    ParseNodeName(node_->input(0), &input_port);
-    if (input->attr().find("_output_shapes") != input->attr().end()) {
-      auto shape = input->attr().at("_output_shapes").list().shape(input_port);
-      if (shape.dim_size() != 4) {
-        return false;
-      }
-      if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
-        return true;
-      }
-      if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
-          shape.dim(2).size() == 1) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  bool IsAlongAxis(const std::vector<int>& axis) const {
-    if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
-      auto list = node_->attr().at("squeeze_dims").list();
-      // If list is empty, Squeeze op will squeeze all dimensions of size 1.
-      if (list.i_size() == 0) return true;
-      if (list.i_size() == axis.size()) {
-        bool along_axis = true;
-        for (int i = 0; i < axis.size(); i++) {
-          along_axis = along_axis && (list.i(i) == axis[i]);
-        }
-        if (along_axis) return true;
-      }
-    }
-    return false;
-  }
-  bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
-  bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
-};
-
-class ReduceProcessor : public AgnosticNodeProcessor {
- public:
-  explicit ReduceProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    auto input0 = node_map_->GetNode(node_->input(0));
-    int port;
-    ParseNodeName(node_->input(0), &port);
-    return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
-           IsPortDimsFour(*input0, port) && IsReduceAxisSupported() &&
-           IsOnGPU();
-  }
-
-  Status CustomizedProcessing() override {
-    if (IsReduceAxisSupported()) {
-      DataType dtype = node_->attr().at("Tidx").type();
-      TF_RETURN_IF_ERROR(
-          UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
-    }
-    return Status::OK();
-  }
-
-  Status AddLayoutTransposeToOutputs() override {
-    if (KeepDims()) {
-      return AddTransformToOutputs("Transpose");
-    }
-    return Status::OK();
-  }
-
- private:
-  bool IsReduceAxisSupported() const {
-    return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() ||
-                           IsAlongNHW() || IsAlongHW() || IsAlongC()) &&
-                          !KeepDims());
-  }
-
-  bool IsAlongAxis(const std::vector<int>& axis) const {
-    auto axis_node = node_map_->GetNode(node_->input(1));
-    if (!IsConstant(*axis_node)) {
-      return false;
-    }
-    if (HasAttribute(*axis_node, "value").ok()) {
-      Tensor tensor;
-      auto success = tensor.FromProto(axis_node->attr().at({"value"}).tensor());
-      if (!success) {
-        LOG(ERROR) << "Failed to parse TensorProto.";
-      }
-      if (tensor.dims() == 1 && tensor.dim_size(0) == axis.size()) {
-        bool along_axis = true;
-        for (int i = 0; i < axis.size(); i++) {
-          along_axis = along_axis && (tensor.flat<int>()(i) == axis[i]);
-        }
-        if (along_axis) return true;
-      }
-    }
-    return false;
-  }
-
-  bool IsAlongAllFourDims() const { return IsAlongAxis({0, 1, 2, 3}); }
-
-  bool IsAlongHWC() const { return IsAlongAxis({1, 2, 3}); }
-
-  bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
-
-  bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
-
-  bool IsAlongC() const { return IsAlongAxis({3}); }
-
-  bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
-};
-
-class SwitchProcessor : public AgnosticNodeProcessor {
- public:
-  explicit SwitchProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  std::set<int> GetOutputPos() const override {
-    std::set<int> output_pos;
-    const int num_outs =
-        node_->attr().count("num_outs") ? node_->attr().at("num_outs").i() : 2;
-    for (int i = 0; i < num_outs; i++) {
-      output_pos.insert(i);
-    }
-    return output_pos;
-  }
-};
-
-class TileProcessor : public AgnosticNodeProcessor {
- public:
-  explicit TileProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  Status CustomizedProcessing() override {
-    DataType dtype = node_->attr().at("Tmultiples").type();
-    return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
-  }
-};
-
-class DataLayoutOptimizer : GraphProcessor {
- public:
-  explicit DataLayoutOptimizer(
-      const GraphProperties& graph_properties,
-      const VirtualPlacer& virtual_placer,
-      const LayoutOptimizer::TuningConfig& config,
-      const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
-      NodeMap* node_map)
-      : GraphProcessor(graph_properties, virtual_placer, nodes_to_preserve,
-                       graph, node_map),
-        config_(config) {}
-
-  Status Optimize() {
-    VLOG(1) << "Number of nodes for original graph: " << graph_->node_size();
-    TF_RETURN_IF_ERROR(Expand());
-    VLOG(1) << "Number of nodes after Expand: " << graph_->node_size();
-    TF_RETURN_IF_ERROR(Collapse());
-    VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size();
-    return Status::OK();
-  }
-
- private:
-  // Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic.
-  Status Expand() {
-    int node_size_original = graph_->node_size();
-    std::unordered_set<string> devices_with_perm_const;
-
-    FrameView frame_view;
-    TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*graph_));
-
-    // This is the first pass where we expand the nodes which support NCHW.
-    std::set<string> ops_format_supported = GetOpsFormatSupported();
-    for (int i = 0; i < node_size_original; i++) {
-      if (IsNodeByLayoutOptimizer(graph_->node(i).name())) {
-        return Status(error::INVALID_ARGUMENT,
-                      "The graph is already optimized by layout optimizer.");
-      }
-      if (ops_format_supported.find(graph_->node(i).op()) !=
-          ops_format_supported.end()) {
-        auto node = graph_->mutable_node(i);
-        bool is_in_frame = frame_view.IsInFrame(*node);
-        OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
-                                virtual_placer_, nodes_to_preserve_,
-                                is_in_frame, &devices_with_perm_const);
-        std::unique_ptr<NodeProcessor> node_processor;
-        if (IsAvgPoolGrad(*node)) {
-          node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
-        } else if (IsBiasAddGrad(*node)) {
-          node_processor.reset(new BiasAddGradProcessor(opt_cxt));
-        } else if (IsConv2D(*node)) {
-          node_processor.reset(new Conv2DProcessor(opt_cxt, config_.no_gemm));
-        } else if (IsConv2DBackpropFilter(*node)) {
-          node_processor.reset(
-              new Conv2DBackpropFilterProcessor(opt_cxt, config_.no_gemm));
-        } else if (IsConv2DBackpropInput(*node)) {
-          node_processor.reset(
-              new Conv2DBackpropInputProcessor(opt_cxt, config_.no_gemm));
-        } else if (IsDepthwiseConv2dNative(*node)) {
-          node_processor.reset(new Conv2DProcessor(opt_cxt, true));
-        } else if (IsDepthwiseConv2dNativeBackpropFilter(*node)) {
-          node_processor.reset(
-              new Conv2DBackpropFilterProcessor(opt_cxt, true));
-        } else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
-          node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
-        } else if (IsFusedBatchNormGrad(*node)) {
-          node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
-        } else if (IsMaxPoolV2(*node)) {
-          node_processor.reset(new MaxPoolV2Processor(opt_cxt));
-        } else if (IsMaxPoolGradV1(*node) || IsMaxPoolGradGradV1(*node)) {
-          node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
-        } else if (IsMaxPoolGradV2(*node) || IsMaxPoolGradGradV2(*node)) {
-          node_processor.reset(new MaxPoolGradV2Processor(opt_cxt));
-        } else {
-          node_processor.reset(new NodeProcessor(opt_cxt));
-        }
-        TF_RETURN_IF_ERROR(node_processor->ConvertNode());
-      }
-    }
-
-    // This is the second pass where we expand layout-agnostic nodes. This pass
-    // only needs to be performed if at least one node in the previous pass is
-    // expanded.
-    if (graph_->node_size() > node_size_original) {
-      std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
-      for (int i = 0; i < graph_->node_size(); i++) {
-        if (ops_format_agnostic.find(graph_->node(i).op()) !=
-            ops_format_agnostic.end()) {
-          auto node = graph_->mutable_node(i);
-          bool is_in_frame = frame_view.IsInFrame(*node);
-          OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
-                                  virtual_placer_, nodes_to_preserve_,
-                                  is_in_frame, &devices_with_perm_const);
-          std::unique_ptr<NodeProcessor> node_processor;
-          if (IsAddN(*node)) {
-            node_processor.reset(new AddNProcessor(opt_cxt));
-          } else if (IsBetainc(*node)) {
-            node_processor.reset(new TernaryOpProcessor(opt_cxt));
-          } else if (IsBinaryOp(*node)) {
-            node_processor.reset(new BinaryOpProcessor(opt_cxt));
-          } else if (IsConcat(*node)) {
-            node_processor.reset(new ConcatProcessor(opt_cxt));
-          } else if (IsFill(*node)) {
-            node_processor.reset(new FillProcessor(opt_cxt));
-          } else if (IsHistogramSummary(*node)) {
-            node_processor.reset(new HistogramSummaryProcessor(opt_cxt));
-          } else if (IsIdentityN(*node)) {
-            node_processor.reset(new IdentityNProcessor(opt_cxt));
-          } else if (IsMerge(*node)) {
-            node_processor.reset(new MergeProcessor(opt_cxt));
-          } else if (IsPad(*node) || IsMirrorPad(*node) ||
-                     IsMirrorPadGrad(*node)) {
-            node_processor.reset(new PadProcessor(opt_cxt));
-          } else if (IsReduceOp(*node)) {
-            node_processor.reset(new ReduceProcessor(opt_cxt));
-          } else if (IsReverseV2(*node)) {
-            node_processor.reset(new ReverseProcessor(opt_cxt));
-          } else if (IsSelect(*node)) {
-            node_processor.reset(new SelectProcessor(opt_cxt));
-          } else if (IsSlice(*node)) {
-            node_processor.reset(new SliceProcessor(opt_cxt));
-          } else if (IsStridedSlice(*node)) {
-            node_processor.reset(new StridedSliceProcessor(opt_cxt));
-          } else if (IsShape(*node) || IsShapeN(*node)) {
-            node_processor.reset(new ShapeProcessor(opt_cxt));
-          } else if (IsSplit(*node)) {
-            node_processor.reset(new SplitProcessor(opt_cxt));
-          } else if (IsSplitV(*node)) {
-            node_processor.reset(new SplitVProcessor(opt_cxt));
-          } else if (IsSqueeze(*node)) {
-            node_processor.reset(new SqueezeProcessor(opt_cxt));
-          } else if (IsStridedSliceGrad(*node)) {
-            node_processor.reset(new StridedSliceGradProcessor(opt_cxt));
-          } else if (IsSwitch(*node)) {
-            node_processor.reset(new SwitchProcessor(opt_cxt));
-          } else if (IsTile(*node)) {
-            node_processor.reset(new TileProcessor(opt_cxt));
-          } else if (IsUnaryGrad(*node)) {
-            node_processor.reset(new UnaryGradProcessor(opt_cxt));
-          } else {
-            node_processor.reset(new AgnosticNodeProcessor(opt_cxt));
-          }
-          TF_RETURN_IF_ERROR(node_processor->ConvertNode());
-        }
-      }
-    }
-    return Status::OK();
-  }
-
-  // Remove all node pairs, where a NCHW-to-NHWC node is followed by
-  // a NHWC-to-NCHW node.
-  Status Collapse() {
-    std::unordered_set<string> nodes_removable;
-    for (int i = 0; i < graph_->node_size(); i++) {
-      auto node = graph_->mutable_node(i);
-      node->mutable_attr()->erase("_output_shapes");
-      if (IsTransposeNHWCToNCHW(node->name()) ||
-          IsDimMapNHWCToNCHW(node->name()) ||
-          IsVecPermuteNHWCToNCHW(node->name())) {
-        bool transpose_pair = IsTransposeNHWCToNCHW(node->name()) &&
-                              IsTransposeNCHWToNHWC(node->input(0));
-        bool dim_map_pair = IsDimMapNHWCToNCHW(node->name()) &&
-                            IsDimMapNCHWToNHWC(node->input(0));
-        bool vec_permute_pair = IsVecPermuteNHWCToNCHW(node->name()) &&
-                                IsVecPermuteNCHWToNHWC(node->input(0));
-        if (transpose_pair || dim_map_pair || vec_permute_pair) {
-          const string& trans_first = node->input(0);
-          const string& trans_second = node->name();
-          auto outputs = node_map_->GetOutputs(trans_second);
-          CHECK(outputs.size() == 1)
-              << "There is always only a single output for a Transpose node, "
-              << "due to the way it is added by NodeProcessor.";
-          NodeDef* output = *outputs.begin();
-          string input = node_map_->GetNode(trans_first)->input(0);
-          for (int i = 0; i < output->input_size(); i++) {
-            if (output->input(i).compare(trans_second) == 0) {
-              *output->mutable_input(i) = input;
-              break;
-            }
-          }
-          nodes_removable.insert(trans_first);
-          nodes_removable.insert(trans_second);
-        }
-      }
-    }
-    graph_->mutable_node()->erase(
-        std::remove_if(
-            graph_->mutable_node()->begin(), graph_->mutable_node()->end(),
-            [nodes_removable](const NodeDef& node) {
-              return nodes_removable.find(node.name()) != nodes_removable.end();
-            }),
-        graph_->mutable_node()->end());
-    return Status::OK();
-  }
-
-  const LayoutOptimizer::TuningConfig& config_;
-};
-
-int GetNumGPUs(const Cluster& cluster) {
-  auto devices = cluster.GetDevices();
-  int num_gpus = 0;
-  for (const auto& device : devices) {
-    if (device.second.type() == "GPU") {
-      num_gpus++;
-    }
-  }
-  return num_gpus;
-}
-}  // namespace
-
-Status LayoutOptimizer::Tune(const GrapplerItem& item,
-                             const GraphProperties& graph_properties,
-                             const TuningConfig& config, GraphDef* output) {
-  auto status = graph_properties.AnnotateOutputShapes(
-      output, /*allow_symbolic_shapes=*/true);
-  if (!status.ok()) {
-    VLOG(1) << "Annotate shape return status: " << status.ToString();
-    *output = item.graph;
-    return status;
-  }
-  NodeMap node_map(output);
-  DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
-                                       config, nodes_to_preserve_, output,
-                                       &node_map);
-  status = layout_optimizer.Optimize();
-  return status;
-}
-
-Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
-                                 GraphDef* output) {
-  if (cluster == nullptr) {
-    LOG(WARNING) << "layout optimizer was called with cluster == nullptr";
-    return errors::Aborted("cluster == nullptr.");
-  }
-  if (GetNumGPUs(*cluster) < 1) {
-    return errors::Aborted(
-        "No GPUs found: LayoutOptimizer is currently only tuned for GPU.");
-  }
-
-  GraphProperties graph_properties(item);
-  TF_RETURN_IF_ERROR(
-      graph_properties.InferStatically(/*assume_valid_feeds=*/false,
-                                       /*aggressive_shape_inference=*/false,
-                                       /*include_tensor_values=*/false));
-  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
-
-  virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));
-  nodes_to_preserve_ = item.NodesToPreserve();
-
-  TuningConfig config;
-  config.no_gemm = true;
-  // TODO(yaozhang): Enable tuning with various TuningConfig choices with
-  // the measurement-based estimator.
-  Status status = Tune(item, graph_properties, config, output);
-  if (!status.ok()) {
-    *output = item.graph;
-  }
-  return status;
-}
-
-}  // end namespace grappler
-}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h
deleted file mode 100644
index ae6307e..0000000
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
-
-#include "tensorflow/core/grappler/costs/graph_properties.h"
-#include "tensorflow/core/grappler/costs/virtual_placer.h"
-#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
-
-namespace tensorflow {
-namespace grappler {
-// Convert the NHWC layout to NCHW for Conv-related ops on GPUs.
-class LayoutOptimizer : public GraphOptimizer {
- public:
-  LayoutOptimizer() {}
-  ~LayoutOptimizer() override {}
-
-  string name() const override { return "layout"; };
-
-  bool UsesFunctionLibrary() const override { return false; }
-
-  struct TuningConfig {
-    // If true, do not use the NHWC GEMM implementation. When filter size is
-    // one or filter size is equal to input image size,
-    // the NHWC implementation of Conv2D, Conv2DBackpropInput, and
-    // Conv2DBackpropFilter will use a specialized GEMM implementation, which is
-    // usually faster than the NCHW implementation. The downside is that this
-    // might result in more non-cancellable layout conversion nodes (implemented
-    // by the Transpose op).
-    bool no_gemm;
-  };
-
-  Status Optimize(Cluster* cluster, const GrapplerItem& item,
-                  GraphDef* output) override;
-
- private:
-  std::unique_ptr<VirtualPlacer> virtual_placer_;
-  std::unordered_set<string> nodes_to_preserve_;
-  Status Tune(const GrapplerItem& item, const GraphProperties& graph_properties,
-              const TuningConfig& config, GraphDef* output);
-};
-
-}  // end namespace grappler
-}  // end namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
deleted file mode 100644
index f8aef8a..0000000
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ /dev/null
@@ -1,1297 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
-
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/kernel_shape_util.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/clusters/single_machine.h"
-#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
-#include "tensorflow/core/grappler/costs/virtual_placer.h"
-#include "tensorflow/core/grappler/devices.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/grappler_test.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/protobuf/device_properties.pb.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-class LayoutOptimizerTest : public GrapplerTest {
- protected:
-  void SetUp() override {
-    gpu_available_ = GetNumAvailableGPUs() > 0;
-
-    if (gpu_available_) {
-      virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
-    } else {
-      DeviceProperties device_properties;
-      device_properties.set_type("GPU");
-      device_properties.mutable_environment()->insert({"architecture", "6"});
-      virtual_cluster_.reset(
-          new VirtualCluster({{"/GPU:1", device_properties}}));
-    }
-    TF_CHECK_OK(virtual_cluster_->Provision());
-  }
-
-  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
-
-  template <typename T = float>
-  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
-                      const string& padding) {
-    return SimpleConv2D<T>(s, input_size, filter_size, padding, "");
-  }
-
-  template <typename T = float>
-  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
-                      const string& padding, const string& device) {
-    int batch_size = 8;
-    int input_height = input_size;
-    int input_width = input_size;
-    int input_depth = 3;
-    int filter_count = 2;
-    int stride = 1;
-    TensorShape input_shape(
-        {batch_size, input_height, input_width, input_depth});
-    Tensor input_data(DataTypeToEnum<T>::value, input_shape);
-    test::FillIota<T>(&input_data, static_cast<T>(1));
-    Output input =
-        ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
-
-    TensorShape filter_shape(
-        {filter_size, filter_size, input_depth, filter_count});
-    Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
-    test::FillIota<T>(&filter_data, static_cast<T>(1));
-    Output filter =
-        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
-
-    ops::Conv2D::Attrs attrs;
-    const int kExplicitPaddings[] = {0, 0, 1, 2, 3, 4, 0, 0};
-    if (padding == "EXPLICIT") {
-      attrs = attrs.ExplicitPaddings(kExplicitPaddings);
-    }
-
-    Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
-                              filter, {1, stride, stride, 1}, padding, attrs);
-    return conv;
-  }
-
-  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
-                                   int filter_size, const string& padding) {
-    return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true,
-                                     true);
-  }
-
-  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
-                                   int filter_size, const string& padding,
-                                   bool const_input_size, bool dilated) {
-    int batch_size = 128;
-    int input_height = input_size;
-    int input_width = input_size;
-    int input_depth = 3;
-    int filter_count = 2;
-    int stride = 1;
-    int dilation = dilated ? 2 : 1;
-    int64_t padding_top = 1;
-    int64_t padding_bottom = 2;
-    int64_t padding_left = 3;
-    int64_t padding_right = 4;
-    int64_t output_height;
-    int64_t output_width;
-    Padding padding_enum;
-    if (padding == "SAME") {
-      padding_enum = SAME;
-    } else if (padding == "VALID") {
-      padding_enum = VALID;
-    } else {
-      CHECK_EQ(padding, "EXPLICIT");
-      padding_enum = EXPLICIT;
-    }
-    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
-        input_height, filter_size, dilation, stride, padding_enum,
-        &output_height, &padding_top, &padding_bottom));
-    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
-        input_width, filter_size, dilation, stride, padding_enum, &output_width,
-        &padding_left, &padding_right));
-    TensorShape input_sizes_shape({4});
-    Tensor input_data(DT_INT32, input_sizes_shape);
-    test::FillValues<int>(&input_data,
-                          {batch_size, input_height, input_width, input_depth});
-    Output input_sizes =
-        ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
-
-    TensorShape filter_shape(
-        {filter_size, filter_size, input_depth, filter_count});
-    Output filter =
-        ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);
-
-    TensorShape output_shape(
-        {batch_size, output_height, output_width, filter_count});
-    Tensor output_data(DT_FLOAT, output_shape);
-    test::FillIota<float>(&output_data, 1.0f);
-    Output output =
-        ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
-
-    Output conv_backprop_input;
-    Output input_sizes_i =
-        ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
-    std::vector<int> dilations{1, dilation, dilation, 1};
-    std::vector<int> explicit_paddings;
-    if (padding == "EXPLICIT") {
-      explicit_paddings = {0,
-                           0,
-                           static_cast<int>(padding_top),
-                           static_cast<int>(padding_bottom),
-                           static_cast<int>(padding_left),
-                           static_cast<int>(padding_right),
-                           0,
-                           0};
-    }
-    auto attrs =
-        ops::Conv2DBackpropInput::Attrs().Dilations(dilations).ExplicitPaddings(
-            explicit_paddings);
-    if (const_input_size) {
-      conv_backprop_input = ops::Conv2DBackpropInput(
-          s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
-          {1, stride, stride, 1}, padding, attrs);
-    } else {
-      conv_backprop_input = ops::Conv2DBackpropInput(
-          s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
-          {1, stride, stride, 1}, padding, attrs);
-    }
-    return conv_backprop_input;
-  }
-
-  Tensor GetAttrValue(const NodeDef& node) {
-    Tensor tensor;
-    CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
-    return tensor;
-  }
-
-  TensorShape GetAttrShape(const NodeDef& node) {
-    return TensorShape(node.attr().at({"shape"}).shape());
-  }
-
-  Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
-    int batch_size = 16;
-    int input_height = 8;
-    int input_width = 8;
-    int input_channels = 3;
-    TensorShape shape({batch_size, input_height, input_width, input_channels});
-    Tensor data(DT_FLOAT, shape);
-    test::FillIota<float>(&data, 1.0f);
-    Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
-    Output y_backprop =
-        ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));
-
-    TensorShape shape_vector({input_channels});
-    Tensor data_vector(DT_FLOAT, shape_vector);
-    test::FillIota<float>(&data_vector, 2.0f);
-    Output scale =
-        ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
-    Output reserve1 =
-        ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
-    Output reserve2 =
-        ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));
-
-    ops::FusedBatchNormGrad::Attrs attrs;
-    attrs.is_training_ = is_training;
-    auto output =
-        ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
-                                x, scale, reserve1, reserve2, attrs);
-    return output.x_backprop;
-  }
-
-  std::unique_ptr<Cluster> virtual_cluster_;
-  bool gpu_available_;
-};
-
-TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "EXPLICIT");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
-  auto input_sizes_node = node_map.GetNode(input_name);
-  CHECK(input_sizes_node);
-  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
-  CHECK(conv2d_backprop_node);
-  EXPECT_EQ(input_name, conv2d_backprop_node->input(0));
-  auto input_sizes = GetAttrValue(*input_sizes_node);
-  Tensor input_sizes_expected(DT_INT32, {4});
-  test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
-  test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
-
-  if (gpu_available_) {
-    TensorShape filter_shape = GetAttrShape(*node_map.GetNode("Filter"));
-    Tensor filter_data = GenerateRandomTensor<DT_FLOAT>(filter_shape);
-    std::vector<string> fetch = {"Fetch"};
-    auto tensors_expected =
-        EvaluateNodes(item.graph, fetch, {{"Filter", filter_data}});
-    auto tensors = EvaluateNodes(output, fetch, {{"Filter", filter_data}});
-    EXPECT_EQ(1, tensors_expected.size());
-    EXPECT_EQ(1, tensors.size());
-    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
-  }
-}
-
-TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false, false);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
-  CHECK(conv2d_backprop_node);
-  EXPECT_EQ(conv2d_backprop_node->input(0),
-            "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
-  auto input_sizes_node = node_map.GetNode(
-      "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
-  CHECK(input_sizes_node);
-  EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
-  EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
-}
-
-TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 2, 2, "VALID");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, ExplicitPadding) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "EXPLICIT");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, DataTypeIsInt32) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D<int32>(&s, 4, 2, "EXPLICIT");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  EXPECT_FALSE(
-      node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
-}
-
-TEST_F(LayoutOptimizerTest, Pad) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto c = ops::Const(s.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
-  auto p = ops::Pad(s.WithOpName("p"), conv, c);
-  auto o = ops::Identity(s.WithOpName("o"), p);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-
-  auto pad = node_map.GetNode("p");
-  EXPECT_EQ(pad->input(0), "Conv2D");
-
-  auto pad_const = node_map.GetNode("p-1-LayoutOptimizer");
-  EXPECT_TRUE(pad_const);
-  EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
-  Tensor tensor;
-  EXPECT_TRUE(
-      tensor.FromProto(pad_const->mutable_attr()->at({"value"}).tensor()));
-  Tensor tensor_expected(DT_INT32, {4, 2});
-  test::FillValues<int>(&tensor_expected, {1, 2, 7, 8, 3, 4, 5, 6});
-  test::ExpectTensorEqual<int>(tensor_expected, tensor);
-}
-
-TEST_F(LayoutOptimizerTest, Connectivity) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
-  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
-  auto i3 = ops::Identity(s.WithOpName("i3"), i2);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  // Make the graph not in topological order to test the handling of multi-hop
-  // connectivity (here we say two nodes are connected if all nodes in the
-  // middle are layout agnostic). If the graph is already in topological order,
-  // the problem is easier, where layout optimizer only needs to check
-  // single-hop connectivity.
-  NodeMap node_map_original(&item.graph);
-  auto node_i1 = node_map_original.GetNode("i1");
-  auto node_i2 = node_map_original.GetNode("i2");
-  node_i2->Swap(node_i1);
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map_output(&output);
-  auto node_i2_output = node_map_output.GetNode("i2");
-  // Layout optimizer should process i2, as it detects i2 is connected with the
-  // Conv2D node two hops away. Similarly i1 is processed as well, as i1 is
-  // directly connected to the Conv2D node. The two added transposes between
-  // i1 and i2 should cancel each other, and as a result i2 is directly
-  // connected to i1.
-  EXPECT_EQ(node_i2_output->input(0), "i1");
-}
-
-TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
-  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
-  auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
-  auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
-  auto i3 = ops::Identity(s.WithOpName("i3"), sub);
-  auto i4 = ops::Identity(s.WithOpName("i4"), i3);
-  auto i5 = ops::Identity(s.WithOpName("i5"), i4);
-  auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
-  auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
-  auto i6 = ops::Identity(s.WithOpName("i6"), mul);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  // Make the graph not in topological order to test the handling of multi-hop
-  // connectivity (here we say two nodes are connected if all nodes in the
-  // middle are layout agnostic). If the graph is already in topological order,
-  // the problem is easier, where layout optimizer only needs to check
-  // single-hop connectivity.
-  NodeMap node_map_original(&item.graph);
-  auto node_i1 = node_map_original.GetNode("i1");
-  auto node_mul = node_map_original.GetNode("mul");
-  node_mul->Swap(node_i1);
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map_output(&output);
-  auto mul_node = node_map_output.GetNode("mul");
-  EXPECT_EQ(mul_node->input(0), "scalar_mul");
-  EXPECT_EQ(mul_node->input(1), "i5");
-}
-
-TEST_F(LayoutOptimizerTest, PreserveFetch) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto i = ops::Identity(s.WithOpName("i"), conv);
-  GrapplerItem item;
-  item.fetch.push_back("Conv2D");
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("Conv2D");
-  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, EmptyDevice) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("Conv2D");
-  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
-}
-
-TEST_F(LayoutOptimizerTest, GPUDevice) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv =
-      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("Conv2D");
-  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
-}
-
-TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv =
-      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("Conv2D");
-  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("Conv2D");
-  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
-  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
-}
-
-TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
-  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimC) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
-  auto c = ops::Const(s.WithOpName("c"), 3, {});
-  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
-  auto i = ops::Identity(s.WithOpName("i"), split[0]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto split_node = node_map.GetNode("split");
-  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
-  EXPECT_EQ(split_node->input(1), "Conv2D");
-  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
-  EXPECT_EQ(split_const->op(), "Const");
-  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimH) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 6, 2, "SAME");
-  auto c = ops::Const(s.WithOpName("c"), 1, {});
-  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
-  auto i = ops::Identity(s.WithOpName("i"), split[0]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto split_node = node_map.GetNode("split");
-  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
-  EXPECT_EQ(split_node->input(1), "Conv2D");
-  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
-  EXPECT_EQ(split_const->op(), "Const");
-  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimW) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
-  auto c = ops::Const(s.WithOpName("c"), 2, {});
-  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
-  auto i = ops::Identity(s.WithOpName("i"), split[0]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto split_node = node_map.GetNode("split");
-  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
-  EXPECT_EQ(split_node->input(1), "Conv2D");
-  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
-  EXPECT_EQ(split_const->op(), "Const");
-  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
-}
-
-TEST_F(LayoutOptimizerTest, SplitDimN) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
-  auto c = ops::Const(s.WithOpName("c"), 0, {});
-  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
-  auto i = ops::Identity(s.WithOpName("i"), split[0]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto split_node = node_map.GetNode("split");
-  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
-  EXPECT_EQ(split_node->input(1), "Conv2D");
-  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
-  EXPECT_EQ(split_const->op(), "Const");
-  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
-}
-
-TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
-  auto c = ops::Const(s.WithOpName("c"), 0, {});
-  auto i1 = ops::Identity(s.WithOpName("i1"), c);
-  auto split = ops::Split(s.WithOpName("split"), i1, conv, 2);
-  auto i2 = ops::Identity(s.WithOpName("i"), split[0]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto split_node = node_map.GetNode("split");
-  EXPECT_EQ(split_node->input(0), "split-0-DimMapNHWCToNCHW-LayoutOptimizer");
-  EXPECT_EQ(split_node->input(1), "Conv2D");
-  auto map_node = node_map.GetNode("split-0-DimMapNHWCToNCHW-LayoutOptimizer");
-  EXPECT_EQ(map_node->op(), "DataFormatDimMap");
-  EXPECT_EQ(map_node->input(0), "i1");
-}
-
-TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
-  auto axis = ops::Const(s.WithOpName("axis"), 3);
-  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
-  auto concat =
-      ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
-  auto o = ops::Identity(s.WithOpName("o"), concat);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto concat_node = node_map.GetNode("concat");
-  EXPECT_EQ(concat_node->input(0), "split:1");
-  EXPECT_EQ(concat_node->input(1), "split:1");
-  EXPECT_EQ(concat_node->input(2), "split:1");
-  EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
-  auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
-  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimH) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
-  auto axis = ops::Const(s.WithOpName("axis"), 1);
-  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
-  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
-  auto o = ops::Identity(s.WithOpName("o"), concat);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto concat_node = node_map.GetNode("concat");
-  EXPECT_EQ(concat_node->input(0), "split");
-  EXPECT_EQ(concat_node->input(1), "split:1");
-  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
-  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
-  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatNonConst) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
-  auto axis = ops::Const(s.WithOpName("axis"), 1);
-  auto i = ops::Identity(s.WithOpName("i"), axis);
-  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
-  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, i);
-  auto o = ops::Identity(s.WithOpName("o"), concat);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto concat_node = node_map.GetNode("concat");
-  EXPECT_EQ(concat_node->input(0), "split");
-  EXPECT_EQ(concat_node->input(1), "split:1");
-  EXPECT_EQ(concat_node->input(2), "concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
-  auto concat_dim =
-      node_map.GetNode("concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
-  EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
-  EXPECT_EQ(concat_dim->input(0), "i");
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimW) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
-  auto axis = ops::Const(s.WithOpName("axis"), 2);
-  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
-  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
-  auto o = ops::Identity(s.WithOpName("o"), concat);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto concat_node = node_map.GetNode("concat");
-  EXPECT_EQ(concat_node->input(0), "split");
-  EXPECT_EQ(concat_node->input(1), "split:1");
-  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
-  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
-  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimN) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto axis = ops::Const(s.WithOpName("axis"), 0);
-  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
-  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
-  auto o = ops::Identity(s.WithOpName("o"), concat);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto concat_node = node_map.GetNode("concat");
-  EXPECT_EQ(concat_node->input(0), "split");
-  EXPECT_EQ(concat_node->input(1), "split:1");
-  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
-  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
-  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
-}
-
-TEST_F(LayoutOptimizerTest, ConcatDimC) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto axis = ops::Const(s.WithOpName("axis"), 3);
-  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
-  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
-  auto o = ops::Identity(s.WithOpName("o"), concat);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto concat_node = node_map.GetNode("concat");
-  EXPECT_EQ(concat_node->input(0), "split");
-  EXPECT_EQ(concat_node->input(1), "split:1");
-  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
-  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
-  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
-}
-
-TEST_F(LayoutOptimizerTest, Sum) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto reduction_indices =
-      ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
-  auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
-  auto o = ops::Identity(s.WithOpName("o"), sum);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
-  // because of the worse performance in some cases.
-  /*
-  NodeMap node_map(&output);
-  auto sum_node = node_map.GetNode("sum");
-  EXPECT_EQ(sum_node->input(0), "Conv2D");
-  EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
-  auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
-  Tensor tensor;
-  EXPECT_TRUE(
-      tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
-  Tensor tensor_expected(DT_INT32, {3});
-  test::FillValues<int>(&tensor_expected, {0, 2, 3});
-  test::ExpectTensorEqual<int>(tensor_expected, tensor);
-  */
-}
-
-TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
-  auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
-  auto o = ops::Identity(s.WithOpName("o"), mul);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto mul_node = node_map.GetNode("mul");
-  EXPECT_EQ(mul_node->input(0), "scalar");
-  EXPECT_EQ(mul_node->input(1), "Conv2D");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
-  auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
-  auto o = ops::Identity(s.WithOpName("o"), mul);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto mul_node = node_map.GetNode("mul");
-  EXPECT_EQ(mul_node->input(0), "Conv2D");
-  EXPECT_EQ(mul_node->input(1), "scalar");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto unknown_rank =
-      ops::Placeholder(s.WithOpName("unknown"), DT_FLOAT,
-                       ops::Placeholder::Shape(PartialTensorShape()));
-  Output c = ops::Const(s.WithOpName("c"), 3.0f, {8, 2, 2, 2});
-  Output mul = ops::Mul(s.WithOpName("mul"), conv, unknown_rank);
-  auto o = ops::AddN(s.WithOpName("o"), {mul, c});
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto mul_node = node_map.GetNode("mul");
-  // Node mul should not be processed by layout optimizer, because one of its
-  // inputs is of unknown rank.
-  EXPECT_EQ(mul_node->input(0),
-            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(mul_node->input(1), "unknown");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto i = ops::Identity(s.WithOpName("i"), conv);
-  auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
-  auto o = ops::Identity(s.WithOpName("o"), mul);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto mul_node = node_map.GetNode("mul");
-  EXPECT_EQ(mul_node->input(0), "Conv2D");
-  EXPECT_EQ(mul_node->input(1), "i");
-}
-
-TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
-  auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
-  auto o = ops::Identity(s.WithOpName("o"), mul);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto mul_node = node_map.GetNode("mul");
-  EXPECT_EQ(mul_node->input(0), "Conv2D");
-  EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
-  auto mul_const = node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
-  Tensor tensor;
-  EXPECT_TRUE(
-      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
-  Tensor tensor_expected(DT_INT32, {4});
-  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
-  test::ExpectTensorEqual<int>(tensor_expected, tensor);
-}
-
-TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
-  auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
-  auto o = ops::Identity(s.WithOpName("o"), mul);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto mul_node = node_map.GetNode("mul");
-  EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
-  EXPECT_EQ(mul_node->input(1), "Conv2D");
-  auto mul_const = node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
-  Tensor tensor;
-  EXPECT_TRUE(
-      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
-  Tensor tensor_expected(DT_INT32, {4});
-  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
-  test::ExpectTensorEqual<int>(tensor_expected, tensor);
-}
-
-TEST_F(LayoutOptimizerTest, SliceConst) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
-  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
-  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
-  auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
-  auto o = ops::Identity(s.WithOpName("o"), slice);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto slice_node = node_map.GetNode("slice");
-  EXPECT_EQ(slice_node->input(0), "Conv2D");
-  EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
-  EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");
-
-  auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
-  Tensor begin_tensor;
-  EXPECT_TRUE(begin_tensor.FromProto(
-      begin_const->mutable_attr()->at({"value"}).tensor()));
-  Tensor begin_tensor_expected(DT_INT32, {4});
-  test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
-  test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);
-
-  auto size_const = node_map.GetNode("slice-2-LayoutOptimizer");
-  Tensor size_tensor;
-  EXPECT_TRUE(size_tensor.FromProto(
-      size_const->mutable_attr()->at({"value"}).tensor()));
-  Tensor size_tensor_expected(DT_INT32, {4});
-  test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
-  test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
-}
-
-TEST_F(LayoutOptimizerTest, SliceNonConst) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
-  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
-  auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
-  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
-  auto isize = ops::Identity(s.WithOpName("isize"), size);
-  auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
-  auto o = ops::Identity(s.WithOpName("o"), slice);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto slice_node = node_map.GetNode("slice");
-  EXPECT_EQ(slice_node->input(0), "Conv2D");
-  EXPECT_EQ(slice_node->input(1),
-            "slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
-  EXPECT_EQ(slice_node->input(2),
-            "slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
-  auto perm1 = node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
-  EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
-  EXPECT_EQ(perm1->input(0), "ibegin");
-  auto perm2 = node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
-  EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
-  EXPECT_EQ(perm2->input(0), "isize");
-}
-
-TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto scalar =
-      ops::Const(s.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
-  auto mul = ops::Mul(s.WithOpName("mul"), scalar, scalar);
-  auto o = ops::Identity(s.WithOpName("o"), mul);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  EXPECT_TRUE(errors::IsInvalidArgument(status));
-}
-
-TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, conv});
-  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto shapen_node = node_map.GetNode("shapen");
-  EXPECT_EQ(shapen_node->input(0), "Conv2D");
-  EXPECT_EQ(shapen_node->input(1), "Conv2D");
-  auto add_node = node_map.GetNode("add");
-  EXPECT_EQ(add_node->input(0),
-            "shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(add_node->input(1),
-            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
-  auto vec_permute1 =
-      node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(vec_permute1->input(0), "shapen");
-  EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");
-  auto vec_permute2 =
-      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(vec_permute2->input(0), "shapen:1");
-  EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
-}
-
-TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {7});
-  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {vector, conv});
-  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto shapen_node = node_map.GetNode("shapen");
-  EXPECT_EQ(shapen_node->input(0), "vector");
-  EXPECT_EQ(shapen_node->input(1), "Conv2D");
-  auto add_node = node_map.GetNode("add");
-  EXPECT_EQ(add_node->input(0), "shapen");
-  EXPECT_EQ(add_node->input(1),
-            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
-  auto vec_permute =
-      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(vec_permute->input(0), "shapen:1");
-  EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
-}
-
-TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAndNoNeedToTransform4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
-  auto i1 = ops::Identity(s.WithOpName("i1"), tensor_4d);
-  Output i2 = ops::Identity(s.WithOpName("i2"), i1);
-  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, i2});
-  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto shapen_node = node_map.GetNode("shapen");
-  EXPECT_EQ(shapen_node->input(0), "Conv2D");
-  EXPECT_EQ(shapen_node->input(1), "i2");
-}
-
-TEST_F(LayoutOptimizerTest, Switch) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  ops::Variable ctrl(s.WithOpName("ctrl"), {}, DT_BOOL);
-  auto sw = ops::Switch(s.WithOpName("switch"), conv, ctrl);
-  auto i1 = ops::Identity(s.WithOpName("i1"), sw.output_true);
-  auto i2 = ops::Identity(s.WithOpName("i2"), sw.output_false);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto switch_node = node_map.GetNode("switch");
-  EXPECT_EQ(switch_node->input(0), "Conv2D");
-  EXPECT_EQ(switch_node->input(1), "ctrl");
-  auto i1_node = node_map.GetNode("i1");
-  auto i2_node = node_map.GetNode("i2");
-  auto trans1 = node_map.GetNode(i1_node->input(0));
-  EXPECT_EQ(trans1->input(0), "switch:1");
-  auto trans2 = node_map.GetNode(i2_node->input(0));
-  EXPECT_EQ(trans2->input(0), "switch");
-}
-
-TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  Output i1 = ops::Identity(s.WithOpName("i1"), conv);
-  auto merge = ops::Merge(s.WithOpName("merge"), {conv, i1});
-  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto merge_node = node_map.GetNode("merge");
-  EXPECT_EQ(merge_node->input(0), "Conv2D");
-  EXPECT_EQ(merge_node->input(1), "i1");
-  auto i2_node = node_map.GetNode("i2");
-  EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-  auto transpose =
-      node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(transpose->input(0), "merge");
-}
-
-TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
-  auto merge = ops::Merge(s.WithOpName("merge"), {tensor_4d, conv});
-  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto merge_node = node_map.GetNode("merge");
-  EXPECT_EQ(merge_node->input(0), "tensor_4d");
-  EXPECT_EQ(merge_node->input(1),
-            "Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
-}
-
-TEST_F(LayoutOptimizerTest, Complex) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
-  auto i = ops::Identity(s.WithOpName("i"), comp);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto merge_node = node_map.GetNode("complex");
-  EXPECT_EQ(merge_node->input(0), "Conv2D");
-  EXPECT_EQ(merge_node->input(1), "Conv2D");
-  auto trans =
-      node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
-}
-
-TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
-  auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
-  auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto i = node_map.GetNode("identity_n");
-  EXPECT_EQ(i->input(0), "vector");
-  EXPECT_EQ(i->input(1), "Conv2D");
-  auto trans =
-      node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(trans->input(0), "identity_n:1");
-  auto add_node = node_map.GetNode("add");
-  EXPECT_EQ(add_node->input(0), "identity_n");
-  EXPECT_EQ(add_node->input(1),
-            "identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
-}
-
-TEST_F(LayoutOptimizerTest, LoopNoLiveLock) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto c = ops::Const(s.WithOpName("const"), 3.0f, {8, 3, 3, 2});
-  auto merge = ops::Merge(s.WithOpName("merge"), {c, c});
-  auto i0 = ops::Identity(s.WithOpName("i0"), merge.output);
-  ops::Variable v_ctrl(s.WithOpName("v_ctrl"), {}, DT_BOOL);
-  auto sw = ops::Switch(s.WithOpName("switch"), i0, v_ctrl);
-  auto next = ops::NextIteration(s.WithOpName("next"), sw.output_true);
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto mul = ops::Mul(s.WithOpName("mul"), conv, sw.output_false);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  NodeMap node_map_original(&item.graph);
-  auto merge_node = node_map_original.GetNode("merge");
-  // Modify the graph to create a loop
-  merge_node->set_input(1, "next");
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto conv_node = node_map.GetNode("Conv2D");
-  EXPECT_EQ(conv_node->input(0),
-            "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer");
-  auto mul_node = node_map.GetNode("mul");
-  EXPECT_EQ(mul_node->input(0),
-            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-}
-
-TEST_F(LayoutOptimizerTest, DevicePlacement) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
-  auto shape = ops::Shape(s.WithOpName("s"), conv);
-  auto i = ops::Identity(s.WithOpName("i"), shape);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  VirtualPlacer virtual_placer(virtual_cluster_->GetDevices());
-  for (auto& node : *item.graph.mutable_node()) {
-    string device = virtual_placer.get_canonical_device_name(node);
-    node.set_device(device);
-  }
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto vec_permute =
-      node_map.GetNode("s-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
-  EXPECT_EQ(vec_permute->attr().at("_kernel").s(), "host");
-}
-
-TEST_F(LayoutOptimizerTest, PermConstWithDevice) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  const string worker0_gpu0 = "/job:w/replica:0/task:0/device:gpu:0";
-  const string worker1_gpu1 = "/job:w/replica:0/task:1/device:gpu:1";
-  const string worker0_node_prefix = "job_w_replica_0_task_0_device_gpu_0-";
-  const string worker1_node_prefix = "job_w_replica_0_task_1_device_gpu_1-";
-  const string perm_nchw2nhwc_str = "PermConstNCHWToNHWC-LayoutOptimizer";
-  const string perm_nhwc2nchw_str = "PermConstNHWCToNCHW-LayoutOptimizer";
-  auto conv_0 = SimpleConv2D(&s, 4, 2, "VALID", worker0_gpu0);
-  auto shape_0 = ops::Shape(s.WithOpName("s"), conv_0);
-  auto i_0 = ops::Identity(s.WithOpName("i"), shape_0);
-  auto conv_1 = SimpleConv2D(&s, 4, 2, "VALID", worker1_gpu1);
-  auto shape_1 = ops::Shape(s.WithOpName("s"), conv_1);
-  auto i_1 = ops::Identity(s.WithOpName("i"), shape_1);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  LayoutOptimizer optimizer;
-  GraphDef output;
-  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
-  NodeMap node_map(&output);
-  auto const_permute_0_0 =
-      node_map.GetNode(worker0_node_prefix + perm_nchw2nhwc_str);
-  auto const_permute_0_1 =
-      node_map.GetNode(worker0_node_prefix + perm_nhwc2nchw_str);
-  EXPECT_EQ(const_permute_0_0->device(), worker0_gpu0);
-  EXPECT_EQ(const_permute_0_1->device(), worker0_gpu0);
-  auto const_permute_1_0 =
-      node_map.GetNode(worker1_node_prefix + perm_nchw2nhwc_str);
-  auto const_permute_1_1 =
-      node_map.GetNode(worker1_node_prefix + perm_nhwc2nchw_str);
-  EXPECT_EQ(const_permute_1_0->device(), worker1_gpu1);
-  EXPECT_EQ(const_permute_1_1->device(), worker1_gpu1);
-  EXPECT_FALSE(node_map.GetNode(perm_nchw2nhwc_str));
-  EXPECT_FALSE(node_map.GetNode(perm_nhwc2nchw_str));
-}
-}  // namespace
-}  // namespace grappler
-}  // namespace tensorflow