Ensure `convert_to_tensor` is called only once on input args to `LinearOperator`s.
PiperOrigin-RevId: 271228144
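
The pattern throughout this change: public shape accessors delegate to private `_*_tensor` helpers that accept an optional pre-computed Tensor, so a derived class can convert an argument once and thread the result through every helper that needs it. A minimal sketch of the convention (the class and its attributes are hypothetical, not part of the TF API; only the helper names mirror this change):

import tensorflow as tf

class _SketchOperator:
  """Hypothetical stand-in for a LinearOperator subclass."""

  def __init__(self, values):
    self._values = values  # may be a list, ndarray, or tf.Variable

  def _shape_tensor(self, values=None):
    # `values` may be passed in if the caller already converted it.
    values = tf.convert_to_tensor(self._values) if values is None else values
    return tf.shape(values)

  def _batch_shape_tensor(self, shape=None):
    # Reuse a pre-computed shape Tensor instead of rebuilding it.
    shape = self._shape_tensor() if shape is None else shape
    return shape[:-2]

  def _ones_like_batch(self):
    # One conversion, threaded through every helper that needs it.
    values = tf.convert_to_tensor(self._values)
    shape = self._shape_tensor(values=values)
    return tf.ones(self._batch_shape_tensor(shape=shape), values.dtype)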
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index 619c468..41a1a7b 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -282,8 +282,10 @@
"""
return self._shape()
- @abc.abstractmethod
def _shape_tensor(self):
+ # This is not an abstractmethod, since we want derived classes to be able to
+ # override this with optional kwargs, which can reduce the number of
+ # `convert_to_tensor` calls. See derived classes for examples.
raise NotImplementedError("_shape_tensor is not implemented.")
def shape_tensor(self, name="shape_tensor"):
@@ -335,12 +337,17 @@
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- # Prefer to use statically defined shape if available.
- if self.batch_shape.is_fully_defined():
- return linear_operator_util.shape_tensor(
- self.batch_shape.as_list(), name="batch_shape")
- else:
- return self.shape_tensor()[:-2]
+ return self._batch_shape_tensor()
+
+ def _batch_shape_tensor(self, shape=None):
+ # `shape` may be passed in if this can be pre-computed in a
+ # more efficient manner, e.g. without excessive Tensor conversions.
+ if self.batch_shape.is_fully_defined():
+ return linear_operator_util.shape_tensor(
+ self.batch_shape.as_list(), name="batch_shape")
+ else:
+ shape = self.shape_tensor() if shape is None else shape
+ return shape[:-2]
@property
def tensor_rank(self, name="tensor_rank"):
@@ -373,11 +380,16 @@
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- # Prefer to use statically defined shape if available.
- if self.tensor_rank is not None:
- return ops.convert_to_tensor(self.tensor_rank)
- else:
- return array_ops.size(self.shape_tensor())
+ return self._tensor_rank_tensor()
+
+ def _tensor_rank_tensor(self, shape=None):
+ # `shape` may be passed in if this can be pre-computed in a
+ # more efficient manner, e.g. without excessive Tensor conversions.
+ if self.tensor_rank is not None:
+ return ops.convert_to_tensor(self.tensor_rank)
+ else:
+ shape = self.shape_tensor() if shape is None else shape
+ return array_ops.size(shape)
@property
def domain_dimension(self):
@@ -411,12 +423,17 @@
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- # Prefer to use statically defined shape if available.
- dim_value = tensor_shape.dimension_value(self.domain_dimension)
- if dim_value is not None:
- return ops.convert_to_tensor(dim_value)
- else:
- return self.shape_tensor()[-1]
+ return self._domain_dimension_tensor()
+
+ def _domain_dimension_tensor(self, shape=None):
+ # `shape` may be passed in if this can be pre-computed in a
+ # more efficient manner, e.g. without excessive Tensor conversions.
+ dim_value = tensor_shape.dimension_value(self.domain_dimension)
+ if dim_value is not None:
+ return ops.convert_to_tensor(dim_value)
+ else:
+ shape = self.shape_tensor() if shape is None else shape
+ return shape[-1]
@property
def range_dimension(self):
@@ -450,12 +467,17 @@
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- # Prefer to use statically defined shape if available.
- dim_value = tensor_shape.dimension_value(self.range_dimension)
- if dim_value is not None:
- return ops.convert_to_tensor(dim_value)
- else:
- return self.shape_tensor()[-2]
+ return self._range_dimension_tensor()
+
+ def _range_dimension_tensor(self, shape=None):
+ # `shape` may be passed in if this can be pre-computed in a
+ # more efficient manner, e.g. without excessive Tensor conversions.
+ dim_value = tensor_shape.dimension_value(self.range_dimension)
+ if dim_value is not None:
+ return ops.convert_to_tensor(dim_value)
+ else:
+ shape = self.shape_tensor() if shape is None else shape
+ return shape[-2]
def _assert_non_singular(self):
"""Private default implementation of _assert_non_singular."""
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index 45da3c8..e781445 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -114,11 +114,6 @@
"A [[nested] block] circulant operator is always square.")
is_square = True
- # If spectrum.shape = [s0, s1, s2], and block_depth = 2,
- # block_shape = [s1, s2]
- s_shape = array_ops.shape(self.spectrum)
- self._block_shape_tensor = s_shape[-self.block_depth:]
-
super(_BaseLinearOperatorCirculant, self).__init__(
dtype=dtypes.as_dtype(input_output_dtype),
graph_parents=[self.spectrum],
@@ -175,7 +170,18 @@
def block_shape_tensor(self):
"""Shape of the block dimensions of `self.spectrum`."""
- return self._block_shape_tensor
+ # If spectrum.shape = [s0, s1, s2], and block_depth = 2,
+ # block_shape = [s1, s2]
+ return self._block_shape_tensor()
+
+ def _block_shape_tensor(self, spectrum_shape=None):
+ if self.block_shape.is_fully_defined():
+ return linear_operator_util.shape_tensor(
+ self.block_shape.as_list(), name="block_shape")
+ spectrum_shape = (
+ array_ops.shape(self.spectrum)
+ if spectrum_shape is None else spectrum_shape)
+ return spectrum_shape[-self.block_depth:]
@property
def block_shape(self):
@@ -312,9 +318,10 @@
n_x_n = tensor_shape.TensorShape([n, n])
return batch_shape.concatenate(n_x_n)
- def _shape_tensor(self):
+ def _shape_tensor(self, spectrum=None):
+ spectrum = self.spectrum if spectrum is None else spectrum
# See self.shape for explanation of steps
- s_shape = array_ops.shape(self._spectrum)
+ s_shape = array_ops.shape(spectrum)
batch_shape = s_shape[:-self.block_depth]
trailing_dims = s_shape[-self.block_depth:]
n = math_ops.reduce_prod(trailing_dims)
@@ -369,20 +376,26 @@
def _broadcast_batch_dims(self, x, spectrum):
"""Broadcast batch dims of batch matrix `x` and spectrum."""
+ spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
# spectrum.shape = batch_shape + block_shape
# First make spectrum a batch matrix with
# spectrum.shape = batch_shape + [prod(block_shape), 1]
+ batch_shape = self._batch_shape_tensor(
+ shape=self._shape_tensor(spectrum=spectrum))
spec_mat = array_ops.reshape(
- spectrum, array_ops.concat(
- (self.batch_shape_tensor(), [-1, 1]), axis=0))
+ spectrum, array_ops.concat((batch_shape, [-1, 1]), axis=0))
# Second, broadcast, possibly requiring an addition of array of zeros.
x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims((x,
spec_mat))
# Third, put the block shape back into spectrum.
- batch_shape = array_ops.shape(x)[:-2]
+ x_batch_shape = array_ops.shape(x)[:-2]
+ spectrum_shape = array_ops.shape(spectrum)
spectrum = array_ops.reshape(
spec_mat,
- array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))
+ array_ops.concat(
+ (x_batch_shape,
+ self._block_shape_tensor(spectrum_shape=spectrum_shape)),
+ axis=0))
return x, spectrum
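
For context on `_broadcast_batch_dims` above: the three numbered steps are a reshape/broadcast/reshape sequence. An illustration with concrete, made-up shapes, using plain TF ops (tf.broadcast_to stands in for linear_operator_util.broadcast_matrix_batch_dims here):

import tensorflow as tf

spectrum = tf.ones([3, 4, 5])  # batch_shape=[3], block_shape=[4, 5], block_depth=2
x = tf.ones([2, 3, 20, 7])     # batch matrix with batch_shape=[2, 3]

# First: flatten the block dims, giving batch_shape + [prod(block_shape), 1].
spec_mat = tf.reshape(spectrum, [3, 20, 1])
# Second: broadcast the batch dims of x and spec_mat.
spec_mat = tf.broadcast_to(spec_mat, [2, 3, 20, 1])
# Third: put the block shape back, using x's (broadcast) batch shape.
spectrum_b = tf.reshape(spec_mat, [2, 3, 4, 5])
print(spectrum_b.shape)  # (2, 3, 4, 5)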
diff --git a/tensorflow/python/ops/linalg/linear_operator_householder.py b/tensorflow/python/ops/linalg/linear_operator_householder.py
index 2771d8e..6d1b4de 100644
--- a/tensorflow/python/ops/linalg/linear_operator_householder.py
+++ b/tensorflow/python/ops/linalg/linear_operator_householder.py
@@ -206,9 +206,11 @@
def _trace(self):
# We have (n - 1) +1 eigenvalues and a single -1 eigenvalue.
+ shape = self.shape_tensor()
return math_ops.cast(
- self.domain_dimension_tensor() - 2, self.dtype) * array_ops.ones(
- shape=self.batch_shape_tensor(), dtype=self.dtype)
+ self._domain_dimension_tensor(shape=shape) - 2,
+ self.dtype) * array_ops.ones(
+ shape=self._batch_shape_tensor(shape=shape), dtype=self.dtype)
def _determinant(self):
# For householder transformations, the determinant is -1.
@@ -224,16 +226,18 @@
return self._matmul(rhs, adjoint, adjoint_arg)
def _to_dense(self):
- normalized_axis = self.reflection_axis / linalg.norm(
- self.reflection_axis, axis=-1, keepdims=True)
+ reflection_axis = ops.convert_to_tensor(self.reflection_axis)
+ normalized_axis = reflection_axis / linalg.norm(
+ reflection_axis, axis=-1, keepdims=True)
mat = normalized_axis[..., array_ops.newaxis]
matrix = -2 * math_ops.matmul(mat, mat, adjoint_b=True)
return array_ops.matrix_set_diag(
matrix, 1. + array_ops.matrix_diag_part(matrix))
def _diag_part(self):
- normalized_axis = self.reflection_axis / linalg.norm(
- self.reflection_axis, axis=-1, keepdims=True)
+ reflection_axis = ops.convert_to_tensor(self.reflection_axis)
+ normalized_axis = reflection_axis / linalg.norm(
+ reflection_axis, axis=-1, keepdims=True)
return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)
@property
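
The householder hunks above hoist `convert_to_tensor(self.reflection_axis)` so a possibly non-Tensor argument (e.g. a variable) is converted once per method rather than once per property access. An illustrative sketch of the `_diag_part` arithmetic with a concrete, made-up axis:

import tensorflow as tf

axis = tf.Variable([3., 4.])  # reflection axis, deliberately unnormalized

reflection_axis = tf.convert_to_tensor(axis)  # single conversion
normalized = reflection_axis / tf.norm(reflection_axis, axis=-1, keepdims=True)
print((1. - 2. * normalized * tf.math.conj(normalized)).numpy())  # [0.28, -0.28]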
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index 8803d58..f4c75c1 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -333,6 +333,9 @@
batch_shape = array_ops.broadcast_dynamic_shape(
self.base_operator.batch_shape_tensor(),
array_ops.shape(self.u)[:-2])
+ batch_shape = array_ops.broadcast_dynamic_shape(
+ batch_shape,
+ array_ops.shape(self.v)[:-2])
return array_ops.concat(
[batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)
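
The low-rank-update hunk above fixes `_shape_tensor` to also broadcast against `v`'s batch dims; previously only the base operator and `u` contributed. A concrete-shape illustration (shapes are made up):

import tensorflow as tf

base_batch = tf.constant([2, 1])   # base_operator.batch_shape_tensor()
u_batch = tf.constant([1, 3])      # shape(u)[:-2]
v_batch = tf.constant([5, 1, 1])   # shape(v)[:-2]

batch_shape = tf.broadcast_dynamic_shape(base_batch, u_batch)   # [2, 3]
batch_shape = tf.broadcast_dynamic_shape(batch_shape, v_batch)  # [5, 2, 3]
print(batch_shape.numpy())  # [5 2 3]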
diff --git a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
index 7438b53..be95ce4 100644
--- a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
+++ b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
@@ -140,10 +140,8 @@
"""
with ops.name_scope(name, values=[row, col]):
- self._row = linear_operator_util.convert_nonref_to_tensor(
- row, name="row")
- self._col = linear_operator_util.convert_nonref_to_tensor(
- col, name="col")
+ self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
+ self._col = linear_operator_util.convert_nonref_to_tensor(col, name="col")
self._check_row_col(self._row, self._col)
if is_square is False: # pylint:disable=g-bool-id-comparison
@@ -178,10 +176,12 @@
self.row.shape, self.col.shape)
return v_shape.concatenate(v_shape[-1:])
- def _shape_tensor(self):
+ def _shape_tensor(self, row=None, col=None):
+ row = self.row if row is None else row
+ col = self.col if col is None else col
v_shape = array_ops.broadcast_dynamic_shape(
- array_ops.shape(self.row),
- array_ops.shape(self.col))
+ array_ops.shape(row),
+ array_ops.shape(col))
k = v_shape[-1]
return array_ops.concat((v_shape, [k]), 0)
@@ -208,17 +208,20 @@
# for more details.
x = linalg.adjoint(x) if adjoint_arg else x
expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
+ col = ops.convert_to_tensor(self.col)
+ row = ops.convert_to_tensor(self.row)
circulant_col = array_ops.concat(
- [self._col,
- array_ops.zeros_like(self.col[..., 0:1]),
- array_ops.reverse(self.row[..., 1:], axis=[-1])], axis=-1)
+ [col,
+ array_ops.zeros_like(col[..., 0:1]),
+ array_ops.reverse(row[..., 1:], axis=[-1])], axis=-1)
circulant = linear_operator_circulant.LinearOperatorCirculant(
fft_ops.fft(_to_complex(circulant_col)),
- input_output_dtype=self.row.dtype)
+ input_output_dtype=row.dtype)
result = circulant.matmul(expanded_x, adjoint=adjoint, adjoint_arg=False)
+ shape = self._shape_tensor(row=row, col=col)
return math_ops.cast(
- result[..., :self.domain_dimension_tensor(), :],
+ result[..., :self._domain_dimension_tensor(shape=shape), :],
self.dtype)
def _trace(self):
@@ -232,11 +235,13 @@
[self.domain_dimension_tensor()], self.dtype)
def _to_dense(self):
+ row = ops.convert_to_tensor(self.row)
+ col = ops.convert_to_tensor(self.col)
total_shape = array_ops.broadcast_dynamic_shape(
- array_ops.shape(self.row), array_ops.shape(self.col))
- n = array_ops.shape(self.row)[-1]
- row = array_ops.broadcast_to(self.row, total_shape)
- col = array_ops.broadcast_to(self.col, total_shape)
+ array_ops.shape(row), array_ops.shape(col))
+ n = array_ops.shape(row)[-1]
+ row = array_ops.broadcast_to(row, total_shape)
+ col = array_ops.broadcast_to(col, total_shape)
# We concatenate the column in reverse order to the row.
# This gives us 2*n + 1 elements.
elements = array_ops.concat(
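
For context on the Toeplitz `_matmul` hunk: an n x n Toeplitz matrix embeds in the top-left corner of a 2n x 2n circulant built from [col, 0, reversed(row[1:])], so matmul reduces to FFTs. A small NumPy check of that identity (values are made up):

import numpy as np

col = np.array([1., 2., 3.])  # first column of the Toeplitz matrix
row = np.array([1., 5., 6.])  # first row; row[0] must equal col[0]
n = len(col)

# Embed in a 2n circulant: [col, 0, reversed(row[1:])].
c = np.concatenate([col, [0.], row[:0:-1]])

x = np.array([1., 1., 1.])
expanded_x = np.concatenate([x, np.zeros(n)])

# Circulant matvec via FFT, then keep the first n entries.
y = np.fft.ifft(np.fft.fft(c) * np.fft.fft(expanded_x)).real[:n]

# Reference: dense Toeplitz matvec.
T = np.empty((n, n))
for i in range(n):
  for j in range(n):
    T[i, j] = col[i - j] if i >= j else row[j - i]
print(np.allclose(y, T @ x))  # True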