eager support for OptimizationBenchmark
diff --git a/tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py b/tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py
index bb0c3ed..1ff595f 100644
--- a/tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py
+++ b/tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py
@@ -17,19 +17,13 @@
 from __future__ import division
 from __future__ import print_function
 
-import time
 
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
 
 
-# TODO(b/119837791): Add eager benchmarks too.
-class OptimizationBenchmark(test.Benchmark):
+class OptimizationBenchmark(benchmark_base.DatasetBenchmarkBase):
   """Benchmarks for static optimizations."""
 
   def benchmark_map_fusion(self):
@@ -37,123 +31,96 @@
 
     chain_lengths = [0, 1, 2, 5, 10, 20, 50]
     for chain_length in chain_lengths:
-      self._benchmark_map_fusion(chain_length, False)
-      self._benchmark_map_fusion(chain_length, True)
+      self._benchmark_map_fusion(chain_length=chain_length,
+                                 optimize_dataset=False)
+      self._benchmark_map_fusion(chain_length=chain_length,
+                                 optimize_dataset=True)
 
   def _benchmark_map_fusion(self, chain_length, optimize_dataset):
-    with ops.Graph().as_default():
-      dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
-      for _ in range(chain_length):
-        dataset = dataset.map(lambda x: x)
-      if optimize_dataset:
-        options = dataset_ops.Options()
-        options.experimental_optimization.apply_default_optimizations = False
-        options.experimental_optimization.map_fusion = True
-        dataset = dataset.with_options(options)
 
-      iterator = dataset_ops.make_one_shot_iterator(dataset)
-      next_element = iterator.get_next()
+    dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+    for _ in range(chain_length):
+      dataset = dataset.map(lambda x: x)
+    if optimize_dataset:
+      options = dataset_ops.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_optimization.map_fusion = True
+      dataset = dataset.with_options(options)
 
-      with session.Session() as sess:
-        for _ in range(5):
-          sess.run(next_element.op)
-        deltas = []
-        for _ in range(100):
-          start = time.time()
-          for _ in range(100):
-            sess.run(next_element.op)
-          end = time.time()
-          deltas.append(end - start)
-
-        median_wall_time = np.median(deltas) / 100
-        opt_mark = "opt" if optimize_dataset else "noopt"
-        self.report_benchmark(
-            iters=100,
-            wall_time=median_wall_time,
-            name="map_fusion_{}_chain_length_{}".format(
-                opt_mark, chain_length))
+    opt_mark = "opt" if optimize_dataset else "noopt"
+    self.run_and_report_benchmark(
+        dataset=dataset,
+        num_elements=100,
+        iters=10,
+        warmup=True,
+        name="map_fusion_{}_chain_length_{}".format(
+            opt_mark, chain_length)
+    )
 
   def benchmark_map_and_filter_fusion(self):
     """Evaluates performance map of fusion."""
 
     chain_lengths = [0, 1, 2, 5, 10, 20, 50]
     for chain_length in chain_lengths:
-      self._benchmark_map_and_filter_fusion(chain_length, False)
-      self._benchmark_map_and_filter_fusion(chain_length, True)
+      self._benchmark_map_and_filter_fusion(chain_length=chain_length,
+                                            optimize_dataset=False)
+      self._benchmark_map_and_filter_fusion(chain_length=chain_length,
+                                            optimize_dataset=True)
 
   def _benchmark_map_and_filter_fusion(self, chain_length, optimize_dataset):
-    with ops.Graph().as_default():
-      dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
-      for _ in range(chain_length):
-        dataset = dataset.map(lambda x: x + 5).filter(
-            lambda x: math_ops.greater_equal(x - 5, 0))
-      if optimize_dataset:
-        options = dataset_ops.Options()
-        options.experimental_optimization.apply_default_optimizations = False
-        options.experimental_optimization.map_and_filter_fusion = True
-        dataset = dataset.with_options(options)
-      iterator = dataset_ops.make_one_shot_iterator(dataset)
-      next_element = iterator.get_next()
 
-      with session.Session() as sess:
-        for _ in range(10):
-          sess.run(next_element.op)
-        deltas = []
-        for _ in range(100):
-          start = time.time()
-          for _ in range(100):
-            sess.run(next_element.op)
-          end = time.time()
-          deltas.append(end - start)
+    dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+    for _ in range(chain_length):
+      dataset = dataset.map(lambda x: x + 5).filter(
+          lambda x: math_ops.greater_equal(x - 5, 0))
+    if optimize_dataset:
+      options = dataset_ops.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_optimization.map_and_filter_fusion = True
+      dataset = dataset.with_options(options)
 
-        median_wall_time = np.median(deltas) / 100
-        opt_mark = "opt" if optimize_dataset else "noopt"
-        self.report_benchmark(
-            iters=100,
-            wall_time=median_wall_time,
-            name="map_and_filter_fusion_{}_chain_length_{}".format(
-                opt_mark, chain_length))
+    opt_mark = "opt" if optimize_dataset else "noopt"
+    self.run_and_report_benchmark(
+        dataset=dataset,
+        num_elements=100,
+        iters=10,
+        warmup=True,
+        name="map_and_filter_fusion_{}_chain_length_{}".format(
+            opt_mark, chain_length)
+    )
 
   # This benchmark compares the performance of pipeline with multiple chained
   # filter with and without filter fusion.
+
   def benchmark_filter_fusion(self):
     chain_lengths = [0, 1, 2, 5, 10, 20, 50]
     for chain_length in chain_lengths:
-      self._benchmark_filter_fusion(chain_length, False)
-      self._benchmark_filter_fusion(chain_length, True)
+      self._benchmark_filter_fusion(chain_length=chain_length,
+                                    optimize_dataset=False)
+      self._benchmark_filter_fusion(chain_length=chain_length,
+                                    optimize_dataset=True)
 
   def _benchmark_filter_fusion(self, chain_length, optimize_dataset):
-    with ops.Graph().as_default():
-      dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
-      for _ in range(chain_length):
-        dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
-      if optimize_dataset:
-        options = dataset_ops.Options()
-        options.experimental_optimization.apply_default_optimizations = False
-        options.experimental_optimization.filter_fusion = True
-        dataset = dataset.with_options(options)
 
-      iterator = dataset_ops.make_one_shot_iterator(dataset)
-      next_element = iterator.get_next()
+    dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
+    for _ in range(chain_length):
+      dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
+    if optimize_dataset:
+      options = dataset_ops.Options()
+      options.experimental_optimization.apply_default_optimizations = False
+      options.experimental_optimization.filter_fusion = True
+      dataset = dataset.with_options(options)
 
-      with session.Session() as sess:
-        for _ in range(10):
-          sess.run(next_element.op)
-        deltas = []
-        for _ in range(100):
-          start = time.time()
-          for _ in range(100):
-            sess.run(next_element.op)
-          end = time.time()
-          deltas.append(end - start)
-
-        median_wall_time = np.median(deltas) / 100
-        opt_mark = "opt" if optimize_dataset else "no-opt"
-        self.report_benchmark(
-            iters=1000,
-            wall_time=median_wall_time,
-            name="chain_length_{}_{}".format(opt_mark, chain_length))
+    opt_mark = "opt" if optimize_dataset else "noopt"
+    self.run_and_report_benchmark(
+        dataset=dataset,
+        num_elements=100,
+        iters=10,
+        warmup=True,
+        name="filter_fusion_{}_chain_length_{}".format(
+            opt_mark, chain_length)
+    )
 
 
 if __name__ == "__main__":
-  test.main()
+  benchmark_base.test.main()