blob: e75edf60867ac697082b219fd54bdcfd533a0453 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the MapFusion, MapAndFilterFusion and FilterFusion optimizations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
  """Tests the map, map-and-filter, and filter fusion dataset rewrites.

  Each test wraps the head of its pipeline in `optimization.assert_next(...)`,
  listing the dataset ops the test expects to follow once
  `optimization.optimize([...])` has applied the rewrite under test. It then
  compares every produced element against the value obtained by calling the
  same Python map/filter functions directly.
  """

  @staticmethod
  def _call_with_element(function, element):
    """Calls `function` on `element` the way the pipeline functions expect.

    Tuple elements are unpacked into positional arguments (e.g. `swap` below
    takes `(x, n)`); any other value is passed as the single argument.

    Args:
      function: a map function or filter predicate.
      element: a Python scalar or tuple previously produced by a map function.

    Returns:
      Whatever `function` returns.
    """
    if isinstance(element, tuple):
      return function(*element)
    return function(element)

  @staticmethod
  def map_functions():
    """Returns named parameter tuples of map-function chains to fuse.

    Every ordered pair and triple of the scalar functions below is produced,
    plus two chains whose intermediate elements are 2-tuples.

    Returns:
      A tuple of `(test_name, [map_fn, ...])` entries for
      `parameterized.named_parameters`.
    """
    identity = lambda x: x
    increment = lambda x: x + 1

    def increment_and_square(x):
      y = x + 1
      return y * y

    functions = [identity, increment, increment_and_square]
    tests = []
    for i, fun1 in enumerate(functions):
      for j, fun2 in enumerate(functions):
        tests.append((
            "Test{}{}".format(i, j),
            [fun1, fun2],
        ))
        for k, fun3 in enumerate(functions):
          tests.append((
              "Test{}{}{}".format(i, j, k),
              [fun1, fun2, fun3],
          ))

    swap = lambda x, n: (n, x)
    tests.append((
        "Swap1",
        [lambda x: (x, 42), swap],
    ))
    tests.append((
        "Swap2",
        [lambda x: (x, 42), swap, swap],
    ))
    return tuple(tests)

  @parameterized.named_parameters(*map_functions.__func__())
  def testMapFusion(self, functions):
    """Checks the `map_fusion` rewrite on a chain of consecutive maps."""
    dataset = dataset_ops.Dataset.range(5).apply(
        optimization.assert_next(["Map", "Prefetch"]))
    for function in functions:
      dataset = dataset.map(function)
    dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
    iterator = dataset.make_one_shot_iterator()
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      for x in range(5):
        result = sess.run(get_next)
        # Reproduce the chain of maps in plain Python for the expected value.
        expected = x
        for function in functions:
          expected = self._call_with_element(function, expected)
        self.assertAllEqual(expected, result)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @staticmethod
  def map_and_filter_functions():
    """Returns named `(map_fn, predicate)` parameter tuples.

    Covers every combination of the scalar map functions and predicates below,
    plus two cases whose map function yields a 2-tuple, so the predicate takes
    two arguments.

    Returns:
      A tuple of `(test_name, map_fn, predicate)` entries for
      `parameterized.named_parameters`.
    """
    identity = lambda x: x
    increment = lambda x: x + 1
    minus_five = lambda x: x - 5

    def increment_and_square(x):
      y = x + 1
      return y * y

    take_all = lambda x: constant_op.constant(True)
    is_zero = lambda x: math_ops.equal(x, 0)
    # True when x is even (x % 2 == 0).
    is_even = lambda x: math_ops.equal(x % 2, 0)
    greater = lambda x: math_ops.greater(x + 5, 0)
    functions = [identity, increment, minus_five, increment_and_square]
    filters = [take_all, is_zero, is_even, greater]
    tests = []
    for x, fun in enumerate(functions):
      for y, predicate in enumerate(filters):
        tests.append(("Mixed{}{}".format(x, y), fun, predicate))

    # Multi output
    tests.append(("Multi1", lambda x: (x, x),
                  lambda x, y: constant_op.constant(True)))
    tests.append(
        ("Multi2", lambda x: (x, 2),
         lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
    return tuple(tests)

  @parameterized.named_parameters(*map_and_filter_functions.__func__())
  def testMapFilterFusion(self, function, predicate):
    """Checks the `map_and_filter_fusion` rewrite on `map` then `filter`."""
    dataset = dataset_ops.Dataset.range(10).apply(
        optimization.assert_next(
            ["Map",
             "FilterByLastComponent"])).map(function).filter(predicate).apply(
                 optimization.optimize(["map_and_filter_fusion"]))
    self._testMapAndFilter(dataset, function, predicate)

  def _testMapAndFilter(self, dataset, function, predicate):
    """Verifies `dataset` yields `function(x)` for each kept x in range(10).

    Args:
      dataset: the pipeline under test. It is expected to be equivalent to
        `Dataset.range(10).map(function).filter(predicate)`.
      function: map function applied to each input integer.
      predicate: filter predicate applied to the mapped value.
    """
    iterator = dataset.make_one_shot_iterator()
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      for x in range(10):
        r = function(x)
        b = self._call_with_element(predicate, r)
        # Only elements the predicate keeps should come out of the pipeline.
        if sess.run(b):
          result = sess.run(get_next)
          self.assertAllEqual(r, result)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testAdditionalInputs(self):
    """Checks that a predicate capturing an outside tensor is left unfused."""
    a = constant_op.constant(3, dtype=dtypes.int64)
    b = constant_op.constant(4, dtype=dtypes.int64)
    some_tensor = math_ops.mul(a, b)
    function = lambda x: x * x

    def predicate(y):
      return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)

    # We are currently not supporting functions with additional inputs, so
    # the separate Map and Filter ops are expected to remain.
    dataset = dataset_ops.Dataset.range(10).apply(
        optimization.assert_next(
            ["Map", "Filter"])).map(function).filter(predicate).apply(
                optimization.optimize(["map_and_filter_fusion"]))
    self._testMapAndFilter(dataset, function, predicate)

  @staticmethod
  def filter_functions():
    """Returns named `(map_fn, [predicate, ...])` parameter tuples.

    Produces every ordered pair and triple of the scalar predicates below
    (after an identity map), plus two cases whose map function yields a
    2-tuple, so each predicate takes two arguments.

    Returns:
      A tuple of `(test_name, map_fn, predicates)` entries for
      `parameterized.named_parameters`.
    """
    take_all = lambda x: constant_op.constant(True)
    is_zero = lambda x: math_ops.equal(x, 0)
    greater = lambda x: math_ops.greater(x + 5, 0)
    tests = []
    filters = [take_all, is_zero, greater]
    identity = lambda x: x
    for x, predicate_1 in enumerate(filters):
      for y, predicate_2 in enumerate(filters):
        tests.append(("Mixed{}{}".format(x, y), identity,
                      [predicate_1, predicate_2]))
        for z, predicate_3 in enumerate(filters):
          tests.append(("Mixed{}{}{}".format(x, y, z), identity,
                        [predicate_1, predicate_2, predicate_3]))

    take_all_multiple = lambda x, y: constant_op.constant(True)
    # Multi output
    tests.append(("Multi1", lambda x: (x, x),
                  [take_all_multiple, take_all_multiple]))
    tests.append(("Multi2", lambda x: (x, 2), [
        take_all_multiple,
        lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
    ]))
    return tuple(tests)

  @parameterized.named_parameters(*filter_functions.__func__())
  def testFilterFusion(self, map_function, predicates):
    """Checks the `filter_fusion` rewrite on a chain of consecutive filters."""
    dataset = dataset_ops.Dataset.range(5).apply(
        optimization.assert_next(["Map", "Filter",
                                  "Prefetch"])).map(map_function)
    for predicate in predicates:
      dataset = dataset.filter(predicate)
    dataset = dataset.prefetch(0).apply(
        optimization.optimize(["filter_fusion"]))
    iterator = dataset.make_one_shot_iterator()
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      for x in range(5):
        r = map_function(x)
        # `all` stops at the first rejecting predicate, which mirrors how the
        # chain of filters drops an element without consulting later ones.
        kept = all(
            sess.run(self._call_with_element(predicate, r))
            for predicate in predicates)
        if kept:
          result = sess.run(get_next)
          self.assertAllEqual(r, result)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()