[tf.data] Removing unused experimental optimizations.

This CL removes API and implementation for tf.data optimizations that are either not used (`filter_with_random_uniform_fusion`, `hoist_random_uniform`, `reorder_data_discarding_ops`) or not needed because there is no longer any mechanism to trigger them (`latency_all_edges`).

PiperOrigin-RevId: 379942022
Change-Id: I0ce93df9d098949677e37bc35126e95c8250f1b8
diff --git a/RELEASE.md b/RELEASE.md
index bbf72bb..65eb2d6 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -154,9 +154,11 @@
         *   `tf.data.experimental.StatsOptions.*`
         *   `tf.data.experimental.bytes_produced_stats`
         *   `tf.data.experimental.latency_stats`
-    *   Removed experimental tf.data API for map vectorization:
-        *   `tf.data.experimental.OptimizationOptions.map_vectorization`
+    *   Removed the following experimental tf.data optimization APIs:
         *   `tf.data.experimental.MapVectorizationOptions.*`
+        *   `tf.data.experimental.OptimizationOptions.filter_with_random_uniform_fusion`
+        *   `tf.data.experimental.OptimizationOptions.hoist_random_uniform`
+        *   `tf.data.experimental.OptimizationOptions.map_vectorization`                 *   `tf.data.experimental.OptimizationOptions.reorder_data_discarding_ops`
 *   `tf.keras`:
     *   Fix usage of `__getitem__` slicing in Keras Functional APIs when the
         inputs are `RaggedTensor` objects.
diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc
index 915ca2e..35d95de 100644
--- a/tensorflow/core/data/dataset_utils.cc
+++ b/tensorflow/core/data/dataset_utils.cc
@@ -88,13 +88,9 @@
 constexpr char kMapParallelizationOpt[] = "map_parallelization";
 constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
 constexpr char kFilterFusionOpt[] = "filter_fusion";
-constexpr char kFilterWithRandomUniformFusionOpt[] =
-    "filter_with_random_uniform_fusion";
-constexpr char kHoistRandomUniformOpt[] = "hoist_random_uniform";
 constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
 constexpr char kMapFusionOpt[] = "map_fusion";
 constexpr char kParallelBatchOpt[] = "parallel_batch";
-constexpr char kReorderDataDiscardingOpsOpt[] = "reorder_data_discarding_ops";
 constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
 constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
     "disable_prefetch_legacy_autotune";
@@ -139,22 +135,6 @@
       optimization_disabled->insert(kFilterFusionOpt);
     }
   }
-  if (optimization_options.optional_filter_with_random_uniform_fusion_case() ==
-      OptimizationOptions::kFilterWithRandomUniformFusion) {
-    if (optimization_options.filter_with_random_uniform_fusion()) {
-      optimization_enabled->insert(kFilterWithRandomUniformFusionOpt);
-    } else {
-      optimization_disabled->insert(kFilterWithRandomUniformFusionOpt);
-    }
-  }
-  if (optimization_options.optional_hoist_random_uniform_case() ==
-      OptimizationOptions::kHoistRandomUniform) {
-    if (optimization_options.hoist_random_uniform()) {
-      optimization_enabled->insert(kHoistRandomUniformOpt);
-    } else {
-      optimization_disabled->insert(kHoistRandomUniformOpt);
-    }
-  }
   if (optimization_options.optional_map_and_batch_fusion_case() ==
       OptimizationOptions::kMapAndBatchFusion) {
     if (optimization_options.map_and_batch_fusion()) {
@@ -203,14 +183,6 @@
       optimization_disabled->insert(kParallelBatchOpt);
     }
   }
-  if (optimization_options.optional_reorder_data_discarding_ops_case() ==
-      OptimizationOptions::kReorderDataDiscardingOps) {
-    if (optimization_options.reorder_data_discarding_ops()) {
-      optimization_enabled->insert(kReorderDataDiscardingOpsOpt);
-    } else {
-      optimization_disabled->insert(kReorderDataDiscardingOpsOpt);
-    }
-  }
   if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
       OptimizationOptions::kShuffleAndRepeatFusion) {
     if (optimization_options.shuffle_and_repeat_fusion()) {
diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc
index 81d0d10..c3927f7 100644
--- a/tensorflow/core/data/dataset_utils_test.cc
+++ b/tensorflow/core/data/dataset_utils_test.cc
@@ -629,26 +629,21 @@
   options.set_deterministic(false);
   options.mutable_optimization_options()->set_autotune_buffers(true);
   options.mutable_optimization_options()->set_filter_fusion(true);
-  options.mutable_optimization_options()->set_filter_with_random_uniform_fusion(
-      true);
-  options.mutable_optimization_options()->set_hoist_random_uniform(true);
   options.mutable_optimization_options()->set_map_and_batch_fusion(true);
   options.mutable_optimization_options()->set_map_and_filter_fusion(true);
   options.mutable_optimization_options()->set_map_fusion(true);
   options.mutable_optimization_options()->set_map_parallelization(true);
   options.mutable_optimization_options()->set_noop_elimination(true);
   options.mutable_optimization_options()->set_parallel_batch(true);
-  options.mutable_optimization_options()->set_reorder_data_discarding_ops(true);
   options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true);
   options.set_slack(true);
   return {options,
           /*expected_enabled=*/
           {"autotune_buffer_sizes", "disable_prefetch_legacy_autotune",
-           "filter_fusion", "filter_with_random_uniform_fusion",
-           "hoist_random_uniform", "make_sloppy", "map_and_batch_fusion",
+           "filter_fusion", "make_sloppy", "map_and_batch_fusion",
            "map_and_filter_fusion", "map_fusion", "map_parallelization",
-           "noop_elimination", "parallel_batch", "reorder_data_discarding_ops",
-           "shuffle_and_repeat_fusion", "slack"},
+           "noop_elimination", "parallel_batch", "shuffle_and_repeat_fusion",
+           "slack"},
           /*expected_disabled=*/{},
           /*expected_default=*/{}};
 }
diff --git a/tensorflow/core/framework/dataset_options.proto b/tensorflow/core/framework/dataset_options.proto
index 273b906..85e9852 100644
--- a/tensorflow/core/framework/dataset_options.proto
+++ b/tensorflow/core/framework/dataset_options.proto
@@ -64,15 +64,10 @@
   oneof optional_filter_fusion {
     bool filter_fusion = 6;
   }
-  // Whether to fuse filter dataset that predicts random_uniform < rate into a
-  // sampling dataset.
-  oneof optional_filter_with_random_uniform_fusion {
-    bool filter_with_random_uniform_fusion = 7;
-  }
-  // Whether to hoist tf.random_uniform() ops out of map transformations.
-  oneof optional_hoist_random_uniform {
-    bool hoist_random_uniform = 8;
-  }
+  // NOTE: field id 7 deleted in June 2021.
+  reserved 7;
+  // NOTE: field id 8 deleted in June 2021.
+  reserved 8;
   // Whether to fuse map and batch transformations.
   oneof optional_map_and_batch_fusion {
     bool map_and_batch_fusion = 9;
@@ -106,15 +101,8 @@
   oneof optional_parallel_batch {
     bool parallel_batch = 15;
   }
-  // Whether to reorder ops that will discard data to the front of unary
-  // cardinality preserving transformations, e.g. dataset.map(...).take(3) will
-  // be optimized to dataset.take(3).map(...). For now this optimization will
-  // move `skip`, `shard` and `take` to the front of `map` and `prefetch`. This
-  // optimization is only for performance; it will not affect the output of the
-  // dataset.
-  oneof optional_reorder_data_discarding_ops {
-    bool reorder_data_discarding_ops = 16;
-  }
+  // Field id 16 was removed in 06/2021.
+  reserved 16;
   // Whether to fuse shuffle and repeat transformations.
   oneof optional_shuffle_and_repeat_fusion {
     bool shuffle_and_repeat_fusion = 17;
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 9e80a73..4a2dbf4 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -22,9 +22,6 @@
         ":disable_prefetch_legacy_autotune",
         ":enable_gradient_descent",
         ":filter_fusion",
-        ":filter_with_random_uniform_fusion",
-        ":hoist_random_uniform",
-        ":latency_all_edges",
         ":make_sloppy",
         ":map_and_batch_fusion",
         ":map_and_filter_fusion",
@@ -33,7 +30,6 @@
         ":meta_optimizer",
         ":noop_elimination",
         ":parallel_batch",
-        ":reorder_data_discarding_ops",
         ":shuffle_and_repeat_fusion",
         ":slack",
         ":use_private_thread_pool",
@@ -288,44 +284,6 @@
 )
 
 cc_library(
-    name = "filter_with_random_uniform_fusion",
-    srcs = ["filter_with_random_uniform_fusion.cc"],
-    hdrs = [
-        "filter_with_random_uniform_fusion.h",
-    ],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":graph_utils",
-        ":optimizer_base",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/grappler:graph_view",
-        "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:op_types",
-        "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
-    ] + tf_protos_all(),
-    alwayslink = 1,
-)
-
-tf_cc_test(
-    name = "filter_with_random_uniform_fusion_test",
-    size = "small",
-    srcs = ["filter_with_random_uniform_fusion_test.cc"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":filter_with_random_uniform_fusion",
-        ":graph_test_utils",
-        ":graph_utils",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-        "//tensorflow/core/grappler:grappler_item",
-    ],
-)
-
-cc_library(
     name = "fusion_utils",
     srcs = ["fusion_utils.cc"],
     hdrs = [
@@ -450,81 +408,6 @@
 )
 
 cc_library(
-    name = "hoist_random_uniform",
-    srcs = ["hoist_random_uniform.cc"],
-    hdrs = [
-        "hoist_random_uniform.h",
-    ],
-    deps = [
-        ":function_utils",
-        ":graph_utils",
-        ":optimizer_base",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/grappler:mutable_graph_view",
-        "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:op_types",
-        "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
-        "//tensorflow/core:lib_internal",
-    ] + tf_protos_all(),
-    alwayslink = 1,
-)
-
-tf_cc_test(
-    name = "hoist_random_uniform_test",
-    size = "small",
-    srcs = ["hoist_random_uniform_test.cc"],
-    deps = [
-        ":graph_test_utils",
-        ":graph_utils",
-        ":hoist_random_uniform",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-        "//tensorflow/core/grappler:grappler_item",
-    ] + tf_protos_all(),
-)
-
-cc_library(
-    name = "latency_all_edges",
-    srcs = ["latency_all_edges.cc"],
-    hdrs = [
-        "latency_all_edges.h",
-    ],
-    deps = [
-        ":graph_utils",
-        ":optimizer_base",
-        "//tensorflow/core/grappler:mutable_graph_view",
-        "//tensorflow/core/data:stats_utils",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:op_types",
-        "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
-    ] + tf_protos_all(),
-    alwayslink = 1,
-)
-
-tf_cc_test(
-    name = "latency_all_edges_test",
-    size = "small",
-    srcs = ["latency_all_edges_test.cc"],
-    deps = [
-        ":graph_utils",
-        ":latency_all_edges",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-        "//tensorflow/core/grappler:grappler_item",
-    ],
-)
-
-cc_library(
     name = "make_sloppy",
     srcs = ["make_sloppy.cc"],
     hdrs = ["make_sloppy.h"],
@@ -813,45 +696,6 @@
 )
 
 cc_library(
-    name = "reorder_data_discarding_ops",
-    srcs = ["reorder_data_discarding_ops.cc"],
-    hdrs = [
-        "reorder_data_discarding_ops.h",
-    ],
-    deps = [
-        ":function_utils",
-        ":graph_utils",
-        ":optimizer_base",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/grappler:mutable_graph_view",
-        "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:op_types",
-        "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
-        "//tensorflow/core:lib_internal",
-    ] + tf_protos_all(),
-    alwayslink = 1,
-)
-
-tf_cc_test(
-    name = "reorder_data_discarding_ops_test",
-    size = "small",
-    srcs = ["reorder_data_discarding_ops_test.cc"],
-    deps = [
-        ":graph_test_utils",
-        ":graph_utils",
-        ":reorder_data_discarding_ops",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-        "//tensorflow/core/grappler:grappler_item",
-    ] + tf_protos_all(),
-)
-
-cc_library(
     name = "shuffle_and_repeat_fusion",
     srcs = ["shuffle_and_repeat_fusion.cc"],
     hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.cc
deleted file mode 100644
index ba29e60..0000000
--- a/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.cc
+++ /dev/null
@@ -1,272 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.h"
-
-#include <iostream>
-
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-
-constexpr char kFusedOpName[] = "SamplingDataset";
-
-NodeDef MakeFusedNode(const NodeDef& filter_node, float rate, int64 seed,
-                      int64 seed2, MutableGraphView* graph) {
-  NodeDef fused_node;
-  graph_utils::SetUniqueGraphNodeName("fused_sampling", graph->graph(),
-                                      &fused_node);
-  fused_node.set_op(kFusedOpName);
-
-  // Copy over inputs.
-  for (int i = 0; i < filter_node.input_size(); ++i) {
-    fused_node.add_input(filter_node.input(i));
-  }
-
-  // Required attrs.
-  for (auto key : {"output_shapes", "output_types"}) {
-    graph_utils::CopyAttribute(key, filter_node, &fused_node);
-  }
-
-  // Optional attrs.
-  for (auto key : {"use_inter_op_parallelism", "sloppy"}) {
-    if (gtl::FindOrNull(filter_node.attr(), key)) {
-      graph_utils::CopyAttribute(key, filter_node, &fused_node);
-    }
-  }
-
-  NodeDef* tmp_rate = graph_utils::AddScalarConstNode<float>(rate, graph);
-  fused_node.add_input(tmp_rate->name());
-  NodeDef* tmp_seed = graph_utils::AddScalarConstNode<int64>(seed, graph);
-  fused_node.add_input(tmp_seed->name());
-  NodeDef* tmp_seed2 = graph_utils::AddScalarConstNode<int64>(seed2, graph);
-  fused_node.add_input(tmp_seed2->name());
-
-  return fused_node;
-}
-
-const NodeDef* FunctionFindNodeDef(const FunctionDef& function, const string op,
-                                   const string func, const string match) {
-  for (const NodeDef& func_node : function.node_def()) {
-    if (func_node.op() != op) {
-      continue;
-    }
-    if (func_node.name() + match != func) {
-      continue;
-    }
-    return &func_node;
-  }
-  return nullptr;
-}
-
-bool FunctionFindFloatConst(const FunctionDef& function, const string& func,
-                            const string& match, float* result) {
-  const NodeDef* const_node =
-      FunctionFindNodeDef(function, "Const", func, match);
-  if (const_node == nullptr) {
-    return false;
-  }
-  if (const_node->attr().at("dtype").type() != DT_FLOAT) {
-    return false;
-  }
-  const auto& value = const_node->attr().at("value").tensor().float_val(0);
-  *result = value;
-  return true;
-}
-
-bool FunctionExpectFloatConst(const FunctionDef& function, const string& func,
-                              const string match, const float val) {
-  float result;
-  if (FunctionFindFloatConst(function, func, match, &result) && result == val) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-// This optimization fuses one of the following two forms of
-// filter + random_uniform predication into a single data sampling operation:
-// fuse:
-//   filter
-//   |
-//   + predication: less [0]
-//                  |
-//                  + random_uniform [1]
-//                  |
-//                  + rate
-// or:
-//   filter
-//   |
-//   + predication: less
-//                  |
-//                  + random_uniform[]
-//                  |
-//                  + rate
-// into:
-//   sampling(rate)
-Status FilterWithRandomUniformFusion::OptimizeAndCollectStats(
-    Cluster* cluster, const GrapplerItem& item, GraphDef* output,
-    OptimizationStats* stats) {
-  *output = item.graph;
-  MutableGraphView graph(output);
-  absl::flat_hash_set<string> nodes_to_delete;
-  float rate;
-  int64 seed, seed2;
-
-  for (const NodeDef& node : item.graph.node()) {
-    // stage 1 -- recognition
-    if (node.op() != "FilterDataset") {
-      continue;
-    }
-
-    // Use a more descriptive variable name
-    const NodeDef& filter_node = node;
-
-    // find predicate function of the node
-    const auto& predicate = filter_node.attr().at("predicate");
-    const string func_name = predicate.func().name();
-
-    bool function_match = false;
-    // find the function that matches func_name
-    for (const auto& function : item.graph.library().function()) {
-      if (function.signature().name() == func_name) {
-        if (function.ret().size() != 1) {
-          continue;
-        }
-        auto it = function.ret().begin();
-        string node_name = it->second;
-        const NodeDef* func_node =
-            FunctionFindNodeDef(function, "Identity", node_name, ":output:0");
-        while (func_node != nullptr) {
-          node_name = func_node->input(0);
-          func_node =
-              FunctionFindNodeDef(function, "Identity", node_name, ":output:0");
-        }
-        func_node = FunctionFindNodeDef(function, "StridedSlice", node_name,
-                                        ":output:0");
-        const NodeDef* less_node;
-        if (func_node != nullptr) {
-          // for form one: datasetS = datasetS.filter(lambda x:
-          // tf.less(tf.random_uniform([1]), rate)[0])
-          less_node = FunctionFindNodeDef(function, "Less", func_node->input(0),
-                                          ":z:0");
-        } else {
-          // for form two: datasetS = datasetS.filter(lambda _:
-          // tf.random_uniform([]) < rate)
-          less_node = FunctionFindNodeDef(function, "Less", node_name, ":z:0");
-        }
-        if (less_node == nullptr) {
-          continue;
-        }
-
-        // check whether the function is actually doing
-        // random_uniform[0.0, 1.0) < rate
-        // There could be two forms of random_uniform[0.0, 1.0) in the graph
-        // * Simple form just have a RandomUniform node which means
-        //   random_uniform[0.0, 1.0)
-        // * Expanded form is "RandomUniform * (1.0 - 0.0) + 0.0", which is
-        //   still random_uniform[0.0, 1.0)
-        //
-        // First detect whether simple form is used
-        const NodeDef* random_uniform_node = FunctionFindNodeDef(
-            function, "RandomUniform", less_node->input(0), ":output:0");
-        if (random_uniform_node == nullptr) {
-          // If expanded form is used, check boundaries
-          const NodeDef* random_uniform_result_node =
-              FunctionFindNodeDef(function, "Add", less_node->input(0), ":z:0");
-
-          if (!FunctionExpectFloatConst(function,
-                                        random_uniform_result_node->input(1),
-                                        ":output:0", 0.0f)) {
-            continue;
-          }
-
-          const NodeDef* random_uniform_mul_node = FunctionFindNodeDef(
-              function, "Mul", random_uniform_result_node->input(0), ":z:0");
-
-          const NodeDef* random_uniform_sub_node = FunctionFindNodeDef(
-              function, "Sub", random_uniform_mul_node->input(1), ":z:0");
-
-          if (!FunctionExpectFloatConst(function,
-                                        random_uniform_sub_node->input(0),
-                                        ":output:0", 1.0f)) {
-            continue;
-          }
-
-          if (!FunctionExpectFloatConst(function,
-                                        random_uniform_sub_node->input(1),
-                                        ":output:0", 0.0f)) {
-            continue;
-          }
-
-          random_uniform_node = FunctionFindNodeDef(
-              function, "RandomUniform", random_uniform_mul_node->input(0),
-              ":output:0");
-          if (random_uniform_node == nullptr) {
-            continue;
-          }
-        }
-
-        seed = random_uniform_node->attr().at("seed").i();
-        seed2 = random_uniform_node->attr().at("seed2").i();
-
-        if (!FunctionFindFloatConst(function, less_node->input(1), ":output:0",
-                                    &rate)) {
-          continue;
-        }
-
-        function_match = true;
-        break;
-      }
-    }
-
-    if (!function_match) {
-      continue;
-    }
-
-    // stage 2 -- fuse
-    const auto* fused_sampling =
-        graph.AddNode(MakeFusedNode(filter_node, rate, seed, seed2, &graph));
-
-    TF_RETURN_IF_ERROR(
-        graph.UpdateFanouts(filter_node.name(), fused_sampling->name()));
-
-    // Mark the `Filter` node for removal.
-    nodes_to_delete.insert(filter_node.name());
-    stats->num_changes++;
-  }
-
-  TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
-  return Status::OK();
-}
-
-// TODO(b/131229793): The current implementation of the optimization is brittle
-// as it depends on the order of inputs to commutative nodes. Make the
-// optimization robust to the input ordering before re-enabling it.
-// REGISTER_GRAPH_OPTIMIZER_AS(FilterWithRandomUniformFusion,
-//                             "filter_with_random_uniform_fusion");
-
-}  // end namespace grappler
-}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.h b/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.h
deleted file mode 100644
index 419c38e..0000000
--- a/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_WITH_RANDOM_UNIFORM_FUSION_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_WITH_RANDOM_UNIFORM_FUSION_H_
-
-#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
-
-namespace tensorflow {
-namespace grappler {
-
-class FilterWithRandomUniformFusion : public TFDataOptimizerBase {
- public:
-  FilterWithRandomUniformFusion() = default;
-  ~FilterWithRandomUniformFusion() override = default;
-
-  string name() const override { return "filter_with_random_uniform_fusion"; };
-
-  bool UsesFunctionLibrary() const override { return false; }
-
-  Status Init(
-      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
-    return Status::OK();
-  }
-
-  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
-                                 GraphDef* output,
-                                 OptimizationStats* stats) override;
-};
-
-}  // end namespace grappler
-}  // end namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_WITH_RANDOM_UNIFORM_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion_test.cc
deleted file mode 100644
index 62b7f56..0000000
--- a/tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion_test.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/filter_with_random_uniform_fusion.h"
-
-#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-using graph_tests_utils::MakeFilterNode;
-
-TEST(FilterWithRandomUniformFusionTest, FuseToSampling) {
-  using test::function::NDef;
-  GrapplerItem item;
-  item.graph = test::function::GDef(
-      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
-       NDef("stop", "Const", {}, {{"value", 10000}, {"dtype", DT_INT32}}),
-       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
-       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
-       MakeFilterNode("filter1", "range", "RandomUniformLess")},
-      // FunctionLib
-      {
-          test::function::RandomUniformLess(),
-      });
-
-  FilterWithRandomUniformFusion optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("SamplingDataset", output));
-  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output));
-}
-
-TEST(FilterWithoutRandomUniformFusionTest, FuseToSampling) {
-  using test::function::NDef;
-  GrapplerItem item;
-  item.graph = test::function::GDef(
-      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
-       NDef("stop", "Const", {}, {{"value", 10000}, {"dtype", DT_INT32}}),
-       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
-       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
-       MakeFilterNode("filter1", "range")},
-      // FunctionLib
-      {
-          test::function::XTimesTwo(),
-          test::function::IsZero(),
-      });
-
-  FilterWithRandomUniformFusion optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  EXPECT_FALSE(graph_utils::ContainsNodeWithOp("SamplingDataset", output));
-  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("filter1", output));
-}
-
-}  // namespace
-}  // namespace grappler
-}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
deleted file mode 100644
index 6ec49de..0000000
--- a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
+++ /dev/null
@@ -1,298 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/mutable_graph_view.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
-                         const FunctionDef& stateless_function,
-                         MutableGraphView* graph) {
-  NodeDef stateless_map;
-  graph_utils::SetUniqueGraphNodeName("stateless_map", graph->graph(),
-                                      &stateless_map);
-
-  stateless_map.set_op("MapDataset");
-  stateless_map.add_input(zip_node.name());
-  // Add placeholders.
-  for (int i = 1; i < map_node.input_size(); i++)
-    stateless_map.add_input(map_node.input(i));
-
-  auto attr = map_node.attr().at("f");
-  *attr.mutable_func()->mutable_name() = stateless_function.signature().name();
-  *attr.mutable_func()->mutable_attr() = stateless_function.attr();
-  (*stateless_map.mutable_attr())["f"] = std::move(attr);
-
-  graph_utils::CopyAttribute("Targuments", map_node, &stateless_map);
-  for (auto key : {"output_shapes", "output_types"})
-    graph_utils::CopyAttribute(key, map_node, &stateless_map);
-
-  if (const auto* attr =
-          gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
-    (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr;
-
-  return stateless_map;
-}
-
-NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
-                          MutableGraphView* graph) {
-  NodeDef random_dataset;
-  random_dataset.set_op("ExperimentalRandomDataset");
-  graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->graph(),
-                                      &random_dataset);
-
-  const auto* seed = graph_utils::AddScalarConstNode<int64>(
-      random_uniform_node.attr().at("seed").i(), graph);
-  const auto* seed2 = graph_utils::AddScalarConstNode<int64>(
-      random_uniform_node.attr().at("seed2").i(), graph);
-
-  random_dataset.add_input(seed->name());
-  random_dataset.add_input(seed2->name());
-
-  (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape();
-  (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
-      DT_INT64);
-
-  return random_dataset;
-}
-
-NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
-  NodeDef batch_dataset;
-  batch_dataset.set_op("BatchDatasetV2");
-  graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->graph(),
-                                      &batch_dataset);
-  const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
-  const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
-  batch_dataset.add_input(random_dataset.name());
-  batch_dataset.add_input(batch_size->name());
-  batch_dataset.add_input(drop_reminder->name());
-
-  (*batch_dataset.mutable_attr())["output_shapes"]
-      .mutable_list()
-      ->add_shape()
-      ->mutable_dim()
-      ->Add()
-      ->set_size(-1);
-  (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
-      DT_INT64);
-
-  return batch_dataset;
-}
-
-NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
-                    MutableGraphView* graph) {
-  NodeDef zip_node;
-  graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->graph(),
-                                      &zip_node);
-
-  zip_node.set_op("ZipDataset");
-  zip_node.add_input(first_node.name());
-  zip_node.add_input(second_node.name());
-
-  for (auto key : {"output_shapes", "output_types"})
-    graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node);
-
-  (*zip_node.mutable_attr())["N"].set_i(2);
-
-  return zip_node;
-}
-
-// We need to insert our argument before the placeholders, which are the last
-// arguments.
-OpDef_ArgDef* InsertSeedArgument(FunctionDef* function, int num_placeholders) {
-  OpDef* signature = function->mutable_signature();
-  int new_argument_idx = signature->input_arg_size() - num_placeholders;
-  signature->add_input_arg();
-  for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
-    signature->mutable_input_arg()->SwapElements(i - 1, i);
-  }
-  auto* seed_arg = signature->mutable_input_arg(new_argument_idx);
-  seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
-  seed_arg->set_type(DT_INT64);
-
-  // Update arg_attr, any arg_attrs for the placeholders how have index one
-  // higher.
-  for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
-    if (function->arg_attr().contains(i - 1)) {
-      (*function->mutable_arg_attr())[i] =
-          (*function->mutable_arg_attr())[i - 1];
-      function->mutable_arg_attr()->erase(i - 1);
-    }
-  }
-
-  return seed_arg;
-}
-
-// Make function that uses `StatelessRandomUniform` instead of `RandomUniform`
-// to make it less statefull.  The function can still be stateful, but in when
-// other stateful ops are e.g. `Assert`, then it will be parallelizable.
-const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
-                                            bool is_stateful,
-                                            int num_placeholders,
-                                            FunctionDefLibrary* library) {
-  FunctionDef* stateless_function = library->add_function();
-  *stateless_function = map_function;
-  if (is_stateful)
-    stateless_function->mutable_signature()->set_is_stateful(is_stateful);
-  graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
-                                          stateless_function);
-
-  auto* seed_arg = InsertSeedArgument(stateless_function, num_placeholders);
-
-  auto* const random_uniform = stateless_function->mutable_node_def(
-      function_utils::FindFunctionNodeWithOp("RandomUniform",
-                                             *stateless_function));
-
-  // Replace RandomUniform node with StatelessRandomUniform.
-  random_uniform->set_op("StatelessRandomUniform");
-  random_uniform->add_input(seed_arg->name());
-  (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
-  random_uniform->mutable_attr()->erase("seed");
-  random_uniform->mutable_attr()->erase("seed2");
-
-  return stateless_function;
-}
-// This function returns true if function is stateful and has single
-// RandomUniform op and no other stateful ops except Assert and If/While.
-// `is_stateful_after_hoisting` is set to true if RandomUniform is the only
-// stateful op and hoisting can be performed.
-bool CanHoistRandomUniform(const FunctionDef& map_function,
-                           const FunctionLibraryDefinition& library,
-                           bool* is_stateful_after_hoisting,
-                           const NodeDef** random_uniform_op) {
-  if (!map_function.signature().is_stateful()) return false;
-  *is_stateful_after_hoisting = true;
-
-  bool have_other_stateful_ops = false;
-
-  for (const auto& node : map_function.node_def()) {
-    const OpDef* op_def;
-    TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
-    if (!op_def->is_stateful()) continue;
-
-    if (!function_utils::IsNodeStateful(library, node, true)) {
-      // Skip ops that are marked stateful but are in fact not stateful.
-      have_other_stateful_ops = true;
-      continue;
-    }
-
-    // TODO(prazek): For now we only handle RandomUniform, we should handle
-    // RandomUniformInt as well.
-    if (op_def->name() != "RandomUniform") return false;
-
-    // TODO(prazek): For now we can only hoist single RandomUniform.
-    if (*random_uniform_op != nullptr) return false;
-
-    *random_uniform_op = &node;
-  }
-
-  if (!have_other_stateful_ops) *is_stateful_after_hoisting = false;
-
-  // Have we found single RandomUniform?
-  return *random_uniform_op != nullptr;
-}
-
-int NumberOfPlaceholders(const NodeDef& map_node) {
-  // First input of MapDataset is the argument to the function.  Rest of the
-  // inputs are placeholders.
-  return map_node.input_size() - 1;
-}
-
-}  // namespace
-
-Status HoistRandomUniform::OptimizeAndCollectStats(Cluster* cluster,
-                                                   const GrapplerItem& item,
-                                                   GraphDef* output,
-                                                   OptimizationStats* stats) {
-  *output = item.graph;
-
-  MutableGraphView graph(output);
-  absl::flat_hash_set<string> nodes_to_delete;
-  FunctionLibraryDefinition function_library(OpRegistry::Global(),
-                                             item.graph.library());
-
-  auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
-    // TODO(prazek): we could also handle ParallelMapDataset and
-    // MapAndBatchDataset.
-    if (node.op() == "MapDataset") return &node;
-    return nullptr;
-  };
-
-  for (const NodeDef& node : item.graph.node()) {
-    const NodeDef* map_node = get_map_node(node);
-    if (!map_node) continue;
-
-    const auto& fun = map_node->attr().at("f");
-    const FunctionDef* func = function_library.Find(fun.func().name());
-
-    const NodeDef* random_uniform_op = nullptr;
-    bool is_stateful_after_hoisting = true;
-    if (!CanHoistRandomUniform(*func, function_library,
-                               &is_stateful_after_hoisting, &random_uniform_op))
-      continue;
-    const auto* random_seed_dataset =
-        graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph));
-
-    const auto* batch_dataset =
-        graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph));
-
-    const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph);
-
-    const auto* zip_node =
-        graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph));
-
-    const auto* stateless_func = MakeLessStatefulFunction(
-        *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node),
-        output->mutable_library());
-
-    const auto* stateless_map = graph.AddNode(
-        MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
-
-    TF_RETURN_IF_ERROR(
-        graph.UpdateFanouts(map_node->name(), stateless_map->name()));
-
-    // TODO(b/116285210): we could also remove map functions from library if
-    // they are not used anymore.
-    nodes_to_delete.insert(map_node->name());
-    stats->num_changes++;
-  }
-
-  TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
-  return Status::OK();
-}
-
-REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform");
-
-}  // namespace grappler
-}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
deleted file mode 100644
index 77562f0..0000000
--- a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
-
-#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
-
-namespace tensorflow {
-namespace grappler {
-
-// This optimization hoists instances of `random_uniform` out of a function
-// with the aim of making it stateless.  It creates a new function that takes a
-// random seed as an extra argument and uses `stateless_random_uniform` instead
-// of `random_uniform` to make it stateless.
-// It also creates RandomDataset(seed).batch(2), which is zipped with old input
-// to the map.  The batching in RandomDataset is because we need 2 seeds for
-// `stateless_random_uniform`.
-// TODO(prazek): for now only `RandomUniform` is handled, but we could handle
-// `RandomUniformInt` similarly.
-class HoistRandomUniform : public TFDataOptimizerBase {
- public:
-  HoistRandomUniform() = default;
-  ~HoistRandomUniform() override = default;
-
-  string name() const override { return "hoist_random_uniform"; };
-
-  bool UsesFunctionLibrary() const override { return true; }
-
-  Status Init(
-      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
-    return Status::OK();
-  }
-
-  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
-                                 GraphDef* output,
-                                 OptimizationStats* stats) override;
-};
-
-}  // namespace grappler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
deleted file mode 100644
index b6a29a4..0000000
--- a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
-
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-TEST(HoistRandomUniform, SimpleHoisting) {
-  using test::function::NDef;
-  GrapplerItem item;
-  item.graph = test::function::GDef(
-      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
-       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
-       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
-       NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
-       NDef("range", "RangeDataset", {"start", "stop", "step"},
-            {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
-             {"output_types", gtl::ArraySlice<DataType>{}}}),
-       graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"),
-       NDef("cache", "CacheDataset", {"map1", "filename"}, {})},
-      // FunctionLib
-      {
-          test::function::RandomUniform(),
-      });
-
-  HoistRandomUniform optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
-  const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
-  const int zip_dataset_id =
-      graph_utils::FindGraphNodeWithOp("ZipDataset", output);
-  const int random_dataset_id =
-      graph_utils::FindGraphNodeWithOp("ExperimentalRandomDataset", output);
-  const int batch_random_id =
-      graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output);
-  ASSERT_NE(random_dataset_id, -1);
-  ASSERT_NE(zip_dataset_id, -1);
-  ASSERT_NE(new_map_id, -1);
-  ASSERT_NE(batch_random_id, -1);
-
-  const auto& new_map = output.node(new_map_id);
-  const auto& zip = output.node(zip_dataset_id);
-  const auto& random = output.node(random_dataset_id);
-  const auto& batch = output.node(batch_random_id);
-
-  ASSERT_EQ(new_map.input_size(), 1);
-  EXPECT_EQ(new_map.input(0), zip.name());
-
-  ASSERT_EQ(zip.input_size(), 2);
-  EXPECT_EQ(zip.input(0), "range");
-  EXPECT_EQ(zip.input(1), batch.name());
-
-  ASSERT_EQ(batch.input_size(), 3);
-  EXPECT_EQ(batch.input(0), random.name());
-}
-
-}  // namespace
-}  // namespace grappler
-}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
deleted file mode 100644
index 2638d42..0000000
--- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/latency_all_edges.h"
-
-#include "tensorflow/core/data/stats_utils.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/mutable_graph_view.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-constexpr char kInsertOpName[] = "LatencyStatsDataset";
-constexpr char kModelDataset[] = "ModelDataset";
-
-// Creates a LatencyStatsDataset node whose input is `node`.
-Status MakeLatencyNode(const NodeDef& node, MutableGraphView* graph,
-                       NodeDef* result) {
-  result->set_op(kInsertOpName);
-  graph_utils::SetUniqueGraphNodeName(strings::StrCat(kInsertOpName),
-                                      graph->graph(), result);
-  // Set the input of LatencyDataset node as `node`
-  result->add_input(node.name());
-
-  string tag_name = strings::StrCat("record_latency",
-                                    data::stats_utils::kDelimiter, node.name());
-  NodeDef* tag = graph_utils::AddScalarConstNode<StringPiece>(
-      StringPiece(tag_name), graph);
-  result->add_input(tag->name());
-
-  // Set `output_types` and `output_shapes` attributes by copying the relevant
-  // attrs from the input node. This is an imperfect heuristic; some dataset ops
-  // might not have these attrs. If we encounter such an op, return an error
-  // instead of creating a node.
-  for (auto key : {"output_shapes", "output_types"}) {
-    if (node.attr().find(key) != node.attr().end()) {
-      (*result->mutable_attr())[key] = node.attr().at(key);
-    } else {
-      const char* kInferredAttrPrefix = "T";
-      if (node.attr().find(strings::StrCat(kInferredAttrPrefix, key)) !=
-          node.attr().end()) {
-        (*result->mutable_attr())[key] =
-            node.attr().at(strings::StrCat(kInferredAttrPrefix, key));
-      } else {
-        return errors::InvalidArgument(
-            "Could not create LatencyStatsDataset after ", node.op(),
-            " node because it does not have a (T)output_types or output_shapes "
-            "attr.");
-      }
-    }
-  }
-  return Status::OK();
-}
-
-}  // namespace
-
-Status LatencyAllEdges::OptimizeAndCollectStats(Cluster* cluster,
-                                                const GrapplerItem& item,
-                                                GraphDef* output,
-                                                OptimizationStats* stats) {
-  *output = item.graph;
-  MutableGraphView graph(output);
-
-  // Add LatencyDatasetOp node after each node.
-  // TODO(shivaniagrawal): Add Op to return Latency for the particular Op than
-  // for the edge (e2 - e1?).
-  for (const NodeDef& node : item.graph.node()) {
-    if (!absl::EndsWith(node.op(), "Dataset") || node.attr().empty()) {
-      // TODO(b/111805951): Replace this with non-approximate way to check if
-      // node corresponds to a `Dataset` op.
-      continue;
-    }
-    // We don't add LatencyStatsDataset after ModelDataset.
-    if (node.op() == kModelDataset) {
-      continue;
-    }
-
-    NodeDef latency_node;
-    // Try to make a latency node. This may fail if the input node doesn't have
-    // output_types or output_shapes attrs. In those cases, we don't add a node
-    // after `node`.
-    Status s = MakeLatencyNode(node, &graph, &latency_node);
-    if (s.ok()) {
-      NodeDef* latency_node_pointer = graph.AddNode(std::move(latency_node));
-      TF_RETURN_IF_ERROR(
-          graph.UpdateFanouts(node.name(), latency_node_pointer->name()));
-      stats->num_changes++;
-    } else {
-      LOG(WARNING) << s.error_message();
-    }
-  }
-  return Status::OK();
-}
-
-REGISTER_GRAPH_OPTIMIZER_AS(LatencyAllEdges, "latency_all_edges");
-
-}  // namespace grappler
-}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.h b/tensorflow/core/grappler/optimizers/data/latency_all_edges.h
deleted file mode 100644
index 038da34..0000000
--- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_
-
-#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
-
-namespace tensorflow {
-namespace grappler {
-
-class LatencyAllEdges : public TFDataOptimizerBase {
- public:
-  LatencyAllEdges() = default;
-  ~LatencyAllEdges() override = default;
-
-  string name() const override { return "latency_all_edges"; };
-
-  bool UsesFunctionLibrary() const override { return false; }
-
-  Status Init(
-      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
-    return Status::OK();
-  }
-
-  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
-                                 GraphDef* output,
-                                 OptimizationStats* stats) override;
-};
-
-}  // namespace grappler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc
deleted file mode 100644
index 71d761a..0000000
--- a/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc
+++ /dev/null
@@ -1,107 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/latency_all_edges.h"
-
-#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-class LatencyAllEdgesTest : public ::testing::TestWithParam<bool> {};
-
-TEST_P(LatencyAllEdgesTest, AddLatenciesAfterTensorMapPrefetch) {
-  const bool contain_model = GetParam();
-  using test::function::NDef;
-  GrapplerItem item;
-  NodeDef component_node =
-      NDef("component_node", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}});
-  NodeDef from_tensor_node =
-      NDef("from_tensor_node", "TensorDataset", {"component_node"},
-           {{"Toutput_types", {}}, {"output_shapes", {}}});
-
-  NodeDef captured_input_node = NDef("captured_input_node", "Const", {},
-                                     {{"value", ""}, {"dtype", DT_STRING}});
-  NodeDef map_node = NDef("map_node", "MapDataset",
-                          {"from_tensor_node", "captured_input_node"},
-                          {{"f", {}},
-                           {"Targumemts", {}},
-                           {"output_shapes", {}},
-                           {"output_types", {}}});
-  NodeDef buffer_size_node = NDef("buffer_size_node", "Const", {},
-                                  {{"value", 1}, {"dtype", DT_INT32}});
-  NodeDef prefetch_node =
-      NDef("prefetch_node", "PrefetchDataset", {"map_node", "buffer_size_node"},
-           {{"output_shapes", {}}, {"output_types", {}}});
-  NodeDef model_node = NDef("model_node", "ModelDataset", {"prefetch_node"},
-                            {{"output_shapes", {}}, {"output_types", {}}});
-
-  if (contain_model) {
-    item.graph = test::function::GDef(
-        {component_node, from_tensor_node, captured_input_node, map_node,
-         buffer_size_node, prefetch_node, model_node});
-  } else {
-    item.graph = test::function::GDef({component_node, from_tensor_node,
-                                       captured_input_node, map_node,
-                                       buffer_size_node, prefetch_node});
-  }
-
-  LatencyAllEdges optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  EXPECT_EQ(output.node_size(), contain_model ? 13 : 12);
-  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("LatencyStatsDataset", output));
-  std::vector<int> latency_node_indices =
-      graph_utils::FindAllGraphNodesWithOp("LatencyStatsDataset", output);
-  EXPECT_EQ(latency_node_indices.size(), 3);
-  std::vector<NodeDef> dataset_nodes = {std::move(from_tensor_node),
-                                        std::move(map_node),
-                                        std::move(prefetch_node)};
-  for (int i = 0; i < latency_node_indices.size(); i++) {
-    NodeDef latency_node = output.node(latency_node_indices[i]);
-    EXPECT_EQ(latency_node.input_size(), 2);
-    EXPECT_EQ(latency_node.input(0), dataset_nodes[i].name());
-    EXPECT_TRUE(
-        AreAttrValuesEqual(latency_node.attr().at("output_shapes"),
-                           dataset_nodes[i].attr().at("output_shapes")));
-    if (dataset_nodes[i].attr().find("output_types") !=
-        dataset_nodes[i].attr().end()) {
-      EXPECT_TRUE(
-          AreAttrValuesEqual(latency_node.attr().at("output_types"),
-                             dataset_nodes[i].attr().at("output_types")));
-    } else {
-      if (dataset_nodes[i].attr().find("Toutput_types") !=
-          dataset_nodes[i].attr().end()) {
-        EXPECT_TRUE(
-            AreAttrValuesEqual(latency_node.attr().at("output_types"),
-                               dataset_nodes[i].attr().at("Toutput_types")));
-      }
-    }
-  }
-}
-
-INSTANTIATE_TEST_SUITE_P(Test, LatencyAllEdgesTest,
-                         ::testing::Values(false, true));
-
-}  // namespace
-}  // namespace grappler
-}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
index 3a199d2..63526e0 100644
--- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
@@ -35,23 +35,19 @@
     std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;
 
 // tf.data optimizations, in the order we want to perform them.
-constexpr std::array<const char*, 20> kTFDataOptimizations = {
+constexpr std::array<const char*, 16> kTFDataOptimizations = {
     "noop_elimination",
     "disable_intra_op_parallelism",
     "use_private_thread_pool",
     "shuffle_and_repeat_fusion",
     "map_fusion",
     "filter_fusion",
-    "filter_with_random_uniform_fusion",
     "map_and_filter_fusion",
-    "hoist_random_uniform",
     "map_parallelization",
     "map_and_batch_fusion",
     "batch_parallelization",
-    "latency_all_edges",
     "make_sloppy",
     "parallel_batch",
-    "reorder_data_discarding_ops",
     "slack",
     "autotune_buffer_sizes",
     "disable_prefetch_legacy_autotune",
diff --git a/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.cc b/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.cc
deleted file mode 100644
index 59fbb4e..0000000
--- a/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.cc
+++ /dev/null
@@ -1,128 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.h"
-
-#include "absl/container/flat_hash_set.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/mutable_graph_view.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-constexpr char kReorderDataDiscardingOpPrefix[] =
-    "reorder_data_discarding_ops/";
-
-constexpr std::array<const char*, 3> kDataDiscarding = {
-    "ShardDataset",
-    "SkipDataset",
-    "TakeDataset",
-};
-
-// TODO(zilinzhu): Support memory cache op when file cache op and
-// memory cache op are separated.
-const std::array<const char*, 4> kCardinalityPreserving = {
-    "PrefetchDataset",
-    "MapDataset",
-    "ParallelMapDataset",
-    "ParallelMapDatasetV2",
-};
-
-bool IsDataDiscarding(const NodeDef& node) {
-  for (const auto& discard_op : kDataDiscarding) {
-    if (node.op() == discard_op) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool IsCardinalityPreserving(const NodeDef& node) {
-  for (const auto& cardinality_preserving_op : kCardinalityPreserving) {
-    if (node.op() != cardinality_preserving_op) {
-      continue;
-    }
-    // Map ops with preserve_cardinality=false do not qualify.
-    auto attr = node.attr().find("preserve_cardinality");
-    if (attr != node.attr().end() && !attr->second.b()) {
-      return false;
-    }
-    return true;
-  }
-  return false;
-}
-
-}  // namespace
-
-Status ReorderDataDiscardingOps::OptimizeAndCollectStats(
-    Cluster* cluster, const GrapplerItem& item, GraphDef* output,
-    OptimizationStats* stats) {
-  *output = item.graph;
-  MutableGraphView graph(output);
-  bool updated;
-  do {
-    updated = false;
-    for (int i = 0; i < graph.graph()->node_size(); ++i) {
-      NodeDef* discard_node = graph.graph()->mutable_node(i);
-      if (!IsDataDiscarding(*discard_node)) {
-        continue;
-      }
-      NodeDef* start = discard_node;
-      NodeDef* start_parent = graph_utils::GetInputNode(*start, graph);
-      while (IsCardinalityPreserving(*start_parent)) {
-        start = start_parent;
-        start_parent = graph_utils::GetInputNode(*start, graph);
-      }
-      if (start->name() == discard_node->name()) {
-        continue;
-      }
-      NodeDef* parent = graph_utils::GetInputNode(*discard_node, graph);
-      TF_RETURN_IF_ERROR(
-          graph.UpdateFanouts(discard_node->name(), parent->name()));
-      if (!absl::StartsWith(discard_node->name(),
-                            kReorderDataDiscardingOpPrefix)) {
-        TF_RETURN_IF_ERROR(
-            graph.UpdateNodeName(discard_node->name(),
-                                 strings::StrCat(kReorderDataDiscardingOpPrefix,
-                                                 discard_node->name()),
-                                 false));
-      }
-      for (const auto& attr_name : {"output_types", "output_shapes"}) {
-        graph_utils::CopyAttribute(attr_name, *start_parent, discard_node);
-      }
-      *discard_node->mutable_input(0) = start_parent->name();
-      *start->mutable_input(0) = discard_node->name();
-      updated = true;
-      break;
-    }
-  } while (updated);
-  return Status::OK();
-}
-
-REGISTER_GRAPH_OPTIMIZER_AS(ReorderDataDiscardingOps,
-                            "reorder_data_discarding_ops");
-
-}  // namespace grappler
-}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.h b/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.h
deleted file mode 100644
index 72ed68d..0000000
--- a/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REORDER_DATA_DISCARDING_OPS_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REORDER_DATA_DISCARDING_OPS_H_
-
-#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
-
-namespace tensorflow {
-namespace grappler {
-
-// This optimization reorders the data discarding ops (such as `skip`, `take`
-// and `shard`) to avoid unnecessary computation,
-// e.g. reordering ds.map(...).take(5) to ds.take(5).map(...).
-class ReorderDataDiscardingOps : public TFDataOptimizerBase {
- public:
-  ReorderDataDiscardingOps() = default;
-  ~ReorderDataDiscardingOps() override = default;
-
-  string name() const override { return "reorder_data_discarding_ops"; };
-
-  bool UsesFunctionLibrary() const override { return false; }
-
-  Status Init(
-      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
-    return Status::OK();
-  }
-
-  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
-                                 GraphDef* output,
-                                 OptimizationStats* stats) override;
-};
-
-}  // namespace grappler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REORDER_DATA_DISCARDING_OPS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops_test.cc b/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops_test.cc
deleted file mode 100644
index 743769e..0000000
--- a/tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops_test.cc
+++ /dev/null
@@ -1,94 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/reorder_data_discarding_ops.h"
-
-#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-TEST(ReorderDataDiscardingOpsTest, ExampleOps) {
-  using test::function::NDef;
-  GrapplerItem item;
-  item.graph = test::function::GDef(
-      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
-       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
-       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
-       NDef("range", "RangeDataset", {"start", "stop", "step"},
-            {
-                {"output_shapes", gtl::ArraySlice<TensorShape>{}},
-                {"output_types", gtl::ArraySlice<DataType>{}},
-            }),
-       graph_tests_utils::MakeMapNode("map", "range", "XTimesTwo"),
-       NDef("take_count", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}),
-       graph_tests_utils::MakeTakeNode("take", "map", "take_count"),
-       NDef("skip_count", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
-       graph_tests_utils::MakeSkipNode("skip", "take", "skip_count"),
-       NDef("batch_size", "Const", {}, {{"value", 2}, {"dtype", DT_INT32}}),
-       NDef("drop_remainder", "Const", {},
-            {{"value", true}, {"dtype", DT_BOOL}}),
-       graph_tests_utils::MakeMapAndBatchNode("map_and_batch", "skip",
-                                              "batch_size", "drop_remainder",
-                                              "XTimesTwo"),
-       NDef("num_shards", "Const", {}, {{"value", 2}, {"dtype", DT_INT32}}),
-       NDef("index", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
-       graph_tests_utils::MakeShardNode("shard", "map_and_batch", "num_shards",
-                                        "index")},
-      // FunctionLib
-      {
-          test::function::XTimesTwo(),
-      });
-
-  ReorderDataDiscardingOps optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(
-      "reorder_data_discarding_ops/take", output));
-  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(
-      "reorder_data_discarding_ops/skip", output));
-  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName(
-      "reorder_data_discarding_ops/shard", output));
-
-  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("take", output));
-  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("skip", output));
-  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("shard", output));
-
-  MutableGraphView graph(&output);
-  EXPECT_EQ(graph_utils::GetInputNode(
-                *graph.GetNode("reorder_data_discarding_ops/take"), graph)
-                ->name(),
-            "range");
-  EXPECT_EQ(graph_utils::GetInputNode(
-                *graph.GetNode("reorder_data_discarding_ops/skip"), graph)
-                ->name(),
-            "reorder_data_discarding_ops/take");
-  EXPECT_EQ(
-      graph_utils::GetInputNode(*graph.GetNode("map_and_batch"), graph)->name(),
-      "map");
-}
-
-}  // namespace
-}  // namespace grappler
-}  // namespace tensorflow
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index b83d047..bfc1987 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -38,31 +38,6 @@
     ],
 )
 
-tf_py_test(
-    name = "filter_with_random_uniform_fusion_test",
-    size = "medium",
-    srcs = ["filter_with_random_uniform_fusion_test.py"],
-    tags = [
-        "manual",
-        "no_oss",
-        "no_pip",
-        "no_windows",
-        "notap",  # TODO(b/131229793)
-    ],
-    deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python/data/experimental/ops:optimization_options",
-        "//tensorflow/python/data/experimental/ops:testing",
-        "//tensorflow/python/data/kernel_tests:test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
 cuda_py_test(
     name = "grappler_test",
     size = "medium",
@@ -86,27 +61,6 @@
 )
 
 tf_py_test(
-    name = "hoist_random_uniform_test",
-    size = "small",
-    srcs = ["hoist_random_uniform_test.py"],
-    deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:random_ops",
-        "//tensorflow/python/data/experimental/ops:optimization_options",
-        "//tensorflow/python/data/experimental/ops:testing",
-        "//tensorflow/python/data/kernel_tests:test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
-tf_py_test(
     name = "map_and_batch_fusion_test",
     size = "small",
     srcs = ["map_and_batch_fusion_test.py"],
@@ -217,24 +171,6 @@
 )
 
 tf_py_test(
-    name = "reorder_data_discarding_ops_test",
-    size = "small",
-    srcs = ["reorder_data_discarding_ops_test.py"],
-    tags = [
-        "no_oss",
-        "no_pip",
-        "no_windows",
-    ],
-    deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:errors",
-        "//tensorflow/python/data/experimental/ops:testing",
-        "//tensorflow/python/data/kernel_tests:test_base",
-        "//tensorflow/python/data/ops:dataset_ops",
-    ],
-)
-
-tf_py_test(
     name = "shuffle_and_repeat_fusion_test",
     size = "small",
     srcs = ["shuffle_and_repeat_fusion_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py
deleted file mode 100644
index 7600625..0000000
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/filter_with_random_uniform_fusion_test.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the `FilterWithRandomUniformFusion` optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.python.data.experimental.ops import testing
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import combinations
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-class FilterWithRandomUniformFusionTest(test_base.DatasetTestBase,
-                                        parameterized.TestCase):
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testFilterWithRandomUniformFusion(self):
-    dataset = dataset_ops.Dataset.range(10000000).apply(
-        testing.assert_next(["Sampling"]))
-    dataset = dataset.filter(lambda _: random_ops.random_uniform([]) < 0.05)
-
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.filter_with_random_uniform_fusion = True
-    dataset = dataset.with_options(options)
-
-    get_next = self.getNext(dataset)
-    self.evaluate(get_next())
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
deleted file mode 100644
index 1097b1e..0000000
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the `HoistRandomUniform` optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import functools
-
-from absl.testing import parameterized
-
-from tensorflow.python.data.experimental.ops import testing
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import combinations
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-def _test_combinations():
-  def random(_):
-    return random_ops.random_uniform([],
-                                     minval=1,
-                                     maxval=10,
-                                     dtype=dtypes.float32,
-                                     seed=42)
-
-  def random_with_assert(x):
-    y = random(x)
-    assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
-    with ops.control_dependencies([assert_op]):
-      return y
-
-  cases = [
-      ("Increment", lambda x: x + 1, False),
-      ("Random", random, True),
-      ("RandomWithAssert", random_with_assert, True),
-      ("Complex", lambda x: (random(x) + random(x)) / 2, False),
-  ]
-
-  def reduce_fn(x, y):
-    name, map_fn, should_optimize = y
-    return x + combinations.combine(
-        map_fn=combinations.NamedObject(name, map_fn),
-        should_optimize=should_optimize)
-
-  return functools.reduce(reduce_fn, cases, [])
-
-
-class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
-
-  def _testDataset(self, dataset):
-    previous_result = 0
-    get_next = self.getNext(dataset)
-    for _ in range(5):
-      result = self.evaluate(get_next())
-      self.assertLessEqual(1, result)
-      self.assertLessEqual(result, 10)
-      # This checks if the result is somehow random by checking if we are not
-      # generating the same values.
-      self.assertNotEqual(previous_result, result)
-      previous_result = result
-    with self.assertRaises(errors.OutOfRangeError):
-      self.evaluate(get_next())
-    with self.assertRaises(errors.OutOfRangeError):
-      self.evaluate(get_next())
-
-  @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         _test_combinations()))
-  def testHoistFunction(self, map_fn, should_optimize):
-    dataset = dataset_ops.Dataset.range(5).apply(
-        testing.assert_next(
-            ["Zip[0]", "Map"] if should_optimize else ["Map"])).map(map_fn)
-
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.hoist_random_uniform = True
-    dataset = dataset.with_options(options)
-    self._testDataset(dataset)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testCapturedInputs(self):
-    a = constant_op.constant(1, dtype=dtypes.float32)
-    b = constant_op.constant(0, dtype=dtypes.float32)
-    some_tensor = math_ops.mul(a, b)
-
-    def random_with_capture(_):
-      return some_tensor + random_ops.random_uniform(
-          [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
-
-    dataset = dataset_ops.Dataset.range(5).apply(
-        testing.assert_next(["Zip[0]", "Map"])).map(random_with_capture)
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.hoist_random_uniform = True
-    dataset = dataset.with_options(options)
-    self._testDataset(dataset)
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/reorder_data_discarding_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/reorder_data_discarding_ops_test.py
deleted file mode 100644
index 66b509c..0000000
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/reorder_data_discarding_ops_test.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the `ReorderDataDiscardingOps` rewrite."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.python.data.experimental.ops import testing
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import combinations
-from tensorflow.python.platform import test
-
-
-class ReorderDataDiscardingOpsTest(test_base.DatasetTestBase,
-                                   parameterized.TestCase):
-
-  @combinations.generate(
-      combinations.combine(tf_api_version=2, mode=["eager", "graph"]))
-  def testSimpleReorderingV2(self):
-    dataset = dataset_ops.Dataset.range(100)
-    dataset = dataset.apply(
-        testing.assert_next(
-            ["FiniteSkip", "FiniteTake", "Shard", "ParallelMap", "Prefetch"]))
-    dataset = dataset.map(lambda x: x + 1, num_parallel_calls=10)
-    dataset = dataset.skip(10)
-    dataset = dataset.prefetch(1)
-    dataset = dataset.take(50)
-    dataset = dataset.shard(2, 0)
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.reorder_data_discarding_ops = True
-    dataset = dataset.with_options(options)
-    self.assertDatasetProduces(dataset, range(11, 61, 2))
-
-  @combinations.generate(
-      combinations.combine(tf_api_version=1, mode=["eager", "graph"]))
-  def testSimpleReorderingV1(self):
-    dataset = dataset_ops.Dataset.range(100)
-    # Map ops have preserve_cardinality=false in tensorflow v1.
-    dataset = dataset.apply(
-        testing.assert_next(
-            ["ParallelMap", "FiniteSkip", "FiniteTake", "Shard", "Prefetch"]))
-    dataset = dataset.map(lambda x: x + 1, num_parallel_calls=10)
-    dataset = dataset.skip(10)
-    dataset = dataset.prefetch(1)
-    dataset = dataset.take(50)
-    dataset = dataset.shard(2, 0)
-    options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
-    options.experimental_optimization.reorder_data_discarding_ops = True
-    dataset = dataset.with_options(options)
-    self.assertDatasetProduces(dataset, range(11, 61, 2))
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py
index 45dcb4f..998cd00 100644
--- a/tensorflow/python/data/experimental/ops/optimization_options.py
+++ b/tensorflow/python/data/experimental/ops/optimization_options.py
@@ -90,20 +90,6 @@
       docstring=
       "Whether to fuse filter transformations. If None, defaults to False.")
 
-  filter_with_random_uniform_fusion = options.create_option(
-      name="filter_with_random_uniform_fusion",
-      ty=bool,
-      docstring=
-      "Whether to fuse filter dataset that predicts random_uniform < rate into "
-      "a sampling dataset. If None, defaults to False.")
-
-  hoist_random_uniform = options.create_option(
-      name="hoist_random_uniform",
-      ty=bool,
-      docstring=
-      "Whether to hoist `tf.random_uniform()` ops out of map transformations. "
-      "If None, defaults to False.")
-
   map_and_batch_fusion = options.create_option(
       name="map_and_batch_fusion",
       ty=bool,
@@ -148,17 +134,6 @@
       "batching and b) you have validated that this optimization improves "
       "performance. If None, defaults to False.")
 
-  reorder_data_discarding_ops = options.create_option(
-      name="reorder_data_discarding_ops",
-      ty=bool,
-      docstring="Whether to reorder ops that will discard data to the front of "
-      "unary cardinality preserving transformations, e.g. "
-      "dataset.map(...).take(3) will be optimized to dataset.take(3).map(...). "
-      "For now this optimization will move `skip`, `shard` and `take` to the "
-      "front of `map` and `prefetch`. This optimization is only for "
-      "performance; it will not affect the output of the dataset. "
-      "If None, defaults to True.")
-
   shuffle_and_repeat_fusion = options.create_option(
       name="shuffle_and_repeat_fusion",
       ty=bool,
@@ -179,11 +154,6 @@
       pb.autotune_ram_budget = self.autotune_ram_budget
     if self.filter_fusion is not None:
       pb.filter_fusion = self.filter_fusion
-    if self.filter_with_random_uniform_fusion is not None:
-      pb.filter_with_random_uniform_fusion = (
-          self.filter_with_random_uniform_fusion)
-    if self.hoist_random_uniform is not None:
-      pb.hoist_random_uniform = self.hoist_random_uniform
     if self.map_and_batch_fusion is not None:
       pb.map_and_batch_fusion = self.map_and_batch_fusion
     if self.map_and_filter_fusion is not None:
@@ -196,8 +166,6 @@
       pb.noop_elimination = self.noop_elimination
     if self.parallel_batch is not None:
       pb.parallel_batch = self.parallel_batch
-    if self.reorder_data_discarding_ops is not None:
-      pb.reorder_data_discarding_ops = self.reorder_data_discarding_ops
     if self.shuffle_and_repeat_fusion is not None:
       pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
     return pb
@@ -215,11 +183,6 @@
       self.autotune_ram_budget = pb.autotune_ram_budget
     if pb.WhichOneof("optional_filter_fusion") is not None:
       self.filter_fusion = pb.filter_fusion
-    if pb.WhichOneof("optional_filter_with_random_uniform_fusion") is not None:
-      self.filter_with_random_uniform_fusion = (
-          pb.filter_with_random_uniform_fusion)
-    if pb.WhichOneof("optional_hoist_random_uniform") is not None:
-      self.hoist_random_uniform = pb.hoist_random_uniform
     if pb.WhichOneof("optional_map_and_batch_fusion") is not None:
       self.map_and_batch_fusion = pb.map_and_batch_fusion
     if pb.WhichOneof("optional_map_and_filter_fusion") is not None:
@@ -232,8 +195,6 @@
       self.noop_elimination = pb.noop_elimination
     if pb.WhichOneof("optional_parallel_batch") is not None:
       self.parallel_batch = pb.parallel_batch
-    if pb.WhichOneof("optional_reorder_data_discarding_ops") is not None:
-      self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
     if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
       self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
 
diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py
index 25bfa9a..8c00ae0 100644
--- a/tensorflow/python/data/kernel_tests/options_test.py
+++ b/tensorflow/python/data/kernel_tests/options_test.py
@@ -153,15 +153,12 @@
     options.experimental_optimization.autotune_cpu_budget = 10
     options.experimental_optimization.autotune_ram_budget = 20
     options.experimental_optimization.filter_fusion = True
-    options.experimental_optimization.filter_with_random_uniform_fusion = True
-    options.experimental_optimization.hoist_random_uniform = True
     options.experimental_optimization.map_and_batch_fusion = True
     options.experimental_optimization.map_and_filter_fusion = True
     options.experimental_optimization.map_fusion = True
     options.experimental_optimization.map_parallelization = True
     options.experimental_optimization.noop_elimination = True
     options.experimental_optimization.parallel_batch = True
-    options.experimental_optimization.reorder_data_discarding_ops = True
     options.experimental_optimization.shuffle_and_repeat_fusion = True
     options.experimental_slack = True
     options.threading.max_intra_op_parallelism = 30
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt
index 84dce55..17c6afc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt
@@ -28,14 +28,6 @@
     mtype: "<type \'property\'>"
   }
   member {
-    name: "filter_with_random_uniform_fusion"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "hoist_random_uniform"
-    mtype: "<type \'property\'>"
-  }
-  member {
     name: "map_and_batch_fusion"
     mtype: "<type \'property\'>"
   }
@@ -60,10 +52,6 @@
     mtype: "<type \'property\'>"
   }
   member {
-    name: "reorder_data_discarding_ops"
-    mtype: "<type \'property\'>"
-  }
-  member {
     name: "shuffle_and_repeat_fusion"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt
index 84dce55..17c6afc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt
@@ -28,14 +28,6 @@
     mtype: "<type \'property\'>"
   }
   member {
-    name: "filter_with_random_uniform_fusion"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "hoist_random_uniform"
-    mtype: "<type \'property\'>"
-  }
-  member {
     name: "map_and_batch_fusion"
     mtype: "<type \'property\'>"
   }
@@ -60,10 +52,6 @@
     mtype: "<type \'property\'>"
   }
   member {
-    name: "reorder_data_discarding_ops"
-    mtype: "<type \'property\'>"
-  }
-  member {
     name: "shuffle_and_repeat_fusion"
     mtype: "<type \'property\'>"
   }