| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """`LinearOperator` that wraps a [batch] matrix.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops.linalg import linear_operator |
| from tensorflow.python.ops.linalg import linear_operator_util |
| from tensorflow.python.util.tf_export import tf_export |
| |
| __all__ = ["LinearOperatorFullMatrix"] |
| |
| |
@tf_export("linalg.LinearOperatorFullMatrix")
class LinearOperatorFullMatrix(linear_operator.LinearOperator):
  """`LinearOperator` that wraps a [batch] matrix.

  The wrapped [batch] matrix `A` is a `Tensor` of shape `[B1,...,Bb, M, N]`
  for some `b >= 0`. The leading `b` indices select a batch member: for every
  batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is an `M x N` matrix.

  ```python
  # Create a 2 x 2 linear operator.
  matrix = [[1., 2.], [3., 4.]]
  operator = LinearOperatorFullMatrix(matrix)

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorFullMatrix(matrix)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  `LinearOperatorFullMatrix` has exactly the same performance as would be
  achieved by using standard `TensorFlow` matrix ops. Intelligent choices are
  made based on the following initialization hints.

  * If `dtype` is real, and `is_self_adjoint` and `is_positive_definite`, a
    Cholesky factorization is used for the determinant and solve.

  In all cases, suppose `operator` is a `LinearOperatorFullMatrix` of shape
  `[M, N]`, and `x.shape = [N, R]`. Then

  * `operator.matmul(x)` is `O(M * N * R)`.
  * If `M=N`, `operator.solve(x)` is `O(N^3 * R)`.
  * If `M=N`, `operator.determinant()` is `O(N^3)`.

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               matrix,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorFullMatrix"):
    r"""Initialize a `LinearOperatorFullMatrix`.

    Args:
      matrix:  Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `matrix.dtype` is not an allowed type.
      ValueError:  If `matrix` has statically known rank less than 2.
    """
    # The `is_X` hints are forwarded to the base class unchanged.
    property_hints = dict(
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square)

    with ops.name_scope(name, values=[matrix]):
      # Store the converted matrix first, then validate it statically.
      self._matrix = linear_operator_util.convert_nonref_to_tensor(
          matrix, name="matrix")
      self._check_matrix(self._matrix)

      super(LinearOperatorFullMatrix, self).__init__(
          dtype=self._matrix.dtype,
          graph_parents=None,
          name=name,
          **property_hints)
      # TODO(b/143910018) Remove graph_parents in V3.
      self._set_graph_parents([self._matrix])

  def _check_matrix(self, matrix):
    """Static check of the `matrix` argument.

    Args:
      matrix: Value to validate; it is converted to a `Tensor` before the
        dtype and rank checks.

    Raises:
      TypeError: If the dtype is not one of the allowed floating or complex
        types.
      ValueError: If the static rank is known and is less than 2.
    """
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]

    as_tensor = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")

    found_dtype = as_tensor.dtype
    if found_dtype not in allowed_dtypes:
      message = "Argument matrix must have dtype in %s. Found: %s"
      raise TypeError(message % (allowed_dtypes, found_dtype))

    # The rank check only applies when the rank is statically known.
    static_rank = as_tensor.shape.ndims
    if static_rank is None:
      return
    if static_rank < 2:
      message = "Argument matrix must have at least 2 dimensions. Found: %s"
      raise ValueError(message % as_tensor)

  def _shape(self):
    """Return the static shape of the wrapped matrix."""
    wrapped = self._matrix
    return wrapped.shape

  def _shape_tensor(self):
    """Return the shape of the wrapped matrix as computed by `array_ops`."""
    wrapped = self._matrix
    return array_ops.shape(wrapped)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiply the wrapped matrix with `x` via `math_ops.matmul`.

    `adjoint` is passed as `adjoint_a` (applies to the wrapped matrix) and
    `adjoint_arg` as `adjoint_b` (applies to `x`).
    """
    left_operand = self._matrix
    return math_ops.matmul(
        left_operand, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve against `rhs` by delegating to `self._dense_solve`.

    `_dense_solve` is not defined in this class; presumably it is inherited
    from the `LinearOperator` base class.
    """
    solve_options = dict(adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._dense_solve(rhs, **solve_options)

  def _to_dense(self):
    """Return the wrapped matrix itself, which is already dense."""
    dense = self._matrix
    return dense