| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for tensorflow.ops.check_ops.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import time |
| |
| import numpy as np |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import rewriter_config_pb2 |
| from tensorflow.python.client import session |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import check_ops |
| from tensorflow.python.ops import gradients |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.platform import test |
| |
| |
| # pylint:disable=g-error-prone-assert-raises |
class AssertV2Asserts(test.TestCase):
  """Table-driven smoke test for the TF2 (`*_v2`) assertion functions."""

  def test_passes_when_it_should(self):
    """Every v2 assert accepts good arguments and rejects bad ones.

    Each table row is `(assert_fn, passing_args, failing_args[, error])`.
    When `error` is omitted, `errors.InvalidArgumentError` is expected from the
    failing call.
    """
    # This is a v2 test and need to run eagerly
    with context.eager_mode():
      neg = constant_op.constant(-1, name="minus_one", dtype=dtypes.int32)
      pos = constant_op.constant(2, name="two", dtype=dtypes.int32)
      vec = constant_op.constant([3., 3.], name="three", dtype=dtypes.float32)
      vec_near = constant_op.constant(
          [3., 3.5], name="three_and_a_half", dtype=dtypes.float32)
      # `neg` also serves as the scalar / integer example, and `vec` as the
      # non-scalar / non-integer example.
      table = [
          (check_ops.assert_equal_v2, (neg, neg), (neg, pos)),
          (check_ops.assert_less_v2, (neg, pos), (neg, neg)),
          (check_ops.assert_near_v2, (vec, vec), (vec, vec_near)),
          (check_ops.assert_greater_v2, (pos, neg), (neg, neg)),
          (check_ops.assert_negative_v2, (neg,), (pos,)),
          (check_ops.assert_positive_v2, (pos,), (neg,)),
          (check_ops.assert_less_equal_v2, (neg, neg), (pos, neg)),
          (check_ops.assert_none_equal_v2, (neg, pos), (vec, vec_near)),
          (check_ops.assert_non_negative_v2, (pos,), (neg,)),
          (check_ops.assert_non_positive_v2, (neg,), (pos,)),
          (check_ops.assert_greater_equal_v2, (neg, neg), (neg, pos)),
          (check_ops.assert_type_v2, (neg, dtypes.int32),
           (neg, dtypes.float32), TypeError),
          (check_ops.assert_integer_v2, (neg,), (vec,), TypeError),
          (check_ops.assert_scalar_v2, (neg,), (vec,), ValueError),
          (check_ops.assert_rank_v2, (neg, 0), (vec, 2), ValueError),
          (check_ops.assert_rank_in_v2, (neg, [0, 1]), (neg, [1, 2]),
           ValueError),
          (check_ops.assert_rank_at_least_v2, (vec, 1), (neg, 1), ValueError),
      ]

      for row in table:
        fn, passing_args, failing_args = row[:3]
        err_type = row[3] if len(row) > 3 else errors.InvalidArgumentError

        print("Testing %s passing properly." % fn)
        fn(*passing_args)

        print("Testing %s failing properly." % fn)

        # Wrapped in a tf.function so the failing call is exercised through
        # function tracing/execution rather than a plain eager call.
        @def_function.function
        def failing_fn():
          fn(*failing_args, message="fail")  # pylint: disable=cell-var-from-loop

        with self.assertRaisesRegex(err_type, "fail"):
          failing_fn()

        del failing_fn
| |
| |
class AssertProperIterableTest(test.TestCase):
  """Tests for `check_ops.assert_proper_iterable`.

  Tensors, SparseTensors, ndarrays and strings must be rejected with a
  message mentioning "proper"; non-iterables must be rejected as well; plain
  Python containers and generators must be accepted.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_single_tensor_raises(self):
    tensor = constant_op.constant(1)
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(tensor)

  @test_util.run_in_graph_and_eager_modes
  def test_single_sparse_tensor_raises(self):
    ten = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(ten)

  @test_util.run_in_graph_and_eager_modes
  def test_single_ndarray_raises(self):
    array = np.array([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(array)

  @test_util.run_in_graph_and_eager_modes
  def test_single_string_raises(self):
    mystr = "hello"
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(mystr)

  @test_util.run_in_graph_and_eager_modes
  def test_non_iterable_object_raises(self):
    non_iterable = 1234
    with self.assertRaisesRegex(TypeError, "to be an iterable"):
      check_ops.assert_proper_iterable(non_iterable)

  @test_util.run_in_graph_and_eager_modes
  def test_list_does_not_raise(self):
    list_of_stuff = [
        constant_op.constant([11, 22]), constant_op.constant([1, 2])
    ]
    check_ops.assert_proper_iterable(list_of_stuff)

  @test_util.run_in_graph_and_eager_modes
  def test_tuple_does_not_raise(self):
    tuple_of_stuff = (constant_op.constant([11, 22]),
                      constant_op.constant([1, 2]))
    check_ops.assert_proper_iterable(tuple_of_stuff)

  @test_util.run_in_graph_and_eager_modes
  def test_generator_does_not_raise(self):
    # A genuine generator expression (not a parenthesised tuple), so this
    # covers lazily-evaluated iterables rather than duplicating the
    # list/tuple cases.
    generator_of_stuff = (
        constant_op.constant(values) for values in ([11, 22], [1, 2]))
    check_ops.assert_proper_iterable(generator_of_stuff)
| |
| |
class AssertEqualTest(test.TestCase):
  """Tests for `check_ops.assert_equal`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies([check_ops.assert_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_scalar_comparison(self):
    const_true = constant_op.constant(True, name="true")
    const_false = constant_op.constant(False, name="false")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(const_true, const_false, message="fail")

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      small = constant_op.constant([1, 2], name="small")
      x = check_ops.assert_equal(small, small)
      # assertIsNone rather than a bare `assert`, which `python -O` strips.
      self.assertIsNone(x)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    # Static check
    static_small = constant_op.constant([1, 2], name="small")
    static_big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(static_big, static_small, message="fail")

  @test_util.run_deprecated_v1
  def test_raises_when_greater_dynamic(self):
    # Placeholder values are only known when fed, so the failure surfaces
    # from `out.eval` rather than from building the assert.
    with self.cached_session():
      small = array_ops.placeholder(dtypes.int32, name="small")
      big = array_ops.placeholder(dtypes.int32, name="big")
      with ops.control_dependencies(
          [check_ops.assert_equal(big, small, message="fail")]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("fail.*big.*small"):
        out.eval(feed_dict={small: [1, 2], big: [3, 4]})

  def test_error_message_eager(self):
    # The expected strings are regexes; the indented rows mirror numpy's
    # multi-line `str()` of a 2-D index array.
    expected_error_msg_full = r"""big does not equal small
Condition x == y did not hold.
Indices of first 3 different values:
\[\[0 0\]
 \[1 1\]
 \[2 0\]\]
Corresponding x values:
\[2 3 6\]
Corresponding y values:
\[20 30 60\]
First 6 elements of x:
\[2 2 3 3 6 6\]
First 6 elements of y:
\[20 2 3 30 60 6\]"""
    expected_error_msg_default = r"""big does not equal small
Condition x == y did not hold.
Indices of first 3 different values:
\[\[0 0\]
 \[1 1\]
 \[2 0\]\]
Corresponding x values:
\[2 3 6\]
Corresponding y values:
\[20 30 60\]
First 3 elements of x:
\[2 2 3\]
First 3 elements of y:
\[20 2 3\]"""
    expected_error_msg_short = r"""big does not equal small
Condition x == y did not hold.
Indices of first 2 different values:
\[\[0 0\]
 \[1 1\]\]
Corresponding x values:
\[2 3\]
Corresponding y values:
\[20 30\]
First 2 elements of x:
\[2 2\]
First 2 elements of y:
\[20 2\]"""
    with context.eager_mode():
      big = constant_op.constant([[2, 2], [3, 3], [6, 6]])
      small = constant_op.constant([[20, 2], [3, 30], [60, 6]])
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_full):
        check_ops.assert_equal(big, small, message="big does not equal small",
                               summarize=10)
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_default):
        check_ops.assert_equal(big, small, message="big does not equal small")
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_short):
        check_ops.assert_equal(big, small, message="big does not equal small",
                               summarize=2)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    # Static check
    static_small = constant_op.constant([3, 1], name="small")
    static_big = constant_op.constant([4, 2], name="big")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(static_big, static_small, message="fail")

  @test_util.run_deprecated_v1
  def test_raises_when_less_dynamic(self):
    with self.cached_session():
      small = array_ops.placeholder(dtypes.int32, name="small")
      big = array_ops.placeholder(dtypes.int32, name="big")
      with ops.control_dependencies([check_ops.assert_equal(small, big)]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("small.*big"):
        out.eval(feed_dict={small: [3, 1], big: [4, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([[1, 2], [1, 2]], name="small")
    small_2 = constant_op.constant([1, 2], name="small_2")
    with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    small_2 = constant_op.constant([1, 1], name="small_2")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (errors.InvalidArgumentError, ValueError)):
      with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_not_equal_and_broadcastable_shapes(self):
    cond = constant_op.constant([True, False], name="small")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(cond, False, message="fail")

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_noop_when_both_identical(self):
    larry = constant_op.constant([])
    check_op = check_ops.assert_equal(larry, larry)
    if context.executing_eagerly():
      self.assertIsNone(check_op)
    else:
      self.assertEqual(check_op.type, "NoOp")
| |
| |
class AssertNoneEqualTest(test.TestCase):
  """Tests for `check_ops.assert_none_equal`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_not_equal(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([10, 20], name="big")
    with ops.control_dependencies(
        [check_ops.assert_none_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([3, 1], name="small")
    with self.assertRaisesOpError("x != y did not hold"):
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3], name="big")
    with ops.control_dependencies(
        [check_ops.assert_none_equal(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([10, 10], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors.InvalidArgumentError)):
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_none_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      t1 = constant_op.constant([1, 2])
      t2 = constant_op.constant([3, 4])
      x = check_ops.assert_none_equal(t1, t2)
      # assertIsNone rather than a bare `assert`, which `python -O` strips.
      self.assertIsNone(x)

  def test_static_check_in_graph_mode(self):
    # Python ints are known at graph-construction time, so the failure is
    # raised while building the graph, with the custom message.
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_none_equal(1, 1, message="Custom error message")

  def test_error_message_eager(self):
    # Note that the following three strings are regexes
    expected_error_msg_full = r"""\[ *0\. +1\. +2\. +3\. +4\. +5\.\]"""
    expected_error_msg_default = r"""\[ *0\. +1\. +2\.\]"""
    expected_error_msg_short = r"""\[ *0\. +1\.\]"""
    with context.eager_mode():
      t = constant_op.constant(
          np.array(range(6)), shape=[2, 3], dtype=np.float32)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_full):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=10)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_full):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=-1)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_default):
        check_ops.assert_none_equal(t, t, message="This is the error message.")
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_short):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=2)
| |
| |
class AssertAllCloseTest(test.TestCase):
  """Tests for `check_ops.assert_near`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1., name="y")
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_32_bit_due_to_default_rtol(self):
    eps = np.finfo(np.float32).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1. + 2 * eps, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_32_bit_due_to_default_atol(self):
    eps = np.finfo(np.float32).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(0., name="x")
    y = constant_op.constant(0. + 2 * eps, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, rtol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_64_bit_due_to_default_rtol(self):
    eps = np.finfo(np.float64).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(1., name="x", dtype=np.float64)
    y = constant_op.constant(1. + 2 * eps, name="y", dtype=np.float64)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_64_bit_due_to_default_atol(self):
    eps = np.finfo(np.float64).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(0., name="x", dtype=np.float64)
    y = constant_op.constant(0. + 2 * eps, name="y", dtype=np.float64)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, rtol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_due_to_custom_rtol(self):
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1.1, name="y")
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., rtol=0.5,
                               message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_due_to_custom_atol(self):
    x = constant_op.constant(0., name="x")
    y = constant_op.constant(0.1, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0.5, rtol=0.,
                               message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_near(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_atol_violated(self):
    x = constant_op.constant(10., name="x")
    y = constant_op.constant(10.2, name="y")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x and y not equal to tolerance"):
      with ops.control_dependencies(
          [check_ops.assert_near(x, y, atol=0.1,
                                 message="failure message")]):
        out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_default_rtol_violated(self):
    x = constant_op.constant(0.1, name="x")
    y = constant_op.constant(0.0, name="y")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x and y not equal to tolerance"):
      with ops.control_dependencies(
          [check_ops.assert_near(x, y, message="failure message")]):
        out = array_ops.identity(x)
      self.evaluate(out)

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      t1 = constant_op.constant([1., 2.])
      t2 = constant_op.constant([1., 2.])
      x = check_ops.assert_near(t1, t2)
      # assertIsNone rather than a bare `assert`, which `python -O` strips.
      self.assertIsNone(x)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_complex(self):
    x = constant_op.constant(1. + 0.1j, name="x")
    y = constant_op.constant(1.1 + 0.1j, name="y")
    with ops.control_dependencies([
        check_ops.assert_near(
            x, y, atol=0., rtol=0.5, message="failure message")
    ]):
      out = array_ops.identity(x)
    self.evaluate(out)
| |
| |
class AssertLessTest(test.TestCase):
  """Tests for `check_ops.assert_less`."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "failure message.*\n*.* x < y did not hold"):
      with ops.control_dependencies(
          [check_ops.assert_less(
              small, small, message="failure message")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x < y did not hold"):
      with ops.control_dependencies([check_ops.assert_less(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less(self):
    small = constant_op.constant([3, 1], name="small")
    big = constant_op.constant([4, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([3, 2], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors.InvalidArgumentError)):
      with ops.control_dependencies([check_ops.assert_less(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_less(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_returns_none_with_eager(self):
    with context.eager_mode():
      t1 = constant_op.constant([1, 2])
      t2 = constant_op.constant([3, 4])
      x = check_ops.assert_less(t1, t2)
      # assertIsNone rather than a bare `assert`, which `python -O` strips.
      self.assertIsNone(x)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_less(1, 1, message="Custom error message")
| |
| |
class AssertLessEqualTest(test.TestCase):
  """Tests for `check_ops.assert_less_equal`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies(
        [check_ops.assert_less_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_less_equal(
              big, small, message="fail")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_equal(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less_equal(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 1], name="big")
    with ops.control_dependencies([check_ops.assert_less_equal(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([3, 1], name="small")
    big = constant_op.constant([1, 1, 1], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    # Use the shared helper (as the other comparison tests do) so every
    # backend's incompatible-shape wording is accepted.
    with self.assertRaisesIncompatibleShapesError(
        (errors.InvalidArgumentError, ValueError)):
      with ops.control_dependencies(
          [check_ops.assert_less_equal(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_less_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_less_equal(1, 0, message="Custom error message")
| |
| |
class AssertGreaterTest(test.TestCase):
  """Tests for `check_ops.assert_greater`."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_greater(
              small, small, message="fail")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x > y did not hold"):
      with ops.control_dependencies([check_ops.assert_greater(small, big)]):
        out = array_ops.identity(big)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater(self):
    small = constant_op.constant([3, 1], name="small")
    big = constant_op.constant([4, 2], name="big")
    with ops.control_dependencies([check_ops.assert_greater(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies([check_ops.assert_greater(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_greater_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([3, 2], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    # Use the shared helper (as the other comparison tests do) so every
    # backend's incompatible-shape wording is accepted.
    with self.assertRaisesIncompatibleShapesError(
        (errors.InvalidArgumentError, ValueError)):
      with ops.control_dependencies([check_ops.assert_greater(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_greater(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_greater(0, 1, message="Custom error message")
| |
| |
class AssertGreaterEqualTest(test.TestCase):
  """Tests for `check_ops.assert_greater_equal`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(
              small, big, message="fail")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_equal(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 1], name="big")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    # Exercises assert_greater_equal with shapes [2] and [3], which cannot be
    # broadcast against each other.
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([3, 1], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    # Use the shared helper (as the other comparison tests do) so every
    # backend's incompatible-shape wording is accepted.
    with self.assertRaisesIncompatibleShapesError(
        (errors.InvalidArgumentError, ValueError)):
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_greater_equal(0, 1, message="Custom error message")
| |
| |
class AssertNegativeTest(test.TestCase):
  """Tests for `check_ops.assert_negative` (every element must be < 0)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_negative(self):
    negatives = constant_op.constant([-1, -2], name="frank")
    sign_check = check_ops.assert_negative(negatives)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(negatives)
    self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_positive(self):
    positives = constant_op.constant([1, 2], name="doug")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      sign_check = check_ops.assert_negative(positives, message="fail")
      with ops.control_dependencies([sign_check]):
        guarded = array_ops.identity(positives)
      self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_zero(self):
    zeros = constant_op.constant([0], name="claire")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x < 0 did not hold"):
      sign_check = check_ops.assert_negative(zeros)
      with ops.control_dependencies([sign_check]):
        guarded = array_ops.identity(zeros)
      self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # "Negative" means x_i < 0 for every element x_i. An empty tensor has no
    # elements, so the condition holds vacuously.
    empty = constant_op.constant([], name="empty")
    sign_check = check_ops.assert_negative(empty)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(empty)
    self.evaluate(guarded)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_negative(1, message="Custom error message")
| |
| |
| # pylint:disable=g-error-prone-assert-raises |
class AssertPositiveTest(test.TestCase):
  """Tests for `check_ops.assert_positive` (every element must be > 0)."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_negative(self):
    negatives = constant_op.constant([-1, -2], name="freddie")
    with self.assertRaisesOpError("fail"):
      sign_check = check_ops.assert_positive(negatives, message="fail")
      with ops.control_dependencies([sign_check]):
        guarded = array_ops.identity(negatives)
      self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_positive(self):
    positives = constant_op.constant([1, 2], name="remmy")
    sign_check = check_ops.assert_positive(positives)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(positives)
    self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_zero(self):
    zeros = constant_op.constant([0], name="meechum")
    with self.assertRaisesOpError("x > 0 did not hold"):
      sign_check = check_ops.assert_positive(zeros)
      with ops.control_dependencies([sign_check]):
        guarded = array_ops.identity(zeros)
      self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # "Positive" means x_i > 0 for every element x_i. An empty tensor has no
    # elements, so the condition holds vacuously.
    empty = constant_op.constant([], name="empty")
    sign_check = check_ops.assert_positive(empty)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(empty)
    self.evaluate(guarded)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_positive(-1, message="Custom error message")
| |
| |
class EnsureShapeTest(test.TestCase):
  """Tests for `check_ops.ensure_shape`.

  The first group checks static shape inference at graph-construction time.
  The second group feeds concrete values and checks the runtime shape
  validation and the gradient.
  """

  # Static shape inference
  @test_util.run_deprecated_v1
  def testStaticShape(self):
    unknown_input = array_ops.placeholder(dtypes.int32)
    shaped = check_ops.ensure_shape(unknown_input, (3, 3, 3))
    self.assertEqual(shaped.get_shape(), (3, 3, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_MergesShapes(self):
    # Known dims from the input and from the requested shape are combined.
    partial_input = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    shaped = check_ops.ensure_shape(partial_input, (5, 4, None))
    self.assertEqual(shaped.get_shape(), (5, 4, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_RaisesErrorWhenRankIncompatible(self):
    partial_input = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    with self.assertRaises(ValueError):
      check_ops.ensure_shape(partial_input, (2, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_RaisesErrorWhenDimIncompatible(self):
    partial_input = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    with self.assertRaises(ValueError):
      check_ops.ensure_shape(partial_input, (2, 2, 4))

  @test_util.run_deprecated_v1
  def testStaticShape_CanSetUnknownShape(self):
    unknown_input = array_ops.placeholder(dtypes.int32)
    shaped = check_ops.ensure_shape(unknown_input / 3, None)
    self.assertEqual(shaped.get_shape(), None)

  # Dynamic shape check
  @test_util.run_deprecated_v1
  @test_util.disable_xla(
      "b/123337890")  # Dynamic shapes not supported now with XLA
  def testEnsuresDynamicShape_RaisesError(self):
    fed = array_ops.placeholder(dtypes.int32)
    quotient = math_ops.divide(fed, 3, name="MyDivide")
    checked = check_ops.ensure_shape(quotient, (3, 3, 3))
    expected_message = (
        r"Shape of tensor MyDivide \[2,1\] is not compatible with "
        r"expected shape \[3,3,3\].")
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               expected_message):
        sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  @test_util.disable_xla(
      "b/123337890")  # Dynamic shapes not supported now with XLA
  def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
    fed = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(fed / 3, (None, None, 3))
    expected_message = (
        r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
        r"expected shape \[\?,\?,3\].")
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               expected_message):
        sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  def testEnsuresDynamicShape(self):
    fed = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(fed / 3, (2, 1))
    with self.cached_session() as sess:
      sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  def testEnsuresDynamicShape_WithUnknownDims(self):
    fed = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(fed / 3, (None, None))
    with self.cached_session() as sess:
      sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  def testGradient(self):
    # The expected all-ones result shows that ensure_shape passes the
    # incoming gradient through unchanged.
    fed = array_ops.placeholder(dtypes.float32)
    checked = check_ops.ensure_shape(fed, (None, None))
    grads = gradients.gradients(checked, fed)

    with self.cached_session() as sess:
      (grad_value,) = sess.run(grads, feed_dict={fed: [[4.0], [-1.0]]})

      self.assertAllEqual(grad_value, [[1.0], [1.0]])
| |
| |
class EnsureShapeBenchmark(test.Benchmark):
  """Measures the per-`session.run` overhead added by `ensure_shape` ops."""

  def _grappler_all_off_config(self):
    """Builds a session config with graph optimizations turned off.

    Without this, Grappler could constant-fold, prune or simplify the
    identity / `ensure_shape` chains being timed.
    """
    config = config_pb2.ConfigProto()
    off = rewriter_config_pb2.RewriterConfig.OFF
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.disable_model_pruning = 1
    config.graph_options.rewrite_options.constant_folding = off
    config.graph_options.rewrite_options.layout_optimizer = off
    config.graph_options.rewrite_options.arithmetic_optimization = off
    config.graph_options.rewrite_options.dependency_optimization = off
    return config

  def _run(self, op, feed_dict=None, num_iters=5000, name=None, **kwargs):
    """Runs `op` repeatedly and reports the median per-call wall time.

    Args:
      op: Tensor or operation fetched by each `session.run` call.
      feed_dict: Optional feeds passed to every `session.run` call.
      num_iters: Number of timed `session.run` calls. Five untimed warm-up
        calls run first.
      name: Benchmark name forwarded to `report_benchmark`.
      **kwargs: Forwarded verbatim as the benchmark's `extras`.
    """
    config = self._grappler_all_off_config()
    with session.Session(config=config) as sess:
      # Warm up the session so one-time setup cost is not timed.
      for _ in range(5):
        sess.run(op, feed_dict=feed_dict)
      deltas = []
      for _ in range(num_iters):
        # perf_counter is the high-resolution clock meant for short
        # durations; time.time() can be coarse or be adjusted mid-run.
        start = time.perf_counter()
        sess.run(op, feed_dict=feed_dict)
        deltas.append(time.perf_counter() - start)
      # The median is robust to occasional scheduler hiccups.
      median_time = np.median(deltas)
      # `report_benchmark` documents `wall_time` in seconds, so report the
      # median as-is rather than scaling it to microseconds.
      self.report_benchmark(
          name=name,
          wall_time=median_time,
          extras=kwargs,
      )

  def benchmark_const_op(self):
    """Baseline: a single identity over a random tensor."""
    # In this case, we expect that the overhead of a `session.run` call
    # far outweighs the time taken to execute the op...
    shape = (3, 3, 100)
    input_op = random_ops.random_normal(shape)
    self._run(array_ops.identity(input_op), name="SingleConstOp")

  def benchmark_single_ensure_op(self):
    """A single `ensure_shape` over a random tensor."""
    # In this case, we expect that the overhead of a `session.run` call
    # far outweighs the time taken to execute the op...
    shape = (3, 3, 100)
    input_op = random_ops.random_normal(shape)
    ensure_shape_op = check_ops.ensure_shape(input_op, shape)
    self._run(ensure_shape_op, name="SingleEnsureShapeOp")

  def _apply_n_times(self, op, target, n=1000):
    """Returns `op` applied `n` times in sequence starting from `target`."""
    for _ in range(n):
      target = op(target)
    return target

  def benchmark_n_ops(self):
    """Baseline: a chain of 1000 identity ops."""
    shape = (1000,)
    input_op = random_ops.random_normal(shape)
    n_ops = self._apply_n_times(array_ops.identity, input_op)
    self._run(n_ops, name="NIdentityOps_1000")

  def benchmark_n_ensure_ops(self):
    """A chain of 1000 (identity, ensure_shape) pairs."""
    shape = (1000,)
    input_op = random_ops.random_normal(shape)
    n_ensure_ops = self._apply_n_times(
        lambda x: check_ops.ensure_shape(array_ops.identity(x), shape),
        input_op)
    self._run(n_ensure_ops, name="NEnsureShapeAndIdentityOps_1000")
| |
| |
class AssertRankTest(test.TestCase):
  """Tests for `check_ops.assert_rank` with static and fed (dynamic) ranks."""

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    with self.assertRaisesRegex(ValueError, "fail.*must have rank 1"):
      rank_check = check_ops.assert_rank(scalar, 1, message="fail")
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 1, message="fail")
      with ops.control_dependencies([rank_check]):
        with self.assertRaisesOpError("fail.*my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    rank_check = check_ops.assert_rank(scalar, 0)
    with ops.control_dependencies([rank_check]):
      self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 0)
      with ops.control_dependencies([rank_check]):
        array_ops.identity(fed).eval(feed_dict={fed: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank"):
      rank_check = check_ops.assert_rank(vector, 0)
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 0)
      with ops.control_dependencies([rank_check]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    rank_check = check_ops.assert_rank(vector, 1)
    with ops.control_dependencies([rank_check]):
      self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 1)
      with ops.control_dependencies([rank_check]):
        array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank"):
      rank_check = check_ops.assert_rank(vector, 2)
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 2)
      with ops.control_dependencies([rank_check]):
        with self.assertRaisesOpError("my_tensor.* must have rank"):
          array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_scalar_static(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    non_scalar_rank = np.array([], dtype=np.int32)
    with self.assertRaisesRegex(ValueError, "Argument `rank` must be a scalar"):
      check_ops.assert_rank(vector, non_scalar_rank)

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      fed_rank = array_ops.placeholder(dtypes.int32, name="rank_tensor")
      with self.assertRaisesOpError("Rank must be a scalar."):
        rank_check = check_ops.assert_rank(vector, fed_rank)
        with ops.control_dependencies([rank_check]):
          array_ops.identity(vector).eval(feed_dict={fed_rank: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_integer_static(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(TypeError, "must be of type <dtype: 'int32'>"):
      check_ops.assert_rank(vector, .5)

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      fed_rank = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      with self.assertRaisesRegex(TypeError,
                                  "must be of type <dtype: 'int32'>"):
        rank_check = check_ops.assert_rank(vector, fed_rank)
        with ops.control_dependencies([rank_check]):
          array_ops.identity(vector).eval(feed_dict={fed_rank: .5})
| |
| |
class AssertRankInTest(test.TestCase):
  """Tests for `check_ops.assert_rank_in` (rank must be one of `ranks`)."""

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
    scalar = constant_op.constant(42, name="my_tensor")
    with self.assertRaisesRegex(ValueError, "fail.*must have rank.*1.*2"):
      rank_check = check_ops.assert_rank_in(scalar, (1, 2), message="fail")
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_in(fed, (1, 2), message="fail")
      with ops.control_dependencies([rank_check]):
        with self.assertRaisesOpError("fail.*my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: 42.0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    scalar = constant_op.constant(42, name="my_tensor")
    # The matching rank may appear at any position in `ranks`.
    orderings = ((0, 1, 2), (1, 0, 2), (1, 2, 0))
    for allowed_ranks in orderings:
      rank_check = check_ops.assert_rank_in(scalar, allowed_ranks)
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      orderings = ((0, 1, 2), (1, 0, 2), (1, 2, 0))
      for allowed_ranks in orderings:
        rank_check = check_ops.assert_rank_in(fed, allowed_ranks)
        with ops.control_dependencies([rank_check]):
          array_ops.identity(fed).eval(feed_dict={fed: 42.0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    vector = constant_op.constant([42, 43], name="my_tensor")
    orderings = ((0, 1, 2), (1, 0, 2), (1, 2, 0))
    for allowed_ranks in orderings:
      rank_check = check_ops.assert_rank_in(vector, allowed_ranks)
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      orderings = ((0, 1, 2), (1, 0, 2), (1, 2, 0))
      for allowed_ranks in orderings:
        rank_check = check_ops.assert_rank_in(fed, allowed_ranks)
        with ops.control_dependencies([rank_check]):
          array_ops.identity(fed).eval(feed_dict={fed: (42.0, 43.0)})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self):
    vector = constant_op.constant((42, 43), name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank"):
      rank_check = check_ops.assert_rank_in(vector, (0, 2))
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_in(fed, (0, 2))
      with ops.control_dependencies([rank_check]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: (42.0, 43.0)})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_scalar_static(self):
    vector = constant_op.constant((42, 43), name="my_tensor")
    scalar_rank = np.array(1, dtype=np.int32)
    vector_rank = np.array((2, 1), dtype=np.int32)
    with self.assertRaisesRegex(
        ValueError, "Argument `ranks` must contain scalar tensors."):
      check_ops.assert_rank_in(vector, (scalar_rank, vector_rank))

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      first_rank = array_ops.placeholder(dtypes.int32, name="rank0_tensor")
      second_rank = array_ops.placeholder(dtypes.int32, name="rank1_tensor")
      with self.assertRaisesOpError("Rank must be a scalar"):
        rank_check = check_ops.assert_rank_in(vector,
                                              (first_rank, second_rank))
        with ops.control_dependencies((rank_check,)):
          array_ops.identity(vector).eval(feed_dict={
              first_rank: 1,
              second_rank: [2, 1],
          })

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_integer_static(self):
    vector = constant_op.constant((42, 43), name="my_tensor")
    with self.assertRaisesRegex(TypeError, "must be of type <dtype: 'int32'>"):
      check_ops.assert_rank_in(vector, (1, .5,))

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      fed_rank = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      with self.assertRaisesRegex(TypeError,
                                  "must be of type <dtype: 'int32'>"):
        rank_check = check_ops.assert_rank_in(vector, (1, fed_rank))
        with ops.control_dependencies([rank_check]):
          array_ops.identity(vector).eval(feed_dict={fed_rank: .5})
| |
| |
class AssertRankAtLeastTest(test.TestCase):
  """Tests for `check_ops.assert_rank_at_least` (rank must be >= a bound)."""

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank at least 1"):
      rank_check = check_ops.assert_rank_at_least(scalar, 1)
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_at_least(fed, 1)
      with ops.control_dependencies([rank_check]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    rank_check = check_ops.assert_rank_at_least(scalar, 0)
    with ops.control_dependencies([rank_check]):
      self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_at_least(fed, 0)
      with ops.control_dependencies([rank_check]):
        array_ops.identity(fed).eval(feed_dict={fed: 0})

  # NOTE: "raise_raise" in this name is a typo, kept so the test ID is stable.
  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    rank_check = check_ops.assert_rank_at_least(vector, 0)
    with ops.control_dependencies([rank_check]):
      self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_at_least(fed, 0)
      with ops.control_dependencies([rank_check]):
        array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    rank_check = check_ops.assert_rank_at_least(vector, 1)
    with ops.control_dependencies([rank_check]):
      self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_at_least(fed, 1)
      with ops.control_dependencies([rank_check]):
        array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank at least 2"):
      rank_check = check_ops.assert_rank_at_least(vector, 2)
      with ops.control_dependencies([rank_check]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_at_least(fed, 2)
      with ops.control_dependencies([rank_check]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})
| |
| |
class AssertNonNegativeTest(test.TestCase):
  """Tests for `check_ops.assert_non_negative` (every element >= 0)."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_negative(self):
    negatives = constant_op.constant([-1, -2], name="zoe")
    with self.assertRaisesOpError("x >= 0 did not hold"):
      sign_check = check_ops.assert_non_negative(negatives)
      with ops.control_dependencies([sign_check]):
        guarded = array_ops.identity(negatives)
      self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_zero_and_positive(self):
    zero_and_positive = constant_op.constant([0, 2], name="lucas")
    sign_check = check_ops.assert_non_negative(zero_and_positive)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(zero_and_positive)
    self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # "Non-negative" means x_i >= 0 for every element x_i. An empty tensor
    # has no elements, so the condition holds vacuously.
    empty = constant_op.constant([], name="empty")
    sign_check = check_ops.assert_non_negative(empty)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(empty)
    self.evaluate(guarded)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_non_negative(-1, message="Custom error message")
| |
| |
class AssertNonPositiveTest(test.TestCase):
  """Tests for `check_ops.assert_non_positive` (every element <= 0)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_zero_and_negative(self):
    zero_and_negative = constant_op.constant([0, -2], name="tom")
    sign_check = check_ops.assert_non_positive(zero_and_negative)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(zero_and_negative)
    self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_positive(self):
    has_positive = constant_op.constant([0, 2], name="rachel")
    with self.assertRaisesOpError("x <= 0 did not hold"):
      sign_check = check_ops.assert_non_positive(has_positive)
      with ops.control_dependencies([sign_check]):
        guarded = array_ops.identity(has_positive)
      self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # "Non-positive" means x_i <= 0 for every element x_i. An empty tensor
    # has no elements, so the condition holds vacuously.
    empty = constant_op.constant([], name="empty")
    sign_check = check_ops.assert_non_positive(empty)
    with ops.control_dependencies([sign_check]):
      guarded = array_ops.identity(empty)
    self.evaluate(guarded)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_non_positive(1, message="Custom error message")
| |
| |
class AssertIntegerTest(test.TestCase):
  """Tests for `check_ops.assert_integer` (dtype must be an integer type)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_integer(self):
    ints = constant_op.constant([1, 2], name="integers")
    dtype_check = check_ops.assert_integer(ints)
    with ops.control_dependencies([dtype_check]):
      guarded = array_ops.identity(ints)
    self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_float(self):
    non_ints = constant_op.constant([1.0, 2.0], name="floats")
    with self.assertRaisesRegex(TypeError, "Expected.*integer"):
      check_ops.assert_integer(non_ints)
| |
| |
class AssertTypeTest(test.TestCase):
  """Tests for `check_ops.assert_type` on dense and sparse tensors."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_correct_type(self):
    int64_values = constant_op.constant([1, 2], dtype=dtypes.int64)
    type_check = check_ops.assert_type(int64_values, dtypes.int64)
    with ops.control_dependencies([type_check]):
      guarded = array_ops.identity(int64_values)
    self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_sparsetensor_doesnt_raise_when_correct_type(self):
    indices = constant_op.constant([[111], [232]], dtypes.int64)
    values = constant_op.constant([23.4, -43.2], dtypes.float32)
    dense_shape = constant_op.constant([500], dtypes.int64)
    sparse_float = sparse_tensor.SparseTensor(indices, values, dense_shape)

    type_check = check_ops.assert_type(sparse_float, dtypes.float32)
    with ops.control_dependencies([type_check]):
      guarded = sparse_tensor.SparseTensor(
          sparse_float.indices,
          array_ops.identity(sparse_float.values),
          sparse_float.dense_shape)
    self.evaluate(guarded)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_wrong_type(self):
    half_floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16)
    with self.assertRaisesRegex(TypeError, "must be of type.*float32"):
      check_ops.assert_type(half_floats, dtypes.float32)

  @test_util.run_in_graph_and_eager_modes
  def test_sparsetensor_raises_when_wrong_type(self):
    indices = constant_op.constant([[111], [232]], dtypes.int64)
    values = constant_op.constant([23.4, -43.2], dtypes.float16)
    dense_shape = constant_op.constant([500], dtypes.int64)
    sparse_float16 = sparse_tensor.SparseTensor(indices, values, dense_shape)
    with self.assertRaisesRegex(TypeError, "must be of type.*float32"):
      check_ops.assert_type(sparse_float16, dtypes.float32)

  def test_raise_when_tf_type_is_not_dtype(self):
    # Regression test for GitHub issue:
    # https://github.com/tensorflow/tensorflow/issues/45975
    # A tuple containing a dtype is not itself a dtype.
    scalar = constant_op.constant(0.0)
    with self.assertRaisesRegex(TypeError,
                                "Cannot convert.*to a TensorFlow DType"):
      check_ops.assert_type(scalar, (dtypes.float32,))
| |
| |
class AssertShapesTest(test.TestCase):
  """Tests for `check_ops.assert_shapes` on dense tensors.

  Mismatches that are visible from static shapes are expected to raise
  `ValueError` while the assertion is being built (`raises_static_error`).
  Mismatches that depend on fed placeholder values are expected to raise
  `errors.InvalidArgumentError` when the graph is run
  (`raises_dynamic_error`).
  """

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_shape_mismatch(self):
    x = array_ops.ones([3, 2], name="x")
    y = array_ops.ones([2, 3], name="y")
    # "N" is bound to 3 by x, so y's leading size of 2 conflicts with it.
    self.raises_static_error(
        shapes=[(x, ("N", "Q")), (y, ("N", "D"))],
        regex=(r"Specified by tensor .* dimension 0. "
               r"Tensor .* dimension 0 must have size 3. "
               r"Received size 2"))

  def test_raise_dynamic_shape_mismatch(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, [None, 2], name="x")
      y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
      # Leading sizes are unknown until feed time, so the failure is dynamic.
      expected = (r"\[Specified by tensor x.* dimension 0\] "
                  r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
      self.raises_dynamic_error(
          shapes=[(x, ("N", "Q")), (y, ("N", "D"))],
          regex=expected,
          feed_dict={x: np.ones([3, 2]), y: np.ones([2, 3])})

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_shape_explicit_mismatch(self):
    x = array_ops.ones([3, 2], name="x")
    y = array_ops.ones([2, 3], name="y")
    # The literal 3 is an explicit size requirement that y violates.
    self.raises_static_error(
        shapes=[(x, (3, "Q")), (y, (3, "D"))],
        regex=(r"Specified explicitly. "
               r"Tensor .* dimension 0 must have size 3. "
               r"Received size 2"))

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_rank_one_size_one_equivalence(self):
    rank_one_size_one = array_ops.ones([1], name="rank_one_size_one")
    rank_zero = array_ops.constant(5, name="rank_zero")
    # A scalar and a size-1 vector must both satisfy `()` and `(1,)`.
    for spec in [(), (1,)]:
      check_ops.assert_shapes([
          (rank_one_size_one, spec),
          (rank_zero, spec),
      ])

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_rank_1_size_not_1_mismatch_scalar(self):
    vector = array_ops.constant([2, 2], name="x")
    self.raises_static_error(
        shapes=[(vector, ())],
        regex=(r"Specified explicitly. "
               r"Tensor .* dimension 0 must have size 1. "
               r"Received size 2"))

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_scalar_mismatch_rank_1_size_not_1(self):
    scalar = array_ops.constant(2, name="x")
    self.raises_static_error(
        shapes=[(scalar, (2,))],
        regex=(r"Specified explicitly. "
               r"Tensor .* dimension 0 must have size 2. "
               r"Received size 1"))

  @test_util.run_in_graph_and_eager_modes
  def test_scalar_implies_size_one(self):
    scalar = array_ops.constant(5, name="rank_zero")
    matrix = array_ops.ones([2, 2], name="x")
    # The scalar binds "a" to 1, which the matrix's leading size 2 violates.
    self.raises_static_error(
        shapes=[(scalar, ("a",)), (matrix, ("a", 2))],
        regex=(r"Specified by tensor .* dimension 0. "
               r"Tensor .* dimension 0 must have size 1. "
               r"Received size 2"))

  @test_util.run_in_graph_and_eager_modes
  def test_raise_not_iterable(self):
    x = array_ops.constant([1, 2], name="x")
    # A bare integer is not accepted as a shape specification.
    self.raises_static_error(
        shapes=[(x, 2)],
        regex=(r"Tensor .*. "
               r"Specified shape must be an iterable. "
               r"An iterable has the attribute `__iter__` or `__getitem__`. "
               r"Received specified shape: 2"))

  def test_raise_dynamic_shape_explicit_mismatch(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, [None, 2], name="xa")
      y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
      expected = (r"\[Specified explicitly\] "
                  r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
      self.raises_dynamic_error(
          shapes=[(x, (3, "Q")), (y, (3, "D"))],
          regex=expected,
          feed_dict={x: np.ones([3, 2]), y: np.ones([2, 3])})

  @test_util.run_in_graph_and_eager_modes
  def test_no_op_when_specified_as_unknown(self):
    x = array_ops.constant([1, 1], name="x")
    # A `None` shape specification places no constraint on the tensor.
    with ops.control_dependencies([check_ops.assert_shapes([(x, None)])]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_static_incorrect_rank(self):
    rank_two_specs = [(1, 1), (1, 3), ("a", "b"), (None, None)]
    rank_three_specs = [
        (1, 1, 1),
        ("a", "b", "c"),
        (None, None, None),
        (1, "b", None),
    ]
    # Each row: (specs to try, tensor, expected rank, actual rank).
    cases = [
        (rank_two_specs, array_ops.ones([1]), 2, 1),
        (rank_three_specs, array_ops.ones([1, 1]), 3, 2),
        (rank_three_specs, array_ops.constant(1), 3, 0),
    ]
    for specs, tensor, expected_rank, received_rank in cases:
      regex = (r"Tensor .* must have rank %d. Received rank %d" %
               (expected_rank, received_rank))
      for spec in specs:
        self.raises_static_error(shapes=[(tensor, spec)], regex=regex)

  def test_raises_dynamic_incorrect_rank(self):
    with ops.Graph().as_default():
      # Unknown static rank defers the check to run time; a scalar is fed.
      unknown_rank = array_ops.placeholder(dtypes.float32, None)
      for spec in [(1, 1), (1, 3), ("a", "b"), (None, None)]:
        self.raises_dynamic_error(
            shapes=[(unknown_rank, spec)],
            regex=r"Tensor .* must have rank\] \[2\]",
            feed_dict={unknown_rank: 5})

  @test_util.run_in_graph_and_eager_modes
  def test_correctly_matching(self):
    u = array_ops.constant(1, name="u")
    v = array_ops.ones([1, 2], name="v")
    w = array_ops.ones([3], name="w")
    x = array_ops.ones([1, 2, 3], name="x")
    y = array_ops.ones([3, 1, 2], name="y")
    z = array_ops.ones([2, 3, 1], name="z")
    # Compatible specifications: first purely symbolic, then mixing in
    # explicit sizes.
    symbolic_specs = [
        (x, ("a", "b", "c")),
        (y, ("c", "a", "b")),
        (z, ("b", "c", "a")),
        (v, ("a", "b")),
        (w, ("c",)),
        (u, "a"),
    ]
    mixed_specs = [
        (x, (1, "b", "c")),
        (y, ("c", "a", 2)),
        (z, ("b", 3, "a")),
        (v, ("a", 2)),
        (w, (3,)),
        (u, ()),
    ]
    for specs in (symbolic_specs, mixed_specs):
      assertion = check_ops.assert_shapes(specs)
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_variable_length_symbols(self):
    x = array_ops.ones([4, 1], name="x")
    y = array_ops.ones([4, 2], name="y")
    # Multi-character strings inside a tuple are used as dimension symbols.
    with ops.control_dependencies([
        check_ops.assert_shapes([
            (x, ("num_observations", "input_dim")),
            (y, ("num_observations", "output_dim")),
        ])
    ]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_implicit_mismatch_using_iterable_alternatives(self):
    x = array_ops.ones([2, 2], name="x")
    y = array_ops.ones([1, 3], name="y")
    expected = (r"Specified by tensor .* dimension 0. "
                r"Tensor .* dimension 0 must have size 2. "
                r"Received size 1")
    # Tuples, strings, lists, numpy arrays and mixtures of them should all
    # report the same symbolic mismatch.
    equivalent_specs = [
        [(x, ("A", "B")), (y, ("A", "C"))],
        [(x, "AB"), (y, "AC")],
        [(x, ["A", "B"]), (y, ["A", "C"])],
        [(x, np.array(["A", "B"])), (y, np.array(["A", "C"]))],
        [(x, ("A", "B")), (y, "AC")],
    ]
    for shapes in equivalent_specs:
      self.raises_static_error(shapes=shapes, regex=expected)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_explicit_mismatch_using_iterable_alternatives(self):
    x = array_ops.ones([2, 2], name="x")
    y = array_ops.ones([1, 3], name="y")
    expected = (r"Specified explicitly. "
                r"Tensor .* dimension 0 must have size 2. "
                r"Received size 1")
    # The same iterable forms, but with explicit sizes instead of symbols.
    equivalent_specs = [
        [(x, (2, 2)), (y, (2, 3))],
        [(x, "22"), (y, "23")],
        [(x, [2, 2]), (y, [2, 3])],
        [(x, np.array([2, 2])), (y, np.array([2, 3]))],
        [(x, (2, 2)), (y, "23")],
    ]
    for shapes in equivalent_specs:
      self.raises_static_error(shapes=shapes, regex=expected)

  @test_util.run_in_graph_and_eager_modes
  def test_dim_size_specified_as_unknown(self):
    x = array_ops.ones([1, 2, 3], name="x")
    y = array_ops.ones([2, 1], name="y")
    # `None` (in a sequence) and "." (in a sequence or a string) both leave
    # a dimension's size unconstrained.
    unknown_dim_specs = (
        [(x, (None, 2, None)), (y, (None, 1))],
        [(x, (".", 2, ".")), (y, (".", 1))],
        [(x, ".2."), (y, ".1")],
    )
    assertions = [check_ops.assert_shapes(s) for s in unknown_dim_specs]
    with ops.control_dependencies(assertions):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_static_shape_explicit_mismatch_innermost_dims(self):
    x = array_ops.ones([3, 2], name="x")
    y = array_ops.ones([2, 3], name="y")
    expected = (r"Specified explicitly. "
                r"Tensor .* dimension -2 must have size 3. "
                r"Received size 2")
    # `Ellipsis` in a tuple and "*" in a string both stand for a variable
    # number of unspecified leading dimensions.
    for shapes in ([(x, (3, "Q")), (y, (Ellipsis, 3, "D"))],
                   [(x, "3Q"), (y, "*3D")]):
      self.raises_static_error(shapes=shapes, regex=expected)

  @test_util.run_in_graph_and_eager_modes
  def test_correctly_matching_innermost_dims(self):
    x = array_ops.ones([1, 2, 3, 2], name="x")
    y = array_ops.ones([2, 3, 3], name="y")
    assertions = [
        check_ops.assert_shapes(shapes) for shapes in (
            [(x, (Ellipsis, "N", "Q")), (y, (Ellipsis, "N", "D"))],
            [(x, "*NQ"), (y, "*ND")],
        )
    ]
    with ops.control_dependencies(assertions):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raise_variable_num_outer_dims_prefix_misuse(self):
    x = array_ops.ones([1, 2], name="x")
    expected = (r"Tensor .* specified shape index .*. "
                r"Symbol `...` or `\*` for a variable number of "
                r"unspecified dimensions is only allowed as the first entry")
    # The variable-dimension marker appears in the middle, which is invalid.
    for misplaced in (("N", Ellipsis, "Q"), "N*Q"):
      self.raises_static_error(shapes=[(x, misplaced)], regex=expected)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_shapes_dict_no_op(self):
    with ops.control_dependencies([check_ops.assert_shapes([])]):
      out = array_ops.identity(0)
    self.evaluate(out)

  def raises_static_error(self, shapes, regex):
    """Checks that building `assert_shapes(shapes)` raises a ValueError.

    Args:
      shapes: Specification list passed straight to `assert_shapes`.
      regex: Pattern the ValueError message must match.
    """
    with self.assertRaisesRegex(ValueError, regex):
      check_ops.assert_shapes(shapes)

  def raises_dynamic_error(self, shapes, regex, feed_dict):
    """Checks that running an `assert_shapes(shapes)` check fails.

    The assertion is attached as a control dependency of a dummy identity
    op, which is run in `self.session()` with `feed_dict`. Every caller in
    this class invokes it inside `ops.Graph().as_default()`.

    Args:
      shapes: Specification list passed straight to `assert_shapes`.
      regex: Pattern the InvalidArgumentError message must match.
      feed_dict: Placeholder values supplied to `sess.run`.
    """
    with self.session() as sess, self.assertRaisesRegex(
        errors.InvalidArgumentError, regex):
      check = check_ops.assert_shapes(shapes)
      with ops.control_dependencies([check]):
        guarded = array_ops.identity(0)
      sess.run(guarded, feed_dict=feed_dict)
| |
| |
class AssertShapesSparseTensorTest(test.TestCase):
  """Tests for `check_ops.assert_shapes` when given `SparseTensor` inputs.

  The shape being checked is the SparseTensor's dense shape, i.e. the value
  of its `dense_shape` component.
  """

  @staticmethod
  def _scalar_sparse():
    """Builds a float32 SparseTensor whose `dense_shape` is `[]` (rank 0)."""
    return sparse_tensor.SparseTensor(
        constant_op.constant([[]], dtypes.int64),
        constant_op.constant([42], dtypes.float32),
        constant_op.constant([], dtypes.int64))

  @staticmethod
  def _vector_sparse():
    """Builds a float32 SparseTensor with `dense_shape` `[500]`."""
    return sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], dtypes.float32),
        constant_op.constant([500], dtypes.int64))

  @staticmethod
  def _matrix_sparse(dense_shape):
    """Builds an int32 SparseTensor with two entries at [5, 6] and [7, 8].

    Args:
      dense_shape: Two-element list used as the tensor's `dense_shape`.

    Returns:
      A rank-2 `SparseTensor`.
    """
    return sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6], [7, 8]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant(dense_shape, dtypes.int64))

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_scalar_target_success(self):
    sparse_float = self._scalar_sparse()
    assertion = check_ops.assert_shapes([(sparse_float, [])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_float)
    self.evaluate(out)

  def test_assert_shapes_sparse_tensor_nonscalar_target_fail(self):
    sparse_float = self._scalar_sparse()
    with self.assertRaisesRegex(ValueError,
                                r"must have rank 2.*Received rank 0"):
      assertion = check_ops.assert_shapes([(sparse_float, [None, None])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_float)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_fully_specified_target_success(self):
    sparse_float = self._vector_sparse()
    assertion = check_ops.assert_shapes([(sparse_float, [500])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_float)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_fully_specified_target_fail(self):
    sparse_float = self._vector_sparse()
    with self.assertRaisesRegex(ValueError, r"dimension 0 must have size 499"):
      assertion = check_ops.assert_shapes([(sparse_float, [499])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_float)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_partially_specified_target_success(self):
    sparse_int = self._matrix_sparse([30, 40])
    assertion = check_ops.assert_shapes([(sparse_int, [None, 40])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_int)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_symbolic_match_success(self):
    # Rank-3 tensor; the first two dense dimensions are equal, matching "N".
    sparse_int = sparse_tensor.SparseTensor(
        constant_op.constant([[5, 6, 7], [8, 9, 10]], dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant([30, 30, 40], dtypes.int64))
    assertion = check_ops.assert_shapes([(sparse_int, ["N", "N", "D"])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_int)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_partially_specified_target_fail(self):
    sparse_int = self._matrix_sparse([30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 41"):
      assertion = check_ops.assert_shapes([(sparse_int, [None, 41])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_int)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_wrong_rank_fail(self):
    sparse_int = self._matrix_sparse([30, 40])
    with self.assertRaisesRegex(ValueError,
                                r"must have rank 3\..* Received rank 2"):
      assertion = check_ops.assert_shapes([(sparse_int, [None, None, 40])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_int)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_wrong_symbolic_match_fail(self):
    sparse_int = self._matrix_sparse([30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(sparse_int, ["D", "D"])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_int)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_multiple_assertions_success(self):
    sparse_scalar = self._scalar_sparse()
    sparse_2d = self._matrix_sparse([30, 30])
    assertion = check_ops.assert_shapes([(sparse_scalar, []),
                                         (sparse_2d, ["N", "N"])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_2d)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_multiple_assertions_fail(self):
    sparse_scalar = self._scalar_sparse()
    sparse_2d = self._matrix_sparse([30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(sparse_scalar, []),
                                           (sparse_2d, ["N", "N"])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_2d)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_success(self):
    dense_scalar = constant_op.constant([42], dtypes.float32)
    sparse_2d = self._matrix_sparse([30, 30])
    assertion = check_ops.assert_shapes([(dense_scalar, []),
                                         (sparse_2d, ["N", "N"])])
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(sparse_2d)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_fail(self):
    dense_scalar = constant_op.constant([42], dtypes.float32)
    sparse_2d = self._matrix_sparse([30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(dense_scalar, []),
                                           (sparse_2d, ["N", "N"])])
      with ops.control_dependencies([assertion]):
        out = array_ops.identity(sparse_2d)
      self.evaluate(out)
| |
| |
class IsStrictlyIncreasingTest(test.TestCase):
  """Tests for `check_ops.is_strictly_increasing`.

  As exercised below, rank-2 inputs are judged in flattened (row-major)
  order, and inputs with fewer than two elements count as increasing.
  """

  def _strictly_increasing(self, values):
    """Returns the evaluated `is_strictly_increasing(values)` result."""
    return self.evaluate(check_ops.is_strictly_increasing(values))

  @test_util.run_in_graph_and_eager_modes
  def test_constant_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self._strictly_increasing([1, 1, 1]))

  @test_util.run_in_graph_and_eager_modes
  def test_decreasing_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self._strictly_increasing([1, 0, -1]))

  @test_util.run_in_graph_and_eager_modes
  def test_2d_decreasing_tensor_is_not_strictly_increasing(self):
    # Flattened this is 1, 3, 2, 4, which drops from 3 to 2.
    self.assertFalse(self._strictly_increasing([[1, 3], [2, 4]]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_tensor_is_increasing(self):
    self.assertTrue(self._strictly_increasing([1, 2, 3]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self._strictly_increasing([[-1, 2], [3, 4]]))

  @test_util.run_in_graph_and_eager_modes
  def test_tensor_with_one_element_is_strictly_increasing(self):
    self.assertTrue(self._strictly_increasing([1]))

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_is_strictly_increasing(self):
    self.assertTrue(self._strictly_increasing([]))
| |
| |
class IsNonDecreasingTest(test.TestCase):
  """Tests for `check_ops.is_non_decreasing`.

  Repeated values are allowed. As exercised below, rank-2 inputs are judged
  in flattened (row-major) order.
  """

  def _non_decreasing(self, values):
    """Returns the evaluated `is_non_decreasing(values)` result."""
    return self.evaluate(check_ops.is_non_decreasing(values))

  @test_util.run_in_graph_and_eager_modes
  def test_constant_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1, 1, 1]))

  @test_util.run_in_graph_and_eager_modes
  def test_decreasing_tensor_is_not_non_decreasing(self):
    self.assertFalse(self._non_decreasing([3, 2, 1]))

  @test_util.run_in_graph_and_eager_modes
  def test_2d_decreasing_tensor_is_not_non_decreasing(self):
    # Flattened this is 1, 3, 2, 4, which drops from 3 to 2.
    self.assertFalse(self._non_decreasing([[1, 3], [2, 4]]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_one_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1, 2, 3]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self._non_decreasing([[-1, 2], [3, 3]]))

  @test_util.run_in_graph_and_eager_modes
  def test_tensor_with_one_element_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1]))

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([]))
| |
| |
class FloatDTypeTest(test.TestCase):
  """Tests for `check_ops.assert_same_float_dtype`."""

  @test_util.run_in_graph_and_eager_modes
  def test_assert_same_float_dtype(self):
    float32 = dtypes.float32
    same_float = check_ops.assert_same_float_dtype

    # With no actual tensors, the result is float32 whether or not float32
    # is requested explicitly.
    empty_cases = [
        (None, None),
        ([], None),
        ([], float32),
        (None, float32),
        ([None, None], None),
        ([None, None], float32),
    ]
    for tensors, dtype in empty_cases:
      self.assertIs(float32, same_float(tensors, dtype))

    # A dense float32 tensor matches float32 but not int32.
    const_float = constant_op.constant(3.0, dtype=float32)
    self.assertIs(float32, same_float([const_float], float32))
    with self.assertRaises(ValueError):
      same_float([const_float], dtypes.int32)

    # A sparse float32 tensor matches float32 but not int32 or float64.
    sparse_float = sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], float32),
        constant_op.constant([500], dtypes.int64))
    self.assertIs(float32, same_float([sparse_float], float32))
    for tensors, dtype in [
        ([sparse_float], dtypes.int32),
        ([const_float, None, sparse_float], dtypes.float64),
    ]:
      with self.assertRaises(ValueError):
        same_float(tensors, dtype)

    # Dense and sparse float32 tensors agree with each other.
    self.assertIs(float32, same_float([const_float, sparse_float]))
    self.assertIs(float32, same_float([const_float, sparse_float], float32))

    # Any integer tensor in the list is rejected, whatever dtype is asked.
    const_int = constant_op.constant(3, dtype=dtypes.int32)
    rejected_calls = [
        ([sparse_float, const_int],),
        ([sparse_float, const_int], dtypes.int32),
        ([sparse_float, const_int], float32),
        ([const_int],),
    ]
    for call_args in rejected_calls:
      with self.assertRaises(ValueError):
        same_float(*call_args)
| |
| |
class AssertScalarTest(test.TestCase):
  """Tests for `check_ops.assert_scalar`."""

  @test_util.run_in_graph_and_eager_modes
  def test_assert_scalar(self):
    # Scalar tensors and plain Python scalars, numeric or string, all pass.
    scalars = (
        constant_op.constant(3),
        constant_op.constant("foo"),
        3,
        "foo",
    )
    for scalar in scalars:
      check_ops.assert_scalar(scalar)

    vector = constant_op.constant([3, 4])
    with self.assertRaisesRegex(ValueError, "Expected scalar"):
      check_ops.assert_scalar(vector)
| |
| |
if __name__ == "__main__":
  # When this module is executed directly, hand control to the test runner
  # from `tensorflow.python.platform.test`.
  test.main()