# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal utilities for `LinearOperator` classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.util import nest
################################################################################
# To make this file more friendly to TF2.
################################################################################
def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
"""Converts the given `value` to a `Tensor` if input is nonreference type.
  This function converts Python objects of various types to `Tensor` objects
  except if the input has reference semantics. Reference semantics are
  characterized by `is_ref`: any object that is a `tf.Variable` or an instance
  of `tf.Module` is returned as-is. This function accepts any input which
  `tf.convert_to_tensor` would also accept.
Note: This function diverges from default Numpy behavior for `float` and
`string` types when `None` is present in a Python list or scalar. Rather
than silently converting `None` values, an error will be thrown.
Args:
value: An object whose type has a registered `Tensor` conversion function.
dtype: Optional element type for the returned tensor. If missing, the
type is inferred from the type of `value`.
dtype_hint: Optional element type for the returned tensor,
used when dtype is None. In some cases, a caller may not have a
dtype in mind when converting to a tensor, so dtype_hint
can be used as a soft preference. If the conversion to
`dtype_hint` is not possible, this argument has no effect.
name: Optional name to use if a new `Tensor` is created.
Returns:
tensor: A `Tensor` based on `value`.
Raises:
TypeError: If no conversion function is registered for `value` to `dtype`.
RuntimeError: If a registered conversion function returns an invalid value.
ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
#### Examples:
```python
x = tf.Variable(0.)
y = convert_nonref_to_tensor(x)
x is y
# ==> True
x = tf.constant(0.)
y = convert_nonref_to_tensor(x)
x is y
# ==> True
x = np.array(0.)
y = convert_nonref_to_tensor(x)
x is y
# ==> False
tf.is_tensor(y)
# ==> True
x = tfp.util.DeferredTensor(13.37, lambda x: x)
y = convert_nonref_to_tensor(x)
x is y
# ==> True
tf.is_tensor(y)
# ==> False
tf.equal(y, 13.37)
# ==> True
```
"""
# We explicitly do not use a tf.name_scope to avoid graph clutter.
if value is None:
return None
if is_ref(value):
if dtype is None:
return value
dtype_base = base_dtype(dtype)
value_dtype_base = base_dtype(value.dtype)
if dtype_base != value_dtype_base:
raise TypeError('Mutable type must be of dtype "{}" but is "{}".'.format(
dtype_name(dtype_base), dtype_name(value_dtype_base)))
return value
return ops.convert_to_tensor_v2_with_dispatch(
value, dtype=dtype, dtype_hint=dtype_hint, name=name)
def base_dtype(dtype):
"""Returns a non-reference `dtype` based on this `dtype`."""
dtype = dtypes.as_dtype(dtype)
if hasattr(dtype, "base_dtype"):
return dtype.base_dtype
return dtype
def dtype_name(dtype):
"""Returns the string name for this `dtype`."""
dtype = dtypes.as_dtype(dtype)
if hasattr(dtype, "name"):
return dtype.name
if hasattr(dtype, "__name__"):
return dtype.__name__
return str(dtype)
def check_dtype(arg, dtype):
"""Check that arg.dtype == self.dtype."""
if arg.dtype.base_dtype != dtype:
raise TypeError(
"Expected argument to have dtype %s. Found: %s in tensor %s" %
(dtype, arg.dtype, arg))
def is_ref(x):
"""Evaluates if the object has reference semantics.
An object is deemed "reference" if it is a `tf.Variable` instance or is
derived from a `tf.Module` with `dtype` and `shape` properties.
Args:
x: Any object.
Returns:
    is_ref: Python `bool` indicating whether the input has reference semantics,
      i.e., is a `tf.Variable` or a `tf.Module` with `dtype` and `shape`
      properties.
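  #### Examples
  A minimal sketch (the objects below are illustrative):
  ```python
  is_ref(tf.Variable(0.))
  # ==> True, `tf.Variable` has reference semantics.
  is_ref(tf.constant(0.))
  # ==> False, a constant `Tensor` is not reference-like.
  ```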
"""
return (
      # Note: we check that tf.Variable is a class because we might be using a
      # backend other than TF.
isinstance(x, variables_module.Variable) or
(isinstance(x, module.Module) and hasattr(x, "dtype") and
hasattr(x, "shape")))
def assert_not_ref_type(x, arg_name):
if is_ref(x):
raise TypeError(
"Argument %s cannot be reference type. Found: %s" % (arg_name, type(x)))
################################################################################
# Asserts.
################################################################################
def assert_no_entries_with_modulus_zero(
x, message=None, name="assert_no_entries_with_modulus_zero"):
"""Returns `Op` that asserts Tensor `x` has no entries with modulus zero.
Args:
x: Numeric `Tensor`, real, integer, or complex.
message: A string message to prepend to failure message.
name: A name to give this `Op`.
Returns:
An `Op` that asserts `x` has no entries with modulus zero.
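  #### Example (illustrative values):
  ```python
  x = tf.constant([1. + 1.j, 2. + 0.j])
  # Passes, since |1 + i| > 0 and |2| > 0.
  assertion = assert_no_entries_with_modulus_zero(x)
  ```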
"""
with ops.name_scope(name, values=[x]):
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
dtype = x.dtype.base_dtype
should_be_nonzero = math_ops.abs(x)
zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
return check_ops.assert_less(zero, should_be_nonzero, message=message)
def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
"""Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.
Args:
x: Numeric `Tensor`, real, integer, or complex.
message: A string message to prepend to failure message.
name: A name to give this `Op`.
Returns:
    An `Op` that asserts `x` has no entries with a non-zero imaginary part.
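  #### Example (illustrative values):
  ```python
  x = tf.constant([1., 2.])
  # Real dtype, so the check is a no-op.
  assertion = assert_zero_imag_part(x)
  ```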
"""
with ops.name_scope(name, values=[x]):
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
dtype = x.dtype.base_dtype
if dtype.is_floating:
return control_flow_ops.no_op()
zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
def assert_compatible_matrix_dimensions(operator, x):
"""Assert that an argument to solve/matmul has proper domain dimension.
If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
`operator.matmul(x)` is defined only if `N = Q`. This `Op` returns an
`Assert` that "fires" if this is not the case. Static checks are already
done by the base class `LinearOperator`.
Args:
operator: `LinearOperator`.
x: `Tensor`.
Returns:
`Assert` `Op`.
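  #### Example (a sketch; the operator and shapes are illustrative):
  ```python
  operator = tf.linalg.LinearOperatorIdentity(num_rows=3)  # Shape [3, 3].
  x = tf.random.normal(shape=[2, 3, 5])                    # shape[-2] == 3.
  # Does not fire, since shape(x)[-2] equals the domain dimension, 3.
  assert_op = assert_compatible_matrix_dimensions(operator, x)
  ```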
"""
# Static checks are done in the base class. Only tensor asserts here.
assert_same_dd = check_ops.assert_equal(
array_ops.shape(x)[-2],
operator.domain_dimension_tensor(),
# This error message made to look similar to error raised by static check
# in the base class.
message=("Dimensions are not compatible. "
"shape[-2] of argument to be the same as this operator"))
return assert_same_dd
def assert_is_batch_matrix(tensor):
"""Static assert that `tensor` has rank `2` or higher."""
sh = tensor.shape
if sh.ndims is not None and sh.ndims < 2:
raise ValueError(
"Expected [batch] matrix to have at least two dimensions. Found: "
"%s" % tensor)
def shape_tensor(shape, name=None):
"""Convert Tensor using default type, unless empty list or tuple."""
# Works just like random_ops._ShapeTensor.
if isinstance(shape, (tuple, list)) and not shape:
dtype = dtypes.int32
else:
dtype = None
return ops.convert_to_tensor_v2_with_dispatch(shape, dtype=dtype, name=name)
################################################################################
# Broadcasting versions of common linear algebra functions.
# TODO(b/77519145) Do this more efficiently in some special cases.
################################################################################
def broadcast_matrix_batch_dims(batch_matrices, name=None):
"""Broadcast leading dimensions of zero or more [batch] matrices.
Example broadcasting one batch dim of two simple matrices.
```python
x = [[1, 2],
[3, 4]] # Shape [2, 2], no batch dims
y = [[[1]]] # Shape [1, 1, 1], 1 batch dim of shape [1]
x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
x_bc
==> [[[1, 2],
[3, 4]]] # Shape [1, 2, 2], 1 batch dim of shape [1].
y_bc
==> same as y
```
Example broadcasting many batch dims
```python
x = tf.random.normal(shape=(2, 3, 1, 4, 4))
y = tf.random.normal(shape=(1, 3, 2, 5, 5))
x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
x_bc.shape
==> (2, 3, 2, 4, 4)
y_bc.shape
==> (2, 3, 2, 5, 5)
```
Args:
batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
name: A string name to prepend to created ops.
Returns:
bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
the values from `batch_matrices[i]`, with possibly broadcast batch dims.
Raises:
ValueError: If any input `Tensor` is statically determined to have less
than two dimensions.
"""
with ops.name_scope(
name or "broadcast_matrix_batch_dims", values=batch_matrices):
check_ops.assert_proper_iterable(batch_matrices)
batch_matrices = list(batch_matrices)
for i, mat in enumerate(batch_matrices):
batch_matrices[i] = ops.convert_to_tensor_v2_with_dispatch(mat)
assert_is_batch_matrix(batch_matrices[i])
if len(batch_matrices) < 2:
return batch_matrices
# Try static broadcasting.
# bcast_batch_shape is the broadcast batch shape of ALL matrices.
# E.g. if batch_matrices = [x, y], with
# x.shape = [2, j, k] (batch shape = [2])
# y.shape = [3, 1, l, m] (batch shape = [3, 1])
# ==> bcast_batch_shape = [3, 2]
bcast_batch_shape = batch_matrices[0].shape[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_static_shape(
bcast_batch_shape,
mat.shape[:-2])
if bcast_batch_shape.is_fully_defined():
for i, mat in enumerate(batch_matrices):
if mat.shape[:-2] != bcast_batch_shape:
bcast_shape = array_ops.concat(
[bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
return batch_matrices
# Since static didn't work, do dynamic, which always copies data.
bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_dynamic_shape(
bcast_batch_shape,
array_ops.shape(mat)[:-2])
for i, mat in enumerate(batch_matrices):
batch_matrices[i] = array_ops.broadcast_to(
mat,
array_ops.concat(
[bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))
return batch_matrices
def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
"""Solve systems of linear equations."""
with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
rhs = ops.convert_to_tensor_v2_with_dispatch(
rhs, name="rhs", dtype=matrix.dtype)
# If either matrix/rhs has extra dims, we can reshape to get rid of them.
matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
matrix, rhs, adjoint_a=adjoint)
# This will broadcast by brute force if we still need to.
matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
solution = linalg_ops.matrix_solve(
matrix, rhs, adjoint=adjoint and still_need_to_transpose)
return reshape_inv(solution)
def _reshape_for_efficiency(a,
b,
transpose_a=False,
transpose_b=False,
adjoint_a=False,
adjoint_b=False):
"""Maybe reshape a, b, and return an inverse map. For matmul/solve."""
def identity(x):
return x
# At this point, we have not taken transpose/adjoint of a/b.
still_need_to_transpose = True
if a.shape.ndims is None or b.shape.ndims is None:
return a, b, identity, still_need_to_transpose
# This could be handled in the future, but seems less common.
if a.shape.ndims >= b.shape.ndims:
return a, b, identity, still_need_to_transpose
# From now on, we might modify b, but will not modify a.
# Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
b_extra_ndims = b.shape.ndims - a.shape.ndims
# b_extra_sh = S, b_main_sh = C + [n, r]
b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
b_main_sh = array_ops.shape(b)[b_extra_ndims:]
# No reason to flip unless the extra dims of b are big enough. Why?
# Assume adjoint/transpose = False. Then...
# By not flipping, we have to replicate a to shape
# b_extra_sh + a.shape,
# which could use extra memory. But in all cases, the final output has shape
# b_extra_sh + a.shape[:-1] + [b.shape[-1]]
# So we only end up creating a larger object if the end dim of b is smaller
# than the end dim of a. This often happens, e.g. if b was a vector that was
# expanded to a matrix (by appending a singleton).
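  # For example (illustrative shapes): with a.shape = [2, 2] and
  # b.shape = [100, 2, 1], not flipping would replicate a to [100, 2, 2],
  # whereas flipping reshapes b to [2, 100], leaves a as [2, 2], and the
  # result is mapped back to shape [100, 2, 1] by reshape_inv.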
# Since adjoint/transpose may not be False, we must make adjustments here.
# The dim of b that holds the multiple equations.
a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
b_extra_sz_ = (
np.prod(b.shape[:b_extra_ndims].as_list())
if b.shape[:b_extra_ndims].is_fully_defined() else None)
if (a_domain_sz_ is not None and b_eq_sz_ is not None and
b_extra_sz_ is not None):
if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
return a, b, identity, still_need_to_transpose
# At this point, we're flipping for sure!
# Any transposes/adjoints will happen here explicitly, rather than in calling
# code. Why? To avoid having to write separate complex code for each case.
if adjoint_a:
a = array_ops.matrix_transpose(a, conjugate=True)
elif transpose_a:
a = array_ops.matrix_transpose(a, conjugate=False)
if adjoint_b:
b = array_ops.matrix_transpose(b, conjugate=True)
  elif transpose_b:
b = array_ops.matrix_transpose(b, conjugate=False)
still_need_to_transpose = False
# Recompute shapes, since the transpose/adjoint may have changed them.
b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
b_main_sh = array_ops.shape(b)[b_extra_ndims:]
# Permutation to put the extra dims at the end.
perm = (
np.concatenate(
(np.arange(b_extra_ndims, b.shape.ndims),
np.arange(0, b_extra_ndims)), 0))
b_extra_on_end = array_ops.transpose(b, perm=perm)
# Now squash this end into one long dim.
b_squashed_end = array_ops.reshape(
b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))
def reshape_inv(y):
# Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use shape(y)[:-1] + [b_main_sh[-1]] rather than b_main_sh, since
    # y could have different batch dims than a and b because of broadcasting.
y_extra_shape = array_ops.concat(
(array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
y_extra_on_end = array_ops.reshape(y, y_extra_shape)
inverse_perm = np.argsort(perm)
return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
return a, b_squashed_end, reshape_inv, still_need_to_transpose
################################################################################
# Helpers for hints.
################################################################################
def use_operator_or_provided_hint_unless_contradicting(
operator, hint_attr_name, provided_hint_value, message):
"""Get combined hint in the case where operator.hint should equal hint.
Args:
operator: LinearOperator that a meta-operator was initialized with.
hint_attr_name: String name for the attribute.
provided_hint_value: Bool or None. Value passed by user in initialization.
message: Error message to print if hints contradict.
Returns:
True, False, or None.
Raises:
ValueError: If hints contradict.
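  #### Example (a sketch; `operator` is any `LinearOperator`):
  ```python
  # Suppose a meta-operator wraps `operator` and the caller passed
  # `is_self_adjoint=True`. If `operator.is_self_adjoint` is None, the
  # provided hint wins:
  combined_hint = use_operator_or_provided_hint_unless_contradicting(
      operator=operator,
      hint_attr_name="is_self_adjoint",
      provided_hint_value=True,
      message="Self-adjoint hints contradict.")
  # ==> True
  ```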
"""
op_hint = getattr(operator, hint_attr_name)
# pylint: disable=g-bool-id-comparison
if op_hint is False and provided_hint_value:
raise ValueError(message)
if op_hint and provided_hint_value is False:
raise ValueError(message)
if op_hint or provided_hint_value:
return True
if op_hint is False or provided_hint_value is False:
return False
# pylint: enable=g-bool-id-comparison
return None
################################################################################
# Utilities for blockwise operators.
################################################################################
def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
"""Detect if input should be interpreted as a list of blocks."""
# Tuples and lists of length equal to the number of operators may be
# blockwise.
if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
# If the elements of the iterable are not nested, interpret the input as
# blockwise.
if not any(nest.is_nested(x) for x in arg):
return True
else:
arg_dims = [ops.convert_to_tensor_v2_with_dispatch(
x).shape[arg_split_dim] for x in arg]
self_dims = [dim.value for dim in block_dimensions]
# If none of the operator dimensions are known, interpret the input as
# blockwise if its matching dimensions are unequal.
if all(self_d is None for self_d in self_dims):
# A nested tuple/list with a single outermost element is not blockwise
if len(arg_dims) == 1:
return False
elif any(dim != arg_dims[0] for dim in arg_dims):
return True
else:
raise ValueError(
"Parsing of the input structure is ambiguous. Please input "
"a blockwise iterable of `Tensor`s or a single `Tensor`.")
# If input dimensions equal the respective (known) blockwise operator
# dimensions, then the input is blockwise.
if all(self_d == arg_d or self_d is None
for self_d, arg_d in zip(self_dims, arg_dims)):
return True
      # If the input dimensions are all equal, and are greater than or equal to
      # the sum of the known operator dimensions, interpret the input as a
      # single Tensor rather than blockwise.
self_dim = sum(self_d for self_d in self_dims if self_d is not None)
if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
return False
# If none of these conditions is met, the input shape is mismatched.
raise ValueError("Input dimension does not match operator dimension.")
else:
return False
def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
"""Split `x` into blocks matching `operators`'s `domain_dimension`.
Specifically, if we have a blockwise lower-triangular matrix, with block
sizes along the diagonal `[M_j, M_j] j = 0,1,2..J`, this method splits `arg`
on `axis` into `J` tensors, whose shape at `axis` is `M_j`.
Args:
block_dims: Iterable of `TensorShapes`.
block_dims_fn: Callable returning an iterable of `Tensor`s.
arg: `Tensor`. `arg` is split into `J` tensors.
    axis: Python `int` representing the axis on which to split `arg`.
Returns:
A list of `Tensor`s.
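  #### Example (a sketch; block sizes below are illustrative):
  ```python
  block_dims = [tf.compat.v1.Dimension(2), tf.compat.v1.Dimension(3)]
  arg = tf.random.normal(shape=[5, 4])
  blocks = split_arg_into_blocks(
      block_dims, block_dims_fn=lambda: [2, 3], arg=arg, axis=-2)
  [b.shape for b in blocks]
  # ==> [(2, 4), (3, 4)]
  ```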
"""
block_sizes = [dim.value for dim in block_dims]
if any(d is None for d in block_sizes):
block_sizes = block_dims_fn()
return array_ops.split(arg, block_sizes, axis=axis)