blob: a2cc54d104e2d7027dae53595a8bb5a55d0afa73 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MapDefunOp."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl.testing import parameterized
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import map_defun
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
# TODO(b/123903858): Add eager and V2 test coverage
class MapDefunTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Graph-mode, TF1-API tests for `map_defun.map_defun` (MapDefunOp).

  Every test builds a `function.defun` with an explicit `input_signature`.
  It then applies that function with `map_defun` to the slices of its inputs
  along dimension 0. Finally it checks one of three things:

  * the stacked outputs;
  * their statically inferred shape;
  * the error raised for invalid inputs.
  """

  def _serializeTestSparseTensor(self, out_type):
    """Returns the shared 3x4 test `SparseTensor`, serialized as `out_type`.

    The tensor has values [1, 2] at indices [0, 0] and [1, 2].

    Args:
      out_type: `dtypes.variant` or `dtypes.string`. It is forwarded to
        `sparse_ops.serialize_sparse_v2`.

    Returns:
      The serialized tensor produced by `sparse_ops.serialize_sparse_v2`.
    """
    st = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    return sparse_ops.serialize_sparse_v2(st, out_type=out_type)

  def _assertDeserializesToStackedTestSparseTensor(self, map_defun_op):
    """Asserts `map_defun_op` holds two serialized copies of the test tensor.

    Deserializing stacks the copies along a new leading dimension. The
    expected value is therefore a [2, 3, 4] `SparseTensorValue`.

    Args:
      map_defun_op: `MapDefun` output holding two serialized copies of the
        tensor built by `_serializeTestSparseTensor`.
    """
    deserialized = sparse_ops.deserialize_sparse(map_defun_op, dtypes.int32)
    expected = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])
    self.assertValuesEqual(expected, self.evaluate(deserialized))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testNoIntraOpLimit(self):
    """`max_intra_op_parallelism=0` still yields the element-wise results."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def simple_fn(x):
      return x * 2 + 3

    nums = [[1, 2], [3, 4], [5, 6]]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    r = map_defun.map_defun(
        simple_fn, [elems], [dtypes.int32], [(2,)],
        max_intra_op_parallelism=0)[0]
    expected = elems * 2 + 3
    self.assertAllEqual(self.evaluate(r), self.evaluate(expected))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunSimple(self):
    """Mapping over rows gives the same result as the vectorized expression."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def simple_fn(x):
      return x * 2 + 3

    nums = [[1, 2], [3, 4], [5, 6]]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
    expected = elems * 2 + 3
    self.assertAllEqual(self.evaluate(r), self.evaluate(expected))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunMismatchedTypes(self):
    """A fn whose output dtype differs from `output_dtypes` fails at run time."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
      return math_ops.cast(x, dtypes.float64)

    nums = [1, 2, 3, 4, 5, 6]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    # The fn returns float64, but int32 is declared as the output dtype.
    r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(r)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunReduceDim(self):
    """The output of the fn may have a different rank from its input."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
      return array_ops.gather(x, 0)

    nums = [[1, 2], [3, 4], [5, 6]]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
    expected = constant_op.constant([1, 3, 5])
    self.assertAllEqual(self.evaluate(r), self.evaluate(expected))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunMultipleOutputs(self):
    """A fn returning a tuple produces one stacked output per tuple element."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
      return (x, math_ops.cast(x * 2 + 3, dtypes.float64))

    nums = [[1, 2], [3, 4], [5, 6]]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64],
                            [(2,), (2,)])
    expected = [elems, elems * 2 + 3]
    self.assertAllEqual(self.evaluate(r), self.evaluate(expected))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunShapeInference(self):
    """The static output shape is [batch] + output_shape when batch is known."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
      return x

    nums = [[1, 2], [3, 4], [5, 6]]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
    self.assertEqual(result.get_shape(), (3, 2))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunPartialShapeInference(self):
    """An unknown batch dimension propagates as `None` to the output shape."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
      return x

    # The dtype matches the fn's int32 input_signature. Only the leading
    # (batch) dimension is left unknown.
    elems = array_ops.placeholder(dtypes.int32, (None, 2))
    result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
    self.assertEqual(result[0].get_shape().as_list(), [None, 2])

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
    """Inputs fed with different dimension-0 sizes are rejected at run time."""

    @function.defun(input_signature=[
        tensor_spec.TensorSpec(None, dtypes.int32),
        tensor_spec.TensorSpec(None, dtypes.int32)
    ])
    def fn(x, y):
      return x, y

    elems1 = array_ops.placeholder(dtypes.int32)
    elems2 = array_ops.placeholder(dtypes.int32)
    result = map_defun.map_defun(fn, [elems1, elems2],
                                 [dtypes.int32, dtypes.int32], [(), ()])
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          "All inputs must have the same dimension 0."):
        sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunRaisesDefunError(self):
    """An error raised inside the mapped fn surfaces from the MapDefun op."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
        return array_ops.identity(x)

    # The element 37 makes the assert_equal inside fn fail.
    elems = constant_op.constant([0, 0, 0, 37, 0])
    result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(result)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunCancelledCorrectly(self):
    """A failing fn invocation reports its own error from the MapDefun op."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)])
    def defun(x):
      # x has shape [5], so gathering index 10 is out of range and raises.
      return array_ops.gather(x, 10)

    # 100 identical rows, so every invocation of `defun` fails the same way.
    c = array_ops.tile(
        array_ops.expand_dims(
            constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
        [100, 1])
    map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                r"indices = 10 is not in \[0, 5\)"):
      self.evaluate(map_defun_op)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithUnspecifiedOutputShape(self):
    """Output shapes given as `None` or partially unknown still evaluate."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def simple_fn(x):
      res = x * 2 + 3
      return (res, res + 1, res + 2)

    nums = [[1, 2], [3, 4], [5, 6]]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    # The shapes are fully unknown, then rank-1 with unknown size, then fully
    # known.
    r = map_defun.map_defun(simple_fn, [elems],
                            [dtypes.int32, dtypes.int32, dtypes.int32],
                            [None, (None,), (2,)])
    expected = elems * 2 + 3
    self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
    self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
    self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithDifferentOutputShapeEachRun(self):
    """With an unknown output shape, each run may produce a different shape."""

    @function.defun(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
    def simple_fn(x):
      return x * 2 + 3

    elems = array_ops.placeholder(dtypes.int32, name="data")
    r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
    with session.Session() as sess:
      self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
      self.assertAllEqual(
          sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithWrongOutputShape(self):
    """A fn output that contradicts the declared output shape fails."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def simple_fn(x):
      return x * 2 + 3

    nums = [[1, 2], [3, 4], [5, 6]]
    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
    # The fn returns shape (2,), but (1,) is declared.
    r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(r)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithInvalidInput(self):
    """Scalar inputs are rejected, statically when possible, else at run time."""

    @function.defun(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
    def simple_fn(x):
      return x * 2

    # A rank-0 input has no dimension 0 to map over.
    c = constant_op.constant(2)
    with self.assertRaises(ValueError):
      # Fails at graph construction time for inputs with known shapes.
      map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])
    # With an unknown static shape, the check can only happen at run time.
    p = array_ops.placeholder(dtypes.int32)
    r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
    with session.Session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(r, feed_dict={p: 0})

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithParentCancellation(self):
    """Checks that cancelling the parent graph cancels a blocked MapDefunOp."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def simple_fn(x):
      del x
      queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
      # Nothing is ever enqueued, so this dequeue blocks until cancelled.
      return queue.dequeue_many(5)

    c = constant_op.constant([1, 2, 3, 4, 5])
    map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
    with self.cached_session() as sess:
      thread = self.checkedThread(
          self.assert_op_cancelled, args=(map_defun_op,))
      thread.start()
      # Heuristic pause so the thread starts running (and blocking in) the op
      # before the session is closed underneath it.
      time.sleep(0.2)
      sess.close()
      thread.join()

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithCapturedInputs(self):
    """Tensors captured by the fn's closure are visible in every invocation."""
    c = constant_op.constant(2)

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
      return x + c

    x = constant_op.constant([1, 2, 3, 4])
    map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0]
    expected = x + c
    self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithVariantTensor(self):
    """Variant tensors can be mapped over and passed through unchanged."""

    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.variant)])
    def fn(x):
      return x

    serialized = self._serializeTestSparseTensor(dtypes.variant)
    serialized = array_ops.stack([serialized, serialized])
    map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.variant],
                                       [None])[0]
    self._assertDeserializesToStackedTestSparseTensor(map_defun_op)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithVariantTensorAsCaptured(self):
    """A captured variant tensor is returned once per mapped element."""
    serialized = self._serializeTestSparseTensor(dtypes.variant)

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
      del x
      return serialized

    # Two elements mean two stacked copies of the captured tensor.
    x = constant_op.constant([0, 0])
    map_defun_op = map_defun.map_defun(fn, [x], [dtypes.variant], [None])[0]
    self._assertDeserializesToStackedTestSparseTensor(map_defun_op)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph"]))
  def testMapDefunWithStrTensor(self):
    """String tensors can be mapped over and passed through unchanged."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def fn(x):
      return x

    serialized = self._serializeTestSparseTensor(dtypes.string)
    serialized = array_ops.stack([serialized, serialized])
    map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.string],
                                       [None])[0]
    self._assertDeserializesToStackedTestSparseTensor(map_defun_op)
if __name__ == "__main__":
  # Run the tests through TensorFlow's test entry point
  # (`tensorflow.python.platform.test`).
  test.main()