Add eager support for CsvDatasetBenchmark via DatasetBenchmarkBase
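
Replace the hand-rolled timing loop (skip() + session.run() + np.median)
with DatasetBenchmarkBase.run_and_report_benchmark, which times dataset
iteration for both graph and eager execution and handles warmup and
reporting for the benchmark methods.

For context, a minimal sketch of the new benchmark path. The dataset built
here is illustrative only; the keyword arguments mirror the
run_and_report_benchmark call in the diff below:

    from tensorflow.python.data.benchmarks import benchmark_base
    from tensorflow.python.data.ops import dataset_ops

    class SketchBenchmark(benchmark_base.DatasetBenchmarkBase):
      """Illustrative benchmark using the same base class as this change."""

      def benchmark_range(self):
        # Hypothetical dataset standing in for the CSV pipelines below.
        dataset = dataset_ops.Dataset.range(1000).repeat()
        self.run_and_report_benchmark(
            dataset=dataset,
            num_elements=1000,  # elements drawn per timed iteration
            name='range_sketch',
            iters=10,           # number of timed iterations
            warmup=True)        # run an untimed warm-up pass first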
diff --git a/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py b/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
index 9348ae8..c58eefc 100644
--- a/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
+++ b/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
@@ -21,21 +21,16 @@
import os
import string
import tempfile
-import time
-import numpy as np
-
-from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import readers
-from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
-from tensorflow.python.platform import test
-class CsvDatasetBenchmark(test.Benchmark):
+class CsvDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.experimental.CsvDataset`."""
FLOAT_VAL = '1.23456E12'
@@ -62,30 +57,14 @@
gfile.DeleteRecursively(self._temp_dir)
def _run_benchmark(self, dataset, num_cols, prefix):
- dataset = dataset.skip(self._num_per_iter - 1)
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- deltas = []
- for _ in range(10):
- next_element = dataset_ops.make_one_shot_iterator(dataset).get_next()
- with session.Session() as sess:
- start = time.time()
- # NOTE: This depends on the underlying implementation of skip, to have
- # the net effect of calling `GetNext` num_per_iter times on the
- # input dataset. We do it this way (instead of a python for loop, or
- # batching N inputs in one iter) so that the overhead from session.run
- # or batch doesn't dominate. If we eventually optimize skip, this has
- # to change.
- sess.run(next_element)
- end = time.time()
- deltas.append(end - start)
- # Median wall time per CSV record read and decoded
- median_wall_time = np.median(deltas) / self._num_per_iter
- self.report_benchmark(
- iters=self._num_per_iter,
- wall_time=median_wall_time,
- name='%s_with_cols_%d' % (prefix, num_cols))
+
+ self.run_and_report_benchmark(
+ dataset=dataset,
+ num_elements=self._num_per_iter,
+ name='%s_with_cols_%d' % (prefix, num_cols),
+ iters=10,
+ warmup=True
+ )
def benchmark_map_with_floats(self):
self._set_up(self.FLOAT_VAL)
@@ -93,8 +72,13 @@
num_cols = self._num_cols[i]
kwargs = {'record_defaults': [[0.0]] * num_cols}
dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._run_benchmark(dataset, num_cols, 'csv_float_map_decode_csv')
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(
+ l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._run_benchmark(
+ dataset=dataset,
+ num_cols=num_cols,
+ prefix='csv_float_map_decode_csv'
+ )
self._tear_down()
def benchmark_map_with_strings(self):
@@ -103,8 +87,13 @@
num_cols = self._num_cols[i]
kwargs = {'record_defaults': [['']] * num_cols}
dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._run_benchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(
+ l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._run_benchmark(
+ dataset=dataset,
+ num_cols=num_cols,
+ prefix='csv_strings_map_decode_csv'
+ )
self._tear_down()
def benchmark_csv_dataset_with_floats(self):
@@ -113,8 +102,13 @@
num_cols = self._num_cols[i]
kwargs = {'record_defaults': [[0.0]] * num_cols}
dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
- self._run_benchmark(dataset, num_cols, 'csv_float_fused_dataset')
+ dataset = readers.CsvDataset(
+ self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
+ self._run_benchmark(
+ dataset=dataset,
+ num_cols=num_cols,
+ prefix='csv_float_fused_dataset'
+ )
self._tear_down()
def benchmark_csv_dataset_with_strings(self):
@@ -123,9 +117,15 @@
num_cols = self._num_cols[i]
kwargs = {'record_defaults': [['']] * num_cols}
dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
- self._run_benchmark(dataset, num_cols, 'csv_strings_fused_dataset')
+ dataset = readers.CsvDataset(
+ self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
+ self._run_benchmark(
+ dataset=dataset,
+ num_cols=num_cols,
+ prefix='csv_strings_fused_dataset'
+ )
self._tear_down()
+
if __name__ == '__main__':
- test.main()
+ benchmark_base.test.main()