[tf.data] Adding microbenchmark for control flow lowering.
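
This change adds an `apply_default_optimizations` argument to
`DatasetBenchmarkBase.run_benchmark` and `run_and_report_benchmark`.
Benchmarks continue to disable tf.data's default optimizations unless
they opt in; the two new control flow benchmarks opt in.

A minimal sketch of the new argument from a benchmark subclass (the
dataset and benchmark name here are illustrative, not part of this
change):

  dataset = dataset_ops.Dataset.range(1000)
  self.run_and_report_benchmark(
      dataset,
      num_elements=1000,
      name="example",
      apply_default_optimizations=True)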

PiperOrigin-RevId: 306781111
Change-Id: Ic34f5ae3a026cc5636b7572961018bed461ac31c
diff --git a/tensorflow/python/data/benchmarks/benchmark_base.py b/tensorflow/python/data/benchmarks/benchmark_base.py
index 40a8c61..1e1a716 100644
--- a/tensorflow/python/data/benchmarks/benchmark_base.py
+++ b/tensorflow/python/data/benchmarks/benchmark_base.py
@@ -31,7 +31,12 @@
 class DatasetBenchmarkBase(test.Benchmark):
   """Base class for dataset benchmarks."""
 
-  def run_benchmark(self, dataset, num_elements, iters=1, warmup=True):
+  def run_benchmark(self,
+                    dataset,
+                    num_elements,
+                    iters=1,
+                    warmup=True,
+                    apply_default_optimizations=False):
     """Benchmarks the dataset.
 
     Runs the dataset `iters` times. In each iteration, the benchmark measures
@@ -43,6 +48,8 @@
         iteration.
       iters: Number of times to repeat the timing.
       warmup: If true, warms up the session caches by running an untimed run.
+      apply_default_optimizations: Whether tf.data's default optimizations
+        should be applied. Defaults to `False`.
 
     Returns:
       A float, representing the per-element wall time of the dataset in seconds.
@@ -50,7 +57,10 @@
      to go through `num_elements` elements, divided by `num_elements`.
     """
     options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
+    options.experimental_optimization.apply_default_optimizations = (
+        apply_default_optimizations)
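+    # tf.data's default rewrites stay disabled unless the benchmark opts
+    # in, so measurements reflect the pipeline as written.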
     dataset = dataset.with_options(options)
     # NOTE: We use `dataset.skip()` to perform the iterations in C++, avoiding
     # the overhead of multiple `session.run()` calls. Note that this relies on
@@ -82,9 +90,11 @@
                                name,
                                iters=5,
                                extras=None,
-                               warmup=True):
+                               warmup=True,
+                               apply_default_optimizations=False):
     # Measure the per-element wall time.
-    wall_time = self.run_benchmark(dataset, num_elements, iters, warmup)
+    wall_time = self.run_benchmark(dataset, num_elements, iters, warmup,
+                                   apply_default_optimizations)
 
     if extras is None:
       extras = {}
diff --git a/tensorflow/python/data/benchmarks/map_benchmark.py b/tensorflow/python/data/benchmarks/map_benchmark.py
index 6638842..aea0fe9 100644
--- a/tensorflow/python/data/benchmarks/map_benchmark.py
+++ b/tensorflow/python/data/benchmarks/map_benchmark.py
@@ -20,6 +20,12 @@
 from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.experimental.ops import stats_aggregator
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import map_fn
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 
 
 # TODO(b/119837791): Add eager benchmarks.
@@ -28,11 +34,11 @@
 
   def benchmark_chain_of_maps(self):
 
-    def benchmark_helper(chain_length, map_fn, use_inter_op_parallelism, label):
+    def benchmark_helper(chain_length, fn, use_inter_op_parallelism, label):
       dataset = dataset_ops.Dataset.range(10000)
       for _ in range(chain_length):
         dataset = dataset_ops.MapDataset(
-            dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+            dataset, fn, use_inter_op_parallelism=use_inter_op_parallelism)
       self.run_and_report_benchmark(
           dataset,
           num_elements=10000,
@@ -47,11 +53,11 @@
   def benchmark_map_fan_out(self):
     fan_outs = [1, 2, 5, 10, 20, 50, 100]
 
-    def benchmark_helper(fan_out, map_fn, use_inter_op_parallelism, label):
+    def benchmark_helper(fan_out, fn, use_inter_op_parallelism, label):
       dataset = dataset_ops.Dataset.from_tensors(
           tuple(0 for _ in range(fan_out))).repeat(None)
       dataset = dataset_ops.MapDataset(
-          dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+          dataset, fn, use_inter_op_parallelism=use_inter_op_parallelism)
       self.run_and_report_benchmark(
           dataset,
           num_elements=10000,
@@ -76,6 +82,43 @@
       self.run_and_report_benchmark(
           dataset, num_elements=10000, name="stats_%s" % stats)
 
+  def benchmark_sequential_control_flow(self):
+    dataset = dataset_ops.Dataset.from_tensors(100000)
+
+    def fn(x):
+      i = constant_op.constant(0)
+
+      def body(i, x):
+        return math_ops.add(i, 1), x
+
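+      # `math_ops.less` is used directly as the loop condition (i < x), so
+      # this runs 100000 strictly sequential loop iterations per element.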
+      return control_flow_ops.while_loop(math_ops.less, body, [i, x])
+
+    dataset = dataset.map(fn)
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=1,
+        name="sequential_control_flow",
+        apply_default_optimizations=True)
+
+  def benchmark_parallel_control_flow(self):
+    dataset = dataset_ops.Dataset.from_tensors(
+        random_ops.random_uniform([100, 10000000]))
+
+    def fn(x):
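+      # map_fn iterates over the 100 rows of `x`; parallel_iterations=10
+      # allows up to 10 loop iterations to execute concurrently.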
+      return map_fn.map_fn(
+          lambda y: y * array_ops.transpose(y), x, parallel_iterations=10)
+
+    dataset = dataset.map(fn)
+    self.run_and_report_benchmark(
+        dataset,
+        num_elements=1,
+        name="parallel_control_flow",
+        apply_default_optimizations=True)
+
 
 if __name__ == "__main__":
   benchmark_base.test.main()