refactor parameterized test cases
diff --git a/tensorflow/python/data/util/random_seed_test.py b/tensorflow/python/data/util/random_seed_test.py
index d8233b3..cf4b3cd 100644
--- a/tensorflow/python/data/util/random_seed_test.py
+++ b/tensorflow/python/data/util/random_seed_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import functools
+
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
@@ -33,82 +35,96 @@
# NOTE(vikoth18): Arguments of parameterized tests are lifted into lambdas to make
# sure they are not executed before the (eager- or graph-mode) test environment
# has been set up.
-#
+
+def _test_random_seed_combinations():
+
+ cases = [
+ # Each test case is a tuple with input to get_seed:
+ # (input_graph_seed, input_op_seed)
+ # and output from get_seed:
+ # (output_graph_seed, output_op_seed)
+ (
+ "CASE_0",
+ lambda: (None, None),
+ lambda: (0, 0),
+ ),
+ (
+ "CASE_1",
+ lambda: (None, 1),
+ lambda: (random_seed.DEFAULT_GRAPH_SEED, 1)
+ ),
+ (
+ "CASE_2",
+ lambda: (1, 1),
+ lambda: (1, 1)
+ ),
+ (
+ # Avoid nondeterministic (0, 0) output
+ "CASE_3",
+ lambda: (0, 0),
+ lambda: (0, 2**31 - 1)
+ ),
+ (
+ # Don't wrap to (0, 0) either
+ "CASE_4",
+ lambda: (2**31 - 1, 0),
+ lambda: (0, 2**31 - 1)
+ ),
+ (
+ # Wrapping for the other argument
+ "CASE_5",
+ lambda: (0, 2**31 - 1),
+ lambda: (0, 2**31 - 1)
+ ),
+ (
+ # Once more, with tensor-valued arguments
+ "CASE_6",
+ lambda: (None, constant_op.constant(1, dtype=dtypes.int64, name='one')),
+ lambda: (random_seed.DEFAULT_GRAPH_SEED, 1)
+ ),
+ (
+ "CASE_7",
+ lambda: (1, constant_op.constant(1, dtype=dtypes.int64, name='one')),
+ lambda: (1, 1)
+ ),
+ (
+ "CASE_8",
+ lambda: (0, constant_op.constant(0, dtype=dtypes.int64, name='zero')),
+ lambda: (0, 2**31 - 1) # Avoid nondeterministic (0, 0) output
+ ),
+ (
+ "CASE_9",
+ lambda: (2**31 - 1, constant_op.constant(0, dtype=dtypes.int64, name='zero')),
+ lambda: (0, 2**31 - 1) # Don't wrap to (0, 0) either
+ ),
+ (
+ "CASE_10",
+ lambda: (0, constant_op.constant(2**31 - 1, dtype=dtypes.int64, name='intmax')),
+ lambda: (0, 2**31 - 1) # Wrapping for the other argument
+ )
+ ]
+ def reduce_fn(x, y):
+ name, input_fn, output_fn = y
+ return x + combinations.combine(
+ input_fn=combinations.NamedObject(
+ "input_fn.{}".format(name), input_fn),
+ output_fn=combinations.NamedObject(
+ "output_fn.{}".format(name), output_fn)
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
class RandomSeedTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(test_case_fn=[
- # Each test case is a tuple with input to get_seed:
- # (input_graph_seed, input_op_seed)
- # and output from get_seed:
- # (output_graph_seed, output_op_seed)
- combinations.NamedObject(
- "Case_0",
- lambda: ((None, None), (0, 0))
- ),
- combinations.NamedObject(
- "Case_1",
- lambda: ((None, 1), (random_seed.DEFAULT_GRAPH_SEED, 1)),
- ),
- combinations.NamedObject(
- "Case_2",
- lambda: ((1, 1), (1, 1)),
- ),
- combinations.NamedObject(
- "Case_3",
- # Avoid nondeterministic (0, 0) output
- lambda: ((0, 0), (0, 2**31 - 1)),
- ),
- combinations.NamedObject(
- "Case_4",
- # Don't wrap to (0, 0) either
- lambda: ((2**31 - 1, 0), (0, 2**31 - 1)),
- ),
- combinations.NamedObject(
- "Case_5",
- # Wrapping for the other argument
- lambda: ((0, 2**31 - 1), (0, 2**31 - 1)),
- ),
- combinations.NamedObject(
- "Case_6",
- # Once more, with tensor-valued arguments
- lambda: ((None, constant_op.constant(
- 1, dtype=dtypes.int64, name='one')),
- (random_seed.DEFAULT_GRAPH_SEED, 1)),
- ),
- combinations.NamedObject(
- "Case_7",
- lambda: ((1, constant_op.constant(1, dtype=dtypes.int64, name='one')),
- (1, 1)),
- ),
- combinations.NamedObject(
- "Case_8",
- lambda: ((0, constant_op.constant(
- 0, dtype=dtypes.int64, name='zero')),
- (0, 2**31 - 1)), # Avoid nondeterministic (0, 0) output
- ),
- combinations.NamedObject(
- "Case_9",
- lambda: ((2**31 - 1, constant_op.constant(
- 0, dtype=dtypes.int64, name='zero')),
- (0, 2**31 - 1)), # Don't wrap to (0, 0) either
- ),
- combinations.NamedObject(
- "Case_10",
- lambda: ((0, constant_op.constant(
- 2**31 - 1, dtype=dtypes.int64, name='intmax')),
- (0, 2**31 - 1)), # Wrapping for the other argument
- )
- ])
+ _test_random_seed_combinations()
)
)
- def testRandomSeed(self, test_case_fn):
- test_case_fn = test_case_fn._obj # pylint: disable=protected-access
- test_case = test_case_fn()
- tinput, toutput = test_case[0], test_case[1]
+ def testRandomSeed(self, input_fn, output_fn):
+ tinput, toutput = input_fn._obj(), output_fn._obj() # pylint: disable=protected-access
def check(tinput, toutput):
random_seed.set_random_seed(tinput[0])
g_seed, op_seed = data_random_seed.get_seed(tinput[1])
diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py
index f95687a..7a338b8 100644
--- a/tensorflow/python/data/util/sparse_test.py
+++ b/tensorflow/python/data/util/sparse_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import functools
+
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
@@ -36,49 +38,384 @@
# sure they are not executed before the (eager- or graph-mode) test environment
# has been set up.
#
+
+def _test_any_sparse_combinations():
+
+ cases = [
+ ("CASE_0", lambda: (), False),
+ ("CASE_1", lambda: (ops.Tensor), False),
+ ("CASE_2", lambda: (((ops.Tensor))), False),
+ ("CASE_3", lambda: (ops.Tensor, ops.Tensor), False),
+ ("CASE_4", lambda: (ops.Tensor, sparse_tensor.SparseTensor), True),
+ ("CASE_5", lambda: (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor), True),
+ ("CASE_6", lambda: (((sparse_tensor.SparseTensor))), True)
+ ]
+ def reduce_fn(x, y):
+ name, classes_fn, expected = y
+ return x + combinations.combine(
+ classes_fn=combinations.NamedObject(
+ "classes_fn.{}".format(name), classes_fn
+ ),
+ expected=combinations.NamedObject(
+ "expected.{}".format(name), expected
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+def _test_as_dense_shapes_combinations():
+
+ cases = [
+ (
+ "CASE_0",
+ lambda: (),
+ lambda: (),
+ lambda: ()
+ ),
+ (
+ "CASE_1",
+ lambda: tensor_shape.TensorShape([]),
+ lambda: ops.Tensor,
+ lambda: tensor_shape.TensorShape([])
+ ),
+ (
+ "CASE_2",
+ lambda: tensor_shape.TensorShape([]),
+ lambda: sparse_tensor.SparseTensor,
+ lambda: tensor_shape.unknown_shape()
+ ),
+ (
+ "CASE_3",
+ lambda: (tensor_shape.TensorShape([])),
+ lambda: (ops.Tensor),
+ lambda: (tensor_shape.TensorShape([]))
+ ),
+ (
+ "CASE_4",
+ lambda: (tensor_shape.TensorShape([])),
+ lambda: (sparse_tensor.SparseTensor),
+ lambda: (tensor_shape.unknown_shape())
+ ),
+ (
+ "CASE_5",
+ lambda: (tensor_shape.TensorShape([]), ()),
+ lambda: (ops.Tensor, ()),
+ lambda: (tensor_shape.TensorShape([]), ())
+ ),
+ (
+ "CASE_6",
+ lambda: ((), tensor_shape.TensorShape([])),
+ lambda: ((), ops.Tensor),
+ lambda: ((), tensor_shape.TensorShape([]))
+ ),
+ (
+ "CASE_7",
+ lambda: (tensor_shape.TensorShape([]), ()),
+ lambda: (sparse_tensor.SparseTensor, ()),
+ lambda: (tensor_shape.unknown_shape(), ())
+ ),
+ (
+ "CASE_8",
+ lambda: ((), tensor_shape.TensorShape([])),
+ lambda: ((), sparse_tensor.SparseTensor),
+ lambda: ((), tensor_shape.unknown_shape())
+ ),
+ (
+ "CASE_9",
+ lambda: (tensor_shape.TensorShape([]), (),
+ tensor_shape.TensorShape([])),
+ lambda: (ops.Tensor, (), ops.Tensor),
+ lambda: (tensor_shape.TensorShape([]), (),
+ tensor_shape.TensorShape([]))
+ ),
+ (
+ "CASE_10",
+ lambda: (tensor_shape.TensorShape([]), (),
+ tensor_shape.TensorShape([])),
+ lambda: (sparse_tensor.SparseTensor, (),
+ sparse_tensor.SparseTensor),
+ lambda: (tensor_shape.unknown_shape(), (),
+ tensor_shape.unknown_shape())
+ ),
+ (
+ "CASE_11",
+ lambda: ((), tensor_shape.TensorShape([]), ()),
+ lambda: ((), ops.Tensor, ()),
+ lambda: ((), tensor_shape.TensorShape([]), ())
+ ),
+ (
+ "CASE_12",
+ lambda: ((), tensor_shape.TensorShape([]), ()),
+ lambda: ((), sparse_tensor.SparseTensor, ()),
+ lambda: ((), tensor_shape.unknown_shape(), ())
+ )
+ ]
+ def reduce_fn(x, y):
+ name, types_fn, classes_fn, expected_fn = y
+ return x + combinations.combine(
+ types_fn=combinations.NamedObject(
+ "types_fn.{}".format(name), types_fn
+ ),
+ classes_fn=combinations.NamedObject(
+ "classes_fn.{}".format(name), classes_fn
+ ),
+ expected_fn=combinations.NamedObject(
+ "expected_fn.{}".format(name), expected_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_as_dense_types_combinations():
+ cases = [
+ (
+ "CASE_0",
+ lambda: (),
+ lambda: (),
+ lambda: ()
+ ),
+ (
+ "CASE_1",
+ lambda: dtypes.int32,
+ lambda: ops.Tensor,
+ lambda: dtypes.int32
+ ),
+ (
+ "CASE_2",
+ lambda: dtypes.int32,
+ lambda: sparse_tensor.SparseTensor,
+ lambda: dtypes.variant
+ ),
+ (
+ "CASE_3",
+ lambda: (dtypes.int32),
+ lambda: (ops.Tensor),
+ lambda: (dtypes.int32)
+ ),
+ (
+ "CASE_4",
+ lambda: (dtypes.int32),
+ lambda: (sparse_tensor.SparseTensor),
+ lambda: (dtypes.variant)
+ ),
+ (
+ "CASE_5",
+ lambda: (dtypes.int32, ()),
+ lambda: (ops.Tensor, ()),
+ lambda: (dtypes.int32, ())
+ ),
+ (
+ "CASE_6",
+ lambda: ((), dtypes.int32),
+ lambda: ((), ops.Tensor),
+ lambda: ((), dtypes.int32)
+ ),
+ (
+ "CASE_7",
+ lambda: (dtypes.int32, ()),
+ lambda: (sparse_tensor.SparseTensor, ()),
+ lambda: (dtypes.variant, ())
+ ),
+ (
+ "CASE_8",
+ lambda: ((), dtypes.int32),
+ lambda: ((), sparse_tensor.SparseTensor),
+ lambda: ((), dtypes.variant)
+ ),
+ (
+ "CASE_9",
+ lambda: (dtypes.int32, (), dtypes.int32),
+ lambda: (ops.Tensor, (), ops.Tensor),
+ lambda: (dtypes.int32, (), dtypes.int32)
+ ),
+ (
+ "CASE_10",
+ lambda: (dtypes.int32, (), dtypes.int32),
+ lambda: (sparse_tensor.SparseTensor, (),
+ sparse_tensor.SparseTensor),
+ lambda: (dtypes.variant, (), dtypes.variant)
+ ),
+ (
+ "CASE_11",
+ lambda: ((), dtypes.int32, ()),
+ lambda: ((), ops.Tensor, ()),
+ lambda: ((), dtypes.int32, ())
+ ),
+ (
+ "CASE_12",
+ lambda: ((), dtypes.int32, ()),
+ lambda: ((), sparse_tensor.SparseTensor, ()),
+ lambda: ((), dtypes.variant, ())
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, types_fn, classes_fn, expected_fn = y
+ return x + combinations.combine(
+ types_fn=combinations.NamedObject(
+ "types_fn.{}".format(name), types_fn
+ ),
+ classes_fn=combinations.NamedObject(
+ "classes_fn.{}".format(name), classes_fn
+ ),
+ expected_fn=combinations.NamedObject(
+ "expected_fn.{}".format(name), expected_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+def _test_get_classes_combinations():
+ cases = [
+ (
+ "CASE_0",
+ lambda: (),
+ lambda: ()
+ ),
+ (
+ "CASE_1",
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[0]], values=[1], dense_shape=[1]),
+ lambda: sparse_tensor.SparseTensor
+ ),
+ (
+ "CASE_2",
+ lambda: constant_op.constant([1]),
+ lambda: ops.Tensor
+ ),
+ (
+ "CASE_3",
+ lambda: (sparse_tensor.SparseTensor(
+ indices=[[0]], values=[1], dense_shape=[1])),
+ lambda: (sparse_tensor.SparseTensor)
+ ),
+ (
+ "CASE_4",
+ lambda: (constant_op.constant([1])),
+ lambda: (ops.Tensor)
+ ),
+ (
+ "CASE_5",
+ lambda: (sparse_tensor.SparseTensor(
+ indices=[[0]], values=[1], dense_shape=[1]), ()),
+ lambda: (sparse_tensor.SparseTensor, ())
+ ),
+ (
+ "CASE_6",
+ lambda: ((), sparse_tensor.SparseTensor(
+ indices=[[0]], values=[1], dense_shape=[1])),
+ lambda: ((), sparse_tensor.SparseTensor)
+ ),
+ (
+ "CASE_7",
+ lambda: (constant_op.constant([1]), ()),
+ lambda: (ops.Tensor, ())
+ ),
+ (
+ "CASE_8",
+ lambda: ((), constant_op.constant([1])),
+ lambda: ((), ops.Tensor)
+ ),
+ (
+ "CASE_9",
+ lambda: (sparse_tensor.SparseTensor(
+ indices=[[0]], values=[1], dense_shape=[1]),
+ (), constant_op.constant([1])),
+ lambda: (sparse_tensor.SparseTensor, (), ops.Tensor)
+ ),
+ (
+ "CASE_10",
+ lambda: ((), sparse_tensor.SparseTensor(
+ indices=[[0]], values=[1], dense_shape=[1]), ()),
+ lambda: ((), sparse_tensor.SparseTensor, ())
+ ),
+ (
+ "CASE_11",
+ lambda: ((), constant_op.constant([1]), ()),
+ lambda: ((), ops.Tensor, ())
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, classes_fn, expected_fn = y
+ return x + combinations.combine(
+ classes_fn=combinations.NamedObject(
+ "classes_fn.{}".format(name), classes_fn
+ ),
+ expected_fn=combinations.NamedObject(
+ "expected_fn.{}".format(name), expected_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_serialize_deserialize_combinations():
+ cases = [
+ ("CASE_0", lambda: ()),
+ ("CASE_1", lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
+ ("CASE_2", lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
+ ("CASE_3", lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 0], [3, 4]], values=[1, -1],
+ dense_shape=[4, 5])),
+ ("CASE_4", lambda: (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]))),
+ ("CASE_5", lambda: (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ())),
+ ("CASE_6", lambda: ((), sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])))
+ ]
+ def reduce_fn(x, y):
+ name, input_fn = y
+ return x + combinations.combine(
+ input_fn=combinations.NamedObject(
+ "input_fn.{}".format(name), input_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_serialize_many_deserialize_combinations():
+ cases = [
+ ("CASE_0", lambda: ()),
+ ("CASE_1", lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
+ ("CASE_2", lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
+ ("CASE_3", lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 0], [3, 4]], values=[1, -1],
+ dense_shape=[4, 5])),
+ ("CASE_4", lambda: (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]))),
+ ("CASE_5", lambda: (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ())),
+ ("CASE_6", lambda: ((), sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])))
+ ]
+ def reduce_fn(x, y):
+ name, input_fn = y
+ return x + combinations.combine(
+ input_fn=combinations.NamedObject(
+ "input_fn.{}".format(name), input_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
class SparseTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(test_case_fn=[
- combinations.NamedObject("Case_0", lambda: {
- "classes": (),
- "expected": False
- }),
- combinations.NamedObject("Case_1", lambda: {
- "classes": (ops.Tensor),
- "expected": False
- }),
- combinations.NamedObject("Case_2", lambda: {
- "classes": (((ops.Tensor))),
- "expected": False
- }),
- combinations.NamedObject("Case_3", lambda: {
- "classes": (ops.Tensor, ops.Tensor),
- "expected": False
- }),
- combinations.NamedObject("Case_4", lambda: {
- "classes": (ops.Tensor, sparse_tensor.SparseTensor),
- "expected": True
- }),
- combinations.NamedObject("Case_5", lambda: {
- "classes": (sparse_tensor.SparseTensor,
- sparse_tensor.SparseTensor),
- "expected": True
- }),
- combinations.NamedObject("Case_6", lambda: {
- "classes": (((sparse_tensor.SparseTensor))),
- "expected": True
- }),
- ])
+ _test_any_sparse_combinations()
)
)
- def testAnySparse(self, test_case_fn):
- test_case_fn = test_case_fn._obj # pylint: disable=protected-access
- test_case = test_case_fn()
- classes = test_case["classes"]
- expected = test_case["expected"]
+ def testAnySparse(self, classes_fn, expected):
+ classes = classes_fn._obj() # pylint: disable=protected-access
+ expected = expected._obj # pylint: disable=protected-access
self.assertEqual(sparse.any_sparse(classes), expected)
def assertShapesEqual(self, a, b):
@@ -92,238 +429,36 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(test_case_fn=[
- combinations.NamedObject("Case_0", lambda: {
- "types": (),
- "classes": (),
- "expected": ()
- }),
- combinations.NamedObject("Case_1", lambda: {
- "types": tensor_shape.TensorShape([]),
- "classes": ops.Tensor,
- "expected": tensor_shape.TensorShape([])
- }),
- combinations.NamedObject("Case_2", lambda: {
- "types": tensor_shape.TensorShape([]),
- "classes": sparse_tensor.SparseTensor,
- "expected": tensor_shape.unknown_shape()
- }),
- combinations.NamedObject("Case_3", lambda: {
- "types": (tensor_shape.TensorShape([])),
- "classes": (ops.Tensor),
- "expected": (tensor_shape.TensorShape([]))
- }),
- combinations.NamedObject("Case_4", lambda: {
- "types": (tensor_shape.TensorShape([])),
- "classes": (sparse_tensor.SparseTensor),
- "expected": (tensor_shape.unknown_shape())
- }),
- combinations.NamedObject("Case_5", lambda: {
- "types": (tensor_shape.TensorShape([]), ()),
- "classes": (ops.Tensor, ()),
- "expected": (tensor_shape.TensorShape([]), ())
- }),
- combinations.NamedObject("Case_6", lambda: {
- "types": ((), tensor_shape.TensorShape([])),
- "classes": ((), ops.Tensor),
- "expected": ((), tensor_shape.TensorShape([]))
- }),
- combinations.NamedObject("Case_7", lambda: {
- "types": (tensor_shape.TensorShape([]), ()),
- "classes": (sparse_tensor.SparseTensor, ()),
- "expected": (tensor_shape.unknown_shape(), ())
- }),
- combinations.NamedObject("Case_8", lambda: {
- "types": ((), tensor_shape.TensorShape([])),
- "classes": ((), sparse_tensor.SparseTensor),
- "expected": ((), tensor_shape.unknown_shape())
- }),
- combinations.NamedObject("Case_9", lambda: {
- "types": (tensor_shape.TensorShape([]), (),
- tensor_shape.TensorShape([])),
- "classes": (ops.Tensor, (), ops.Tensor),
- "expected": (tensor_shape.TensorShape([]), (),
- tensor_shape.TensorShape([]))
- }),
- combinations.NamedObject("Case_10", lambda: {
- "types": (tensor_shape.TensorShape([]), (),
- tensor_shape.TensorShape([])),
- "classes": (sparse_tensor.SparseTensor, (),
- sparse_tensor.SparseTensor),
- "expected": (tensor_shape.unknown_shape(), (),
- tensor_shape.unknown_shape())
- }),
- combinations.NamedObject("Case_11", lambda: {
- "types": ((), tensor_shape.TensorShape([]), ()),
- "classes": ((), ops.Tensor, ()),
- "expected": ((), tensor_shape.TensorShape([]), ())
- }),
- combinations.NamedObject("Case_12", lambda: {
- "types": ((), tensor_shape.TensorShape([]), ()),
- "classes": ((), sparse_tensor.SparseTensor, ()),
- "expected": ((), tensor_shape.unknown_shape(), ())
- }),
- ])
+ _test_as_dense_shapes_combinations()
)
)
- def testAsDenseShapes(self, test_case_fn):
- test_case_fn = test_case_fn._obj # pylint: disable=protected-access
- test_case = test_case_fn()
- types = test_case["types"]
- classes = test_case["classes"]
- expected = test_case["expected"]
+ def testAsDenseShapes(self, types_fn, classes_fn, expected_fn):
+ types = types_fn._obj() # pylint: disable=protected-access
+ classes = classes_fn._obj() # pylint: disable=protected-access
+ expected = expected_fn._obj() # pylint: disable=protected-access
self.assertShapesEqual(sparse.as_dense_shapes(types, classes), expected)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(test_case_fn=[
- combinations.NamedObject("Case_0", lambda: {
- "types": (),
- "classes": (),
- "expected": ()
- }),
- combinations.NamedObject("Case_1", lambda: {
- "types": dtypes.int32,
- "classes": ops.Tensor,
- "expected": dtypes.int32
- }),
- combinations.NamedObject("Case_2", lambda: {
- "types": dtypes.int32,
- "classes": sparse_tensor.SparseTensor,
- "expected": dtypes.variant
- }),
- combinations.NamedObject("Case_3", lambda: {
- "types": (dtypes.int32),
- "classes": (ops.Tensor),
- "expected": (dtypes.int32)
- }),
- combinations.NamedObject("Case_4", lambda: {
- "types": (dtypes.int32),
- "classes": (sparse_tensor.SparseTensor),
- "expected": (dtypes.variant)
- }),
- combinations.NamedObject("Case_5", lambda: {
- "types": (dtypes.int32, ()),
- "classes": (ops.Tensor, ()),
- "expected": (dtypes.int32, ())
- }),
- combinations.NamedObject("Case_6", lambda: {
- "types": ((), dtypes.int32),
- "classes": ((), ops.Tensor),
- "expected": ((), dtypes.int32)
- }),
- combinations.NamedObject("Case_7", lambda: {
- "types": (dtypes.int32, ()),
- "classes": (sparse_tensor.SparseTensor, ()),
- "expected": (dtypes.variant, ())
- }),
- combinations.NamedObject("Case_8", lambda: {
- "types": ((), dtypes.int32),
- "classes": ((), sparse_tensor.SparseTensor),
- "expected": ((), dtypes.variant)
- }),
- combinations.NamedObject("Case_9", lambda: {
- "types": (dtypes.int32, (), dtypes.int32),
- "classes": (ops.Tensor, (), ops.Tensor),
- "expected": (dtypes.int32, (), dtypes.int32)
- }),
- combinations.NamedObject("Case_10", lambda: {
- "types": (dtypes.int32, (), dtypes.int32),
- "classes": (sparse_tensor.SparseTensor, (),
- sparse_tensor.SparseTensor),
- "expected": (dtypes.variant, (), dtypes.variant)
- }),
- combinations.NamedObject("Case_11", lambda: {
- "types": ((), dtypes.int32, ()),
- "classes": ((), ops.Tensor, ()),
- "expected": ((), dtypes.int32, ())
- }),
- combinations.NamedObject("Case_12", lambda: {
- "types": ((), dtypes.int32, ()),
- "classes": ((), sparse_tensor.SparseTensor, ()),
- "expected": ((), dtypes.variant, ())
- }),
- ])
+ _test_as_dense_types_combinations()
)
)
- def testAsDenseTypes(self, test_case_fn):
- test_case_fn = test_case_fn._obj # pylint: disable=protected-access
- test_case = test_case_fn()
- types = test_case["types"]
- classes = test_case["classes"]
- expected = test_case["expected"]
+ def testAsDenseTypes(self, types_fn, classes_fn, expected_fn):
+ types = types_fn._obj() # pylint: disable=protected-access
+ classes = classes_fn._obj() # pylint: disable=protected-access
+ expected = expected_fn._obj() # pylint: disable=protected-access
self.assertEqual(sparse.as_dense_types(types, classes), expected)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(test_case_fn=[
- combinations.NamedObject("Case_0", lambda: {
- "classes": (),
- "expected": ()
- }),
- combinations.NamedObject("Case_1", lambda: {
- "classes": sparse_tensor.SparseTensor(
- indices=[[0]], values=[1], dense_shape=[1]
- ),
- "expected": sparse_tensor.SparseTensor
- }),
- combinations.NamedObject("Case_2", lambda: {
- "classes": constant_op.constant([1]),
- "expected": ops.Tensor
- }),
- combinations.NamedObject("Case_3", lambda: {
- "classes": (sparse_tensor.SparseTensor(
- indices=[[0]], values=[1], dense_shape=[1])),
- "expected": (sparse_tensor.SparseTensor)
- }),
- combinations.NamedObject("Case_4", lambda: {
- "classes": (constant_op.constant([1])),
- "expected": (ops.Tensor)
- }),
- combinations.NamedObject("Case_5", lambda: {
- "classes": (sparse_tensor.SparseTensor(
- indices=[[0]], values=[1], dense_shape=[1]), ()),
- "expected": (sparse_tensor.SparseTensor, ())
- }),
- combinations.NamedObject("Case_6", lambda: {
- "classes": ((), sparse_tensor.SparseTensor(
- indices=[[0]], values=[1], dense_shape=[1])),
- "expected": ((), sparse_tensor.SparseTensor)
- }),
- combinations.NamedObject("Case_7", lambda: {
- "classes": (constant_op.constant([1]), ()),
- "expected": (ops.Tensor, ())
- }),
- combinations.NamedObject("Case_8", lambda: {
- "classes": ((), constant_op.constant([1])),
- "expected": ((), ops.Tensor)
- }),
- combinations.NamedObject("Case_9", lambda: {
- "classes": (
- sparse_tensor.SparseTensor(
- indices=[[0]], values=[1], dense_shape=[1]),
- (), constant_op.constant([1])),
- "expected": (sparse_tensor.SparseTensor, (), ops.Tensor)
- }),
- combinations.NamedObject("Case_10", lambda: {
- "classes": ((), sparse_tensor.SparseTensor(
- indices=[[0]], values=[1], dense_shape=[1]), ()),
- "expected": ((), sparse_tensor.SparseTensor, ())
- }),
- combinations.NamedObject("Case_11", lambda: {
- "classes": ((), constant_op.constant([1]), ()),
- "expected": ((), ops.Tensor, ())
- }),
- ])
+ _test_get_classes_combinations()
)
)
- def testGetClasses(self, test_case_fn):
- test_case_fn = test_case_fn._obj # pylint: disable=protected-access
- test_case = test_case_fn()
- classes = test_case["classes"]
- expected = test_case["expected"]
+ def testGetClasses(self, classes_fn, expected_fn):
+ classes = classes_fn._obj() # pylint: disable=protected-access
+ expected = expected_fn._obj() # pylint: disable=protected-access
self.assertEqual(sparse.get_classes(classes), expected)
def assertSparseValuesEqual(self, a, b):
@@ -340,40 +475,11 @@
@combinations.generate(
combinations.times(
test_base.graph_only_combinations(),
- combinations.combine(test_case_fn=[
- combinations.NamedObject(
- "Case_0", lambda: ()
- ),
- combinations.NamedObject(
- "Case_1", lambda: sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1])
- ),
- combinations.NamedObject(
- "Case_2", lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
- ),
- combinations.NamedObject(
- "Case_3", lambda: sparse_tensor.SparseTensor(
- indices=[[0, 0], [3, 4]], values=[1, -1],
- dense_shape=[4, 5])
- ),
- combinations.NamedObject(
- "Case_4", lambda: (sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1]))
- ),
- combinations.NamedObject(
- "Case_5", lambda: (sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ())
- ),
- combinations.NamedObject(
- "Case_6", lambda: ((), sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1]))),
- ])
+ _test_serialize_deserialize_combinations()
)
)
- def testSerializeDeserialize(self, test_case_fn):
- test_case_fn = test_case_fn._obj # pylint: disable=protected-access
- test_case = test_case_fn()
+ def testSerializeDeserialize(self, input_fn):
+ test_case = input_fn._obj() # pylint: disable=protected-access
classes = sparse.get_classes(test_case)
shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
classes)
@@ -388,41 +494,11 @@
@combinations.generate(
combinations.times(
test_base.graph_only_combinations(),
- combinations.combine(test_case_fn=[
- combinations.NamedObject(
- "Case_0", lambda: ()
- ),
- combinations.NamedObject(
- "Case_1", lambda: sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1])
- ),
- combinations.NamedObject(
- "Case_2", lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
- ),
- combinations.NamedObject(
- "Case_3", lambda: sparse_tensor.SparseTensor(
- indices=[[0, 0], [3, 4]], values=[1, -1],
- dense_shape=[4, 5])
- ),
- combinations.NamedObject(
- "Case_4", lambda: (sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1]))
- ),
- combinations.NamedObject(
- "Case_5", lambda: (sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ())
- ),
- combinations.NamedObject(
- "Case_6", lambda: ((), sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1]))
- ),
- ])
+ _test_serialize_many_deserialize_combinations()
)
)
- def testSerializeManyDeserialize(self, test_case_fn):
- test_case_fn = test_case_fn._obj # pylint: disable=protected-access
- test_case = test_case_fn()
+ def testSerializeManyDeserialize(self, input_fn):
+ test_case = input_fn._obj() # pylint: disable=protected-access
classes = sparse.get_classes(test_case)
shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
classes)
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index 7bfc80e..7e962c8 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import collections
+import functools
import numpy as np
import wrapt
@@ -49,6 +50,701 @@
# sure they are not executed before the (eager- or graph-mode) test environment
# has been set up.
#
+
+def _test_flat_structure_combinations():
+ cases = [
+ (
+ "Tensor",
+ lambda: constant_op.constant(37.0),
+ tensor_spec.TensorSpec,
+ [dtypes.float32],
+ [[]]
+ ),
+ (
+ "TensorArray",
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,),
+ size=0),
+ tensor_array_ops.TensorArraySpec,
+ [dtypes.variant],
+ [[]]
+ ),
+ (
+ "SparseTensor",
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1],
+ dense_shape=[4, 5]),
+ sparse_tensor.SparseTensorSpec,
+ [dtypes.variant],
+ [None]
+ ),
+ (
+ "RaggedTensor",
+ lambda: ragged_factory_ops.constant(
+ [[1, 2], [], [4]]),
+ ragged_tensor.RaggedTensorSpec,
+ [dtypes.variant],
+ [None]
+ ),
+ (
+ "Nested_0",
+ lambda: (constant_op.constant(37.0),
+ constant_op.constant([1, 2, 3])),
+ tuple,
+ [dtypes.float32, dtypes.int32],
+ [[], [3]]
+ ),
+ (
+ "Nested_1",
+ lambda: {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ },
+ dict,
+ [dtypes.float32, dtypes.int32],
+ [[], [3]]
+ ),
+ (
+ "Nested_2",
+ lambda: {
+ "a": constant_op.constant(37.0),
+ "b": (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1],
+ dense_shape=[1, 1]), sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1],
+ dense_shape=[4, 5]))
+ },
+ dict,
+ [dtypes.float32, dtypes.variant, dtypes.variant],
+ [[], None, None]
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, value_fn, expected_structure, expected_types, expected_shapes = y
+ return x + combinations.combine(
+ value_fn=combinations.NamedObject(
+ "value_fn.{}".format(name), value_fn
+ ),
+ expected_structure=combinations.NamedObject(
+ "expected_structure.{}".format(name), expected_structure
+ ),
+ expected_types=combinations.NamedObject(
+ "expected_types.{}".format(name), expected_types
+ ),
+ expected_shapes=combinations.NamedObject(
+ "expected_shapes.{}".format(name), expected_shapes
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_is_compatible_with_structure_combinations():
+ cases = [
+ (
+ "Tensor",
+ lambda: constant_op.constant(37.0),
+ lambda: [
+ constant_op.constant(38.0),
+ array_ops.placeholder(dtypes.float32),
+ variables.Variable(100.0), 42.0,
+ np.array(42.0, dtype=np.float32)
+ ],
+ lambda: [constant_op.constant([1.0, 2.0]),
+ constant_op.constant(37)]
+ ),
+ (
+ "TensorArray",
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,), size=0),
+ lambda: [
+ tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,), size=0),
+ tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,), size=10)
+ ],
+ lambda: [
+ tensor_array_ops.TensorArray(
+ dtype=dtypes.int32, element_shape=(3,), size=0),
+ tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(), size=0)
+ ]
+ ),
+ (
+ "SparseTensor",
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ lambda: [
+ sparse_tensor.SparseTensor(
+ indices=[[1, 1], [3, 4]], values=[10, -1],
+ dense_shape=[4, 5]),
+ sparse_tensor.SparseTensorValue(
+ indices=[[1, 1], [3, 4]], values=[10, -1],
+ dense_shape=[4, 5]),
+ array_ops.sparse_placeholder(dtype=dtypes.int32),
+ array_ops.sparse_placeholder(
+ dtype=dtypes.int32, shape=[None, None])
+ ],
+ lambda: [
+ constant_op.constant(37, shape=[4, 5]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1],
+ dense_shape=[5, 6]),
+ array_ops.sparse_placeholder(
+ dtype=dtypes.int32, shape=[None, None, None]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1.0],
+ dense_shape=[4, 5])
+ ]
+ ),
+ (
+ "RaggedTensor",
+ lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
+ lambda: [
+ ragged_factory_ops.constant([[1, 2], [3, 4], []]),
+ ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
+ ],
+ lambda: [
+ ragged_factory_ops.constant(1),
+ ragged_factory_ops.constant([1, 2]),
+ ragged_factory_ops.constant([[1], [2]]),
+ ragged_factory_ops.constant([["a", "b"]]),
+ ]
+ ),
+ (
+ "Nested",
+ lambda: {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ },
+ lambda: [{
+ "a": constant_op.constant(15.0),
+ "b": constant_op.constant([4, 5, 6])
+ }],
+ lambda: [
+ {
+ "a": constant_op.constant(15.0),
+ "b": constant_op.constant([4, 5, 6, 7])
+ },
+ {
+ "a": constant_op.constant(15),
+ "b": constant_op.constant([4, 5, 6])
+ },
+ {
+ "a": constant_op.constant(15),
+ "b": sparse_tensor.SparseTensor(
+ indices=[[0], [1], [2]], values=[4, 5, 6],
+ dense_shape=[3])
+ },
+ (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))
+ ]
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, original_value_fn, compatible_values_fn, incompatible_values_fn = y
+ return x + combinations.combine(
+ original_value_fn=combinations.NamedObject(
+ "original_value_fn.{}".format(name), original_value_fn
+ ),
+ compatible_values_fn=combinations.NamedObject(
+ "compatible_values_fn.{}".format(name), compatible_values_fn
+ ),
+ incompatible_values_fn=combinations.NamedObject(
+ "incompatible_values_fn.{}".format(name), incompatible_values_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_structure_from_value_equality_combinations():
+ cases = [
+ (
+ "Tensor",
+ lambda: constant_op.constant(37.0),
+ lambda: constant_op.constant(42.0),
+ lambda: constant_op.constant([5])
+ ),
+ (
+ "TensorArray",
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,), size=0),
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,), size=0),
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.int32, element_shape=(), size=0)
+ ),
+ (
+ "SparseTensor",
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3]], values=[-1], dense_shape=[5]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])
+ ),
+ (
+ "RaggedTensor",
+ lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
+ lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
+ lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]],
+ ragged_rank=1),
+ lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
+ lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
+ ),
+ (
+ "Nested",
+ lambda: {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ },
+ lambda: {
+ "a": constant_op.constant(42.0),
+ "b": constant_op.constant([4, 5, 6])
+ },
+ lambda: {
+ "a": constant_op.constant([1, 2, 3]),
+ "b": constant_op.constant(37.0)
+ }
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, value1_fn, value2_fn, *not_equal_value_fns = y
+ return x + combinations.combine(
+ value1_fn=combinations.NamedObject(
+ "value1_fn.{}".format(name), value1_fn
+ ),
+ value2_fn=combinations.NamedObject(
+ "value2_fn.{}".format(name), value2_fn
+ ),
+ not_equal_value_fns=combinations.NamedObject(
+ "not_equal_value_fns.{}".format(name), not_equal_value_fns
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_ragged_structure_inequality_combinations():
+ cases = [
+ (
+ "RaggedTensor_RaggedRank",
+ ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
+ ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)
+ ),
+ (
+ "RaggedTensor_Shape",
+ ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1),
+ ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1)
+ ),
+ (
+ "RaggedTensor_DType",
+ ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
+ ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, spec1, spec2 = y
+ return x + combinations.combine(
+ spec1=combinations.NamedObject("spec1.{}".format(name), spec1),
+ spec2=combinations.NamedObject("spec2.{}".format(name), spec2),
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_hash_combinations():
+ cases = [
+ (
+ "Tensor",
+ lambda: constant_op.constant(37.0),
+ lambda: constant_op.constant(42.0),
+ lambda: constant_op.constant([5])
+ ),
+ (
+ "TensorArray",
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,), size=0),
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(3,), size=0),
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.int32, element_shape=(), size=0)
+ ),
+ (
+ "SparseTensor",
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3]], values=[-1], dense_shape=[5])
+ ),
+ (
+ "Nested",
+ lambda: {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ },
+ lambda: {
+ "a": constant_op.constant(42.0),
+ "b": constant_op.constant([4, 5, 6])
+ },
+ lambda: {
+ "a": constant_op.constant([1, 2, 3]),
+ "b": constant_op.constant(37.0)
+ }
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, value1_fn, value2_fn, value3_fn = y
+ return x + combinations.combine(
+ value1_fn=combinations.NamedObject(
+ "value1_fn.{}".format(name), value1_fn
+ ),
+ value2_fn=combinations.NamedObject(
+ "value2_fn.{}".format(name), value2_fn
+ ),
+ value3_fn=combinations.NamedObject(
+ "value3_fn.{}".format(name), value3_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_round_trip_conversion_combinations():
+ cases = [
+ (
+ "Tensor",
+ lambda: constant_op.constant(37.0),
+ ),
+ (
+ "SparseTensor",
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ ),
+ (
+ "TensorArray",
+ lambda: tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, element_shape=(), size=1
+ ).write(0, 7)
+ ),
+ (
+ "RaggedTensor",
+ lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
+ ),
+ (
+ "Nested_0",
+ lambda: {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ },
+ ),
+ (
+ "Nested_1",
+ lambda: {
+ "a": constant_op.constant(37.0),
+ "b": (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1],
+ dense_shape=[4, 5]))
+ },
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, value_fn = y
+ return x + combinations.combine(
+ value_fn=combinations.NamedObject(
+ "value_fn.{}".format(name), value_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+
+def _test_convert_legacy_structure_combinations():
+ cases = [
+ (
+ "Tensor",
+ dtypes.float32,
+ tensor_shape.TensorShape([]),
+ ops.Tensor,
+ tensor_spec.TensorSpec([], dtypes.float32)
+ ),
+ (
+ "SparseTensor",
+ dtypes.int32,
+ tensor_shape.TensorShape([2, 2]),
+ sparse_tensor.SparseTensor,
+ sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)
+ ),
+ (
+ "TensorArray_0",
+ dtypes.int32,
+ tensor_shape.TensorShape(
+ [None, True, 2, 2]),
+ tensor_array_ops.TensorArray,
+ tensor_array_ops.TensorArraySpec(
+ [2, 2], dtypes.int32, dynamic_size=None,
+ infer_shape=True)
+ ),
+ (
+ "TensorArray_1",
+ dtypes.int32,
+ tensor_shape.TensorShape([True, None, 2, 2]),
+ tensor_array_ops.TensorArray,
+ tensor_array_ops.TensorArraySpec(
+ [2, 2], dtypes.int32, dynamic_size=True,
+ infer_shape=None)
+ ),
+ (
+ "TensorArray_2",
+ dtypes.int32,
+ tensor_shape.TensorShape([True, False, 2, 2]),
+ tensor_array_ops.TensorArray,
+ tensor_array_ops.TensorArraySpec(
+ [2, 2], dtypes.int32, dynamic_size=True,
+ infer_shape=False)
+ ),
+ (
+ "RaggedTensor",
+ dtypes.int32,
+ tensor_shape.TensorShape([2, None]),
+ ragged_tensor.RaggedTensorSpec(
+ [2, None], dtypes.int32, 1),
+ ragged_tensor.RaggedTensorSpec(
+ [2, None], dtypes.int32, 1)
+ ),
+ (
+ "Nested",
+ {
+ "a": dtypes.float32,
+ "b": (dtypes.int32, dtypes.string)
+ },
+ {
+ "a": tensor_shape.TensorShape([]),
+ "b": (tensor_shape.TensorShape([2, 2]),
+ tensor_shape.TensorShape([]))
+ },
+ {
+ "a": ops.Tensor,
+ "b": (sparse_tensor.SparseTensor, ops.Tensor)
+ },
+ {
+ "a": tensor_spec.TensorSpec([], dtypes.float32),
+ "b": (sparse_tensor.SparseTensorSpec(
+ [2, 2], dtypes.int32),
+ tensor_spec.TensorSpec([], dtypes.string))
+ }
+ )
+ ]
+ def reduce_fn(x, y):
+ name, output_types, output_shapes, output_classes, expected_structure = y
+ return x + combinations.combine(
+ output_types=combinations.NamedObject(
+ "output_types.{}".format(name), output_types
+ ),
+ output_shapes=combinations.NamedObject(
+ "output_shapes.{}".format(name), output_shapes
+ ),
+ output_classes=combinations.NamedObject(
+ "output_classes.{}".format(name), output_classes
+ ),
+ expected_structure=combinations.NamedObject(
+ "expected_structure.{}".format(name), expected_structure
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+def _test_batch_combinations():
+ cases = [
+ (
+ "Tensor",
+ tensor_spec.TensorSpec([], dtypes.float32),
+ 32,
+ tensor_spec.TensorSpec([32], dtypes.float32)
+ ),
+ (
+ "TensorUnknown",
+ tensor_spec.TensorSpec([], dtypes.float32),
+ None,
+ tensor_spec.TensorSpec([None], dtypes.float32)
+ ),
+ (
+ "SparseTensor",
+ sparse_tensor.SparseTensorSpec([None], dtypes.float32),
+ 32,
+ sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)
+ ),
+ (
+ "SparseTensorUnknown",
+ sparse_tensor.SparseTensorSpec([4], dtypes.float32),
+ None,
+ sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32)
+ ),
+ (
+ "RaggedTensor",
+ ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1),
+ 32,
+ ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2)
+ ),
+ (
+ "RaggedTensorUnknown",
+ ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1),
+ None,
+ ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)
+ ),
+ (
+ "Nested",
+ {
+ "a":
+ tensor_spec.TensorSpec([], dtypes.float32),
+ "b": (
+ sparse_tensor.SparseTensorSpec(
+ [2, 2], dtypes.int32),
+ tensor_spec.TensorSpec([], dtypes.string)
+ )
+ },
+ 128,
+ {
+ "a":
+ tensor_spec.TensorSpec([128], dtypes.float32),
+ "b": (
+ sparse_tensor.SparseTensorSpec(
+ [128, 2, 2], dtypes.int32),
+ tensor_spec.TensorSpec([128], dtypes.string)
+ )
+ }
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, element_structure, batch_size, expected_batched_structure = y
+ return x + combinations.combine(
+ element_structure=combinations.NamedObject(
+ "element_structure.{}".format(name), element_structure
+ ),
+ batch_size=combinations.NamedObject(
+ "batch_size.{}".format(name), batch_size
+ ),
+ expected_batched_structure=combinations.NamedObject(
+ "expected_batched_structure.{}".format(name),
+ expected_batched_structure
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+def _test_unbatch_combinations():
+ cases = [
+ (
+ "Tensor",
+ tensor_spec.TensorSpec([32], dtypes.float32),
+ tensor_spec.TensorSpec([], dtypes.float32)
+ ),
+ (
+ "TensorUnknown",
+ tensor_spec.TensorSpec([None], dtypes.float32),
+ tensor_spec.TensorSpec([], dtypes.float32)
+ ),
+ (
+ "SparseTensor",
+ sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
+ sparse_tensor.SparseTensorSpec([None], dtypes.float32)
+ ),
+ (
+ "SparseTensorUnknown",
+ sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
+ sparse_tensor.SparseTensorSpec([4], dtypes.float32)
+ ),
+ (
+ "RaggedTensor",
+ ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2),
+ ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)
+ ),
+ (
+ "RaggedTensorUnknown",
+ ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32, 2),
+ ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)
+ ),
+ (
+ "Nested",
+ {
+ "a": tensor_spec.TensorSpec([128], dtypes.float32),
+ "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
+ tensor_spec.TensorSpec([None], dtypes.string))
+ },
+ {
+ "a": tensor_spec.TensorSpec([], dtypes.float32),
+ "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
+ tensor_spec.TensorSpec([], dtypes.string))
+ }
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, element_structure, expected_unbatched_structure = y
+ return x + combinations.combine(
+ element_structure=combinations.NamedObject(
+ "element_structure.{}".format(name), element_structure
+ ),
+ expected_unbatched_structure=combinations.NamedObject(
+ "expected_unbatched_structure.{}".format(name),
+ expected_unbatched_structure
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
+def _test_to_batched_tensor_list_combinations():
+ cases = [
+ (
+ "Tensor",
+ lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
+ lambda: constant_op.constant([1.0, 2.0])
+ ),
+ (
+ "SparseTensor",
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 1]], values=[13, 27],
+ dense_shape=[2, 2]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[0]], values=[13], dense_shape=[2])
+ ),
+ (
+ "RaggedTensor",
+ lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
+ lambda: ragged_factory_ops.constant([[1]])
+ ),
+ (
+ "Nest",
+ lambda: (
+ constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 1]], values=[13, 27],
+ dense_shape=[2, 2])
+ ),
+ lambda: (
+ constant_op.constant([1.0, 2.0]),
+ sparse_tensor.SparseTensor(
+ indices=[[0]], values=[13], dense_shape=[2]))
+ ),
+ ]
+ def reduce_fn(x, y):
+ name, value_fn, element_0_fn = y
+ return x + combinations.combine(
+ value_fn=combinations.NamedObject(
+ "value_fn.{}".format(name), value_fn
+ ),
+ element_0_fn=combinations.NamedObject(
+ "element_0_fn.{}".format(name), element_0_fn
+ )
+ )
+
+ return functools.reduce(reduce_fn, cases, [])
+
# TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure.
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
@@ -56,81 +752,15 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(data=[
- combinations.NamedObject(
- "Tensor", [
- lambda: constant_op.constant(37.0),
- tensor_spec.TensorSpec,
- [dtypes.float32],
- [[]]
- ]),
- combinations.NamedObject(
- "TensorArray", [
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,),
- size=0),
- tensor_array_ops.TensorArraySpec,
- [dtypes.variant],
- [[]]
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1],
- dense_shape=[4, 5]),
- sparse_tensor.SparseTensorSpec,
- [dtypes.variant],
- [None]
- ]),
- combinations.NamedObject(
- "RaggedTensor", [
- lambda: ragged_factory_ops.constant(
- [[1, 2], [], [4]]),
- ragged_tensor.RaggedTensorSpec,
- [dtypes.variant],
- [None]
- ]),
- combinations.NamedObject(
- "Nested_0", [
- lambda: (constant_op.constant(37.0),
- constant_op.constant([1, 2, 3])),
- tuple,
- [dtypes.float32, dtypes.int32],
- [[], [3]]
- ]),
- combinations.NamedObject(
- "Nested_1", [
- lambda: {
- "a": constant_op.constant(37.0),
- "b": constant_op.constant([1, 2, 3])
- },
- dict,
- [dtypes.float32, dtypes.int32],
- [[], [3]]
- ]),
- combinations.NamedObject(
- "Nested_2", [
- lambda: {
- "a": constant_op.constant(37.0),
- "b": (sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1],
- dense_shape=[1, 1]), sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1],
- dense_shape=[4, 5]))
- },
- dict,
- [dtypes.float32, dtypes.variant, dtypes.variant],
- [[], None, None]
- ]),
- ])
+ _test_flat_structure_combinations()
)
)
- def testFlatStructure(self, data):
- data = data._obj
- value_fn = data[0]
- expected_structure = data[1]
- expected_types = data[2]
- expected_shapes = data[3]
+ def testFlatStructure(self, value_fn, expected_structure,
+ expected_types, expected_shapes):
+ value_fn = value_fn._obj
+ expected_structure = expected_structure._obj
+ expected_types = expected_types._obj
+ expected_shapes = expected_shapes._obj
value = value_fn()
s = structure.type_spec_from_value(value)
self.assertIsInstance(s, expected_structure)
@@ -147,112 +777,16 @@
@combinations.generate(
combinations.times(
test_base.graph_only_combinations(),
- combinations.combine(value_fns=[
- combinations.NamedObject(
- "Tensor", [
- lambda: constant_op.constant(37.0),
- lambda: [
- constant_op.constant(38.0),
- array_ops.placeholder(dtypes.float32),
- variables.Variable(100.0), 42.0,
- np.array(42.0, dtype=np.float32)
- ],
- lambda: [constant_op.constant([1.0, 2.0]),
- constant_op.constant(37)]
- ]),
- combinations.NamedObject(
- "TensorArray", [
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,), size=0),
- lambda: [
- tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,), size=0),
- tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,), size=10)
- ],
- lambda: [
- tensor_array_ops.TensorArray(
- dtype=dtypes.int32, element_shape=(3,), size=0),
- tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(), size=0)
- ]
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- lambda: [
- sparse_tensor.SparseTensor(
- indices=[[1, 1], [3, 4]], values=[10, -1],
- dense_shape=[4, 5]),
- sparse_tensor.SparseTensorValue(
- indices=[[1, 1], [3, 4]], values=[10, -1],
- dense_shape=[4, 5]),
- array_ops.sparse_placeholder(dtype=dtypes.int32),
- array_ops.sparse_placeholder(
- dtype=dtypes.int32, shape=[None, None])
- ],
- lambda: [
- constant_op.constant(37, shape=[4, 5]),
- sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1],
- dense_shape=[5, 6]),
- array_ops.sparse_placeholder(
- dtype=dtypes.int32, shape=[None, None, None]),
- sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1.0],
- dense_shape=[4, 5])
- ]
- ]),
- combinations.NamedObject(
- "RaggedTensor", [
- lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
- lambda: [
- ragged_factory_ops.constant([[1, 2], [3, 4], []]),
- ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
- ],
- lambda: [
- ragged_factory_ops.constant(1),
- ragged_factory_ops.constant([1, 2]),
- ragged_factory_ops.constant([[1], [2]]),
- ragged_factory_ops.constant([["a", "b"]]),
- ]
- ]),
- combinations.NamedObject(
- "Nested", [
- lambda: {
- "a": constant_op.constant(37.0),
- "b": constant_op.constant([1, 2, 3])
- },
- lambda: [{
- "a": constant_op.constant(15.0),
- "b": constant_op.constant([4, 5, 6])
- }],
- lambda: [{
- "a": constant_op.constant(15.0),
- "b": constant_op.constant([4, 5, 6, 7])
- }, {
- "a": constant_op.constant(15),
- "b": constant_op.constant([4, 5, 6])
- }, {
- "a":
- constant_op.constant(15),
- "b":
- sparse_tensor.SparseTensor(
- indices=[[0], [1], [2]], values=[4, 5, 6],
- dense_shape=[3])
- }, (constant_op.constant(15.0),
- constant_op.constant([4, 5, 6]))]
- ]),
- ])))
- def testIsCompatibleWithStructure(self, value_fns):
- value_fns = value_fns._obj
- original_value_fn = value_fns[0]
- compatible_values_fn = value_fns[1]
- incompatible_values_fn = value_fns[2]
- original_value = original_value_fn()
- compatible_values = compatible_values_fn()
- incompatible_values = incompatible_values_fn()
+ _test_is_compatible_with_structure_combinations()
+ )
+ )
+ def testIsCompatibleWithStructure(self, original_value_fn,
+ compatible_values_fn,
+ incompatible_values_fn):
+ original_value = original_value_fn._obj()
+ compatible_values = compatible_values_fn._obj()
+ incompatible_values = incompatible_values_fn._obj()
+
s = structure.type_spec_from_value(original_value)
for compatible_value in compatible_values:
self.assertTrue(
@@ -266,64 +800,15 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(value_fns=[
- combinations.NamedObject(
- "Tensor", [
- lambda: constant_op.constant(37.0),
- lambda: constant_op.constant(42.0),
- lambda: constant_op.constant([5])
- ]),
- combinations.NamedObject(
- "TensorArray", [
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,), size=0),
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,), size=0),
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.int32, element_shape=(), size=0)
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- lambda: sparse_tensor.SparseTensor(
- indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
- lambda: sparse_tensor.SparseTensor(
- indices=[[3]], values=[-1], dense_shape=[5]),
- lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])
- ]),
- combinations.NamedObject(
- "RaggedTensor", [
- lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
- lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
- lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]],
- ragged_rank=1),
- lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
- lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
- ]),
- combinations.NamedObject(
- "Nested", [
- lambda: {
- "a": constant_op.constant(37.0),
- "b": constant_op.constant([1, 2, 3])},
- lambda: {
- "a": constant_op.constant(42.0),
- "b": constant_op.constant([4, 5, 6])},
- lambda: {
- "a": constant_op.constant([1, 2, 3]),
- "b": constant_op.constant(37.0)
- }
- ]),
- ])
+ _test_structure_from_value_equality_combinations()
)
)
- def testStructureFromValueEquality(self, value_fns):
+ def testStructureFromValueEquality(self, value1_fn, value2_fn,
+ not_equal_value_fns):
# pylint: disable=g-generic-assert
- value_fns = value_fns._obj
- value1_fn = value_fns[0]
- value2_fn = value_fns[1]
- not_equal_value_fns = value_fns[2:]
+ value1_fn = value1_fn._obj
+ value2_fn = value2_fn._obj
+ not_equal_value_fns = not_equal_value_fns._obj
s1 = structure.type_spec_from_value(value1_fn())
s2 = structure.type_spec_from_value(value2_fn())
self.assertEqual(s1, s1) # check __eq__ operator.
@@ -343,86 +828,26 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(specs=[
- combinations.NamedObject(
- "RaggedTensor_RaggedRank", [
- ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
- ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)
- ]),
- combinations.NamedObject(
- "RaggedTensor_Shape", [
- ragged_tensor.RaggedTensorSpec(
- [3, None], dtypes.int32, 1),
- ragged_tensor.RaggedTensorSpec(
- [5, None], dtypes.int32, 1)
- ]),
- combinations.NamedObject(
- "RaggedTensor_DType", [
- ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
- ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)
- ]),
- ])
+ _test_ragged_structure_inequality_combinations()
)
)
- def testRaggedStructureInequality(self, specs):
- specs = specs._obj
- s1 = specs[0]
- s2 = specs[1]
+ def testRaggedStructureInequality(self, spec1, spec2):
+ spec1 = spec1._obj
+ spec2 = spec2._obj
# pylint: disable=g-generic-assert
- self.assertNotEqual(s1, s2) # check __ne__ operator.
- self.assertFalse(s1 == s2) # check __eq__ operator.
+ self.assertNotEqual(spec1, spec2) # check __ne__ operator.
+ self.assertFalse(spec1 == spec2) # check __eq__ operator.
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(value_fns=[
- combinations.NamedObject(
- "Tensor", [
- lambda: constant_op.constant(37.0),
- lambda: constant_op.constant(42.0),
- lambda: constant_op.constant([5])
- ]),
- combinations.NamedObject(
- "TensorArray", [
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,), size=0),
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(3,), size=0),
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.int32, element_shape=(), size=0)
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- lambda: sparse_tensor.SparseTensor(
- indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
- lambda: sparse_tensor.SparseTensor(
- indices=[[3]], values=[-1], dense_shape=[5])
- ]),
- combinations.NamedObject(
- "Nested", [
- lambda: {
- "a": constant_op.constant(37.0),
- "b": constant_op.constant([1, 2, 3])
- },
- lambda: {
- "a": constant_op.constant(42.0),
- "b": constant_op.constant([4, 5, 6])
- },
- lambda: {
- "a": constant_op.constant([1, 2, 3]),
- "b": constant_op.constant(37.0)
- }
- ]),
- ])
+ _test_hash_combinations()
)
)
- def testHash(self, value_fns):
- value_fns = value_fns._obj
- value1_fn = value_fns[0]
- value2_fn = value_fns[1]
- value3_fn = value_fns[2]
+ def testHash(self, value1_fn, value2_fn, value3_fn):
+ value1_fn = value1_fn._obj
+ value2_fn = value2_fn._obj
+ value3_fn = value3_fn._obj
s1 = structure.type_spec_from_value(value1_fn())
s2 = structure.type_spec_from_value(value2_fn())
s3 = structure.type_spec_from_value(value3_fn())
@@ -435,51 +860,11 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(value_fn=[
- combinations.NamedObject(
- "Tensor",
- lambda: constant_op.constant(37.0),
- ),
- combinations.NamedObject(
- "SparseTensor",
- lambda: sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- ),
- combinations.NamedObject(
- "TensorArray",
- lambda: tensor_array_ops.TensorArray(
- dtype=dtypes.float32, element_shape=(), size=1
- ).write(0, 7)
- ),
- combinations.NamedObject(
- "RaggedTensor",
- lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
- ),
- combinations.NamedObject(
- "Nested_0",
- lambda: {
- "a": constant_op.constant(37.0),
- "b": constant_op.constant([1, 2, 3])
- },
- ),
- combinations.NamedObject(
- "Nested_1",
- lambda: {
- "a":
- constant_op.constant(37.0),
- "b": (sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
- sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1],
- dense_shape=[4, 5]))
- },
- ),
- ])
+ _test_round_trip_conversion_combinations()
)
)
def testRoundTripConversion(self, value_fn):
- value_fn = value_fn._obj
- value = value_fn()
+ value = value_fn._obj()
s = structure.type_spec_from_value(value)
def maybe_stack_ta(v):
@@ -702,92 +1087,15 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(data=[
- combinations.NamedObject(
- "Tensor", [
- dtypes.float32,
- tensor_shape.TensorShape([]),
- ops.Tensor,
- tensor_spec.TensorSpec([], dtypes.float32)
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- dtypes.int32,
- tensor_shape.TensorShape([2, 2]),
- sparse_tensor.SparseTensor,
- sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)
- ]),
- combinations.NamedObject(
- "TensorArray_0", [
- dtypes.int32,
- tensor_shape.TensorShape(
- [None, True, 2, 2]),
- tensor_array_ops.TensorArray,
- tensor_array_ops.TensorArraySpec(
- [2, 2], dtypes.int32, dynamic_size=None,
- infer_shape=True)
- ]),
- combinations.NamedObject(
- "TensorArray_1", [
- dtypes.int32,
- tensor_shape.TensorShape([True, None, 2, 2]),
- tensor_array_ops.TensorArray,
- tensor_array_ops.TensorArraySpec(
- [2, 2], dtypes.int32, dynamic_size=True,
- infer_shape=None)
- ]),
- combinations.NamedObject(
- "TensorArray_2", [
- dtypes.int32,
- tensor_shape.TensorShape([True, False, 2, 2]),
- tensor_array_ops.TensorArray,
- tensor_array_ops.TensorArraySpec(
- [2, 2], dtypes.int32, dynamic_size=True,
- infer_shape=False)
- ]),
- combinations.NamedObject(
- "RaggedTensor", [
- dtypes.int32,
- tensor_shape.TensorShape([2, None]),
- ragged_tensor.RaggedTensorSpec(
- [2, None], dtypes.int32, 1),
- ragged_tensor.RaggedTensorSpec(
- [2, None], dtypes.int32, 1)
- ]),
- combinations.NamedObject(
- "Nested", [
- {
- "a": dtypes.float32,
- "b": (dtypes.int32, dtypes.string)
- },
- {
- "a": tensor_shape.TensorShape([]),
- "b": (tensor_shape.TensorShape([2, 2]),
- tensor_shape.TensorShape([]))
- },
- {
- "a": ops.Tensor,
- "b": (sparse_tensor.SparseTensor, ops.Tensor)
- },
- {
- "a":
- tensor_spec.TensorSpec([], dtypes.float32),
- "b": (
- sparse_tensor.SparseTensorSpec(
- [2, 2], dtypes.int32),
- tensor_spec.TensorSpec([], dtypes.string)
- )
- }
- ]),
- ])
+ _test_convert_legacy_structure_combinations()
)
)
- def testConvertLegacyStructure(self, data):
- data = data._obj
- output_types = data[0]
- output_shapes = data[1]
- output_classes = data[2]
- expected_structure = data[3]
+ def testConvertLegacyStructure(self, output_types, output_shapes,
+ output_classes, expected_structure):
+ output_types = output_types._obj
+ output_shapes = output_shapes._obj
+ output_classes = output_classes._obj
+ expected_structure = expected_structure._obj
actual_structure = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
self.assertEqual(actual_structure, expected_structure)
@@ -824,77 +1132,13 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(data=[
- combinations.NamedObject(
- "Tensor", [
- tensor_spec.TensorSpec([], dtypes.float32),
- 32,
- tensor_spec.TensorSpec([32], dtypes.float32)
- ]),
- combinations.NamedObject(
- "TensorUnknown", [
- tensor_spec.TensorSpec([], dtypes.float32),
- None,
- tensor_spec.TensorSpec([None], dtypes.float32)
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- sparse_tensor.SparseTensorSpec([None], dtypes.float32),
- 32,
- sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)
- ]),
- combinations.NamedObject(
- "SparseTensorUnknown", [
- sparse_tensor.SparseTensorSpec([4], dtypes.float32),
- None,
- sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32)
- ]),
- combinations.NamedObject(
- "RaggedTensor", [
- ragged_tensor.RaggedTensorSpec(
- [2, None], dtypes.float32, 1),
- 32,
- ragged_tensor.RaggedTensorSpec(
- [32, 2, None], dtypes.float32, 2)
- ]),
- combinations.NamedObject(
- "RaggedTensorUnknown", [
- ragged_tensor.RaggedTensorSpec(
- [4, None], dtypes.float32, 1),
- None,
- ragged_tensor.RaggedTensorSpec(
- [None, 4, None], dtypes.float32, 2)
- ]),
- combinations.NamedObject(
- "Nested", [
- {
- "a":
- tensor_spec.TensorSpec([], dtypes.float32),
- "b": (
- sparse_tensor.SparseTensorSpec(
- [2, 2], dtypes.int32),
- tensor_spec.TensorSpec([], dtypes.string)
- )
- },
- 128,
- {
- "a":
- tensor_spec.TensorSpec([128], dtypes.float32),
- "b": (
- sparse_tensor.SparseTensorSpec(
- [128, 2, 2], dtypes.int32),
- tensor_spec.TensorSpec([128], dtypes.string)
- )
- }
- ]),
- ])
+ _test_batch_combinations()
)
)
- def testBatch(self, data):
- data = data._obj
- element_structure = data[0]
- batch_size = data[1]
- expected_batched_structure = data[2]
+ def testBatch(self, element_structure, batch_size, expected_batched_structure):
+ element_structure = element_structure._obj
+ batch_size = batch_size._obj
+ expected_batched_structure = expected_batched_structure._obj
batched_structure = nest.map_structure(
lambda component_spec: component_spec._batch(batch_size),
element_structure)
@@ -903,70 +1147,12 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(data=[
- combinations.NamedObject(
- "Tensor", [
- tensor_spec.TensorSpec([32], dtypes.float32),
- tensor_spec.TensorSpec([], dtypes.float32)
- ]),
- combinations.NamedObject(
- "TensorUnknown", [
- tensor_spec.TensorSpec([None], dtypes.float32),
- tensor_spec.TensorSpec([], dtypes.float32)
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- sparse_tensor.SparseTensorSpec(
- [32, None], dtypes.float32),
- sparse_tensor.SparseTensorSpec([None], dtypes.float32)
- ]),
- combinations.NamedObject(
- "SparseTensorUnknown", [
- sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
- sparse_tensor.SparseTensorSpec([4], dtypes.float32)
- ]),
- combinations.NamedObject(
- "RaggedTensor", [
- ragged_tensor.RaggedTensorSpec([32, None, None],
- dtypes.float32, 2),
- ragged_tensor.RaggedTensorSpec([None, None],
- dtypes.float32, 1)
- ]),
- combinations.NamedObject(
- "RaggedTensorUnknown", [
- ragged_tensor.RaggedTensorSpec([None, None, None],
- dtypes.float32, 2),
- ragged_tensor.RaggedTensorSpec([None, None],
- dtypes.float32, 1)
- ]),
- combinations.NamedObject(
- "Nested", [
- {
- "a":
- tensor_spec.TensorSpec([128], dtypes.float32),
- "b": (
- sparse_tensor.SparseTensorSpec(
- [128, 2, 2], dtypes.int32),
- tensor_spec.TensorSpec([None], dtypes.string)
- )
- },
- {
- "a":
- tensor_spec.TensorSpec([], dtypes.float32),
- "b": (
- sparse_tensor.SparseTensorSpec(
- [2, 2], dtypes.int32),
- tensor_spec.TensorSpec([], dtypes.string)
- )
- }
- ]),
- ])
+ _test_unbatch_combinations()
)
)
- def testUnbatch(self, data):
- data = data._obj
- element_structure = data[0]
- expected_unbatched_structure = data[1]
+ def testUnbatch(self, element_structure, expected_unbatched_structure):
+ element_structure = element_structure._obj
+ expected_unbatched_structure = expected_unbatched_structure._obj
unbatched_structure = nest.map_structure(
lambda component_spec: component_spec._unbatch(), element_structure)
self.assertEqual(unbatched_structure, expected_unbatched_structure)
@@ -975,45 +1161,12 @@
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
- combinations.combine(value_fns=[
- combinations.NamedObject(
- "Tensor", [
- lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
- lambda: constant_op.constant([1.0, 2.0])
- ]),
- combinations.NamedObject(
- "SparseTensor", [
- lambda: sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 1]], values=[13, 27],
- dense_shape=[2, 2]),
- lambda: sparse_tensor.SparseTensor(
- indices=[[0]], values=[13], dense_shape=[2])
- ]),
- combinations.NamedObject(
- "RaggedTensor", [
- lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
- lambda: ragged_factory_ops.constant([[1]])
- ]),
- combinations.NamedObject(
- "Nest", [
- lambda: (
- constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 1]], values=[13, 27],
- dense_shape=[2, 2])
- ),
- lambda: (
- constant_op.constant([1.0, 2.0]),
- sparse_tensor.SparseTensor(
- indices=[[0]], values=[13], dense_shape=[2]))
- ]),
- ])
+ _test_to_batched_tensor_list_combinations()
)
)
- def testToBatchedTensorList(self, value_fns):
- value_fns = value_fns._obj
- value_fn = value_fns[0]
- element_0_fn = value_fns[1]
+ def testToBatchedTensorList(self, value_fn, element_0_fn):
+ value_fn = value_fn._obj
+ element_0_fn = element_0_fn._obj
batched_value = value_fn()
s = structure.type_spec_from_value(batched_value)
batched_tensor_list = structure.to_batched_tensor_list(s, batched_value)