blob: 213ecbcf764df65abfce7a3ee4e819507800eb01 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.test_util."""
import collections
import copy
import random
import threading
import unittest
import weakref
from absl.testing import parameterized
import numpy as np
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python import pywrap_sanitizers
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Exercises the assertions, decorators and helpers exposed by test_util."""

  def test_assert_ops_in_graph(self):
    """assert_ops_in_graph accepts a matching op and rejects mismatches."""
    with ops.Graph().as_default():
      constant_op.constant(["hello", "taffy"], name="hello")
      test_util.assert_ops_in_graph({"hello": "Const"}, ops.get_default_graph())

      self.assertRaises(ValueError, test_util.assert_ops_in_graph,
                        {"bye": "Const"}, ops.get_default_graph())

      self.assertRaises(ValueError, test_util.assert_ops_in_graph,
                        {"hello": "Variable"}, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def test_session_functions(self):
    """Checks caching of test/cached sessions and closing of plain sessions."""
    with self.test_session() as sess:
      sess_ref = weakref.ref(sess)
      with self.cached_session(graph=None, config=None) as sess2:
        # We make sure that sess2 is sess.
        self.assertIs(sess2, sess)
        # We make sure we raise an exception if we use cached_session with
        # different values.
        with self.assertRaises(ValueError):
          with self.cached_session(graph=ops.Graph()) as sess2:
            pass
        with self.assertRaises(ValueError):
          with self.cached_session(force_gpu=True) as sess2:
            pass
    # We make sure that test_session will cache the session even after the
    # with scope.
    self.assertFalse(sess_ref()._closed)
    with self.session() as unique_sess:
      unique_sess_ref = weakref.ref(unique_sess)
      with self.session() as sess2:
        self.assertIsNot(sess2, unique_sess)
    # We make sure the session is closed when we leave the with statement.
    self.assertTrue(unique_sess_ref()._closed)

  def test_assert_equal_graph_def(self):
    """assert_equal_graph_def ignores node order but detects extra nodes."""
    with ops.Graph().as_default() as g:
      def_empty = g.as_graph_def()
      constant_op.constant(5, name="five")
      constant_op.constant(7, name="seven")
      def_57 = g.as_graph_def()
    with ops.Graph().as_default() as g:
      constant_op.constant(7, name="seven")
      constant_op.constant(5, name="five")
      def_75 = g.as_graph_def()
    # Comparing strings is order dependent
    self.assertNotEqual(str(def_57), str(def_75))
    # assert_equal_graph_def doesn't care about order
    test_util.assert_equal_graph_def(def_57, def_75)
    # Compare two unequal graphs
    with self.assertRaisesRegex(AssertionError,
                                r"^Found unexpected node '{{node seven}}"):
      test_util.assert_equal_graph_def(def_57, def_empty)

  def test_assert_equal_graph_def_hash_table(self):
    """Checks the hash_table_shared_name switch of assert_equal_graph_def."""

    def get_graph_def():
      with ops.Graph().as_default() as g:
        x = constant_op.constant([2, 9], name="x")
        keys = constant_op.constant([1, 2], name="keys")
        values = constant_op.constant([3, 4], name="values")
        default = constant_op.constant(-1, name="default")
        table = lookup_ops.StaticHashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), default)
        _ = table.lookup(x)
        return g.as_graph_def()

    def_1 = get_graph_def()
    def_2 = get_graph_def()
    # The unique shared_name of each table makes the graph unequal.
    with self.assertRaisesRegex(AssertionError, "hash_table_"):
      test_util.assert_equal_graph_def(def_1, def_2,
                                       hash_table_shared_name=False)
    # That can be ignored. (NOTE: modifies GraphDefs in-place.)
    test_util.assert_equal_graph_def(def_1, def_2,
                                     hash_table_shared_name=True)

  def testIsGoogleCudaEnabled(self):
    """Smoke test: IsGoogleCudaEnabled() can be called."""
    # The test doesn't assert anything. It ensures the py wrapper
    # function is generated correctly.
    if test_util.IsGoogleCudaEnabled():
      print("GoogleCuda is enabled")
    else:
      print("GoogleCuda is disabled")

  def testIsMklEnabled(self):
    """Smoke test: IsMklEnabled() can be called."""
    # This test doesn't assert anything.
    # It ensures the py wrapper function is generated correctly.
    if test_util.IsMklEnabled():
      print("MKL is enabled")
    else:
      print("MKL is disabled")

  @test_util.disable_asan("Skip test if ASAN is enabled.")
  def testDisableAsan(self):
    """The body must only run when ASAN is reported as disabled."""
    self.assertFalse(pywrap_sanitizers.is_asan_enabled())

  @test_util.disable_msan("Skip test if MSAN is enabled.")
  def testDisableMsan(self):
    """The body must only run when MSAN is reported as disabled."""
    self.assertFalse(pywrap_sanitizers.is_msan_enabled())

  @test_util.disable_tsan("Skip test if TSAN is enabled.")
  def testDisableTsan(self):
    """The body must only run when TSAN is reported as disabled."""
    self.assertFalse(pywrap_sanitizers.is_tsan_enabled())

  @test_util.disable_ubsan("Skip test if UBSAN is enabled.")
  def testDisableUbsan(self):
    """The body must only run when UBSAN is reported as disabled."""
    self.assertFalse(pywrap_sanitizers.is_ubsan_enabled())

  @test_util.run_in_graph_and_eager_modes
  def testAssertProtoEqualsStr(self):
    """assertProtoEquals accepts both text-format strings and messages."""
    graph_str = "node { name: 'w1' op: 'params' }"
    graph_def = graph_pb2.GraphDef()
    text_format.Merge(graph_str, graph_def)

    # test string based comparison
    self.assertProtoEquals(graph_str, graph_def)

    # test original comparison
    self.assertProtoEquals(graph_def, graph_def)

  @test_util.run_in_graph_and_eager_modes
  def testAssertProtoEqualsAny(self):
    """assertProtoEquals handles messages that pack a protobuf.Any field."""
    # Test assertProtoEquals with a protobuf.Any field.
    meta_graph_def_str = """
meta_info_def {
meta_graph_version: "outer"
any_info {
[type.googleapis.com/tensorflow.MetaGraphDef] {
meta_info_def {
meta_graph_version: "inner"
}
}
}
}
"""
    meta_graph_def_outer = meta_graph_pb2.MetaGraphDef()
    meta_graph_def_outer.meta_info_def.meta_graph_version = "outer"
    meta_graph_def_inner = meta_graph_pb2.MetaGraphDef()
    meta_graph_def_inner.meta_info_def.meta_graph_version = "inner"
    meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner)
    self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer)
    self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer)

    # Check if the assertion failure message contains the content of
    # the inner proto.
    with self.assertRaisesRegex(AssertionError, r'meta_graph_version: "inner"'):
      self.assertProtoEquals("", meta_graph_def_outer)

  @test_util.run_in_graph_and_eager_modes
  def testNDArrayNear(self):
    """_NDArrayNear is true for equal arrays and false for distant ones."""
    a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
    self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
    self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))

  @test_util.run_in_graph_and_eager_modes
  def testCheckedThreadSucceeds(self):
    """A checkedThread whose target succeeds joins without error."""

    def noop(ev):
      ev.set()

    event_arg = threading.Event()

    self.assertFalse(event_arg.is_set())
    t = self.checkedThread(target=noop, args=(event_arg,))
    t.start()
    t.join()
    self.assertTrue(event_arg.is_set())

  @test_util.run_in_graph_and_eager_modes
  def testCheckedThreadFails(self):
    """An exception raised in a checkedThread surfaces as a failure on join."""

    def err_func():
      return 1 // 0

    t = self.checkedThread(target=err_func)
    t.start()
    with self.assertRaises(self.failureException) as fe:
      t.join()
    self.assertIn("integer division or modulo by zero", str(fe.exception))

  @test_util.run_in_graph_and_eager_modes
  def testCheckedThreadWithWrongAssertionFails(self):
    """A failed assertion inside a checkedThread is re-raised on join."""
    x = 37

    def err_func():
      # assertTrue is deliberate: its "False is not true" message is checked.
      self.assertTrue(x < 10)

    t = self.checkedThread(target=err_func)
    t.start()
    with self.assertRaises(self.failureException) as fe:
      t.join()
    self.assertIn("False is not true", str(fe.exception))

  @test_util.run_in_graph_and_eager_modes
  def testMultipleThreadsWithOneFailure(self):
    """Only the failing thread among several raises when joined."""

    def err_func(i):
      self.assertTrue(i != 7)

    threads = [
        self.checkedThread(
            target=err_func, args=(i,)) for i in range(10)
    ]
    for t in threads:
      t.start()
    for i, t in enumerate(threads):
      if i == 7:
        with self.assertRaises(self.failureException):
          t.join()
      else:
        t.join()

  def _WeMustGoDeeper(self, msg):
    """Raises an op error from a nested scope and expects `msg` to match it.

    Args:
      msg: Pattern handed to assertRaisesOpError.
    """
    with self.assertRaisesOpError(msg):
      with ops.Graph().as_default():
        node_def = ops._NodeDef("IntOutput", "name")
        node_def_orig = ops._NodeDef("IntOutput", "orig")
        op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
        op = ops.Operation(node_def, ops.get_default_graph(),
                           original_op=op_orig)
        raise errors.UnauthenticatedError(node_def, op, "true_err")

  @test_util.run_in_graph_and_eager_modes
  def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
    """assertRaisesOpError matches real error text and rejects other text."""
    with self.assertRaises(AssertionError):
      self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")

    self._WeMustGoDeeper("true_err")
    self._WeMustGoDeeper("name")
    self._WeMustGoDeeper("orig")

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseTensors(self):
    """assertAllClose accepts tensors, raw data and dicts of tensors."""
    a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    a = constant_op.constant(a_raw_data)
    b = math_ops.add(1, constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
    self.assertAllClose(a, b)
    self.assertAllClose(a, a_raw_data)

    a_dict = {"key": a}
    b_dict = {"key": b}
    self.assertAllClose(a_dict, b_dict)

    # Disable this subtest until we debug the new np.array() coercion behavior
    # https://numpy.org/doc/stable/release/1.20.0-notes.html#array-coercion-restructure
    # x_list is of the form [Op1, Op2] in np<1.20 this works fine, but in
    # >=1.20 it behaves like np.array([np.array(Op1), np.array(op2)]) which
    # doesn't work in either 1.19 or 1.20.
    # TODO(b/202303409): Disable the gate once we fix assertAllClose or fix
    # a deeper issue with conversion.
    versions = np.version.version.split(".")
    major, minor = int(versions[0]), int(versions[1])
    if major == 1 and minor < 20:
      x_list = [a, b]
      y_list = [a_raw_data, b]
      self.assertAllClose(x_list, y_list)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseScalars(self):
    """assertAllClose on scalars passes within tolerance and fails outside."""
    self.assertAllClose(7, 7 + 1e-8)
    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
      self.assertAllClose(7, 7 + 1e-5)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseList(self):
    """assertAllClose reports differing list elements."""
    with self.assertRaisesRegex(AssertionError, r"not close dif"):
      self.assertAllClose([0], [1])

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseDictToNonDict(self):
    """assertAllClose refuses to compare a dict with a non-dict."""
    with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
      self.assertAllClose(1, {"a": 1})
    with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
      self.assertAllClose({"a": 1}, 1)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseNamedtuples(self):
    """assertAllClose compares namedtuples with dicts and with each other."""
    a = 7
    b = (2., 3.)
    c = np.ones((3, 2, 4)) * 7.
    expected = {"a": a, "b": b, "c": c}
    my_named_tuple = collections.namedtuple("MyNamedTuple", ["a", "b", "c"])

    # Identity.
    self.assertAllClose(expected, my_named_tuple(a=a, b=b, c=c))
    self.assertAllClose(
        my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c))

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseDicts(self):
    """assertAllClose on dicts detects key, value and shape mismatches."""
    a = 7
    b = (2., 3.)
    c = np.ones((3, 2, 4)) * 7.
    expected = {"a": a, "b": b, "c": c}

    # Identity.
    self.assertAllClose(expected, expected)
    self.assertAllClose(expected, dict(expected))

    # With each item removed.
    for k in expected:
      actual = dict(expected)
      del actual[k]
      with self.assertRaisesRegex(AssertionError, r"mismatched keys"):
        self.assertAllClose(expected, actual)

    # With each item changed.
    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
      self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c})
    with self.assertRaisesRegex(AssertionError, r"Shape mismatch"):
      self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c})
    c_copy = np.array(c)
    c_copy[1, 1, 1] += 1e-5
    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
      self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseListOfNamedtuples(self):
    """assertAllClose compares a list of namedtuples with plain tuples."""
    my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"])
    l1 = [
        my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])),
        my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]]))
    ]
    l2 = [
        ([[2.3, 2.5]], [[0.97, 0.96]]),
        ([[3.3, 3.5]], [[0.98, 0.99]]),
    ]
    self.assertAllClose(l1, l2)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseNestedStructure(self):
    """assertAllClose names the path of a mismatch in nested structures."""
    a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])}
    self.assertAllClose(a, a)

    b = copy.deepcopy(a)
    self.assertAllClose(a, b)

    # Test mismatched values
    b["y"][1][0]["nested"]["n"] = 4.2
    with self.assertRaisesRegex(AssertionError,
                                r"\[y\]\[1\]\[0\]\[nested\]\[n\]"):
      self.assertAllClose(a, b)

  @test_util.run_in_graph_and_eager_modes
  def testArrayNear(self):
    """assertArrayNear rejects length/rank mismatches, accepts equal lists."""
    a = [1, 2]
    b = [1, 2, 5]
    with self.assertRaises(AssertionError):
      self.assertArrayNear(a, b, 0.001)
    a = [1, 2]
    b = [[1, 2], [3, 4]]
    with self.assertRaises(TypeError):
      self.assertArrayNear(a, b, 0.001)
    a = [1, 2]
    b = [1, 2]
    self.assertArrayNear(a, b, 0.001)

  @test_util.skip_if(True)  # b/117665998
  def testForceGPU(self):
    """Expects InvalidArgumentError when running Assert under force_gpu."""
    with self.assertRaises(errors.InvalidArgumentError):
      with self.test_session(force_gpu=True):
        # this relies on us not having a GPU implementation for assert, which
        # seems sensible
        x = constant_op.constant(True)
        y = [15]
        control_flow_ops.Assert(x, y).run()

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllCloseAccordingToType(self):
    """assertAllCloseAccordingToType picks the tolerance from the dtype."""
    # test plain int
    self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8)

    # test float64
    self.assertAllCloseAccordingToType(
        np.asarray([1e-8], dtype=np.float64),
        np.asarray([2e-8], dtype=np.float64),
        rtol=1e-8, atol=1e-8
    )

    self.assertAllCloseAccordingToType(
        constant_op.constant([1e-8], dtype=dtypes.float64),
        constant_op.constant([2e-8], dtype=dtypes.float64),
        rtol=1e-8,
        atol=1e-8)

    with self.assertRaises(AssertionError):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-7], dtype=np.float64),
          np.asarray([2e-7], dtype=np.float64),
          rtol=1e-8, atol=1e-8
      )

    # test float32
    self.assertAllCloseAccordingToType(
        np.asarray([1e-7], dtype=np.float32),
        np.asarray([2e-7], dtype=np.float32),
        rtol=1e-8, atol=1e-8,
        float_rtol=1e-7, float_atol=1e-7
    )

    self.assertAllCloseAccordingToType(
        constant_op.constant([1e-7], dtype=dtypes.float32),
        constant_op.constant([2e-7], dtype=dtypes.float32),
        rtol=1e-8,
        atol=1e-8,
        float_rtol=1e-7,
        float_atol=1e-7)

    with self.assertRaises(AssertionError):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-6], dtype=np.float32),
          np.asarray([2e-6], dtype=np.float32),
          rtol=1e-8, atol=1e-8,
          float_rtol=1e-7, float_atol=1e-7
      )

    # test float16
    self.assertAllCloseAccordingToType(
        np.asarray([1e-4], dtype=np.float16),
        np.asarray([2e-4], dtype=np.float16),
        rtol=1e-8, atol=1e-8,
        float_rtol=1e-7, float_atol=1e-7,
        half_rtol=1e-4, half_atol=1e-4
    )

    self.assertAllCloseAccordingToType(
        constant_op.constant([1e-4], dtype=dtypes.float16),
        constant_op.constant([2e-4], dtype=dtypes.float16),
        rtol=1e-8,
        atol=1e-8,
        float_rtol=1e-7,
        float_atol=1e-7,
        half_rtol=1e-4,
        half_atol=1e-4)

    with self.assertRaises(AssertionError):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-3], dtype=np.float16),
          np.asarray([2e-3], dtype=np.float16),
          rtol=1e-8, atol=1e-8,
          float_rtol=1e-7, float_atol=1e-7,
          half_rtol=1e-4, half_atol=1e-4
      )

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllEqual(self):
    """assertAllEqual works on variables, constants and op outputs."""
    i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i")
    j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j")
    k = math_ops.add(i, j, name="k")

    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual([100] * 3, i)
    self.assertAllEqual([120] * 3, k)
    self.assertAllEqual([20] * 3, j)

    with self.assertRaisesRegex(AssertionError, r"not equal lhs"):
      self.assertAllEqual([0] * 3, k)

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllEqual(self):
    """assertNotAllEqual fails, with the extra message, on equal values."""
    i = variables.Variable([100], dtype=dtypes.int32, name="i")
    j = constant_op.constant([20], dtype=dtypes.int32, name="j")
    k = math_ops.add(i, j, name="k")

    self.evaluate(variables.global_variables_initializer())
    self.assertNotAllEqual([100] * 3, i)
    self.assertNotAllEqual([120] * 3, k)
    self.assertNotAllEqual([20] * 3, j)

    with self.assertRaisesRegex(
        AssertionError, r"two values are equal at all elements.*extra message"):
      self.assertNotAllEqual([120], k, msg="extra message")

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllClose(self):
    """assertNotAllClose passes for differing values, fails for equal ones."""
    # Test with arrays
    self.assertNotAllClose([0.1], [0.2])
    with self.assertRaises(AssertionError):
      self.assertNotAllClose([-1.0, 2.0], [-1.0, 2.0])

    # Test with tensors
    x = constant_op.constant([1.0, 1.0], name="x")
    y = math_ops.add(x, x)

    self.assertAllClose([2.0, 2.0], y)
    self.assertNotAllClose([0.9, 1.0], x)

    with self.assertRaises(AssertionError):
      self.assertNotAllClose([1.0, 1.0], x)

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllCloseRTol(self):
    """assertNotAllClose fails when values agree within the given rtol."""
    # Test with arrays
    with self.assertRaises(AssertionError):
      self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], rtol=0.2)

    # Test with tensors
    x = constant_op.constant([1.0, 1.0], name="x")
    y = math_ops.add(x, x)

    self.assertAllClose([2.0, 2.0], y)

    with self.assertRaises(AssertionError):
      self.assertNotAllClose([0.9, 1.0], x, rtol=0.2)

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllCloseATol(self):
    """assertNotAllClose fails when values agree within the given atol."""
    # Test with arrays
    with self.assertRaises(AssertionError):
      self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], atol=0.2)

    # Test with tensors
    x = constant_op.constant([1.0, 1.0], name="x")
    y = math_ops.add(x, x)

    self.assertAllClose([2.0, 2.0], y)

    with self.assertRaises(AssertionError):
      self.assertNotAllClose([0.9, 1.0], x, atol=0.2)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllGreaterLess(self):
    """assertAllGreater / assertAllLess require every element to comply."""
    x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
    y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
    z = math_ops.add(x, y)

    self.assertAllClose([110.0, 120.0, 130.0], z)

    self.assertAllGreater(x, 95.0)
    self.assertAllLess(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllGreater(x, 105.0)
    with self.assertRaises(AssertionError):
      self.assertAllGreater(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllLess(x, 115.0)
    with self.assertRaises(AssertionError):
      self.assertAllLess(x, 95.0)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllGreaterLessEqual(self):
    """assertAllGreaterEqual / assertAllLessEqual check every element."""
    x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
    y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
    z = math_ops.add(x, y)

    self.assertAllEqual([110.0, 120.0, 130.0], z)

    self.assertAllGreaterEqual(x, 95.0)
    self.assertAllLessEqual(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllGreaterEqual(x, 105.0)
    with self.assertRaises(AssertionError):
      self.assertAllGreaterEqual(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllLessEqual(x, 115.0)
    with self.assertRaises(AssertionError):
      self.assertAllLessEqual(x, 95.0)

  def testAssertAllInRangeWithNonNumericValuesFails(self):
    """assertAllInRange rejects string, complex and boolean tensors."""
    s1 = constant_op.constant("Hello, ", name="s1")
    c = constant_op.constant([1 + 2j, -3 + 5j], name="c")
    b = constant_op.constant([False, True], name="b")

    with self.assertRaises(AssertionError):
      self.assertAllInRange(s1, 0.0, 1.0)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(c, 0.0, 1.0)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(b, 0, 1)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRange(self):
    """assertAllInRange honours open lower and upper bounds."""
    x = constant_op.constant([10.0, 15.0], name="x")
    self.assertAllInRange(x, 10, 15)

    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, 15, open_lower_bound=True)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, 15, open_upper_bound=True)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(
          x, 10, 15, open_lower_bound=True, open_upper_bound=True)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeScalar(self):
    """assertAllInRange on scalars, including NaN and out-of-range input."""
    x = constant_op.constant(10.0, name="x")
    nan = constant_op.constant(np.nan, name="nan")
    self.assertAllInRange(x, 5, 15)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(nan, 5, 15)

    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, 15, open_lower_bound=True)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 1, 2)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeErrorMessageEllipses(self):
    """assertAllInRange still fails cleanly when many elements violate."""
    x_init = np.array([[10.0, 15.0]] * 12)
    x = constant_op.constant(x_init, name="x")
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 5, 10)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeDetectsNaNs(self):
    """assertAllInRange fails on input containing NaN."""
    x = constant_op.constant(
        [[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x")
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 0.0, 2.0)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeWithInfinities(self):
    """assertAllInRange treats an infinite bound as closed unless told open."""
    x = constant_op.constant([10.0, np.inf], name="x")
    self.assertAllInRange(x, 10, np.inf)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, np.inf, open_upper_bound=True)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInSet(self):
    """assertAllInSet accepts list, tuple and set; fails on a missing value."""
    b = constant_op.constant([True, False], name="b")
    x = constant_op.constant([13, 37], name="x")

    self.assertAllInSet(b, [False, True])
    self.assertAllInSet(b, (False, True))
    self.assertAllInSet(b, {False, True})
    self.assertAllInSet(x, [0, 13, 37, 42])
    self.assertAllInSet(x, (0, 13, 37, 42))
    self.assertAllInSet(x, {0, 13, 37, 42})

    with self.assertRaises(AssertionError):
      self.assertAllInSet(b, [False])
    with self.assertRaises(AssertionError):
      self.assertAllInSet(x, (42,))

  def testRandomSeed(self):
    """Calling setUp() again reproduces the same random draws."""
    # Call setUp again for WithCApi case (since it makes a new default graph
    # after setup).
    # TODO(skyewm): remove this when C API is permanently enabled.
    with context.eager_mode():
      self.setUp()
      a = random.randint(1, 1000)
      a_np_rand = np.random.rand(1)
      a_rand = random_ops.random_normal([1])
      # ensure that randomness in multiple testCases is deterministic.
      self.setUp()
      b = random.randint(1, 1000)
      b_np_rand = np.random.rand(1)
      b_rand = random_ops.random_normal([1])
      self.assertEqual(a, b)
      self.assertEqual(a_np_rand, b_np_rand)
      self.assertAllEqual(a_rand, b_rand)

  @test_util.run_in_graph_and_eager_modes
  def test_callable_evaluate(self):
    """self.evaluate accepts a callable when executing eagerly."""

    def model():
      return resource_variable_ops.ResourceVariable(
          name="same_name",
          initial_value=1) + 1

    with context.eager_mode():
      self.assertEqual(2, self.evaluate(model))

  @test_util.run_in_graph_and_eager_modes
  def test_nested_tensors_evaluate(self):
    """self.evaluate preserves the structure of nested dicts of tensors."""
    expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
    nested = {"a": constant_op.constant(1),
              "b": constant_op.constant(2),
              "nested": {"d": constant_op.constant(3),
                         "e": constant_op.constant(4)}}

    self.assertEqual(expected, self.evaluate(nested))

  def test_run_in_graph_and_eager_modes(self):
    """The decorator runs both modes, with or without call parentheses."""
    l = []

    def inc(self, with_brackets):
      del self  # self argument is required by run_in_graph_and_eager_modes.
      mode = "eager" if context.executing_eagerly() else "graph"
      with_brackets = "with_brackets" if with_brackets else "without_brackets"
      l.append((with_brackets, mode))

    f = test_util.run_in_graph_and_eager_modes(inc)
    f(self, with_brackets=False)
    f = test_util.run_in_graph_and_eager_modes()(inc)  # pylint: disable=assignment-from-no-return
    f(self, with_brackets=True)

    self.assertEqual(len(l), 4)
    self.assertEqual(set(l), {
        ("with_brackets", "graph"),
        ("with_brackets", "eager"),
        ("without_brackets", "graph"),
        ("without_brackets", "eager"),
    })

  def test_get_node_def_from_graph(self):
    """get_node_def_from_graph returns the named node, or None if absent."""
    graph_def = graph_pb2.GraphDef()
    node_foo = graph_def.node.add()
    node_foo.name = "foo"
    self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
    self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))

  def test_run_in_eager_and_graph_modes_test_class(self):
    """Decorating a class with run_in_graph_and_eager_modes is an error."""
    msg = "`run_in_graph_and_eager_modes` only supports test methods.*"
    with self.assertRaisesRegex(ValueError, msg):

      @test_util.run_in_graph_and_eager_modes()
      class Foo(object):
        pass

      del Foo  # Make pylint unused happy.

  def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self):
    """Skipping in graph mode still lets the eager run execute."""
    modes = []

    def _test(self):
      if not context.executing_eagerly():
        self.skipTest("Skipping in graph mode")
      modes.append("eager" if context.executing_eagerly() else "graph")

    test_util.run_in_graph_and_eager_modes(_test)(self)
    self.assertEqual(modes, ["eager"])

  def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self):
    """Skipping in eager mode still lets the graph run execute."""
    modes = []

    def _test(self):
      if context.executing_eagerly():
        self.skipTest("Skipping in eager mode")
      modes.append("eager" if context.executing_eagerly() else "graph")

    test_util.run_in_graph_and_eager_modes(_test)(self)
    self.assertEqual(modes, ["graph"])

  def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
    """setUp is re-run in eager mode before the eager pass of the test body."""
    modes = []
    mode_name = lambda: "eager" if context.executing_eagerly() else "graph"

    class ExampleTest(test_util.TensorFlowTestCase):

      def runTest(self):
        pass

      def setUp(self):
        modes.append("setup_" + mode_name())

      @test_util.run_in_graph_and_eager_modes
      def testBody(self):
        modes.append("run_" + mode_name())

    e = ExampleTest()
    e.setUp()
    e.testBody()

    # modes[0] comes from the explicit e.setUp() call above and is not checked.
    self.assertEqual(modes[1:2], ["run_graph"])
    self.assertEqual(modes[2:], ["setup_eager", "run_eager"])

  @parameterized.named_parameters(dict(testcase_name="argument",
                                       arg=True))
  @test_util.run_in_graph_and_eager_modes
  def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg):
    """The decorator forwards keyword arguments from named_parameters."""
    self.assertEqual(arg, True)

  @combinations.generate(combinations.combine(arg=True))
  @test_util.run_in_graph_and_eager_modes
  def test_run_in_graph_and_eager_works_with_combinations(self, arg):
    """The decorator forwards arguments supplied by combinations.generate."""
    self.assertEqual(arg, True)

  def test_build_as_function_and_v1_graph(self):
    """build_as_function_and_v1_graph yields a graph and a function variant."""

    class GraphModeAndFunctionTest(parameterized.TestCase):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        super(GraphModeAndFunctionTest, inner_self).__init__()
        inner_self.graph_mode_tested = False
        inner_self.inside_function_tested = False

      def runTest(self):
        del self

      @test_util.build_as_function_and_v1_graph
      def test_modes(inner_self):  # pylint: disable=no-self-argument
        if ops.inside_function():
          self.assertFalse(inner_self.inside_function_tested)
          inner_self.inside_function_tested = True
        else:
          self.assertFalse(inner_self.graph_mode_tested)
          inner_self.graph_mode_tested = True

    test_object = GraphModeAndFunctionTest()
    test_object.test_modes_v1_graph()
    test_object.test_modes_function()
    self.assertTrue(test_object.graph_mode_tested)
    self.assertTrue(test_object.inside_function_tested)

  @test_util.run_in_graph_and_eager_modes
  def test_consistent_random_seed_in_assert_all_equal(self):
    """assertAllEqual sees the same value for both uses of a random tensor."""
    random_seed.set_seed(1066)
    index = random_ops.random_shuffle([0, 1, 2, 3, 4], seed=2021)
    # This failed when `a` and `b` were evaluated in separate sessions.
    self.assertAllEqual(index, index)

  def test_with_forward_compatibility_horizons(self):
    """The decorator runs the test once per forward-compatibility horizon."""

    tested_codepaths = set()

    def some_function_with_forward_compat_behavior():
      if compat.forward_compatible(2050, 1, 1):
        tested_codepaths.add("future")
      else:
        tested_codepaths.add("present")

    @test_util.with_forward_compatibility_horizons(None, [2051, 1, 1])
    def some_test(self):
      del self  # unused
      some_function_with_forward_compat_behavior()

    some_test(None)
    self.assertEqual(tested_codepaths, {"present", "future"})
class SkipTestTest(test_util.TensorFlowTestCase):
  """Tests for the test_util.skip_if_error context manager."""

  def _verify_test_in_set_up_or_tear_down(self):
    """Checks skip_if_error for one skipping and one non-skipping case.

    Called from setUp() and tearDown() so the context manager is exercised
    there as well as inside test methods.

    Raises:
      RuntimeError: If the non-matching case was skipped.
    """
    with self.assertRaises(unittest.SkipTest):
      with test_util.skip_if_error(self, ValueError,
                                   ["foo bar", "test message"]):
        raise ValueError("test message")
    try:
      with self.assertRaisesRegex(ValueError, "foo bar"):
        with test_util.skip_if_error(self, ValueError, "test message"):
          raise ValueError("foo bar")
    except unittest.SkipTest as skip:
      # Chain explicitly so the traceback shows the unexpected skip as cause.
      raise RuntimeError("Test is not supposed to skip.") from skip

  def setUp(self):
    super(SkipTestTest, self).setUp()
    self._verify_test_in_set_up_or_tear_down()

  def tearDown(self):
    super(SkipTestTest, self).tearDown()
    self._verify_test_in_set_up_or_tear_down()

  def test_skip_if_error_should_skip(self):
    """A matching error type and message turns into SkipTest."""
    with self.assertRaises(unittest.SkipTest):
      with test_util.skip_if_error(self, ValueError, "test message"):
        raise ValueError("test message")

  def test_skip_if_error_should_skip_with_list(self):
    """A message matching any entry of a list turns into SkipTest."""
    with self.assertRaises(unittest.SkipTest):
      with test_util.skip_if_error(self, ValueError,
                                   ["foo bar", "test message"]):
        raise ValueError("test message")

  def test_skip_if_error_should_skip_without_expected_message(self):
    """With no expected message, any ValueError turns into SkipTest."""
    with self.assertRaises(unittest.SkipTest):
      with test_util.skip_if_error(self, ValueError):
        raise ValueError("test message")

  def test_skip_if_error_should_skip_without_error_message(self):
    """With no expected message, a message-less ValueError also skips."""
    with self.assertRaises(unittest.SkipTest):
      with test_util.skip_if_error(self, ValueError):
        raise ValueError()

  def test_skip_if_error_should_raise_message_mismatch(self):
    """A non-matching message propagates the original ValueError."""
    try:
      with self.assertRaisesRegex(ValueError, "foo bar"):
        with test_util.skip_if_error(self, ValueError, "test message"):
          raise ValueError("foo bar")
    except unittest.SkipTest as skip:
      raise RuntimeError("Test is not supposed to skip.") from skip

  def test_skip_if_error_should_raise_no_message(self):
    """A message-less error propagates when a message was expected."""
    try:
      with self.assertRaisesRegex(ValueError, ""):
        with test_util.skip_if_error(self, ValueError, "test message"):
          raise ValueError()
    except unittest.SkipTest as skip:
      raise RuntimeError("Test is not supposed to skip.") from skip
# Its own test case to reproduce variable sharing issues which only pop up when
# setUp() is overridden and super() is not called.
class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
  """Creates a same-named variable in graph and eager runs of one test."""

  def setUp(self):
    """Deliberately does not call TensorFlowTestCase.setUp()."""

  @test_util.run_in_graph_and_eager_modes
  def test_no_variable_sharing(self):
    """Creating the variable in both modes must not raise."""
    step_size_init = np.array(1e-5, np.float32)
    variable_scope.get_variable(
        name="step_size",
        trainable=False,
        use_resource=True,
        initializer=step_size_init)
class GarbageCollectionTest(test_util.TensorFlowTestCase):
  """Tests for the leak- and garbage-detecting decorators in test_util."""

  def test_no_reference_cycle_decorator(self):
    """assert_no_garbage_created fails only for the body that makes a cycle."""

    class CycleFixture(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      @test_util.assert_no_garbage_created
      def test_has_cycle(self):
        self_ref = []
        self_ref.append(self_ref)

      @test_util.assert_no_garbage_created
      def test_has_no_cycle(self):
        pass

    self.assertRaises(AssertionError, CycleFixture().test_has_cycle)

    CycleFixture().test_has_no_cycle()

  @test_util.run_in_graph_and_eager_modes
  def test_no_leaked_tensor_decorator(self):
    """assert_no_new_tensors fails only when a tensor outlives the body."""

    class TensorLeakFixture(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      @test_util.assert_no_new_tensors
      def test_has_leak(self):
        self.a = constant_op.constant([3.], name="leak")

      @test_util.assert_no_new_tensors
      def test_has_no_leak(self):
        constant_op.constant([3.], name="no-leak")

    self.assertRaisesRegex(AssertionError, "Tensors not deallocated",
                           TensorLeakFixture().test_has_leak)

    TensorLeakFixture().test_has_no_leak()

  def test_no_new_objects_decorator(self):
    """Runs a leaking (expected failure) and a non-leaking inner test."""

    class ObjectLeakFixture(unittest.TestCase):

      def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        self.accumulation = []

      @unittest.expectedFailure
      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_leak(self):
        self.accumulation.append([1.])

      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_no_leak(self):
        self.not_accumulating = [1.]

    # Same order as before: the leaking case first, then the clean one.
    for method_name in ("test_has_leak", "test_has_no_leak"):
      outcome = ObjectLeakFixture(method_name).run()
      self.assertTrue(outcome.wasSuccessful())
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
                                  parameterized.TestCase):
  """Tests for the test_util.run_functions_eagerly context manager."""

  @parameterized.named_parameters(
      [("_RunEagerly", True), ("_RunGraph", False)])
  def test_run_functions_eagerly(self, run_eagerly):  # pylint: disable=g-wrong-blank-lines
    """Checks the tensor type seen inside a tf.function body.

    Args:
      run_eagerly: Value passed to test_util.run_functions_eagerly.
    """
    results = []

    @def_function.function
    def add_two(x):
      for _ in range(5):
        x += 2
        results.append(x)
      return x

    with test_util.run_functions_eagerly(run_eagerly):
      add_two(constant_op.constant(2.))
      # EagerTensor is expected only when the outer context is eager and
      # functions are forced to run eagerly; otherwise any ops.Tensor will do.
      if context.executing_eagerly() and run_eagerly:
        expected_type = ops.EagerTensor
      else:
        expected_type = ops.Tensor
      # Guard against a vacuous pass if the function body never ran.
      self.assertTrue(results, "add_two() did not record any tensors")
      # The previous assertTrue(<generator>) was always true because a
      # generator object is truthy; check every recorded tensor instead.
      for t in results:
        self.assertIsInstance(t, expected_type)
# Run the tests through googletest.main() when this file is executed directly.
if __name__ == "__main__":
  googletest.main()