blob: 17ea8c2fe95033b7b387366f622219e369df0607 [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.check_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
# pylint:disable=g-error-prone-assert-raises
class AssertV2Asserts(test.TestCase):
  """Table-driven smoke test for the TF2 (`*_v2`) assertion functions."""

  def test_passes_when_it_should(self):
    """Every v2 assertion accepts its good args and rejects its bad ones.

    Each table row is `(assert_fn, good_args, bad_args[, expected_error])`.
    When `expected_error` is omitted the failure must surface as
    `errors.InvalidArgumentError`. The failing call is wrapped in a
    `def_function.function` and must mention the custom message "fail".
    """
    # The v2 assertions are exercised eagerly only.
    with context.eager_mode():
      minus_one = constant_op.constant(
          -1, name="minus_one", dtype=dtypes.int32)
      two = constant_op.constant(2, name="two", dtype=dtypes.int32)
      threes = constant_op.constant(
          [3., 3.], name="three", dtype=dtypes.float32)
      three_and_half = constant_op.constant(
          [3., 3.5], name="three_and_a_half", dtype=dtypes.float32)

      # Role aliases so the table below reads by intent.
      scalar, non_scalar = minus_one, threes
      integer, non_integer = minus_one, threes
      positive, negative = two, minus_one

      table = [
          (check_ops.assert_equal_v2,
           (minus_one, minus_one), (minus_one, two)),
          (check_ops.assert_less_v2,
           (minus_one, two), (minus_one, minus_one)),
          (check_ops.assert_near_v2,
           (threes, threes), (threes, three_and_half)),
          (check_ops.assert_greater_v2,
           (two, minus_one), (minus_one, minus_one)),
          (check_ops.assert_negative_v2, (negative,), (positive,)),
          (check_ops.assert_positive_v2, (positive,), (negative,)),
          (check_ops.assert_less_equal_v2,
           (minus_one, minus_one), (two, minus_one)),
          (check_ops.assert_none_equal_v2,
           (minus_one, two), (threes, three_and_half)),
          (check_ops.assert_non_negative_v2, (positive,), (negative,)),
          (check_ops.assert_non_positive_v2, (negative,), (positive,)),
          (check_ops.assert_greater_equal_v2,
           (minus_one, minus_one), (minus_one, two)),
          (check_ops.assert_type_v2,
           (minus_one, dtypes.int32), (minus_one, dtypes.float32),
           TypeError),
          (check_ops.assert_integer_v2,
           (integer,), (non_integer,), TypeError),
          (check_ops.assert_scalar_v2,
           (scalar,), (non_scalar,), ValueError),
          (check_ops.assert_rank_v2,
           (minus_one, 0), (threes, 2), ValueError),
          (check_ops.assert_rank_in_v2,
           (minus_one, [0, 1]), (minus_one, [1, 2]), ValueError),
          (check_ops.assert_rank_at_least_v2,
           (non_scalar, 1), (scalar, 1), ValueError),
      ]

      for row in table:
        assert_fn, good_args, bad_args = row[0], row[1], row[2]
        if len(row) >= 4:
          expected_error = row[3]
        else:
          expected_error = errors.InvalidArgumentError

        print("Testing %s passing properly." % assert_fn)
        assert_fn(*good_args)

        print("Testing %s failing properly." % assert_fn)

        @def_function.function
        def failing_fn():
          assert_fn(*bad_args, message="fail")  # pylint: disable=cell-var-from-loop

        with self.assertRaisesRegex(expected_error, "fail"):
          failing_fn()
        # Drop the traced function so each iteration traces a fresh one.
        del failing_fn
class AssertProperIterableTest(test.TestCase):
  """Tests for `check_ops.assert_proper_iterable`.

  Tensors, SparseTensors, NumPy arrays and strings are iterable "by accident"
  and must be rejected; lists and generators are proper iterables.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_single_tensor_raises(self):
    tensor = constant_op.constant(1)
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(tensor)

  @test_util.run_in_graph_and_eager_modes
  def test_single_sparse_tensor_raises(self):
    ten = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(ten)

  @test_util.run_in_graph_and_eager_modes
  def test_single_ndarray_raises(self):
    array = np.array([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(array)

  @test_util.run_in_graph_and_eager_modes
  def test_single_string_raises(self):
    mystr = "hello"
    with self.assertRaisesRegex(TypeError, "proper"):
      check_ops.assert_proper_iterable(mystr)

  @test_util.run_in_graph_and_eager_modes
  def test_non_iterable_object_raises(self):
    non_iterable = 1234
    with self.assertRaisesRegex(TypeError, "to be an iterable"):
      check_ops.assert_proper_iterable(non_iterable)

  @test_util.run_in_graph_and_eager_modes
  def test_list_does_not_raise(self):
    list_of_stuff = [
        constant_op.constant([11, 22]), constant_op.constant([1, 2])
    ]
    check_ops.assert_proper_iterable(list_of_stuff)

  @test_util.run_in_graph_and_eager_modes
  def test_generator_does_not_raise(self):
    # A genuine generator expression. A bare parenthesized `(a, b)` would be a
    # tuple, which would leave the generator case untested.
    generator_of_stuff = (
        constant_op.constant(values) for values in ([11, 22], [1, 2]))
    check_ops.assert_proper_iterable(generator_of_stuff)
class AssertEqualTest(test.TestCase):
  """Tests for `check_ops.assert_equal`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies([check_ops.assert_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_scalar_comparison(self):
    const_true = constant_op.constant(True, name="true")
    const_false = constant_op.constant(False, name="false")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(const_true, const_false, message="fail")

  def test_returns_none_with_eager(self):
    """In eager mode `assert_equal` returns None instead of an op."""
    with context.eager_mode():
      small = constant_op.constant([1, 2], name="small")
      x = check_ops.assert_equal(small, small)
      # A bare `assert` is stripped under `python -O` and would silently pass.
      self.assertIsNone(x)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    # Static check: both inputs are constants, so the failure is detected
    # without running the graph.
    static_small = constant_op.constant([1, 2], name="small")
    static_big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(static_big, static_small, message="fail")

  @test_util.run_deprecated_v1
  def test_raises_when_greater_dynamic(self):
    # Placeholders defer the comparison to session run time.
    with self.cached_session():
      small = array_ops.placeholder(dtypes.int32, name="small")
      big = array_ops.placeholder(dtypes.int32, name="big")
      with ops.control_dependencies(
          [check_ops.assert_equal(big, small, message="fail")]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("fail.*big.*small"):
        out.eval(feed_dict={small: [1, 2], big: [3, 4]})

  def test_error_message_eager(self):
    """The eager error message honours `summarize` (default 3)."""
    expected_error_msg_full = r"""big does not equal small
Condition x == y did not hold.
Indices of first 3 different values:
\[\[0 0\]
\[1 1\]
\[2 0\]\]
Corresponding x values:
\[2 3 6\]
Corresponding y values:
\[20 30 60\]
First 6 elements of x:
\[2 2 3 3 6 6\]
First 6 elements of y:
\[20 2 3 30 60 6\]"""
    expected_error_msg_default = r"""big does not equal small
Condition x == y did not hold.
Indices of first 3 different values:
\[\[0 0\]
\[1 1\]
\[2 0\]\]
Corresponding x values:
\[2 3 6\]
Corresponding y values:
\[20 30 60\]
First 3 elements of x:
\[2 2 3\]
First 3 elements of y:
\[20 2 3\]"""
    expected_error_msg_short = r"""big does not equal small
Condition x == y did not hold.
Indices of first 2 different values:
\[\[0 0\]
\[1 1\]\]
Corresponding x values:
\[2 3\]
Corresponding y values:
\[20 30\]
First 2 elements of x:
\[2 2\]
First 2 elements of y:
\[20 2\]"""
    with context.eager_mode():
      big = constant_op.constant([[2, 2], [3, 3], [6, 6]])
      small = constant_op.constant([[20, 2], [3, 30], [60, 6]])
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_full):
        check_ops.assert_equal(big, small, message="big does not equal small",
                               summarize=10)
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_default):
        check_ops.assert_equal(big, small, message="big does not equal small")
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  expected_error_msg_short):
        check_ops.assert_equal(big, small, message="big does not equal small",
                               summarize=2)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    # Static check
    static_small = constant_op.constant([3, 1], name="small")
    static_big = constant_op.constant([4, 2], name="big")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(static_big, static_small, message="fail")

  @test_util.run_deprecated_v1
  def test_raises_when_less_dynamic(self):
    with self.cached_session():
      small = array_ops.placeholder(dtypes.int32, name="small")
      big = array_ops.placeholder(dtypes.int32, name="big")
      with ops.control_dependencies([check_ops.assert_equal(small, big)]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("small.*big"):
        out.eval(feed_dict={small: [3, 1], big: [4, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([[1, 2], [1, 2]], name="small")
    small_2 = constant_op.constant([1, 2], name="small_2")
    with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    small_2 = constant_op.constant([1, 1], name="small_2")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (errors.InvalidArgumentError, ValueError)):
      with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_not_equal_and_broadcastable_shapes(self):
    cond = constant_op.constant([True, False], name="small")
    with self.assertRaisesRegex(errors.InvalidArgumentError, "fail"):
      check_ops.assert_equal(cond, False, message="fail")

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_noop_when_both_identical(self):
    # Comparing a tensor with itself is short-circuited: None eagerly, a
    # NoOp in graph mode.
    larry = constant_op.constant([])
    check_op = check_ops.assert_equal(larry, larry)
    if context.executing_eagerly():
      self.assertIs(check_op, None)
    else:
      self.assertEqual(check_op.type, "NoOp")
class AssertNoneEqualTest(test.TestCase):
  """Tests for `check_ops.assert_none_equal`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_not_equal(self):
    small = constant_op.constant([1, 2], name="small")
    # Named "big" to match the variable (it was mislabelled "small").
    big = constant_op.constant([10, 20], name="big")
    with ops.control_dependencies(
        [check_ops.assert_none_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([3, 1], name="small")
    with self.assertRaisesOpError("x != y did not hold"):
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3], name="big")
    with ops.control_dependencies(
        [check_ops.assert_none_equal(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([10, 10], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors.InvalidArgumentError)):
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_none_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_returns_none_with_eager(self):
    """In eager mode `assert_none_equal` returns None instead of an op."""
    with context.eager_mode():
      t1 = constant_op.constant([1, 2])
      t2 = constant_op.constant([3, 4])
      x = check_ops.assert_none_equal(t1, t2)
      # A bare `assert` is stripped under `python -O` and would silently pass.
      self.assertIsNone(x)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_none_equal(1, 1, message="Custom error message")

  def test_error_message_eager(self):
    """The printed data respects `summarize` (-1 means print everything)."""
    # Note that the following three strings are regexes
    expected_error_msg_full = r"""\[ *0\. +1\. +2\. +3\. +4\. +5\.\]"""
    expected_error_msg_default = r"""\[ *0\. +1\. +2\.\]"""
    expected_error_msg_short = r"""\[ *0\. +1\.\]"""
    with context.eager_mode():
      t = constant_op.constant(
          np.array(range(6)), shape=[2, 3], dtype=np.float32)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_full):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=10)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_full):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=-1)
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_default):
        check_ops.assert_none_equal(t, t, message="This is the error message.")
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, expected_error_msg_short):
        check_ops.assert_none_equal(
            t, t, message="This is the error message.", summarize=2)
class AssertAllCloseTest(test.TestCase):
  """Tests for `check_ops.assert_near` (closeness within atol/rtol)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1., name="y")
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_32_bit_due_to_default_rtol(self):
    eps = np.finfo(np.float32).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1. + 2 * eps, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_32_bit_due_to_default_atol(self):
    eps = np.finfo(np.float32).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(0., name="x")
    y = constant_op.constant(0. + 2 * eps, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, rtol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_64_bit_due_to_default_rtol(self):
    eps = np.finfo(np.float64).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(1., name="x", dtype=np.float64)
    y = constant_op.constant(1. + 2 * eps, name="y", dtype=np.float64)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_64_bit_due_to_default_atol(self):
    eps = np.finfo(np.float64).eps
    # Default rtol/atol is 10*eps
    x = constant_op.constant(0., name="x", dtype=np.float64)
    y = constant_op.constant(0. + 2 * eps, name="y", dtype=np.float64)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, rtol=0., message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_due_to_custom_rtol(self):
    x = constant_op.constant(1., name="x")
    y = constant_op.constant(1.1, name="y")
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0., rtol=0.5,
                               message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_close_enough_due_to_custom_atol(self):
    x = constant_op.constant(0., name="x")
    y = constant_op.constant(0.1, name="y", dtype=np.float32)
    with ops.control_dependencies(
        [check_ops.assert_near(x, y, atol=0.5, rtol=0.,
                               message="failure message")]):
      out = array_ops.identity(x)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_near(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_atol_violated(self):
    x = constant_op.constant(10., name="x")
    y = constant_op.constant(10.2, name="y")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x and y not equal to tolerance"):
      with ops.control_dependencies(
          [check_ops.assert_near(x, y, atol=0.1,
                                 message="failure message")]):
        out = array_ops.identity(x)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_default_rtol_violated(self):
    x = constant_op.constant(0.1, name="x")
    y = constant_op.constant(0.0, name="y")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x and y not equal to tolerance"):
      with ops.control_dependencies(
          [check_ops.assert_near(x, y, message="failure message")]):
        out = array_ops.identity(x)
      self.evaluate(out)

  def test_returns_none_with_eager(self):
    """In eager mode `assert_near` returns None instead of an op."""
    with context.eager_mode():
      t1 = constant_op.constant([1., 2.])
      t2 = constant_op.constant([1., 2.])
      x = check_ops.assert_near(t1, t2)
      # A bare `assert` is stripped under `python -O` and would silently pass.
      self.assertIsNone(x)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_complex(self):
    x = constant_op.constant(1. + 0.1j, name="x")
    y = constant_op.constant(1.1 + 0.1j, name="y")
    with ops.control_dependencies([
        check_ops.assert_near(
            x, y, atol=0., rtol=0.5, message="failure message")
    ]):
      out = array_ops.identity(x)
    self.evaluate(out)
class AssertLessTest(test.TestCase):
  """Tests for `check_ops.assert_less`."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "failure message.*\n*.* x < y did not hold"):
      with ops.control_dependencies(
          [check_ops.assert_less(
              small, small, message="failure message")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x < y did not hold"):
      with ops.control_dependencies([check_ops.assert_less(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less(self):
    small = constant_op.constant([3, 1], name="small")
    big = constant_op.constant([4, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies([check_ops.assert_less(small, big)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_but_non_broadcastable_shapes(self):
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([3, 2], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesIncompatibleShapesError(
        (ValueError, errors.InvalidArgumentError)):
      with ops.control_dependencies([check_ops.assert_less(small, big)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_less(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_returns_none_with_eager(self):
    """In eager mode `assert_less` returns None instead of an op."""
    with context.eager_mode():
      t1 = constant_op.constant([1, 2])
      t2 = constant_op.constant([3, 4])
      x = check_ops.assert_less(t1, t2)
      # A bare `assert` is stripped under `python -O` and would silently pass.
      self.assertIsNone(x)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_less(1, 1, message="Custom error message")
class AssertLessEqualTest(test.TestCase):
  """Behavioural checks for `check_ops.assert_less_equal` (x <= y)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    """A tensor is always <= itself."""
    vals = constant_op.constant([1, 2], name="small")
    guard = check_ops.assert_less_equal(vals, vals)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(vals)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_greater(self):
    """Every element of x exceeding y triggers the custom message."""
    lo = constant_op.constant([1, 2], name="small")
    hi = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      # Built inside the context: eager mode raises on construction.
      guard = check_ops.assert_less_equal(hi, lo, message="fail")
      with ops.control_dependencies([guard]):
        result = array_ops.identity(lo)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_equal(self):
    """Mixed strictly-less and equal elements pass."""
    lo = constant_op.constant([1, 2], name="small")
    hi = constant_op.constant([3, 2], name="big")
    guard = check_ops.assert_less_equal(lo, hi)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(lo)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self):
    """A length-1 tensor broadcasts against a longer one."""
    lo = constant_op.constant([1], name="small")
    hi = constant_op.constant([3, 1], name="big")
    guard = check_ops.assert_less_equal(lo, hi)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(lo)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    """Shapes [2] and [3] cannot broadcast, whatever the values."""
    lo = constant_op.constant([3, 1], name="small")
    hi = constant_op.constant([1, 1, 1], name="big")
    # Eager reports the mismatch from the C++ kernel (InvalidArgumentError);
    # graph mode reports it while building the Operation (ValueError).
    shape_error_pattern = (r"Incompatible shapes: \[2\] vs. \[3\]|"
                           r"Dimensions must be equal, but are 2 and 3")
    with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        (errors.InvalidArgumentError, ValueError), shape_error_pattern):
      guard = check_ops.assert_less_equal(lo, hi)
      with ops.control_dependencies([guard]):
        result = array_ops.identity(lo)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    """Two empty tensors trivially satisfy the condition."""
    first_empty = constant_op.constant([])
    second_empty = constant_op.constant([])
    guard = check_ops.assert_less_equal(first_empty, second_empty)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(first_empty)
    self.evaluate(result)

  def test_static_check_in_graph_mode(self):
    """Python constants are checked while the graph is being built."""
    with ops.Graph().as_default(), self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_less_equal(1, 0, message="Custom error message")
class AssertGreaterTest(test.TestCase):
  """Behavioural checks for `check_ops.assert_greater` (x > y)."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_equal(self):
    """Equality violates the strict inequality."""
    vals = constant_op.constant([1, 2], name="small")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      # Built inside the context: eager mode raises on construction.
      guard = check_ops.assert_greater(vals, vals, message="fail")
      with ops.control_dependencies([guard]):
        result = array_ops.identity(vals)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    """x < y produces the default condition message."""
    lo = constant_op.constant([1, 2], name="small")
    hi = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x > y did not hold"):
      guard = check_ops.assert_greater(lo, hi)
      with ops.control_dependencies([guard]):
        result = array_ops.identity(hi)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater(self):
    """Element-wise strictly greater passes."""
    lo = constant_op.constant([3, 1], name="small")
    hi = constant_op.constant([4, 2], name="big")
    guard = check_ops.assert_greater(hi, lo)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(lo)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_and_broadcastable_shapes(self):
    """A length-1 y broadcasts against a longer x."""
    lo = constant_op.constant([1], name="small")
    hi = constant_op.constant([3, 2], name="big")
    guard = check_ops.assert_greater(hi, lo)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(lo)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_greater_but_non_broadcastable_shapes(self):
    """Shapes [2] and [3] cannot broadcast, whatever the values."""
    lo = constant_op.constant([1, 1, 1], name="small")
    hi = constant_op.constant([3, 2], name="big")
    # Eager reports the mismatch from the C++ kernel (InvalidArgumentError);
    # graph mode reports it while building the Operation (ValueError).
    shape_error_pattern = (r"Incompatible shapes: \[2\] vs. \[3\]|"
                           r"Dimensions must be equal, but are 2 and 3")
    with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        (errors.InvalidArgumentError, ValueError), shape_error_pattern):
      guard = check_ops.assert_greater(hi, lo)
      with ops.control_dependencies([guard]):
        result = array_ops.identity(lo)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    """Two empty tensors trivially satisfy the condition."""
    first_empty = constant_op.constant([])
    second_empty = constant_op.constant([])
    guard = check_ops.assert_greater(first_empty, second_empty)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(first_empty)
    self.evaluate(result)

  def test_static_check_in_graph_mode(self):
    """Python constants are checked while the graph is being built."""
    with ops.Graph().as_default(), self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_greater(0, 1, message="Custom error message")
class AssertGreaterEqualTest(test.TestCase):
  """Tests for `check_ops.assert_greater_equal`."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_equal(self):
    small = constant_op.constant([1, 2], name="small")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(small, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_less(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 4], name="big")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(
              small, big, message="fail")]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_equal(self):
    small = constant_op.constant([1, 2], name="small")
    big = constant_op.constant([3, 2], name="big")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self):
    small = constant_op.constant([1], name="small")
    big = constant_op.constant([3, 1], name="big")
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(big, small)]):
      out = array_ops.identity(small)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    # Op names now match the Python variables; they had been swapped.
    small = constant_op.constant([1, 1, 1], name="small")
    big = constant_op.constant([3, 1], name="big")
    # The exception in eager and non-eager mode is different because
    # eager mode relies on shape check done as part of the C++ op, while
    # graph mode does shape checks when creating the `Operation` instance.
    with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
        (errors.InvalidArgumentError, ValueError),
        (r"Incompatible shapes: \[2\] vs. \[3\]|"
         r"Dimensions must be equal, but are 2 and 3")):
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(big, small)]):
        out = array_ops.identity(small)
      self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_both_empty(self):
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(larry, curly)]):
      out = array_ops.identity(larry)
    self.evaluate(out)

  def test_static_check_in_graph_mode(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
          errors.InvalidArgumentError, "Custom error message"):
        check_ops.assert_greater_equal(0, 1, message="Custom error message")
class AssertNegativeTest(test.TestCase):
  """Behavioural checks for `check_ops.assert_negative` (x < 0)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_negative(self):
    """All-negative input passes."""
    negatives = constant_op.constant([-1, -2], name="frank")
    guard = check_ops.assert_negative(negatives)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(negatives)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_positive(self):
    """Positive input fails with the custom message."""
    positives = constant_op.constant([1, 2], name="doug")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "fail"):
      # Built inside the context: eager mode raises on construction.
      guard = check_ops.assert_negative(positives, message="fail")
      with ops.control_dependencies([guard]):
        result = array_ops.identity(positives)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_zero(self):
    """Zero is not negative, so the default condition message appears."""
    zeros = constant_op.constant([0], name="claire")
    with self.assertRaisesOpError(  # pylint:disable=g-error-prone-assert-raises
        "x < 0 did not hold"):
      guard = check_ops.assert_negative(zeros)
      with ops.control_dependencies([guard]):
        result = array_ops.identity(zeros)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    """An empty tensor passes vacuously.

    "Every element is < 0" holds when there are no elements at all.
    """
    nothing = constant_op.constant([], name="empty")
    guard = check_ops.assert_negative(nothing)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(nothing)
    self.evaluate(result)

  def test_static_check_in_graph_mode(self):
    """Python constants are checked while the graph is being built."""
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_negative(1, message="Custom error message")
# pylint:disable=g-error-prone-assert-raises
class AssertPositiveTest(test.TestCase):
  """Behavioural checks for `check_ops.assert_positive` (x > 0)."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_negative(self):
    """Negative input fails with the custom message."""
    negatives = constant_op.constant([-1, -2], name="freddie")
    with self.assertRaisesOpError("fail"):
      # Built inside the context: eager mode raises on construction.
      guard = check_ops.assert_positive(negatives, message="fail")
      with ops.control_dependencies([guard]):
        result = array_ops.identity(negatives)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_positive(self):
    """All-positive input passes."""
    positives = constant_op.constant([1, 2], name="remmy")
    guard = check_ops.assert_positive(positives)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(positives)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_zero(self):
    """Zero is not positive, so the default condition message appears."""
    zeros = constant_op.constant([0], name="meechum")
    with self.assertRaisesOpError("x > 0 did not hold"):
      guard = check_ops.assert_positive(zeros)
      with ops.control_dependencies([guard]):
        result = array_ops.identity(zeros)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    """An empty tensor passes vacuously.

    "Every element is > 0" holds when there are no elements at all.
    """
    nothing = constant_op.constant([], name="empty")
    guard = check_ops.assert_positive(nothing)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(nothing)
    self.evaluate(result)

  def test_static_check_in_graph_mode(self):
    """Python constants are checked while the graph is being built."""
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, "Custom error message"):
      check_ops.assert_positive(-1, message="Custom error message")
class EnsureShapeTest(test.TestCase):
  """Tests for `check_ops.ensure_shape`.

  The first group checks the static shape produced while the graph is built.
  The second group feeds placeholders to exercise the run-time shape check.
  """

  # Static shape inference

  @test_util.run_deprecated_v1
  def testStaticShape(self):
    unknown = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(unknown, (3, 3, 3))
    self.assertEqual(checked.get_shape(), (3, 3, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_MergesShapes(self):
    partial = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    checked = check_ops.ensure_shape(partial, (5, 4, None))
    # Known dims from the spec (first two) and the input (last) are combined.
    self.assertEqual(checked.get_shape(), (5, 4, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_RaisesErrorWhenRankIncompatible(self):
    partial = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    with self.assertRaises(ValueError):
      check_ops.ensure_shape(partial, (2, 3))

  @test_util.run_deprecated_v1
  def testStaticShape_RaisesErrorWhenDimIncompatible(self):
    partial = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
    with self.assertRaises(ValueError):
      check_ops.ensure_shape(partial, (2, 2, 4))

  @test_util.run_deprecated_v1
  def testStaticShape_CanSetUnknownShape(self):
    unknown = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(unknown / 3, None)
    self.assertEqual(checked.get_shape(), None)

  # Dynamic shape check

  @test_util.run_deprecated_v1
  @test_util.disable_xla(
      "b/123337890")  # Dynamic shapes not supported now with XLA
  def testEnsuresDynamicShape_RaisesError(self):
    fed = array_ops.placeholder(dtypes.int32)
    quotient = math_ops.divide(fed, 3, name="MyDivide")
    checked = check_ops.ensure_shape(quotient, (3, 3, 3))
    expected_error = (
        r"Shape of tensor MyDivide \[2,1\] is not compatible with "
        r"expected shape \[3,3,3\].")
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               expected_error):
        sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  @test_util.disable_xla(
      "b/123337890")  # Dynamic shapes not supported now with XLA
  def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
    fed = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(fed / 3, (None, None, 3))
    expected_error = (
        r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
        r"expected shape \[\?,\?,3\].")
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               expected_error):
        sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  def testEnsuresDynamicShape(self):
    fed = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(fed / 3, (2, 1))
    with self.cached_session() as sess:
      sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  def testEnsuresDynamicShape_WithUnknownDims(self):
    fed = array_ops.placeholder(dtypes.int32)
    checked = check_ops.ensure_shape(fed / 3, (None, None))
    with self.cached_session() as sess:
      sess.run(checked, feed_dict={fed: [[1], [2]]})

  @test_util.run_deprecated_v1
  def testGradient(self):
    fed = array_ops.placeholder(dtypes.float32)
    checked = check_ops.ensure_shape(fed, (None, None))
    grads = gradients.gradients(checked, fed)
    with self.cached_session() as sess:
      (grad_value,) = sess.run(grads, feed_dict={fed: [[4.0], [-1.0]]})
      # An all-ones gradient means ensure_shape acts like identity in backprop.
      self.assertAllEqual(grad_value, [[1.0], [1.0]])
class EnsureShapeBenchmark(test.Benchmark):
  """Compares per-`session.run` latency of `ensure_shape` and `identity`."""

  def _grappler_all_off_config(self):
    """Returns a ConfigProto with graph optimizations and rewrites turned off.

    With these off, the benchmarked ops are not folded or pruned away, so the
    timed graph is the one that was built.
    """
    config = config_pb2.ConfigProto()
    off = rewriter_config_pb2.RewriterConfig.OFF
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.disable_model_pruning = 1
    config.graph_options.rewrite_options.constant_folding = off
    config.graph_options.rewrite_options.layout_optimizer = off
    config.graph_options.rewrite_options.arithmetic_optimization = off
    config.graph_options.rewrite_options.dependency_optimization = off
    return config

  def _run(self, op, feed_dict=None, num_iters=5000, name=None, **kwargs):
    """Times `num_iters` runs of `op` and reports the median per-run latency.

    Args:
      op: Fetch passed to `session.run`.
      feed_dict: Optional feed dict passed to every `session.run` call.
      num_iters: Number of timed `session.run` calls.
      name: Benchmark name passed to `report_benchmark`.
      **kwargs: Reported verbatim as the benchmark's `extras`.
    """
    config = self._grappler_all_off_config()
    deltas = []
    with session.Session(config=config) as sess:
      # Warm up the session so one-time setup cost is excluded from timings.
      for _ in range(5):
        sess.run(op, feed_dict=feed_dict)
      for _ in range(num_iters):
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # when the wall clock is adjusted and may be too coarse for these
        # sub-millisecond measurements.
        start = time.perf_counter()
        sess.run(op, feed_dict=feed_dict)
        deltas.append(time.perf_counter() - start)
    # The median is used (not the mean) so occasional slow outliers do not
    # skew the result.
    median_us = np.median(deltas) * 1e6
    # NOTE(review): wall_time is reported in microseconds; confirm that is the
    # unit expected by consumers of these benchmark results.
    self.report_benchmark(
        iters=num_iters,
        name=name,
        wall_time=median_us,
        extras=kwargs,
    )

  def benchmark_const_op(self):
    # Expected to be dominated by `session.run` overhead rather than by the
    # time taken to execute the op itself.
    shape = (3, 3, 100)
    input_op = random_ops.random_normal(shape)
    self._run(array_ops.identity(input_op), name="SingleConstOp")

  def benchmark_single_ensure_op(self):
    # Expected to be dominated by `session.run` overhead rather than by the
    # time taken to execute the op itself.
    shape = (3, 3, 100)
    input_op = random_ops.random_normal(shape)
    ensure_shape_op = check_ops.ensure_shape(input_op, shape)
    self._run(ensure_shape_op, name="SingleEnsureShapeOp")

  def _apply_n_times(self, op, target, n=1000):
    """Returns `op` applied `n` times in sequence, starting from `target`."""
    for _ in range(n):
      target = op(target)
    return target

  def benchmark_n_ops(self):
    shape = (1000,)
    input_op = random_ops.random_normal(shape)
    n_ops = self._apply_n_times(array_ops.identity, input_op)
    self._run(n_ops, name="NIdentityOps_1000")

  def benchmark_n_ensure_ops(self):
    shape = (1000,)
    input_op = random_ops.random_normal(shape)
    n_ensure_ops = self._apply_n_times(
        lambda x: check_ops.ensure_shape(array_ops.identity(x), shape),
        input_op)
    self._run(n_ensure_ops, name="NEnsureShapeAndIdentityOps_1000")
class AssertRankTest(test.TestCase):
  """Tests for `check_ops.assert_rank`.

  Each scenario is checked twice. One variant uses a constant, whose rank is
  known statically. The other uses an unshaped placeholder, whose rank is only
  known once a value is fed.
  """

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    with self.assertRaisesRegex(ValueError, "fail.*must have rank 1"):
      guard = check_ops.assert_rank(scalar, 1, message="fail")
      with ops.control_dependencies([guard]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      guard = check_ops.assert_rank(fed, 1, message="fail")
      with ops.control_dependencies([guard]):
        with self.assertRaisesOpError("fail.*my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    with ops.control_dependencies([check_ops.assert_rank(scalar, 0)]):
      self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([check_ops.assert_rank(fed, 0)]):
        array_ops.identity(fed).eval(feed_dict={fed: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank"):
      with ops.control_dependencies([check_ops.assert_rank(vector, 0)]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([check_ops.assert_rank(fed, 0)]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with ops.control_dependencies([check_ops.assert_rank(vector, 1)]):
      self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([check_ops.assert_rank(fed, 1)]):
        array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank"):
      with ops.control_dependencies([check_ops.assert_rank(vector, 2)]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([check_ops.assert_rank(fed, 2)]):
        with self.assertRaisesOpError("my_tensor.* must have rank"):
          array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_scalar_static(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    empty_rank = np.array([], dtype=np.int32)
    with self.assertRaisesRegex(ValueError, "Argument `rank` must be a scalar"):
      check_ops.assert_rank(vector, empty_rank)

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      fed_rank = array_ops.placeholder(dtypes.int32, name="rank_tensor")
      with self.assertRaisesOpError("Rank must be a scalar."):
        guard = check_ops.assert_rank(vector, fed_rank)
        with ops.control_dependencies([guard]):
          array_ops.identity(vector).eval(feed_dict={fed_rank: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_integer_static(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(TypeError, "must be of type <dtype: 'int32'>"):
      check_ops.assert_rank(vector, 0.5)

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      fed_rank = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      # A float rank is rejected while the op is being built.
      with self.assertRaisesRegex(TypeError,
                                  "must be of type <dtype: 'int32'>"):
        guard = check_ops.assert_rank(vector, fed_rank)
        with ops.control_dependencies([guard]):
          array_ops.identity(vector).eval(feed_dict={fed_rank: 0.5})
class AssertRankInTest(test.TestCase):
  """Tests for `check_ops.assert_rank_in` (rank must be one of several)."""

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
    scalar = constant_op.constant(42, name="my_tensor")
    with self.assertRaisesRegex(ValueError, "fail.*must have rank.*1.*2"):
      guard = check_ops.assert_rank_in(scalar, (1, 2), message="fail")
      with ops.control_dependencies([guard]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      guard = check_ops.assert_rank_in(fed, (1, 2), message="fail")
      with ops.control_dependencies([guard]):
        with self.assertRaisesOpError("fail.*my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: 42.0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    scalar = constant_op.constant(42, name="my_tensor")
    # Rank 0 is placed at each position of the allowed-ranks tuple.
    for allowed in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
      guard = check_ops.assert_rank_in(scalar, allowed)
      with ops.control_dependencies([guard]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      for allowed in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
        guard = check_ops.assert_rank_in(fed, allowed)
        with ops.control_dependencies([guard]):
          array_ops.identity(fed).eval(feed_dict={fed: 42.0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    vector = constant_op.constant([42, 43], name="my_tensor")
    for allowed in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
      guard = check_ops.assert_rank_in(vector, allowed)
      with ops.control_dependencies([guard]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      for allowed in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
        guard = check_ops.assert_rank_in(fed, allowed)
        with ops.control_dependencies([guard]):
          array_ops.identity(fed).eval(feed_dict={fed: (42.0, 43.0)})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self):
    vector = constant_op.constant((42, 43), name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank"):
      guard = check_ops.assert_rank_in(vector, (0, 2))
      with ops.control_dependencies([guard]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      guard = check_ops.assert_rank_in(fed, (0, 2))
      with ops.control_dependencies([guard]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: (42.0, 43.0)})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_scalar_static(self):
    vector = constant_op.constant((42, 43), name="my_tensor")
    # The second allowed "rank" is a vector, which is invalid.
    allowed = (np.array(1, dtype=np.int32),
               np.array((2, 1), dtype=np.int32))
    with self.assertRaisesRegex(
        ValueError, "Argument `ranks` must contain scalar tensors."):
      check_ops.assert_rank_in(vector, allowed)

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      rank0 = array_ops.placeholder(dtypes.int32, name="rank0_tensor")
      rank1 = array_ops.placeholder(dtypes.int32, name="rank1_tensor")
      with self.assertRaisesOpError("Rank must be a scalar"):
        guard = check_ops.assert_rank_in(vector, (rank0, rank1))
        with ops.control_dependencies((guard,)):
          array_ops.identity(vector).eval(
              feed_dict={rank0: 1, rank1: [2, 1]})

  @test_util.run_in_graph_and_eager_modes
  def test_raises_if_rank_is_not_integer_static(self):
    vector = constant_op.constant((42, 43), name="my_tensor")
    with self.assertRaisesRegex(TypeError, "must be of type <dtype: 'int32'>"):
      check_ops.assert_rank_in(vector, (1, 0.5))

  @test_util.run_deprecated_v1
  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.cached_session():
      vector = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      fed_rank = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      # A float rank is rejected while the op is being built.
      with self.assertRaisesRegex(TypeError,
                                  "must be of type <dtype: 'int32'>"):
        guard = check_ops.assert_rank_in(vector, (1, fed_rank))
        with ops.control_dependencies([guard]):
          array_ops.identity(vector).eval(feed_dict={fed_rank: 0.5})
class AssertRankAtLeastTest(test.TestCase):
  """Tests for `check_ops.assert_rank_at_least` (rank must be >= a minimum)."""

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank at least 1"):
      guard = check_ops.assert_rank_at_least(scalar, 1)
      with ops.control_dependencies([guard]):
        self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      guard = check_ops.assert_rank_at_least(fed, 1)
      with ops.control_dependencies([guard]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    scalar = constant_op.constant(1, name="my_tensor")
    with ops.control_dependencies([check_ops.assert_rank_at_least(scalar, 0)]):
      self.evaluate(array_ops.identity(scalar))

  @test_util.run_deprecated_v1
  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([check_ops.assert_rank_at_least(fed, 0)]):
        array_ops.identity(fed).eval(feed_dict={fed: 0})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self):
    # A rank above the minimum is fine.
    vector = constant_op.constant([1, 2], name="my_tensor")
    with ops.control_dependencies([check_ops.assert_rank_at_least(vector, 0)]):
      self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([check_ops.assert_rank_at_least(fed, 0)]):
        array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with ops.control_dependencies([check_ops.assert_rank_at_least(vector, 1)]):
      self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      with ops.control_dependencies([check_ops.assert_rank_at_least(fed, 1)]):
        array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})

  @test_util.run_in_graph_and_eager_modes
  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    vector = constant_op.constant([1, 2], name="my_tensor")
    with self.assertRaisesRegex(ValueError, "rank at least 2"):
      guard = check_ops.assert_rank_at_least(vector, 2)
      with ops.control_dependencies([guard]):
        self.evaluate(array_ops.identity(vector))

  @test_util.run_deprecated_v1
  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.cached_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      guard = check_ops.assert_rank_at_least(fed, 2)
      with ops.control_dependencies([guard]):
        with self.assertRaisesOpError("my_tensor.*rank"):
          array_ops.identity(fed).eval(feed_dict={fed: [1, 2]})
class AssertNonNegativeTest(test.TestCase):
  """Tests for `check_ops.assert_non_negative` (element-wise `x >= 0`)."""

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_negative(self):
    negatives = constant_op.constant([-1, -2], name="zoe")
    with self.assertRaisesOpError("x >= 0 did not hold"):
      guard = check_ops.assert_non_negative(negatives)
      with ops.control_dependencies([guard]):
        result = array_ops.identity(negatives)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_zero_and_positive(self):
    non_negatives = constant_op.constant([0, 2], name="lucas")
    guard = check_ops.assert_non_negative(non_negatives)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(non_negatives)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # "Every element x_i satisfies x_i >= 0" is vacuously true when there are
    # no elements, so an empty tensor must pass.
    empty = constant_op.constant([], name="empty")
    guard = check_ops.assert_non_negative(empty)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(empty)
    self.evaluate(result)

  def test_static_check_in_graph_mode(self):
    custom_message = "Custom error message"
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, custom_message):
      check_ops.assert_non_negative(-1, message=custom_message)
class AssertNonPositiveTest(test.TestCase):
  """Tests for `check_ops.assert_non_positive` (element-wise `x <= 0`)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_zero_and_negative(self):
    non_positives = constant_op.constant([0, -2], name="tom")
    guard = check_ops.assert_non_positive(non_positives)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(non_positives)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def test_raises_when_positive(self):
    # The 2 violates the check; the 0 on its own would pass.
    mixed = constant_op.constant([0, 2], name="rachel")
    with self.assertRaisesOpError("x <= 0 did not hold"):
      guard = check_ops.assert_non_positive(mixed)
      with ops.control_dependencies([guard]):
        result = array_ops.identity(mixed)
      self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_doesnt_raise(self):
    # "Every element x_i satisfies x_i <= 0" is vacuously true when there are
    # no elements, so an empty tensor must pass.
    empty = constant_op.constant([], name="empty")
    guard = check_ops.assert_non_positive(empty)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(empty)
    self.evaluate(result)

  def test_static_check_in_graph_mode(self):
    custom_message = "Custom error message"
    with ops.Graph().as_default(), self.assertRaisesRegex(
        errors.InvalidArgumentError, custom_message):
      check_ops.assert_non_positive(1, message=custom_message)
class AssertIntegerTest(test.TestCase):
  """Tests for `check_ops.assert_integer` (the tensor dtype must be integral)."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_integer(self):
    ints = constant_op.constant([1, 2], name="integers")
    guard = check_ops.assert_integer(ints)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(ints)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_float(self):
    # A dtype violation is reported as a TypeError when the check is built.
    non_ints = constant_op.constant([1.0, 2.0], name="floats")
    with self.assertRaisesRegex(TypeError, "Expected.*integer"):
      check_ops.assert_integer(non_ints)
class AssertTypeTest(test.TestCase):
  """Tests for `check_ops.assert_type` on dense and sparse tensors."""

  @test_util.run_in_graph_and_eager_modes
  def test_doesnt_raise_when_correct_type(self):
    int64s = constant_op.constant([1, 2], dtype=dtypes.int64)
    guard = check_ops.assert_type(int64s, dtypes.int64)
    with ops.control_dependencies([guard]):
      result = array_ops.identity(int64s)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_sparsetensor_doesnt_raise_when_correct_type(self):
    indices = constant_op.constant([[111], [232]], dtypes.int64)
    values = constant_op.constant([23.4, -43.2], dtypes.float32)
    dense_shape = constant_op.constant([500], dtypes.int64)
    sparse_float = sparse_tensor.SparseTensor(indices, values, dense_shape)
    guard = check_ops.assert_type(sparse_float, dtypes.float32)
    with ops.control_dependencies([guard]):
      result = sparse_tensor.SparseTensor(
          sparse_float.indices,
          array_ops.identity(sparse_float.values),
          sparse_float.dense_shape)
    self.evaluate(result)

  @test_util.run_in_graph_and_eager_modes
  def test_raises_when_wrong_type(self):
    halves = constant_op.constant([1.0, 2.0], dtype=dtypes.float16)
    with self.assertRaisesRegex(TypeError, "must be of type.*float32"):
      check_ops.assert_type(halves, dtypes.float32)

  @test_util.run_in_graph_and_eager_modes
  def test_sparsetensor_raises_when_wrong_type(self):
    indices = constant_op.constant([[111], [232]], dtypes.int64)
    values = constant_op.constant([23.4, -43.2], dtypes.float16)
    dense_shape = constant_op.constant([500], dtypes.int64)
    sparse_float16 = sparse_tensor.SparseTensor(indices, values, dense_shape)
    with self.assertRaisesRegex(TypeError, "must be of type.*float32"):
      check_ops.assert_type(sparse_float16, dtypes.float32)

  def test_raise_when_tf_type_is_not_dtype(self):
    # Regression test for GitHub issue:
    # https://github.com/tensorflow/tensorflow/issues/45975
    # A tuple holding a dtype is not itself a dtype and must be rejected.
    scalar = constant_op.constant(0.0)
    with self.assertRaisesRegex(TypeError,
                                "Cannot convert.*to a TensorFlow DType"):
      check_ops.assert_type(scalar, (dtypes.float32,))
class AssertShapesTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_raise_static_shape_mismatch(self):
  """Conflicting static sizes for a shared symbol are reported."""
  x = array_ops.ones([3, 2], name="x")
  y = array_ops.ones([2, 3], name="y")
  # x binds "N" to 3, but y's dimension 0 is 2.
  expected_error = (r"Specified by tensor .* dimension 0. "
                    r"Tensor .* dimension 0 must have size 3. "
                    r"Received size 2")
  self.raises_static_error(
      shapes=[(x, ("N", "Q")), (y, ("N", "D"))], regex=expected_error)
def test_raise_dynamic_shape_mismatch(self):
  """A conflict that is only visible in fed values is reported at run time."""
  with ops.Graph().as_default():
    x = array_ops.placeholder(dtypes.float32, [None, 2], name="x")
    y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
    expected_error = (r"\[Specified by tensor x.* dimension 0\] "
                      r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
    self.raises_dynamic_error(
        shapes=[(x, ("N", "Q")), (y, ("N", "D"))],
        regex=expected_error,
        feed_dict={x: np.ones([3, 2]), y: np.ones([2, 3])})
@test_util.run_in_graph_and_eager_modes
def test_raise_static_shape_explicit_mismatch(self):
  """An explicit integer size that disagrees statically is reported."""
  x = array_ops.ones([3, 2], name="x")
  y = array_ops.ones([2, 3], name="y")
  expected_error = (r"Specified explicitly. "
                    r"Tensor .* dimension 0 must have size 3. "
                    r"Received size 2")
  self.raises_static_error(
      shapes=[(x, (3, "Q")), (y, (3, "D"))], regex=expected_error)
@test_util.run_in_graph_and_eager_modes
def test_rank_zero_rank_one_size_one_equivalence(self):
  """A scalar and a size-1 vector are both accepted for `()` and `(1,)`."""
  rank_one_size_one = array_ops.ones([1], name="rank_one_size_one")
  rank_zero = array_ops.constant(5, name="rank_zero")
  for spec in ((), (1,)):
    check_ops.assert_shapes([
        (rank_one_size_one, spec),
        (rank_zero, spec),
    ])
@test_util.run_in_graph_and_eager_modes
def test_raise_static_rank_1_size_not_1_mismatch_scalar(self):
  """A length-2 vector does not satisfy the scalar spec `()`."""
  pair = array_ops.constant([2, 2], name="x")
  expected_error = (r"Specified explicitly. "
                    r"Tensor .* dimension 0 must have size 1. "
                    r"Received size 2")
  self.raises_static_error(shapes=[(pair, ())], regex=expected_error)
@test_util.run_in_graph_and_eager_modes
def test_raise_static_scalar_mismatch_rank_1_size_not_1(self):
  """A scalar is treated as size 1, so it does not satisfy `(2,)`."""
  scalar = array_ops.constant(2, name="x")
  expected_error = (r"Specified explicitly. "
                    r"Tensor .* dimension 0 must have size 2. "
                    r"Received size 1")
  self.raises_static_error(shapes=[(scalar, (2,))], regex=expected_error)
@test_util.run_in_graph_and_eager_modes
def test_scalar_implies_size_one(self):
  """A scalar binds its symbol to 1, which then conflicts with size 2."""
  scalar = array_ops.constant(5, name="rank_zero")
  square = array_ops.ones([2, 2], name="x")
  expected_error = (r"Specified by tensor .* dimension 0. "
                    r"Tensor .* dimension 0 must have size 1. "
                    r"Received size 2")
  self.raises_static_error(
      shapes=[(scalar, ("a",)), (square, ("a", 2))], regex=expected_error)
@test_util.run_in_graph_and_eager_modes
def test_raise_not_iterable(self):
  """A bare integer is rejected as a shape spec because it is not iterable."""
  vector = array_ops.constant([1, 2], name="x")
  expected_error = (
      r"Tensor .*. "
      r"Specified shape must be an iterable. "
      r"An iterable has the attribute `__iter__` or `__getitem__`. "
      r"Received specified shape: 2")
  self.raises_static_error(shapes=[(vector, 2)], regex=expected_error)
def test_raise_dynamic_shape_explicit_mismatch(self):
  """An explicit size that only disagrees with fed values fails at run time."""
  with ops.Graph().as_default():
    x = array_ops.placeholder(dtypes.float32, [None, 2], name="xa")
    y = array_ops.placeholder(dtypes.float32, [None, 3], name="y")
    expected_error = (r"\[Specified explicitly\] "
                      r"\[Tensor y.* dimension\] \[0\] \[must have size\] \[3\]")
    self.raises_dynamic_error(
        shapes=[(x, (3, "Q")), (y, (3, "D"))],
        regex=expected_error,
        feed_dict={x: np.ones([3, 2]), y: np.ones([2, 3])})
@test_util.run_in_graph_and_eager_modes
def test_no_op_when_specified_as_unknown(self):
  """A `None` shape spec imposes no constraint."""
  pair = array_ops.constant([1, 1], name="x")
  with ops.control_dependencies([check_ops.assert_shapes([(pair, None)])]):
    result = array_ops.identity(pair)
  self.evaluate(result)
@test_util.run_in_graph_and_eager_modes
def test_raises_static_incorrect_rank(self):
  """Shape specs of the wrong rank fail statically with a rank message."""
  rank_two_shapes = [
      (1, 1),
      (1, 3),
      ("a", "b"),
      (None, None),
  ]
  rank_three_shapes = [
      (1, 1, 1),
      ("a", "b", "c"),
      (None, None, None),
      (1, "b", None),
  ]

  def raises_static_rank_error(shapes, x, correct_rank, actual_rank):
    # The expected message depends only on the two ranks, not on which spec
    # is being tried, so build it once rather than once per shape.
    regex = (r"Tensor .* must have rank %d. Received rank %d" %
             (correct_rank, actual_rank))
    for shape in shapes:
      self.raises_static_error(shapes=[(x, shape)], regex=regex)

  raises_static_rank_error(
      rank_two_shapes, array_ops.ones([1]), correct_rank=2, actual_rank=1)
  raises_static_rank_error(
      rank_three_shapes,
      array_ops.ones([1, 1]),
      correct_rank=3,
      actual_rank=2)
  raises_static_rank_error(
      rank_three_shapes, array_ops.constant(1), correct_rank=3, actual_rank=0)
def test_raises_dynamic_incorrect_rank(self):
  """A fed scalar fails every rank-2 shape spec at run time."""
  x_value = 5
  rank_two_shapes = [(1, 1), (1, 3), ("a", "b"), (None, None)]
  # The expected message is the same for every spec, so it is built once
  # outside the loop.
  regex = r"Tensor .* must have rank\] \[2\]"
  with ops.Graph().as_default():
    # No static shape, so the rank can only be checked once a value is fed.
    x = array_ops.placeholder(dtypes.float32, None)
    for shape in rank_two_shapes:
      self.raises_dynamic_error(
          shapes=[(x, shape)], regex=regex, feed_dict={x: x_value})
@test_util.run_in_graph_and_eager_modes
def test_correctly_matching(self):
  """Consistent specs pass, using symbols alone or mixed with literal sizes."""
  u = array_ops.constant(1, name="u")
  v = array_ops.ones([1, 2], name="v")
  w = array_ops.ones([3], name="w")
  x = array_ops.ones([1, 2, 3], name="x")
  y = array_ops.ones([3, 1, 2], name="y")
  z = array_ops.ones([2, 3, 1], name="z")
  # Here a == 1, b == 2 and c == 3 throughout.
  symbolic_specs = [
      (x, ("a", "b", "c")),
      (y, ("c", "a", "b")),
      (z, ("b", "c", "a")),
      (v, ("a", "b")),
      (w, ("c",)),
      (u, "a"),
  ]
  mixed_specs = [
      (x, (1, "b", "c")),
      (y, ("c", "a", 2)),
      (z, ("b", 3, "a")),
      (v, ("a", 2)),
      (w, (3,)),
      (u, ()),
  ]
  for specs in (symbolic_specs, mixed_specs):
    with ops.control_dependencies([check_ops.assert_shapes(specs)]):
      result = array_ops.identity(x)
    self.evaluate(result)
@test_util.run_in_graph_and_eager_modes
def test_variable_length_symbols(self):
  """Dimension symbols may be multi-character names."""
  inputs = array_ops.ones([4, 1], name="x")
  outputs = array_ops.ones([4, 2], name="y")
  guard = check_ops.assert_shapes([
      (inputs, ("num_observations", "input_dim")),
      (outputs, ("num_observations", "output_dim")),
  ])
  with ops.control_dependencies([guard]):
    result = array_ops.identity(inputs)
  self.evaluate(result)
@test_util.run_in_graph_and_eager_modes
def test_raise_implicit_mismatch_using_iterable_alternatives(self):
  """Symbol mismatches are reported for every supported spec container.

  Covers tuples, strings (one character per dimension), lists, numpy arrays,
  and a mix of styles in a single call.
  """
  x = array_ops.ones([2, 2], name="x")
  y = array_ops.ones([1, 3], name="y")
  styles = [[
      (x, ("A", "B")),
      (y, ("A", "C")),
  ], [
      (x, "AB"),
      (y, "AC")
  ], [
      (x, ["A", "B"]),
      (y, ["A", "C"]),
  ], [
      (x, np.array(["A", "B"])),
      (y, np.array(["A", "C"]))
  ], [
      (x, ("A", "B")),
      (y, "AC")
  ]]
  # The error is the same for every style, so the regex is built once. All
  # pieces are raw strings, matching the other regexes in this class.
  regex = (r"Specified by tensor .* dimension 0. "
           r"Tensor .* dimension 0 must have size 2. "
           r"Received size 1")
  for shapes in styles:
    self.raises_static_error(shapes=shapes, regex=regex)
@test_util.run_in_graph_and_eager_modes
def test_raise_explicit_mismatch_using_iterable_alternatives(self):
  """An explicit size mismatch is reported for every spec style.

  `y` is declared with size 2 at dimension 0 but actually has size 1. The spec
  is spelled as a tuple, a string, a list, a numpy array, and a mix.
  """
  x = array_ops.ones([2, 2], name="x")
  y = array_ops.ones([1, 3], name="y")
  expected_message = (r"Specified explicitly. "
                      r"Tensor .* dimension 0 must have size 2. "
                      r"Received size 1")
  equivalent_specs = (
      [(x, (2, 2)), (y, (2, 3))],
      [(x, "22"), (y, "23")],
      [(x, [2, 2]), (y, [2, 3])],
      [(x, np.array([2, 2])), (y, np.array([2, 3]))],
      [(x, (2, 2)), (y, "23")],
  )
  for spec in equivalent_specs:
    self.raises_static_error(shapes=spec, regex=expected_message)
@test_util.run_in_graph_and_eager_modes
def test_dim_size_specified_as_unknown(self):
  """`None`, "." and the compact string form all leave a dimension unchecked."""
  x = array_ops.ones([1, 2, 3], name="x")
  y = array_ops.ones([2, 1], name="y")
  unknown_spellings = [
      [(x, (None, 2, None)), (y, (None, 1))],
      [(x, (".", 2, ".")), (y, (".", 1))],
      [(x, ".2."), (y, ".1")],
  ]
  checks = [check_ops.assert_shapes(spec) for spec in unknown_spellings]
  with ops.control_dependencies(checks):
    gated = array_ops.identity(x)
  self.evaluate(gated)
@test_util.run_in_graph_and_eager_modes
def test_raise_static_shape_explicit_mismatch_innermost_dims(self):
  """An explicit size after a leading `...`/`*` is checked from the end.

  `y` has size 2 at dimension -2, but the spec demands 3 there.
  """
  x = array_ops.ones([3, 2], name="x")
  y = array_ops.ones([2, 3], name="y")
  expected_message = (r"Specified explicitly. "
                      r"Tensor .* dimension -2 must have size 3. "
                      r"Received size 2")
  tuple_spec = [(x, (3, "Q")), (y, (Ellipsis, 3, "D"))]
  string_spec = [(x, "3Q"), (y, "*3D")]
  for spec in (tuple_spec, string_spec):
    self.raises_static_error(shapes=spec, regex=expected_message)
@test_util.run_in_graph_and_eager_modes
def test_correctly_matching_innermost_dims(self):
  """Tensors of different rank match when only their trailing dims agree."""
  x = array_ops.ones([1, 2, 3, 2], name="x")
  y = array_ops.ones([2, 3, 3], name="y")
  tuple_spec = [(x, (Ellipsis, "N", "Q")), (y, (Ellipsis, "N", "D"))]
  string_spec = [(x, "*NQ"), (y, "*ND")]
  checks = [check_ops.assert_shapes(spec) for spec in (tuple_spec, string_spec)]
  with ops.control_dependencies(checks):
    gated = array_ops.identity(x)
  self.evaluate(gated)
@test_util.run_in_graph_and_eager_modes
def test_raise_variable_num_outer_dims_prefix_misuse(self):
  """`...`/`*` placed anywhere but the first entry is rejected statically."""
  x = array_ops.ones([1, 2], name="x")
  expected_message = (
      r"Tensor .* specified shape index .*. "
      r"Symbol `...` or `\*` for a variable number of "
      r"unspecified dimensions is only allowed as the first entry")
  for spec in ([(x, ("N", Ellipsis, "Q"))], [(x, "N*Q")]):
    self.raises_static_error(shapes=spec, regex=expected_message)
@test_util.run_in_graph_and_eager_modes
def test_empty_shapes_dict_no_op(self):
  """An empty spec list yields an assertion usable as a control dependency."""
  trivial_check = check_ops.assert_shapes([])
  with ops.control_dependencies([trivial_check]):
    zero = array_ops.identity(0)
  self.evaluate(zero)
def raises_static_error(self, shapes, regex):
  """Asserts that constructing the shape assertion raises a ValueError.

  Args:
    shapes: spec list passed unchanged to `check_ops.assert_shapes`.
    regex: pattern the ValueError message must match.
  """
  expectation = self.assertRaisesRegex(ValueError, regex)
  with expectation:
    check_ops.assert_shapes(shapes)
def raises_dynamic_error(self, shapes, regex, feed_dict):
  """Asserts that running an op gated on the shape check raises at run time.

  Args:
    shapes: spec list passed unchanged to `check_ops.assert_shapes`.
    regex: pattern the InvalidArgumentError message must match.
    feed_dict: feeds used when running the gated op in a session.
  """
  with self.session() as sess, self.assertRaisesRegex(
      errors.InvalidArgumentError, regex):
    guard = check_ops.assert_shapes(shapes)
    with ops.control_dependencies([guard]):
      zero = array_ops.identity(0)
    sess.run(zero, feed_dict=feed_dict)
class AssertShapesSparseTensorTest(test.TestCase):
  """Tests `check_ops.assert_shapes` with `SparseTensor` inputs.

  As the expectations below show, a SparseTensor is checked against the shape
  given by its `dense_shape`, and may be mixed with dense tensors in one call.
  """

  @staticmethod
  def _sparse_scalar():
    """Returns a float32 SparseTensor with dense_shape [] and one value (42)."""
    return sparse_tensor.SparseTensor(
        constant_op.constant([[]], dtypes.int64),
        constant_op.constant([42], dtypes.float32),
        constant_op.constant([], dtypes.int64))

  @staticmethod
  def _sparse_vector():
    """Returns a float32 SparseTensor with dense_shape [500] and two values."""
    return sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], dtypes.float32),
        constant_op.constant([500], dtypes.int64))

  @staticmethod
  def _sparse_int(indices, dense_shape):
    """Returns an int32 SparseTensor holding the two values [23, -43].

    Args:
      indices: nested list with one index row per value.
      dense_shape: list of ints used as the SparseTensor's dense shape.
    """
    return sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant([23, -43], dtypes.int32),
        constant_op.constant(dense_shape, dtypes.int64))

  def _evaluate_guarded(self, assertion, value):
    """Evaluates `identity(value)` under a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      out = array_ops.identity(value)
    self.evaluate(out)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_scalar_target_success(self):
    sparse_float = self._sparse_scalar()
    assertion = check_ops.assert_shapes([(sparse_float, [])])
    self._evaluate_guarded(assertion, sparse_float)

  # Decorated like every sibling test so it also runs in graph mode.
  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_nonscalar_target_fail(self):
    sparse_float = self._sparse_scalar()
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                r"must have rank 2.*Received rank 0"):
      assertion = check_ops.assert_shapes([(sparse_float, [None, None])])
      self._evaluate_guarded(assertion, sparse_float)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_fully_specified_target_success(self):
    sparse_float = self._sparse_vector()
    assertion = check_ops.assert_shapes([(sparse_float, [500])])
    self._evaluate_guarded(assertion, sparse_float)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_fully_specified_target_fail(self):
    sparse_float = self._sparse_vector()
    with self.assertRaisesRegex(ValueError, r"dimension 0 must have size 499"):
      assertion = check_ops.assert_shapes([(sparse_float, [499])])
      self._evaluate_guarded(assertion, sparse_float)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_partially_specified_target_success(self):
    sparse_int = self._sparse_int([[5, 6], [7, 8]], [30, 40])
    assertion = check_ops.assert_shapes([(sparse_int, [None, 40])])
    self._evaluate_guarded(assertion, sparse_int)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_symbolic_match_success(self):
    sparse_int = self._sparse_int([[5, 6, 7], [8, 9, 10]], [30, 30, 40])
    assertion = check_ops.assert_shapes([(sparse_int, ["N", "N", "D"])])
    self._evaluate_guarded(assertion, sparse_int)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_partially_specified_target_fail(self):
    sparse_int = self._sparse_int([[5, 6], [7, 8]], [30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 41"):
      assertion = check_ops.assert_shapes([(sparse_int, [None, 41])])
      self._evaluate_guarded(assertion, sparse_int)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_wrong_rank_fail(self):
    sparse_int = self._sparse_int([[5, 6], [7, 8]], [30, 40])
    with self.assertRaisesRegex(ValueError,
                                r"must have rank 3\..* Received rank 2"):
      assertion = check_ops.assert_shapes([(sparse_int, [None, None, 40])])
      self._evaluate_guarded(assertion, sparse_int)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_wrong_symbolic_match_fail(self):
    sparse_int = self._sparse_int([[5, 6], [7, 8]], [30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(sparse_int, ["D", "D"])])
      self._evaluate_guarded(assertion, sparse_int)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_multiple_assertions_success(self):
    sparse_scalar = self._sparse_scalar()
    sparse_2d = self._sparse_int([[5, 6], [7, 8]], [30, 30])
    assertion = check_ops.assert_shapes([(sparse_scalar, []),
                                         (sparse_2d, ["N", "N"])])
    self._evaluate_guarded(assertion, sparse_2d)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_multiple_assertions_fail(self):
    sparse_scalar = self._sparse_scalar()
    sparse_2d = self._sparse_int([[5, 6], [7, 8]], [30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(sparse_scalar, []),
                                           (sparse_2d, ["N", "N"])])
      self._evaluate_guarded(assertion, sparse_2d)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_success(self):
    # NOTE(review): despite its name, `dense_scalar` has shape [1], not [];
    # this test expects assert_shapes to accept it against `[]` -- confirm
    # that is the intended coverage.
    dense_scalar = constant_op.constant([42], dtypes.float32)
    sparse_2d = self._sparse_int([[5, 6], [7, 8]], [30, 30])
    assertion = check_ops.assert_shapes([(dense_scalar, []),
                                         (sparse_2d, ["N", "N"])])
    self._evaluate_guarded(assertion, sparse_2d)

  @test_util.run_in_graph_and_eager_modes
  def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_fail(self):
    dense_scalar = constant_op.constant([42], dtypes.float32)
    sparse_2d = self._sparse_int([[5, 6], [7, 8]], [30, 40])
    with self.assertRaisesRegex(ValueError, r"dimension 1 must have size 30"):
      assertion = check_ops.assert_shapes([(dense_scalar, []),
                                           (sparse_2d, ["N", "N"])])
      self._evaluate_guarded(assertion, sparse_2d)
class IsStrictlyIncreasingTest(test.TestCase):
  """Tests for `check_ops.is_strictly_increasing`."""

  def _strictly_increasing(self, values):
    """Evaluates `check_ops.is_strictly_increasing(values)` and returns it."""
    return self.evaluate(check_ops.is_strictly_increasing(values))

  @test_util.run_in_graph_and_eager_modes
  def test_constant_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self._strictly_increasing([1, 1, 1]))

  @test_util.run_in_graph_and_eager_modes
  def test_decreasing_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self._strictly_increasing([1, 0, -1]))

  @test_util.run_in_graph_and_eager_modes
  def test_2d_decreasing_tensor_is_not_strictly_increasing(self):
    # Expected False; read in row-major order the values are 1, 3, 2, 4.
    self.assertFalse(self._strictly_increasing([[1, 3], [2, 4]]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_tensor_is_increasing(self):
    self.assertTrue(self._strictly_increasing([1, 2, 3]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self._strictly_increasing([[-1, 2], [3, 4]]))

  @test_util.run_in_graph_and_eager_modes
  def test_tensor_with_one_element_is_strictly_increasing(self):
    self.assertTrue(self._strictly_increasing([1]))

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_is_strictly_increasing(self):
    self.assertTrue(self._strictly_increasing([]))
class IsNonDecreasingTest(test.TestCase):
  """Tests for `check_ops.is_non_decreasing`."""

  def _non_decreasing(self, values):
    """Evaluates `check_ops.is_non_decreasing(values)` and returns it."""
    return self.evaluate(check_ops.is_non_decreasing(values))

  @test_util.run_in_graph_and_eager_modes
  def test_constant_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1, 1, 1]))

  @test_util.run_in_graph_and_eager_modes
  def test_decreasing_tensor_is_not_non_decreasing(self):
    self.assertFalse(self._non_decreasing([3, 2, 1]))

  @test_util.run_in_graph_and_eager_modes
  def test_2d_decreasing_tensor_is_not_non_decreasing(self):
    # Expected False; read in row-major order the values are 1, 3, 2, 4.
    self.assertFalse(self._non_decreasing([[1, 3], [2, 4]]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_one_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1, 2, 3]))

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self._non_decreasing([[-1, 2], [3, 3]]))

  @test_util.run_in_graph_and_eager_modes
  def test_tensor_with_one_element_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1]))

  @test_util.run_in_graph_and_eager_modes
  def test_empty_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([]))
class FloatDTypeTest(test.TestCase):
  """Tests for `check_ops.assert_same_float_dtype`."""

  @test_util.run_in_graph_and_eager_modes
  def test_assert_same_float_dtype(self):
    """Checks dtype inference and mismatch detection across tensor kinds."""
    f32 = dtypes.float32
    same_float = check_ops.assert_same_float_dtype

    # With no tensors, or only None entries, the test expects float32 whether
    # or not a dtype is requested.
    for tensors in (None, [], [None, None]):
      for requested in (None, f32):
        self.assertIs(f32, same_float(tensors, requested))

    # Dense float tensor.
    dense = constant_op.constant(3.0, dtype=f32)
    self.assertIs(f32, same_float([dense], f32))
    with self.assertRaises(ValueError):
      same_float([dense], dtypes.int32)

    # Sparse float tensor, alone and mixed with the dense one.
    sparse = sparse_tensor.SparseTensor(
        constant_op.constant([[111], [232]], dtypes.int64),
        constant_op.constant([23.4, -43.2], f32),
        constant_op.constant([500], dtypes.int64))
    self.assertIs(f32, same_float([sparse], f32))
    with self.assertRaises(ValueError):
      same_float([sparse], dtypes.int32)
    with self.assertRaises(ValueError):
      same_float([dense, None, sparse], dtypes.float64)
    self.assertIs(f32, same_float([dense, sparse]))
    self.assertIs(f32, same_float([dense, sparse], f32))

    # Any integer tensor makes the call fail, whatever dtype is requested.
    int_tensor = constant_op.constant(3, dtype=dtypes.int32)
    with self.assertRaises(ValueError):
      same_float([sparse, int_tensor])
    for requested in (dtypes.int32, f32):
      with self.assertRaises(ValueError):
        same_float([sparse, int_tensor], requested)
    with self.assertRaises(ValueError):
      same_float([int_tensor])
class AssertScalarTest(test.TestCase):
  """Tests for `check_ops.assert_scalar`."""

  @test_util.run_in_graph_and_eager_modes
  def test_assert_scalar(self):
    """Scalar tensors and Python scalars pass; a length-2 vector is rejected."""
    accepted = (constant_op.constant(3), constant_op.constant("foo"), 3, "foo")
    for candidate in accepted:
      check_ops.assert_scalar(candidate)
    vector = constant_op.constant([3, 4])
    with self.assertRaisesRegex(ValueError, "Expected scalar"):
      check_ops.assert_scalar(vector)
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
  test.main()