blob: c60d3cbc830419dbd9f7340a5886fb0fd0e04b63 [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Asserts and Boolean Checks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# Fixed allow-list of dtypes this module treats as numeric. Note that it is
# not exhaustive (e.g. float16 and complex128 are not listed).
NUMERIC_TYPES = frozenset((
    dtypes.float32, dtypes.float64,
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8,
    dtypes.qint8, dtypes.qint32, dtypes.quint8,
    dtypes.complex64,
))

# Public API of this module (order preserved).
__all__ = [
    'assert_negative', 'assert_positive', 'assert_proper_iterable',
    'assert_non_negative', 'assert_non_positive',
    'assert_equal', 'assert_none_equal', 'assert_near',
    'assert_integer',
    'assert_less', 'assert_less_equal',
    'assert_greater', 'assert_greater_equal',
    'assert_rank', 'assert_rank_at_least', 'assert_rank_in',
    'assert_same_float_dtype', 'assert_scalar', 'assert_type',
    'assert_shapes',
    'is_non_decreasing', 'is_numeric_tensor', 'is_strictly_increasing',
]
def _maybe_constant_value_string(t):
  """Returns a string describing `t`, using its static value when known.

  Args:
    t: Any object. If it is a `Tensor` whose value can be computed statically
      by `tensor_util.constant_value`, that value is rendered instead of the
      symbolic tensor.

  Returns:
    A `str` in every case, so callers can safely `'\\n'.join(...)` the
    results. Non-`Tensor` objects and `Tensor`s without a static value are
    rendered with `str()`.
  """
  if not isinstance(t, ops.Tensor):
    return str(t)
  const_t = tensor_util.constant_value(t)
  if const_t is not None:
    return str(const_t)
  # Always return a string: _assert_static joins these results with
  # '\n'.join(), which raises TypeError on a raw Tensor.
  return str(t)
def _assert_static(condition, data):
  """Raises an InvalidArgumentError with as much information as possible.

  Args:
    condition: Result of a check evaluated at graph-construction time; a
      falsy value means the check failed.
    data: Items describing the failure. Each one is rendered with
      `_maybe_constant_value_string` and the results are newline-joined.

  Raises:
    InvalidArgumentError: if `condition` is falsy.
  """
  if condition:
    return
  rendered = [_maybe_constant_value_string(item) for item in data]
  raise errors.InvalidArgumentError(
      node_def=None, op=None, message='\n'.join(rendered))
def _shape_and_dtype_str(tensor):
"""Returns a string containing tensor's shape and dtype."""
return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
def _unary_assert_doc(sym, sym_name):
"""Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.
Args:
sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
sym_name: English-language name for the op described by sym
Returns:
Decorator that adds the appropriate docstring to the function for symbol
`sym`.
"""
def _decorator(func):
"""Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
Args:
func: Function for a TensorFlow op
Returns:
Version of `func` with documentation attached.
"""
opname = func.__name__
cap_sym_name = sym_name.capitalize()
func.__doc__ = """
Assert the condition `x {sym}` holds element-wise.
When running in graph mode, you should add a dependency on this operation
to ensure that it runs. Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.debugging.{opname}(x, y)]):
output = tf.reduce_sum(x)
```
{sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
If `x` is empty this is trivially satisfied.
Args:
x: Numeric `Tensor`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "{opname}".
Returns:
Op that raises `InvalidArgumentError` if `x {sym}` is False.
@compatibility(eager)
returns None
@end_compatibility
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym}` is False. The check can be performed immediately during
eager execution or if `x` is statically known.
""".format(
sym=sym, sym_name=cap_sym_name, opname=opname)
return func
return _decorator
def _binary_assert_doc(sym):
"""Common docstring for most of the v1 assert_* ops that compare two tensors element-wise.
Args:
sym: Binary operation symbol, i.e. "=="
Returns:
Decorator that adds the appropriate docstring to the function for
symbol `sym`.
"""
def _decorator(func):
"""Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
Args:
func: Function for a TensorFlow op
Returns:
A version of `func` with documentation attached.
"""
opname = func.__name__
func.__doc__ = """
Assert the condition `x {sym} y` holds element-wise.
This condition holds if for every pair of (possibly broadcast) elements
`x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
If both `x` and `y` are empty, this is trivially satisfied.
When running in graph mode, you should add a dependency on this operation
to ensure that it runs. Example of adding a dependency to an operation:
```python
with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
output = tf.reduce_sum(x)
```
Args:
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
error message and first few entries of `x`, `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional). Defaults to "{opname}".
Returns:
Op that raises `InvalidArgumentError` if `x {sym} y` is False.
@compatibility(eager)
returns None
@end_compatibility
Raises:
InvalidArgumentError: if the check can be performed immediately and
`x {sym} y` is False. The check can be performed immediately during
eager execution or if `x` and `y` are statically known.
""".format(
sym=sym, opname=opname)
return func
return _decorator
def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the pieces of the default eager-mode error message for _binary_assert.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      i.e. "=="
    x: First input to the assertion after applying `convert_to_tensor()`
    y: Second input to the assertion
    summarize: Value of the "summarize" parameter to the original assert_* call;
      tells how many elements of each tensor to print.
    test_op: TensorFlow op that returns a Boolean tensor with True in each
      position where the assertion is satisfied.

  Returns:
    List of strings and NumPy arrays that, when stringified and concatenated,
    will produce the error message string.
  """
  data = ['Condition x %s y did not hold.' % sym]
  if not summarize > 0:
    return data

  # Only list the failing positions when x and y have the same non-scalar
  # shape; with differing (broadcast) shapes the indices are more confusing
  # than useful.
  if x.shape == y.shape and x.shape.as_list():
    failed = math_ops.logical_not(test_op)
    bad_indices = array_ops.where(failed).numpy()
    bad_x = array_ops.boolean_mask(x, failed).numpy().reshape((-1,))
    bad_y = array_ops.boolean_mask(y, failed).numpy().reshape((-1,))
    n_bad = min(summarize, bad_indices.shape[0])
    data.extend([
        'Indices of first %d different values:' % n_bad,
        bad_indices[:n_bad],
        'Corresponding x values:',
        bad_x[:n_bad],
        'Corresponding y values:',
        bad_y[:n_bad],
    ])

  # reshape((-1,)) is the fastest way to get a flat array view.
  flat_x = x.numpy().reshape((-1,))
  flat_y = y.numpy().reshape((-1,))
  n_x = min(flat_x.size, summarize)
  n_y = min(flat_y.size, summarize)
  data.extend([
      'First %d elements of x:' % n_x,
      flat_x[:n_x],
      'First %d elements of y:' % n_y,
      flat_y[:n_y],
  ])
  return data
def _pretty_print(data_item, summarize):
  """Renders one entry of an assert_* `data` list for an eager error message.

  Args:
    data_item: One of the items in the "data" argument to an assert_* function.
      Can be a Tensor or a scalar value.
    summarize: How many elements to retain of each tensor-valued entry in data.

  Returns:
    `str(data_item)` for non-Tensors and for Tensors whose `.numpy()` value is
    a NumPy scalar; otherwise the string form of a list of the first
    `summarize` flattened elements (each converted to str), with '...'
    appended when elements were omitted.
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)
  value = data_item.numpy()
  # Tensor.numpy() returns a scalar for zero-dimensional tensors.
  if np.isscalar(value):
    return str(value)
  flat = value.reshape((-1,))
  pieces = [str(elem) for elem in flat[:summarize]]
  truncated = flat.size > len(pieces)
  return str(pieces + ['...'] if truncated else pieces)
def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      i.e. "=="
    opname: Name of the assert op in the public API, i.e. "assert_equal"
    op_func: Function that, if passed the two Tensor inputs to the assertion (x
      and y), will return the test to be passed to reduce_all(), i.e.
      math_ops.equal for assert_equal()
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing
      True in all positions where the assertion PASSES,
      i.e. np.equal for assert_equal()
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor. In eager mode, None
      means 3 and a negative value means "print everything".
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to the value of
      `opname`.

  Returns:
    See docstring template in _binary_assert_doc(). In eager mode this returns
    None; in graph mode it returns the op from `control_flow_ops.Assert`.

  Raises:
    InvalidArgumentError: in eager mode when the check fails, or in graph mode
      when both `x` and `y` have statically known values that fail the check.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must be an int: summarize is used as a slice bound (e.g. in
        # _pretty_print) and NumPy rejects float slice indices, which would
        # raise TypeError instead of the intended InvalidArgumentError.
        summarize = int(1e9)  # Code below will find exact size of x and y.

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join([_pretty_print(d, summarize) for d in data])))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      # Fail at graph-construction time when both operands are constants.
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_ops.Assert(condition, data, summarize=summarize)
@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  `Ops` that expect iterables of `Tensor` can call this to validate input.
  Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves.

  Args:
    values: Object to be checked.

  Raises:
    TypeError: If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Objects of these types are rejected outright, even though some of them
  # are iterable themselves.
  rejected_types = (
      ops.Tensor, sparse_tensor.SparseTensor, np.ndarray,
  ) + compat.bytes_or_text_types

  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable. Found: %s' %
        type(values))

  if hasattr(values, '__iter__'):
    return

  raise TypeError(
      'Expected argument "values" to be iterable. Found: %s' % type(values))
@tf_export('debugging.assert_negative', v1=[])
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  This Op checks that `x[i] < 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not negative everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper: the v1 implementation also accepts a `data` argument.
  return assert_negative(x, summarize=summarize, message=message, name=name)
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = message or ''
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [message,
              'Condition x < 0 did not hold element-wise:',
              'x (%s) = ' % x_desc, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x < 0 is expressed as assert_less(x, 0).
    return assert_less(x, zero, data=data, summarize=summarize)
@tf_export('debugging.assert_positive', v1=[])
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  This Op checks that `x[i] > 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not positive everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_positive(x, message=message, summarize=summarize, name=name)
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = message or ''
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [message,
              'Condition x > 0 did not hold element-wise:',
              'x (%s) = ' % x_desc, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x > 0 is expressed as assert_less(0, x).
    return assert_less(zero, x, data=data, summarize=summarize)
@tf_export('debugging.assert_non_negative', v1=[])
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  This Op checks that `x[i] >= 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not >= 0 everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. This can
    be used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_non_negative(x, message=message, summarize=summarize,
                             name=name)
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = message or ''
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [message,
              'Condition x >= 0 did not hold element-wise:',
              'x (%s) = ' % x_desc, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x >= 0 is expressed as assert_less_equal(0, x).
    return assert_less_equal(zero, x, data=data, summarize=summarize)
@tf_export('debugging.assert_non_positive', v1=[])
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  This Op checks that `x[i] <= 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not <= 0 everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. This can
    be used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_non_positive(x, message=message, summarize=summarize,
                             name=name)
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_unary_assert_doc.
  message = message or ''
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      if context.executing_eagerly():
        name = _shape_and_dtype_str(x)
      else:
        name = x.name
      data = [
          message,
          # Keep the trailing comma: without it Python implicitly concatenates
          # this header with the next literal, garbling the message.
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % name, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x <= 0 is expressed as assert_less_equal(x, 0).
    return assert_less_equal(x, zero, data=data, summarize=summarize)
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x == y` holds element-wise.

  This Op checks that `x[i] == y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` and `y` are not equal, `message`, as well as the first `summarize`
  entries of `x` and `y` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_equal(x, y, message=message, summarize=summarize, name=name)
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@_binary_assert_doc('==')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_binary_assert_doc.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    # The very same tensor object is trivially equal to itself, so skip
    # building a comparison.
    if x is y:
      if context.executing_eagerly():
        return None
      return control_flow_ops.no_op()
  return _binary_assert(sym='==', opname='assert_equal',
                        op_func=math_ops.equal, static_func=np.equal,
                        x=x, y=y, data=data, summarize=summarize,
                        message=message, name=name)
@tf_export('debugging.assert_none_equal', v1=[])
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  This Op checks that `x[i] != y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If any elements of `x` and `y` are equal, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_none_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x != y` is ever False. This can
    be used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x != y` is False for any pair of elements in `x` and `y`. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_none_equal(x, y, message=message, summarize=summarize,
                           name=name)
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_binary_assert_doc.
  return _binary_assert(sym='!=', opname='assert_none_equal',
                        op_func=math_ops.not_equal, static_func=np.not_equal,
                        x=x, y=y, data=data, summarize=summarize,
                        message=message, name=name)
@tf_export('debugging.assert_near', v1=[])
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])` holds
  for every pair of (possibly broadcast) elements of `x` and `y`. If both `x`
  and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`. The default is thus
  about `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance. Default is `10 * eps`.
    atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance. Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close for some pair of elements. The check can
      be performed immediately during eager execution.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])```.

  Note the comparison is strict, so with `atol = rtol = 0` the check fails
  even when `x == y`.

  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`. The default is thus
  about `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance. Default is `10 * eps`.
    atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance. Default is `10 * eps`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message. Only used when `data`
      is not provided.
    name: A name for this operation (optional). Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = message or ''
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Default tolerances scale with the machine epsilon of x's dtype.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)

    # Eager tensors are described by shape/dtype; graph tensors by name.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    condition = math_ops.reduce_all(math_ops.less(diff, tol))
    return control_flow_ops.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_less', 'assert_less', v1=[])
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x < y` holds element-wise.

  This Op checks that `x[i] < y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not less than `y` element-wise, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is
  raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_less(x, y, message=message, summarize=summarize, name=name)
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@_binary_assert_doc('<')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_binary_assert_doc.
  return _binary_assert(sym='<', opname='assert_less',
                        op_func=math_ops.less, static_func=np.less,
                        x=x, y=y, data=data, summarize=summarize,
                        message=message, name=name)
@tf_export('debugging.assert_less_equal', v1=[])
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x <= y` holds element-wise.

  This Op checks that `x[i] <= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not less than or equal to `y` element-wise, `message`, as well as
  the first `summarize` entries of `x` and `y` are printed, and
  `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_less_equal(x, y, message=message, summarize=summarize,
                           name=name)
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by @_binary_assert_doc.
  return _binary_assert(sym='<=', opname='assert_less_equal',
                        op_func=math_ops.less_equal,
                        static_func=np.less_equal,
                        x=x, y=y, data=data, summarize=summarize,
                        message=message, name=name)
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x > y` holds element-wise.

  This Op checks that `x[i] > y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not greater than `y` element-wise, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is
  raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Thin v2 wrapper around the v1 implementation (which also takes `data`).
  return assert_greater(x, y, message=message, summarize=summarize,
                        name=name)
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@_binary_assert_doc('>')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Docstring presumably attached by `_binary_assert_doc('>')` (not visible
  # here). Delegates to `_binary_assert` with the graph-op and NumPy versions
  # of `>`.
  tf_greater, np_greater = math_ops.greater, np.greater
  return _binary_assert('>', 'assert_greater', tf_greater, np_greater, x, y,
                        data, summarize, message, name)
@tf_export('debugging.assert_greater_equal', v1=[])
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x >= y` holds element-wise.

  For every pair of (possibly broadcast) elements of `x` and `y`, this Op
  checks that `x[i] >= y[i]`. The check passes trivially when both `x` and
  `y` are empty.

  If `x` is not greater than or equal to `y` element-wise, `message` is
  printed together with the first `summarize` entries of `x` and `y`, and
  `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_greater_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # TF2 endpoint: identical to the v1 op except that `data` is not exposed.
  return assert_greater_equal(x, y, message=message, summarize=summarize,
                              name=name)
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  # Docstring presumably attached by `_binary_assert_doc('>=')` (not visible
  # here). Delegates to `_binary_assert` with the graph-op and NumPy versions
  # of `>=`.
  tf_greater_equal, np_greater_equal = (math_ops.greater_equal,
                                        np.greater_equal)
  return _binary_assert('>=', 'assert_greater_equal', tf_greater_equal,
                        np_greater_equal, x, y, data, summarize, message, name)
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Checks that the rank of `x` satisfies a caller-supplied condition.

  When both `rank` and the rank of `x` are known at graph-construction time,
  the check happens in Python: either a `no_op` is returned or a
  `ValueError` is raised. Otherwise an `Assert` op is built from
  `dynamic_condition`.

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `int32` `Tensor`.
    static_condition: Python callable `(actual_rank, given_rank) -> bool`,
      used when both ranks are statically known.
    dynamic_condition: Op-building callable `(actual_rank, given_rank)`
      returning a boolean `Tensor`, used otherwise.
    data: The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the static check passed, else an `Assert` op that raises
    `InvalidArgumentError` when `dynamic_condition` is False.

  Raises:
    ValueError: If a statically known `rank` is not a scalar, or with args
      `('Static rank condition failed', x_rank, rank)` when the static check
      fails.
  """
  assert_type(rank, dtypes.int32)

  # Try to resolve `rank` to a constant so the check can be done in Python.
  rank_static = tensor_util.constant_value(rank)
  rank_is_static = rank_static is not None

  if rank_is_static:
    if rank_static.ndim != 0:
      raise ValueError('Rank must be a scalar.')
    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if static_condition(x_rank_static, rank_static):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', x_rank_static, rank_static)

  condition = dynamic_condition(array_ops.rank(x), rank)

  if not rank_is_static:
    # `rank` itself must be a scalar; this catches `assert_rank(x, [n])`
    # written where `assert_rank(x, n)` was meant.
    rank_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  The Op compares the rank of `x` with `rank`. On a mismatch, `message` is
  printed together with the shape of `x`, and `InvalidArgumentError` is
  raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` does not have rank `rank`. The check can be performed immediately
      during eager execution or if the shape of `x` is statically known.
  """
  # TF2 endpoint: the v1 op minus the `data`/`summarize` arguments.
  return assert_rank(x, rank, message=message, name=name)
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    rank: Scalar integer `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  scope_inputs = (x, rank) + tuple(data or [])
  with ops.name_scope(name, 'assert_rank', scope_inputs):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    # Eager tensors have no meaningful name, so report an empty one.
    tensor_label = '' if context.executing_eagerly() else x.name

    if data is None:
      data = [message, 'Tensor %s must have rank' % tensor_label, rank,
              'Received shape: ', array_ops.shape(x)]

    try:
      return _assert_rank_condition(
          x, rank,
          lambda actual_rank, given_rank: actual_rank == given_rank,
          math_ops.equal, data, summarize)
    except ValueError as e:
      # Only the static-failure sentinel is rewritten into a readable message.
      if e.args[0] != 'Static rank condition failed':
        raise
      raise ValueError(
          '%s. Tensor %s must have rank %d. Received rank %d, shape %s' %
          (message, tensor_label, e.args[2], e.args[1], x.get_shape()))
@tf_export('debugging.assert_rank_at_least', v1=[])
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  The Op checks that the rank of `x` is greater than or equal to `rank`. If
  it is lower, `message` is printed together with the shape of `x`, and
  `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # TF2 endpoint: the v1 op minus the `data`/`summarize` arguments.
  return assert_rank_at_least(x, rank, message=message, name=name)
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  scope_inputs = (x, rank) + tuple(data or [])
  with ops.name_scope(name, 'assert_rank_at_least', scope_inputs):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    # Eager tensors have no meaningful name, so report an empty one.
    tensor_label = '' if context.executing_eagerly() else x.name

    if data is None:
      data = [message, 'Tensor %s must have rank at least' % tensor_label,
              rank, 'Received shape: ', array_ops.shape(x)]

    try:
      return _assert_rank_condition(
          x, rank,
          lambda actual_rank, given_rank: actual_rank >= given_rank,
          math_ops.greater_equal, data, summarize)
    except ValueError as e:
      # Only the static-failure sentinel is rewritten into a readable message.
      if e.args[0] != 'Static rank condition failed':
        raise
      raise ValueError(
          '%s. Tensor %s must have rank at least %d. Received rank %d, '
          'shape %s' % (message, tensor_label, e.args[2], e.args[1],
                        x.get_shape()))
def _static_rank_in(actual_rank, given_ranks):
  """Returns True iff `actual_rank` matches (`is` or `==`) a `given_ranks` entry."""
  for candidate in given_ranks:
    if candidate is actual_rank or candidate == actual_rank:
      return True
  return False
def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds a boolean tensor: does `actual_rank` equal any of `given_ranks`?

  An empty `given_ranks` yields a constant False tensor.
  """
  if len(given_ranks) == 0:  # pylint: disable=g-explicit-length-test
    return ops.convert_to_tensor(False)
  any_match = math_ops.equal(given_ranks[0], actual_rank)
  for candidate in given_ranks[1:]:
    candidate_matches = math_ops.equal(candidate, actual_rank)
    any_match = math_ops.logical_or(any_match, candidate_matches)
  return any_match
def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  Args:
    x: Numeric `Tensor`.
    ranks: Sequence (e.g. tuple) of scalar `int32` `Tensor`s. It is iterated
      more than once, so it must not be a one-shot iterator.
    static_condition: A python function that takes
      `[actual_rank, given_ranks]` and returns `True` if the condition is
      satisfied, `False` otherwise.
    dynamic_condition: An `op` that takes [actual_rank, given_ranks]
      and return `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition.

  Raises:
    ValueError: If any statically known entry of `ranks` is not a scalar, or
      if static checks determine `x` fails static_condition.
  """
  for rank in ranks:
    assert_type(rank, dtypes.int32)

  # Attempt to determine each rank statically; `None` marks a dynamic rank.
  ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks])

  # Every statically known rank must be a scalar, even when some other rank
  # is dynamic. Otherwise a constant like `[2]` would silently broadcast into
  # the dynamic comparison below instead of being rejected.
  for rank_static in ranks_static:
    if rank_static is not None and rank_static.ndim != 0:
      raise ValueError('Rank must be a scalar.')

  if not any(r is None for r in ranks_static):
    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if not static_condition(x_rank_static, ranks_static):
        raise ValueError(
            'Static rank condition failed', x_rank_static, ranks_static)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Dynamic ranks must have rank zero as well. Prevents the bug where
  # someone does assert_rank(x, [n]), rather than assert_rank(x, n).
  for rank, rank_static in zip(ranks, ranks_static):
    if rank_static is None:
      this_data = ['Rank must be a scalar. Received rank: ', rank]
      rank_check = assert_rank(rank, 0, data=this_data)
      condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_rank_in', v1=[])
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  The Op checks that the rank of `x` is one of `ranks`. Otherwise `message`
  is printed together with the shape of `x`, and `InvalidArgumentError` is
  raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # TF2 endpoint: the v1 op minus the `data`/`summarize` arguments.
  return assert_rank_in(x, ranks, message=message, name=name)
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    ranks: Iterable of scalar `Tensor` objects. Any iterable is accepted,
      including one-shot iterators such as generators.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # `ranks` is consumed twice below: once for the name-scope inputs and once
  # for conversion. Materialize it once so a generator is not empty on the
  # second pass.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    x = ops.convert_to_tensor(x, name='x')
    ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
    message = message or ''

    # Eager tensors have no meaningful name, so report an empty one.
    if context.executing_eagerly():
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)
    except ValueError as e:
      # Rewrite only the static-failure sentinel into a readable message.
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s. Tensor %s must have rank in %s. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
@tf_export('debugging.assert_integer', v1=[])
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  When `x` has a non-integer type, `message` is reported together with the
  dtype of `x`, and an error is raised. The dtype is always known statically,
  so this function returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is not a non-quantized integer type.
  """
  # The v1 op returns a `no_op`; the v2 endpoint intentionally drops it.
  assert_integer(x, message=message, name=name)
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')

    # Eager tensors have no graph name, so a generic label is used.
    tensor_label = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%s Expected "x" to be integer type. Found: %s of dtype %s'
        % (message, tensor_label, x.dtype))
@tf_export('debugging.assert_type', v1=[])
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  The dtype is always known statically, so this function returns nothing.

  Args:
    tensor: A `Tensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  # The v1 op returns a `no_op`; the v2 endpoint intentionally drops it.
  assert_type(tensor, tf_type, message=message, name=name)
@tf_export(v1=['debugging.assert_type', 'assert_type'])
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name to give this `Op`. Defaults to "assert_type"

  Raises:
    TypeError: If the tensors data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_type', [tensor]):
    tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype != tf_type:
      # Graph tensors are identified by name; eager ones have none.
      if not context.executing_eagerly():
        raise TypeError('%s %s must be of type %s' % (message, tensor.name,
                                                      tf_type))
      raise TypeError('%s tensor must be of type %s' % (message, tf_type))
    return control_flow_ops.no_op('statically_determined_correct_type')
def _dimension_sizes(x):
  """Gets the dimension sizes of a tensor `x`.

  Each size that is known statically is returned as a Python `int`. Each
  unknown size is returned as a scalar tensor taken from the dynamic shape. A
  scalar `x` is treated as rank 1 with size 1.

  Args:
    x: A `Tensor`.

  Returns:
    The tuple `(1,)` if `x` is statically a scalar. A list mixing `int`s and
    scalar tensors if the rank of `x` is statically known. Otherwise a 1-D
    tensor: the dynamic shape, or `[1]` when `x` turns out to be a scalar at
    runtime.
  """
  dynamic_shape = array_ops.shape(x)
  static_shape = x.get_shape()
  rank = static_shape.rank

  if rank is None:
    # The rank is only known at runtime: choose between `[1]` and the dynamic
    # shape inside the graph.
    is_scalar = math_ops.equal(array_ops.rank(x), 0)
    return control_flow_ops.cond(
        is_scalar, lambda: array_ops.constant([1]), lambda: dynamic_shape)

  if rank == 0:
    return tuple([1])

  return [
      dynamic_shape[i] if size is None else int(size)
      for i, size in enumerate(static_shape.as_list())
  ]
def _symbolic_dimension_sizes(symbolic_shape):
  """Returns `symbolic_shape` unchanged, or `(1,)` if it is empty.

  An empty specified shape is treated like a scalar, i.e. one dimension of
  size one.
  """
  return symbolic_shape if symbolic_shape else (1,)
def _has_known_value(dimension_size):
  """Returns True iff `dimension_size` is not None and `int()` accepts it."""
  if dimension_size is None:
    return False
  try:
    int(dimension_size)
  except (ValueError, TypeError):
    return False
  return True
def _is_symbol_for_any_size(symbol):
  """Returns True iff `symbol` is a wildcard size (`None` or `'.'`)."""
  wildcard_symbols = (None, '.')
  return symbol in wildcard_symbols
# Bookkeeping record used by `assert_shapes`, one per constrained tensor:
#   x:               the tensor being checked.
#   unspecified_dim: True when the spec started with `...`/`*`, i.e. only
#                    the inner-most dimensions are constrained.
#   actual_sizes:    the tensor's dimension sizes (see `_dimension_sizes`).
#   symbolic_sizes:  the specified sizes, with any leading `...`/`*` removed.
_TensorDimSizes = collections.namedtuple(
    '_TensorDimSizes', 'x unspecified_dim actual_sizes symbolic_sizes')
@tf_export('debugging.assert_shapes', v1=[])
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: An iterable of `(Tensor, shape)` pairs, such as a list of tuples
      as in the example above. A dictionary mapping `Tensor` to shape is also
      accepted where one can be built (graph mode). Each shape must be an
      iterable; a shape of `None` imposes no constraint.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied. If static checks determine all constraints are satisfied, a
    `no_op` is returned. This can be used with `tf.control_dependencies`
    inside of `tf.function`s to block followup computation until the check
    has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  # Return the op (like every other v2 assert wrapper) so callers can add a
  # control dependency on it in graph mode.
  return assert_shapes(
      shapes, data=data, summarize=summarize, message=message, name=name)
@tf_export(v1=['debugging.assert_shapes'])
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  ```python
  tf.compat.v1.debugging.assert_shapes([
    (x, ('N', 'Q')),
    (y, ('N', 'D')),
    (param, ('Q',)),
    (scalar, ())
  ])
  ```

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies(
      [tf.compat.v1.debugging.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive; such entries are never hashed.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: An iterable of `(Tensor, shape)` pairs, such as a list of tuples.
      A dictionary mapping `Tensor` to shape is also accepted where one can be
      built (graph mode). Each shape must be an iterable; a shape of `None`
      imposes no constraint on its tensor.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError: If static checks determine any shape constraint is violated,
      if a specified shape is not iterable, or if `...`/`*` appears anywhere
      other than the first entry of a specified shape.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message = message or ''
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint
    shape_constraints = [
        (ops.convert_to_tensor(x), s) for x, s in shapes if s is not None
    ]

    executing_eagerly = context.executing_eagerly()

    def tensor_name(x):
      # Eager tensors have no graph name; describe them by shape and dtype.
      if executing_eagerly:
        return _shape_and_dtype_str(x)
      return x.name

    # Step 1: validate each specified shape and pair it with the tensor's
    # actual dimension sizes.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s.  '
            'Tensor %s.  Specified shape must be an iterable.  '
            'An iterable has the attribute `__iter__` or `__getitem__`.  '
            'Received specified shape: %s' %
            (message, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s.  '
              'Tensor %s specified shape index %d.  '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Step 2: assert the rank of every constrained tensor.
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only need to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Step 3: assert dimension sizes. The first tensor that uses a symbol
    # defines its size; later uses are compared against that.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        if sizes.unspecified_dim:
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        # Test `_has_known_value` first: explicit sizes are never looked up
        # by hash, so unhashable integer-like symbols (e.g. 0-d NumPy arrays)
        # do not raise TypeError here.
        if _has_known_value(size_symbol) or size_symbol in size_specifications:
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = (
                size_specifications[size_symbol])
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          actual_size = sizes.actual_sizes[tensor_dim]
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              raise ValueError(
                  '%s.  %s.  Tensor %s dimension %s must have size %d.  '
                  'Received size %d, shape %s' %
                  (message, size_check_message, tensor_name(sizes.x),
                   tensor_dim, specified_size, actual_size,
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_ops.Assert(condition, data_, summarize=summarize))
        else:
          # First occurrence of this symbol: remember the size it binds to.
          size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

    # Size checks run only after all rank checks have passed.
    with ops.control_dependencies(rank_assertions):
      shapes_assertion = control_flow_ops.group(size_assertions)

    return shapes_assertion
# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Returns `x[1:] - x[:-1]` computed on `x` flattened to 1-D.

  If `x` has fewer than two elements, an empty tensor of the same dtype is
  returned, since there are no adjacent pairs to compare.

  Raises:
    TypeError: if the flattened `x` is not a numeric tensor.
  """
  flat = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(flat):
    raise TypeError('Expected x to be numeric, instead found: %s' % flat)

  too_short = math_ops.less(array_ops.size(flat), 2)

  def _empty_result():
    return ops.convert_to_tensor([], dtype=flat.dtype)

  # 1-element tensor holding len(flat) - 1, used as slice bounds.
  last_index = array_ops.shape(flat) - 1

  def _adjacent_differences():
    upper = array_ops.strided_slice(flat, [1], [1] + last_index)
    lower = array_ops.strided_slice(flat, [0], last_index)
    return upper - lower

  return control_flow_ops.cond(too_short, _empty_result, _adjacent_differences)
@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  The result is `True` exactly when `tensor` is a `tf.Tensor` whose dtype is
  one of:

  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.qint8`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.complex64`

  Any other dtype, or any object that is not a `tf.Tensor`, yields `False`.
  """
  if not isinstance(tensor, ops.Tensor):
    return False
  return tensor.dtype in NUMERIC_TYPES
@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is non-decreasing when every adjacent pair satisfies `x[i] <= x[i+1]`.
  A tensor with fewer than two elements is trivially non-decreasing.

  See also:  `is_strictly_increasing`

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).  Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    steps = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=steps.dtype)
    # An empty `steps` (fewer than two elements) reduces to True.
    step_ok = math_ops.less_equal(zero, steps)
    return math_ops.reduce_all(step_ok)
@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is strictly increasing when every adjacent pair satisfies `x[i] < x[i+1]`.
  A tensor with fewer than two elements is trivially strictly increasing.

  See also:  `is_non_decreasing`

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    steps = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=steps.dtype)
    # An empty `steps` (fewer than two elements) reduces to True.
    step_ok = math_ops.less(zero, steps)
    return math_ops.reduce_all(step_ok)
def _assert_same_base_type(items, expected_type=None):
r"""Asserts all items are of the same base type.
Args:
items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
`Operation`, or `IndexedSlices`). Can include `None` elements, which
will be ignored.
expected_type: Expected type. If not specified, assert all items are
of the same base type.
Returns:
Validated type, or none if neither expected_type nor items provided.
Raises:
ValueError: If any types do not match.
"""
original_expected_type = expected_type
mismatch = False
for item in items:
if item is not None:
item_type = item.dtype.base_dtype
if not expected_type:
expected_type = item_type
elif expected_type != item_type:
mismatch = True
break
if mismatch:
# Loop back through and build up an informative error message (this is very
# slow, so we don't do it unless we found an error above).
expected_type = original_expected_type
original_item_str = None
for item in items:
if item is not None:
item_type = item.dtype.base_dtype
if not expected_type:
expected_type = item_type
original_item_str = item.name if hasattr(item, 'name') else str(item)
elif expected_type != item_type:
raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
item.name if hasattr(item, 'name') else str(item),
item_type, expected_type,
(' as %s' % original_item_str) if original_item_str else ''))
return expected_type # Should be unreachable
else:
return expected_type
@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  For ops such as matrix multiplication, inputs and weights must be of the
  same float type. This function validates that all `tensors` are the same type,
  validates that type is `dtype` (if supplied), and returns the type. Type must
  be a floating point type. If neither `tensors` nor `dtype` is supplied,
  the function will return `dtypes.float32`.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
      ignored.
    dtype: Expected type. Either a `tf.DType` or any value accepted by
      `tf.as_dtype` (for example a NumPy dtype or a dtype name string).

  Returns:
    Validated type.

  Raises:
    ValueError: if the base types of `tensors` differ from each other or from
      `dtype`, or if the resulting type is not a floating point type.
    TypeError: if `dtype` cannot be converted to a `tf.DType`.
  """
  if dtype is not None:
    # Normalize dtype-like values to a DType so the comparisons in
    # `_assert_same_base_type` and the `is_floating` check below both work.
    # Compare against None rather than relying on truthiness, because some
    # dtype-like objects (e.g. `numpy.dtype`) may evaluate as falsy and would
    # otherwise be silently ignored.
    dtype = dtypes.as_dtype(dtype)
  if tensors:
    dtype = _assert_same_base_type(tensors, dtype)
  if not dtype:
    # Nothing to infer from: fall back to the default float type.
    dtype = dtypes.float32
  elif not dtype.is_floating:
    raise ValueError('Expected floating point type, got %s.' % dtype)
  return dtype
@tf_export('debugging.assert_scalar', v1=[])
def assert_scalar_v2(tensor, message=None, name=None):
  """Statically verifies that `tensor` is a scalar.

  A `ValueError` is raised unless the static shape of `tensor` proves it is a
  scalar. An unknown shape also raises `ValueError`. The check happens entirely
  at graph-construction / call time, so nothing is returned.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_scalar"

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  # Delegate to the v1 implementation and discard the tensor it returns.
  assert_scalar(tensor, name=name, message=message)
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Statically verifies that `tensor` is a scalar (zero-dimensional).

  A `ValueError` is raised unless the static shape of `tensor` proves it is a
  scalar. An unknown shape also raises `ValueError`.

  Args:
    tensor: A `Tensor`.
    name: A name for this operation. Defaults to "assert_scalar"
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as scope:
    converted = ops.convert_to_tensor(tensor, name=scope)
    static_shape = converted.get_shape()
    # Only a known rank of exactly 0 passes; unknown rank (None) fails too.
    if static_shape.ndims == 0:
      return converted
    prefix = message or ''
    if context.executing_eagerly():
      # Eager tensors have no meaningful graph name to report.
      raise ValueError('%sExpected scalar shape, saw shape: %s.'
                       % (prefix, static_shape,))
    raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                     % (prefix, converted.name, static_shape))
@tf_export('ensure_shape')
def ensure_shape(x, shape, name=None):
  """Sets the static shape of a tensor and checks it at runtime.

  For example:

  ```python
  x = tf.compat.v1.placeholder(tf.int32)
  print(x.shape)
  ==> TensorShape(None)
  y = x * 2
  print(y.shape)
  ==> TensorShape(None)

  y = tf.ensure_shape(y, (None, 3, 3))
  print(y.shape)
  ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])

  with tf.compat.v1.Session() as sess:
    # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
    # compatible with the shape (None, 3, 3)
    sess.run(y, feed_dict={x: [1, 2, 3]})
  ```

  NOTE: Unlike `Tensor.set_shape`, this both records the static shape on the
  returned tensor and enforces it at runtime, failing if the actual shape is
  incompatible. `Tensor.set_shape` only records the static shape. It performs
  no runtime check, so the statically-known shape can disagree with the
  tensor's real runtime value.

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor` with the same type and contents as `x`. At runtime, raises a
    `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape
    of `x`.
  """
  # as_shape passes TensorShape through unchanged and wraps anything else.
  return array_ops.ensure_shape(x, tensor_shape.as_shape(shape), name=name)
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient of `EnsureShape`: the upstream gradient is passed through."""
  _ = op  # The forward op is not needed to compute this gradient.
  return grad