blob: 2d61a536e291cb77e8f8caa2260b3daa83c72966 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a Toeplitz matrix."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_circulant
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.util.tf_export import tf_export
__all__ = ["LinearOperatorToeplitz",]
class LinearOperatorToeplitz(linear_operator.LinearOperator):
"""`LinearOperator` acting like a [batch] of toeplitz matrices.
This operator acts like a [batch] Toeplitz matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
an `N x N` matrix. This matrix `A` is not materialized, but for
purposes of broadcasting this shape will be relevant.
#### Description in terms of toeplitz matrices
Toeplitz means that `A` has constant diagonals. Hence, `A` can be generated
with two vectors. One represents the first column of the matrix, and the
other represents the first row.
Below is a 4 x 4 example:
A = |a b c d|
|e a b c|
|f e a b|
|g f e a|
#### Example of a Toeplitz operator.
# Create a 3 x 3 Toeplitz operator.
col = [1., 2., 3.]
row = [1., 4., -9.]
operator = LinearOperatorToeplitz(col, row)
==> [[1., 4., -9.],
[2., 1., 4.],
[3., 2., 1.]]
==> [3, 3]
==> scalar Tensor
x = ... Shape [3, 4] Tensor
==> Shape [3, 4] Tensor
#### Shape compatibility
This operator acts on [batch] matrix with compatible shape.
`x` is a batch matrix with compatible shape for `matmul` and `solve` if
operator.shape = [B1,...,Bb] + [N, N], with b >= 0
x.shape = [C1,...,Cc] + [N, R],
and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
#### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning:
* If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a
runtime assert. For example, finite floating point precision may result
in these promises being violated.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
def __init__(self,
r"""Initialize a `LinearOperatorToeplitz`.
col: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
The first column of the operator. Allowed dtypes: `float16`, `float32`,
`float64`, `complex64`, `complex128`. Note that the first entry of
`col` is assumed to be the same as the first entry of `row`.
row: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
The first row of the operator. Allowed dtypes: `float16`, `float32`,
`float64`, `complex64`, `complex128`. Note that the first entry of
`row` is assumed to be the same as the first entry of `col`.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. Nothing in this constructor derives the hint from the dtype of
`row` or `col` (there is no `diag` argument here).
is_positive_definite: Expect that this operator is positive definite,
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`.
with ops.name_scope(name, values=[row, col]):
# Store both defining vectors on the instance; they are exposed through the
# `row` / `col` accessors and read by every other method.
self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
self._col = linear_operator_util.convert_nonref_to_tensor(col, name="col")
# Static validation only: rank >= 1 and matching trailing dimension.
self._check_row_col(self._row, self._col)
# Identity comparison on purpose: `None` means "no hint" and must pass,
# only an explicit `False` is rejected.
if is_square is False: # pylint:disable=g-bool-id-comparison
raise ValueError("Only square Toeplitz operators currently supported.")
# Past the guard the operator is always treated as square.
is_square = True
# NOTE(review): the keyword arguments of this base-class call (and the
# parameter list / docstring terminator above) are missing from this copy
# of the file and must be restored from the upstream source.
super(LinearOperatorToeplitz, self).__init__(
self._set_graph_parents([self._row, self._col])
def _check_row_col(self, row, col):
"""Static check of row and column."""
for name, tensor in [["row", row], ["col", col]]:
if tensor.shape.ndims is not None and tensor.shape.ndims < 1:
raise ValueError("Argument {} must have at least 1 dimension. "
"Found: {}".format(name, tensor))
if row.shape[-1] is not None and col.shape[-1] is not None:
if row.shape[-1] != col.shape[-1]:
raise ValueError(
"Expected square matrix, got row and col with mismatched "
def _shape(self):
  """Static operator shape: broadcast vector shape with its last dim doubled."""
  vector_shape = array_ops.broadcast_static_shape(
      self.row.shape, self.col.shape)
  # E.g. a broadcast vector shape of [5, 3] yields an operator shape [5, 3, 3].
  trailing = vector_shape[-1:]
  return vector_shape.concatenate(trailing)
def _shape_tensor(self, row=None, col=None):
  """Dynamic operator shape as a 1-D tensor.

  Args:
    row: Optional replacement for `self.row` (for example an already
      converted tensor, as `_matmul` passes in).
    col: Optional replacement for `self.col`.

  Returns:
    The broadcast of the shapes of `row` and `col`, with its last entry
    appended once more so the two trailing dimensions are equal.
  """
  row = self.row if row is None else row
  col = self.col if col is None else col
  # Same broadcast that `_to_dense` performs on the two defining vectors.
  v_shape = array_ops.broadcast_dynamic_shape(
      array_ops.shape(row), array_ops.shape(col))
  k = v_shape[-1]
  return array_ops.concat((v_shape, [k]), 0)
def _assert_self_adjoint(self):
  """Return the `assert_equal` check that `row` and `col` coincide.

  NOTE(review): the entries are compared directly, without conjugation; for
  complex dtypes confirm that this is the intended notion of self-adjoint.
  """
  return check_ops.assert_equal(
      self.row,
      self.col,
      message=("row and col are not the same, and "
               "so this operator is not self-adjoint."))
# TODO(srvasude): Add efficient solver and determinant calculations to this
# class (based on Levinson recursion.)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
# Given a Toeplitz matrix, we can embed it in a Circulant matrix to perform
# efficient matrix multiplications. Given a Toeplitz matrix with first row
# [t_0, t_1, ... t_{n-1}] and first column [t_0, t_{-1}, ..., t_{-(n-1)}],
# let C be the circulant matrix with first column [t_0, t_{-1}, ...,
# t_{-(n-1)}, 0, t_{n-1}, ..., t_1]. Also adjoin to our input vector `x`
# `n` zeros, to make it a vector of length `2n` (call it y). It can be shown
# that if we take the first n entries of `Cy`, this is equal to the Toeplitz
# multiplication. See:
# for more details.
# NOTE(review): the reference that followed "See:" is missing from this copy
# of the file; restore the link from the upstream source.
x = linalg.adjoint(x) if adjoint_arg else x
# Append an equally sized block of zeros along axis -2 (the `y` above).
expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
col = ops.convert_to_tensor_v2_with_dispatch(self.col)
row = ops.convert_to_tensor_v2_with_dispatch(self.row)
# First column of the circulant embedding: a single zero followed by the
# row entries after the first, reversed.
# NOTE(review): the opening of the list argument is missing from this copy;
# per the comment above, `col` presumably comes first - confirm upstream.
circulant_col = array_ops.concat(
array_ops.zeros_like(col[..., 0:1]),
array_ops.reverse(row[..., 1:], axis=[-1])], axis=-1)
# NOTE(review): the constructor arguments are missing from this copy.
circulant = linear_operator_circulant.LinearOperatorCirculant(
# `adjoint_arg=False` because `x` was already adjointed above when requested.
result = circulant.matmul(expanded_x, adjoint=adjoint, adjoint_arg=False)
shape = self._shape_tensor(row=row, col=col)
# Keep only the leading domain-dimension entries along axis -2, dropping the
# part produced by the zero padding.
# NOTE(review): the target dtype argument of this cast is missing from this
# copy.
return math_ops.cast(
result[..., :self._domain_dimension_tensor(shape=shape), :],
def _trace(self):
  """Trace of the operator.

  All diagonal entries equal `col[..., 0]` (see `_diag_part`), so the trace
  is that entry multiplied by the domain dimension.
  """
  # Cast the integer dimension so it can multiply the operator's dtype.
  return math_ops.cast(
      self.domain_dimension_tensor(),
      dtype=self.dtype) * self.col[..., 0]
def _diag_part(self):
  """Main diagonal: the first entry of `col` repeated along the diagonal."""
  # Slice with 0:1 (not 0) to keep a trailing axis of size one, which then
  # broadcasts against the vector of ones.
  first_entry = self.col[..., 0:1]
  length = self.domain_dimension_tensor()
  return first_entry * array_ops.ones([length], self.dtype)
def _to_dense(self):
  """Materialize the operator as a dense [batch] Toeplitz matrix."""
  row_t = ops.convert_to_tensor_v2_with_dispatch(self.row)
  col_t = ops.convert_to_tensor_v2_with_dispatch(self.col)
  batch_shape = array_ops.broadcast_dynamic_shape(
      array_ops.shape(row_t), array_ops.shape(col_t))
  size = array_ops.shape(row_t)[-1]
  row_t = array_ops.broadcast_to(row_t, batch_shape)
  col_t = array_ops.broadcast_to(col_t, batch_shape)

  # Lay out every distinct diagonal value once: the column reversed (n
  # entries) followed by the row without its first entry (n - 1 entries),
  # for 2 * n - 1 values in total.
  flipped_col = array_ops.reverse(col_t, axis=[-1])
  diagonals = array_ops.concat([flipped_col, row_t[..., 1:]], axis=-1)

  # Entry (i, j) of the matrix reads position (n - 1 - i) + j of that vector:
  # row 0 starts at the last column entry placed (the shared corner) and walks
  # through the row values, and each later row starts one position earlier.
  column_offset = math_ops.range(0, size)
  row_start = math_ops.range(size - 1, -1, -1)[..., array_ops.newaxis]
  # The modulus keeps every index non-negative (for the gather) and below
  # 2 * n - 1.
  gather_index = math_ops.mod(column_offset + row_start, 2 * size - 1)
  return array_ops.gather(diagonals, gather_index, axis=-1)
@property
def col(self):
  """The first column of the operator, as stored by the constructor."""
  # Exposed as a property because the rest of the class reads it as an
  # attribute (e.g. `self.col[..., 0]`, `self.col.shape`).
  return self._col
@property
def row(self):
  """The first row of the operator, as stored by the constructor."""
  # Exposed as a property because the rest of the class reads it as an
  # attribute (e.g. `self.row.shape`).
  return self._row
def _to_complex(x):
  """Cast `x` to a complex dtype of matching precision.

  `float64` and `complex128` inputs become `complex128`; every other dtype
  becomes `complex64`.
  """
  is_double = x.dtype in [dtypes.float64, dtypes.complex128]
  target = dtypes.complex128 if is_double else dtypes.complex64
  return math_ops.cast(x, target)