# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like the identity matrix."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export
__all__ = [
"LinearOperatorIdentity",
"LinearOperatorScaledIdentity",
]
class BaseLinearOperatorIdentity(linear_operator.LinearOperator):
"""Base class for Identity operators."""
def _check_num_rows_possibly_add_asserts(self):
"""Static check of init arg `num_rows`, possibly add asserts."""
# Possibly add asserts.
if self._assert_proper_shapes:
self._num_rows = control_flow_ops.with_dependencies([
check_ops.assert_rank(
self._num_rows,
0,
message="Argument num_rows must be a 0-D Tensor."),
check_ops.assert_non_negative(
self._num_rows,
message="Argument num_rows must be non-negative."),
], self._num_rows)
# Static checks.
if not self._num_rows.dtype.is_integer:
raise TypeError("Argument num_rows must be integer type. Found:"
" %s" % self._num_rows)
num_rows_static = self._num_rows_static
if num_rows_static is None:
return # Cannot do any other static checks.
if num_rows_static.ndim != 0:
raise ValueError("Argument num_rows must be a 0-D Tensor. Found:"
" %s" % num_rows_static)
if num_rows_static < 0:
raise ValueError("Argument num_rows must be non-negative. Found:"
" %s" % num_rows_static)
def _min_matrix_dim(self):
"""Minimum of domain/range dimension, if statically available, else None."""
domain_dim = tensor_shape.dimension_value(self.domain_dimension)
range_dim = tensor_shape.dimension_value(self.range_dimension)
if domain_dim is None or range_dim is None:
return None
return min(domain_dim, range_dim)
def _min_matrix_dim_tensor(self):
"""Minimum of domain/range dimension, as a tensor."""
return math_ops.reduce_min(self.shape_tensor()[-2:])
def _ones_diag(self):
"""Returns the diagonal of this operator as all ones."""
if self.shape.is_fully_defined():
d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
else:
d_shape = array_ops.concat(
[self.batch_shape_tensor(),
[self._min_matrix_dim_tensor()]], axis=0)
return array_ops.ones(shape=d_shape, dtype=self.dtype)
@tf_export("linalg.LinearOperatorIdentity")
class LinearOperatorIdentity(BaseLinearOperatorIdentity):
"""`LinearOperator` acting like a [batch] square identity matrix.
This operator acts like a [batch] identity matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
an `N x N` matrix. This matrix `A` is not materialized, but for
purposes of broadcasting this shape will be relevant.
  `LinearOperatorIdentity` is initialized with `num_rows`, and optionally
  `batch_shape` and `dtype` arguments. If `batch_shape` is `None`, this
operator efficiently passes through all arguments. If `batch_shape` is
provided, broadcasting may occur, which will require making copies.
```python
# Create a 2 x 2 identity matrix.
operator = LinearOperatorIdentity(num_rows=2, dtype=tf.float32)
operator.to_dense()
==> [[1., 0.]
[0., 1.]]
operator.shape
==> [2, 2]
operator.log_abs_determinant()
==> 0.
x = ... Shape [2, 4] Tensor
operator.matmul(x)
==> Shape [2, 4] Tensor, same as x.
y = tf.random.normal(shape=[3, 2, 4])
# Note that y.shape is compatible with operator.shape because operator.shape
# is broadcast to [3, 2, 2].
# This broadcast does NOT require copying data, since we can infer that y
# will be passed through without changing shape. We are always able to infer
# this if the operator has no batch_shape.
x = operator.solve(y)
==> Shape [3, 2, 4] Tensor, same as y.
# Create a 2-batch of 2x2 identity matrices
operator = LinearOperatorIdentity(num_rows=2, batch_shape=[2])
operator.to_dense()
==> [[[1., 0.]
[0., 1.]],
[[1., 0.]
[0., 1.]]]
# Here, even though the operator has a batch shape, the input is the same as
# the output, so x can be passed through without a copy. The operator is able
# to detect that no broadcast is necessary because both x and the operator
# have statically defined shape.
x = ... Shape [2, 2, 3]
operator.matmul(x)
==> Shape [2, 2, 3] Tensor, same as x
# Here the operator and x have different batch_shape, and are broadcast.
# This requires a copy, since the output is different size than the input.
x = ... Shape [1, 2, 3]
operator.matmul(x)
==> Shape [2, 2, 3] Tensor, equal to [x, x]
```
### Shape compatibility
  This operator acts on [batch] matrices with compatible shape.
`x` is a batch matrix with compatible shape for `matmul` and `solve` if
```
operator.shape = [B1,...,Bb] + [N, N], with b >= 0
x.shape = [C1,...,Cc] + [N, R],
and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
```
### Performance
If `batch_shape` initialization arg is `None`:
* `operator.matmul(x)` is `O(1)`
* `operator.solve(x)` is `O(1)`
* `operator.determinant()` is `O(1)`
If `batch_shape` initialization arg is provided, and static checks cannot
rule out the need to broadcast:
* `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
* `operator.solve(x)` is `O(D1*...*Dd*N*R)`
* `operator.determinant()` is `O(B1*...*Bb)`
#### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning:
* If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a
runtime assert. For example, finite floating point precision may result
in these promises being violated.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
"""
def __init__(self,
num_rows,
batch_shape=None,
dtype=None,
is_non_singular=True,
is_self_adjoint=True,
is_positive_definite=True,
is_square=True,
assert_proper_shapes=False,
name="LinearOperatorIdentity"):
r"""Initialize a `LinearOperatorIdentity`.
The `LinearOperatorIdentity` is initialized with arguments defining `dtype`
and shape.
This operator is able to broadcast the leading (batch) dimensions, which
sometimes requires copying data. If `batch_shape` is `None`, the operator
can take arguments of any batch shape without copying. See examples.
Args:
num_rows: Scalar non-negative integer `Tensor`. Number of rows in the
corresponding identity matrix.
batch_shape: Optional `1-D` integer `Tensor`. The shape of the leading
dimensions. If `None`, this operator has no leading dimensions.
dtype: Data type of the matrix that this operator represents.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
is_positive_definite: Expect that this operator is positive definite,
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
assert_proper_shapes: Python `bool`. If `False`, only perform static
checks that initialization and method arguments have proper shape.
If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`.
Raises:
ValueError: If `num_rows` is determined statically to be non-scalar, or
negative.
ValueError: If `batch_shape` is determined statically to not be 1-D, or
negative.
ValueError: If any of the following is not `True`:
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable).
"""
dtype = dtype or dtypes.float32
self._assert_proper_shapes = assert_proper_shapes
with ops.name_scope(name):
dtype = dtypes.as_dtype(dtype)
if not is_self_adjoint:
raise ValueError("An identity operator is always self adjoint.")
if not is_non_singular:
raise ValueError("An identity operator is always non-singular.")
if not is_positive_definite:
raise ValueError("An identity operator is always positive-definite.")
if not is_square:
raise ValueError("An identity operator is always square.")
super(LinearOperatorIdentity, self).__init__(
dtype=dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")
self._num_rows = linear_operator_util.shape_tensor(
num_rows, name="num_rows")
self._num_rows_static = tensor_util.constant_value(self._num_rows)
self._check_num_rows_possibly_add_asserts()
if batch_shape is None:
self._batch_shape_arg = None
else:
self._batch_shape_arg = linear_operator_util.shape_tensor(
batch_shape, name="batch_shape_arg")
self._batch_shape_static = tensor_util.constant_value(
self._batch_shape_arg)
self._check_batch_shape_possibly_add_asserts()
def _shape(self):
matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
self._num_rows_static))
if self._batch_shape_arg is None:
return matrix_shape
batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
return batch_shape.concatenate(matrix_shape)
def _shape_tensor(self):
matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0)
if self._batch_shape_arg is None:
return matrix_shape
return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)
def _assert_non_singular(self):
return control_flow_ops.no_op("assert_non_singular")
def _assert_positive_definite(self):
return control_flow_ops.no_op("assert_positive_definite")
def _assert_self_adjoint(self):
return control_flow_ops.no_op("assert_self_adjoint")
def _possibly_broadcast_batch_shape(self, x):
"""Return 'x', possibly after broadcasting the leading dimensions."""
# If we have no batch shape, our batch shape broadcasts with everything!
if self._batch_shape_arg is None:
return x
# Static attempt:
# If we determine that no broadcast is necessary, pass x through
# If we need a broadcast, add to an array of zeros.
#
# special_shape is the shape that, when broadcast with x's shape, will give
# the correct broadcast_shape. Note that
# We have already verified the second to last dimension of self.shape
# matches x's shape in assert_compatible_matrix_dimensions.
# Also, the final dimension of 'x' can have any shape.
# Therefore, the final two dimensions of special_shape are 1's.
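    # For example (shape arithmetic only, assuming self.batch_shape = [2, 3]):
    # special_shape = [2, 3, 1, 1]. An `x` of shape [3, N, R] broadcasts with
    # special_shape to bshape = [2, 3, N, R], so a copy (x + zeros) is needed,
    # while an `x` of shape [2, 3, N, R] already equals bshape and is passed
    # through unchanged.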
special_shape = self.batch_shape.concatenate([1, 1])
bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
if special_shape.is_fully_defined():
# bshape.is_fully_defined iff special_shape.is_fully_defined.
if bshape == x.shape:
return x
# Use the built in broadcasting of addition.
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
return x + zeros
# Dynamic broadcast:
# Always add to an array of zeros, rather than using a "cond", since a
# cond would require copying data from GPU --> CPU.
special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
return x + zeros
def _matmul(self, x, adjoint=False, adjoint_arg=False):
# Note that adjoint has no effect since this matrix is self-adjoint.
x = linalg.adjoint(x) if adjoint_arg else x
if self._assert_proper_shapes:
aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
x = control_flow_ops.with_dependencies([aps], x)
return self._possibly_broadcast_batch_shape(x)
def _determinant(self):
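    # det(I_N) = 1 for every batch member, so the determinant is all ones.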
return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)
def _log_abs_determinant(self):
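    # log|det(I_N)| = 0 for every batch member.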
return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
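    # The identity is self-adjoint and its own inverse, so `adjoint` has no
    # effect and solve reduces to matmul.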
return self._matmul(rhs, adjoint_arg=adjoint_arg)
def _trace(self):
# Get Tensor of all ones of same shape as self.batch_shape.
if self.batch_shape.is_fully_defined():
batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
else:
batch_of_ones = array_ops.ones(
shape=self.batch_shape_tensor(), dtype=self.dtype)
if self._min_matrix_dim() is not None:
return self._min_matrix_dim() * batch_of_ones
else:
return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) *
batch_of_ones)
def _diag_part(self):
return self._ones_diag()
def add_to_tensor(self, mat, name="add_to_tensor"):
"""Add matrix represented by this operator to `mat`. Equiv to `I + mat`.
Args:
mat: `Tensor` with same `dtype` and shape broadcastable to `self`.
name: A name to give this `Op`.
Returns:
A `Tensor` with broadcast shape and same `dtype` as `self`.
"""
with self._name_scope(name):
mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
mat_diag = array_ops.matrix_diag_part(mat)
new_diag = 1 + mat_diag
return array_ops.matrix_set_diag(mat, new_diag)
def _eigvals(self):
return self._ones_diag()
def _cond(self):
return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)
def _check_batch_shape_possibly_add_asserts(self):
"""Static check of init arg `batch_shape`, possibly add asserts."""
if self._batch_shape_arg is None:
return
# Possibly add asserts
if self._assert_proper_shapes:
self._batch_shape_arg = control_flow_ops.with_dependencies([
check_ops.assert_rank(
self._batch_shape_arg,
1,
message="Argument batch_shape must be a 1-D Tensor."),
check_ops.assert_non_negative(
self._batch_shape_arg,
message="Argument batch_shape must be non-negative."),
], self._batch_shape_arg)
# Static checks
if not self._batch_shape_arg.dtype.is_integer:
raise TypeError("Argument batch_shape must be integer type. Found:"
" %s" % self._batch_shape_arg)
if self._batch_shape_static is None:
return # Cannot do any other static checks.
if self._batch_shape_static.ndim != 1:
raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:"
" %s" % self._batch_shape_static)
if np.any(self._batch_shape_static < 0):
raise ValueError("Argument batch_shape must be non-negative. Found:"
"%s" % self._batch_shape_static)
@tf_export("linalg.LinearOperatorScaledIdentity")
class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
"""`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`.
This operator acts like a scaled [batch] identity matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
a scaled version of the `N x N` identity matrix.
  `LinearOperatorScaledIdentity` is initialized with `num_rows`, and a
  `multiplier`
(a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the
`multiplier` determines the scale for each batch member.
```python
# Create a 2 x 2 scaled identity matrix.
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)
operator.to_dense()
==> [[3., 0.]
[0., 3.]]
operator.shape
==> [2, 2]
operator.log_abs_determinant()
==> 2 * Log[3]
x = ... Shape [2, 4] Tensor
operator.matmul(x)
==> 3 * x
y = tf.random.normal(shape=[3, 2, 4])
# Note that y.shape is compatible with operator.shape because operator.shape
# is broadcast to [3, 2, 2].
x = operator.solve(y)
  ==> y / 3, a Shape [3, 2, 4] Tensor.
  # Create a 2-batch of 2x2 scaled identity matrices
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=[5., 5.])
operator.to_dense()
==> [[[5., 0.]
[0., 5.]],
[[5., 0.]
[0., 5.]]]
x = ... Shape [2, 2, 3]
operator.matmul(x)
==> 5 * x
# Here the operator and x have different batch_shape, and are broadcast.
x = ... Shape [1, 2, 3]
operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, equal to 5 * [x, x]
```
### Shape compatibility
  This operator acts on [batch] matrices with compatible shape.
`x` is a batch matrix with compatible shape for `matmul` and `solve` if
```
operator.shape = [B1,...,Bb] + [N, N], with b >= 0
x.shape = [C1,...,Cc] + [N, R],
and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
```
### Performance
* `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
* `operator.solve(x)` is `O(D1*...*Dd*N*R)`
* `operator.determinant()` is `O(D1*...*Dd)`
#### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:
* If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a
runtime assert. For example, finite floating point precision may result
in these promises being violated.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
"""
def __init__(self,
num_rows,
multiplier,
is_non_singular=None,
is_self_adjoint=None,
is_positive_definite=None,
is_square=True,
assert_proper_shapes=False,
name="LinearOperatorScaledIdentity"):
r"""Initialize a `LinearOperatorScaledIdentity`.
The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
determines the size of each identity matrix, and a `multiplier`,
which defines `dtype`, batch shape, and scale of each matrix.
This operator is able to broadcast the leading (batch) dimensions.
Args:
num_rows: Scalar non-negative integer `Tensor`. Number of rows in the
corresponding identity matrix.
multiplier: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
is_positive_definite: Expect that this operator is positive definite,
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
assert_proper_shapes: Python `bool`. If `False`, only perform static
checks that initialization and method arguments have proper shape.
If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`.
Raises:
ValueError: If `num_rows` is determined statically to be non-scalar, or
negative.
"""
self._assert_proper_shapes = assert_proper_shapes
with ops.name_scope(name, values=[multiplier, num_rows]):
self._multiplier = linear_operator_util.convert_nonref_to_tensor(
multiplier, name="multiplier")
# Check and auto-set hints.
if not self._multiplier.dtype.is_complex:
if is_self_adjoint is False: # pylint: disable=g-bool-id-comparison
raise ValueError("A real diagonal operator is always self adjoint.")
else:
is_self_adjoint = True
if not is_square:
raise ValueError("A ScaledIdentity operator is always square.")
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
super(LinearOperatorScaledIdentity, self).__init__(
dtype=self._multiplier.dtype.base_dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
self._num_rows = linear_operator_util.shape_tensor(
num_rows, name="num_rows")
self._num_rows_static = tensor_util.constant_value(self._num_rows)
self._check_num_rows_possibly_add_asserts()
self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype)
self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows,
self.dtype.real_dtype)
def _shape(self):
matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
self._num_rows_static))
batch_shape = self.multiplier.shape
return batch_shape.concatenate(matrix_shape)
def _shape_tensor(self):
matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0)
batch_shape = array_ops.shape(self.multiplier)
return array_ops.concat((batch_shape, matrix_shape), 0)
def _assert_non_singular(self):
return check_ops.assert_positive(
math_ops.abs(self.multiplier), message="LinearOperator was singular")
def _assert_positive_definite(self):
return check_ops.assert_positive(
math_ops.real(self.multiplier),
message="LinearOperator was not positive definite.")
def _assert_self_adjoint(self):
imag_multiplier = math_ops.imag(self.multiplier)
return check_ops.assert_equal(
array_ops.zeros_like(imag_multiplier),
imag_multiplier,
message="LinearOperator was not self-adjoint")
def _make_multiplier_matrix(self, conjugate=False):
# Shape [B1,...Bb, 1, 1]
multiplier_matrix = array_ops.expand_dims(
array_ops.expand_dims(self.multiplier, -1), -1)
if conjugate:
multiplier_matrix = math_ops.conj(multiplier_matrix)
return multiplier_matrix
def _matmul(self, x, adjoint=False, adjoint_arg=False):
x = linalg.adjoint(x) if adjoint_arg else x
if self._assert_proper_shapes:
aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
x = control_flow_ops.with_dependencies([aps], x)
return x * self._make_multiplier_matrix(conjugate=adjoint)
def _determinant(self):
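    # det(c * I_N) = c**N, computed elementwise over the batch of multipliers.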
return self.multiplier**self._num_rows_cast_to_dtype
def _log_abs_determinant(self):
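    # log|det(c * I_N)| = N * log|c|.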
return self._num_rows_cast_to_real_dtype * math_ops.log(
math_ops.abs(self.multiplier))
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
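    # Solving (c * I) x = rhs gives x = rhs / c, with c conjugated for the
    # adjoint solve.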
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
if self._assert_proper_shapes:
aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
rhs = control_flow_ops.with_dependencies([aps], rhs)
return rhs / self._make_multiplier_matrix(conjugate=adjoint)
def _trace(self):
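    # trace(c * I_N) = c * N, broadcast over the batch shape.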
# Get Tensor of all ones of same shape as self.batch_shape.
if self.batch_shape.is_fully_defined():
batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
else:
batch_of_ones = array_ops.ones(
shape=self.batch_shape_tensor(), dtype=self.dtype)
if self._min_matrix_dim() is not None:
return self.multiplier * self._min_matrix_dim() * batch_of_ones
else:
return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(),
self.dtype) * batch_of_ones)
def _diag_part(self):
return self._ones_diag() * self.multiplier[..., array_ops.newaxis]
def add_to_tensor(self, mat, name="add_to_tensor"):
"""Add matrix represented by this operator to `mat`. Equiv to `I + mat`.
Args:
mat: `Tensor` with same `dtype` and shape broadcastable to `self`.
name: A name to give this `Op`.
Returns:
A `Tensor` with broadcast shape and same `dtype` as `self`.
"""
with self._name_scope(name):
# Shape [B1,...,Bb, 1]
multiplier_vector = array_ops.expand_dims(self.multiplier, -1)
# Shape [C1,...,Cc, M, M]
mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
# Shape [C1,...,Cc, M]
mat_diag = array_ops.matrix_diag_part(mat)
# multiplier_vector broadcasts here.
new_diag = multiplier_vector + mat_diag
return array_ops.matrix_set_diag(mat, new_diag)
def _eigvals(self):
return self._ones_diag() * self.multiplier[..., array_ops.newaxis]
def _cond(self):
    # The condition number of a scalar multiple of the identity is one, except
    # when the scalar is zero.
return array_ops.where_v2(
math_ops.equal(self._multiplier, 0.),
math_ops.cast(np.nan, dtype=self.dtype),
math_ops.cast(1., dtype=self.dtype))
@property
def multiplier(self):
"""The [batch] scalar `Tensor`, `c` in `cI`."""
return self._multiplier