Ensure convert to tensor is only done once for input args to LinearOperators.

PiperOrigin-RevId: 271228144
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index 619c468..41a1a7b 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -282,8 +282,10 @@
     """
     return self._shape()
 
-  @abc.abstractmethod
   def _shape_tensor(self):
+    # This is not an abstractmethod, since we want derived classes to be able to
+    # override this with optional kwargs, which can reduce the number of
+    # `convert_to_tensor` calls.  See derived classes for examples.
     raise NotImplementedError("_shape_tensor is not implemented.")
 
   def shape_tensor(self, name="shape_tensor"):
@@ -335,12 +337,17 @@
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      # Prefer to use statically defined shape if available.
-      if self.batch_shape.is_fully_defined():
-        return linear_operator_util.shape_tensor(
-            self.batch_shape.as_list(), name="batch_shape")
-      else:
-        return self.shape_tensor()[:-2]
+      return self._batch_shape_tensor()
+
+  def _batch_shape_tensor(self, shape=None):
+    # `shape` may be passed in if this can be pre-computed in a
+    # more efficient manner, e.g. without excessive Tensor conversions.
+    if self.batch_shape.is_fully_defined():
+      return linear_operator_util.shape_tensor(
+          self.batch_shape.as_list(), name="batch_shape")
+    else:
+      shape = self.shape_tensor() if shape is None else shape
+      return shape[:-2]
 
   @property
   def tensor_rank(self, name="tensor_rank"):
@@ -373,11 +380,16 @@
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      # Prefer to use statically defined shape if available.
-      if self.tensor_rank is not None:
-        return ops.convert_to_tensor(self.tensor_rank)
-      else:
-        return array_ops.size(self.shape_tensor())
+      return self._tensor_rank_tensor()
+
+  def _tensor_rank_tensor(self, shape=None):
+    # `shape` may be passed in if this can be pre-computed in a
+    # more efficient manner, e.g. without excessive Tensor conversions.
+    if self.tensor_rank is not None:
+      return ops.convert_to_tensor(self.tensor_rank)
+    else:
+      shape = self.shape_tensor() if shape is None else shape
+      return array_ops.size(shape)
 
   @property
   def domain_dimension(self):
@@ -411,12 +423,17 @@
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      # Prefer to use statically defined shape if available.
-      dim_value = tensor_shape.dimension_value(self.domain_dimension)
-      if dim_value is not None:
-        return ops.convert_to_tensor(dim_value)
-      else:
-        return self.shape_tensor()[-1]
+      return self._domain_dimension_tensor()
+
+  def _domain_dimension_tensor(self, shape=None):
+    # `shape` may be passed in if this can be pre-computed in a
+    # more efficient manner, e.g. without excessive Tensor conversions.
+    dim_value = tensor_shape.dimension_value(self.domain_dimension)
+    if dim_value is not None:
+      return ops.convert_to_tensor(dim_value)
+    else:
+      shape = self.shape_tensor() if shape is None else shape
+      return shape[-1]
 
   @property
   def range_dimension(self):
@@ -450,12 +467,17 @@
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      # Prefer to use statically defined shape if available.
-      dim_value = tensor_shape.dimension_value(self.range_dimension)
-      if dim_value is not None:
-        return ops.convert_to_tensor(dim_value)
-      else:
-        return self.shape_tensor()[-2]
+      return self._range_dimension_tensor()
+
+  def _range_dimension_tensor(self, shape=None):
+    # `shape` may be passed in if this can be pre-computed in a
+    # more efficient manner, e.g. without excessive Tensor conversions.
+    dim_value = tensor_shape.dimension_value(self.range_dimension)
+    if dim_value is not None:
+      return ops.convert_to_tensor(dim_value)
+    else:
+      shape = self.shape_tensor() if shape is None else shape
+      return shape[-2]
 
   def _assert_non_singular(self):
     """Private default implementation of _assert_non_singular."""
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index 45da3c8..e781445 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -114,11 +114,6 @@
             "A [[nested] block] circulant operator is always square.")
       is_square = True
 
-      # If spectrum.shape = [s0, s1, s2], and block_depth = 2,
-      # block_shape = [s1, s2]
-      s_shape = array_ops.shape(self.spectrum)
-      self._block_shape_tensor = s_shape[-self.block_depth:]
-
       super(_BaseLinearOperatorCirculant, self).__init__(
           dtype=dtypes.as_dtype(input_output_dtype),
           graph_parents=[self.spectrum],
@@ -175,7 +170,18 @@
 
   def block_shape_tensor(self):
     """Shape of the block dimensions of `self.spectrum`."""
-    return self._block_shape_tensor
+    # If spectrum.shape = [s0, s1, s2], and block_depth = 2,
+    # block_shape = [s1, s2]
+    return self._block_shape_tensor()
+
+  def _block_shape_tensor(self, spectrum_shape=None):
+    if self.block_shape.is_fully_defined():
+      return linear_operator_util.shape_tensor(
+          self.block_shape.as_list(), name="block_shape")
+    spectrum_shape = (
+        array_ops.shape(self.spectrum)
+        if spectrum_shape is None else spectrum_shape)
+    return spectrum_shape[-self.block_depth:]
 
   @property
   def block_shape(self):
@@ -312,9 +318,10 @@
     n_x_n = tensor_shape.TensorShape([n, n])
     return batch_shape.concatenate(n_x_n)
 
-  def _shape_tensor(self):
+  def _shape_tensor(self, spectrum=None):
+    spectrum = self.spectrum if spectrum is None else spectrum
     # See self.shape for explanation of steps
-    s_shape = array_ops.shape(self._spectrum)
+    s_shape = array_ops.shape(spectrum)
     batch_shape = s_shape[:-self.block_depth]
     trailing_dims = s_shape[-self.block_depth:]
     n = math_ops.reduce_prod(trailing_dims)
@@ -369,20 +376,26 @@
 
   def _broadcast_batch_dims(self, x, spectrum):
     """Broadcast batch dims of batch matrix `x` and spectrum."""
+    spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
     # spectrum.shape = batch_shape + block_shape
     # First make spectrum a batch matrix with
     #   spectrum.shape = batch_shape + [prod(block_shape), 1]
+    batch_shape = self._batch_shape_tensor(
+        shape=self._shape_tensor(spectrum=spectrum))
     spec_mat = array_ops.reshape(
-        spectrum, array_ops.concat(
-            (self.batch_shape_tensor(), [-1, 1]), axis=0))
+        spectrum, array_ops.concat((batch_shape, [-1, 1]), axis=0))
     # Second, broadcast, possibly requiring an addition of array of zeros.
     x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims((x,
                                                                     spec_mat))
     # Third, put the block shape back into spectrum.
-    batch_shape = array_ops.shape(x)[:-2]
+    x_batch_shape = array_ops.shape(x)[:-2]
+    spectrum_shape = array_ops.shape(spectrum)
     spectrum = array_ops.reshape(
         spec_mat,
-        array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))
+        array_ops.concat(
+            (x_batch_shape,
+             self._block_shape_tensor(spectrum_shape=spectrum_shape)),
+            axis=0))
 
     return x, spectrum
 
diff --git a/tensorflow/python/ops/linalg/linear_operator_householder.py b/tensorflow/python/ops/linalg/linear_operator_householder.py
index 2771d8e..6d1b4de 100644
--- a/tensorflow/python/ops/linalg/linear_operator_householder.py
+++ b/tensorflow/python/ops/linalg/linear_operator_householder.py
@@ -206,9 +206,11 @@
 
   def _trace(self):
     # We have (n - 1) +1 eigenvalues and a single -1 eigenvalue.
+    shape = self.shape_tensor()
     return math_ops.cast(
-        self.domain_dimension_tensor() - 2, self.dtype) * array_ops.ones(
-            shape=self.batch_shape_tensor(), dtype=self.dtype)
+        self._domain_dimension_tensor(shape=shape) - 2,
+        self.dtype) * array_ops.ones(
+            shape=self._batch_shape_tensor(shape=shape), dtype=self.dtype)
 
   def _determinant(self):
     # For householder transformations, the determinant is -1.
@@ -224,16 +226,18 @@
     return self._matmul(rhs, adjoint, adjoint_arg)
 
   def _to_dense(self):
-    normalized_axis = self.reflection_axis / linalg.norm(
-        self.reflection_axis, axis=-1, keepdims=True)
+    reflection_axis = ops.convert_to_tensor(self.reflection_axis)
+    normalized_axis = reflection_axis / linalg.norm(
+        reflection_axis, axis=-1, keepdims=True)
     mat = normalized_axis[..., array_ops.newaxis]
     matrix = -2 * math_ops.matmul(mat, mat, adjoint_b=True)
     return array_ops.matrix_set_diag(
         matrix, 1. + array_ops.matrix_diag_part(matrix))
 
   def _diag_part(self):
-    normalized_axis = self.reflection_axis / linalg.norm(
-        self.reflection_axis, axis=-1, keepdims=True)
+    reflection_axis = ops.convert_to_tensor(self.reflection_axis)
+    normalized_axis = reflection_axis / linalg.norm(
+        reflection_axis, axis=-1, keepdims=True)
     return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)
 
   @property
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index 8803d58..f4c75c1 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -333,6 +333,9 @@
     batch_shape = array_ops.broadcast_dynamic_shape(
         self.base_operator.batch_shape_tensor(),
         array_ops.shape(self.u)[:-2])
+    batch_shape = array_ops.broadcast_dynamic_shape(
+        batch_shape,
+        array_ops.shape(self.v)[:-2])
     return array_ops.concat(
         [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)
 
diff --git a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
index 7438b53..be95ce4 100644
--- a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
+++ b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
@@ -140,10 +140,8 @@
     """
 
     with ops.name_scope(name, values=[row, col]):
-      self._row = linear_operator_util.convert_nonref_to_tensor(
-          row, name="row")
-      self._col = linear_operator_util.convert_nonref_to_tensor(
-          col, name="col")
+      self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
+      self._col = linear_operator_util.convert_nonref_to_tensor(col, name="col")
       self._check_row_col(self._row, self._col)
 
       if is_square is False:  # pylint:disable=g-bool-id-comparison
@@ -178,10 +176,12 @@
         self.row.shape, self.col.shape)
     return v_shape.concatenate(v_shape[-1:])
 
-  def _shape_tensor(self):
+  def _shape_tensor(self, row=None, col=None):
+    row = self.row if row is None else row
+    col = self.col if col is None else col
     v_shape = array_ops.broadcast_dynamic_shape(
-        array_ops.shape(self.row),
-        array_ops.shape(self.col))
+        array_ops.shape(row),
+        array_ops.shape(col))
     k = v_shape[-1]
     return array_ops.concat((v_shape, [k]), 0)
 
@@ -208,17 +208,20 @@
     # for more details.
     x = linalg.adjoint(x) if adjoint_arg else x
     expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
+    col = ops.convert_to_tensor(self.col)
+    row = ops.convert_to_tensor(self.row)
     circulant_col = array_ops.concat(
-        [self._col,
-         array_ops.zeros_like(self.col[..., 0:1]),
-         array_ops.reverse(self.row[..., 1:], axis=[-1])], axis=-1)
+        [col,
+         array_ops.zeros_like(col[..., 0:1]),
+         array_ops.reverse(row[..., 1:], axis=[-1])], axis=-1)
     circulant = linear_operator_circulant.LinearOperatorCirculant(
         fft_ops.fft(_to_complex(circulant_col)),
-        input_output_dtype=self.row.dtype)
+        input_output_dtype=row.dtype)
     result = circulant.matmul(expanded_x, adjoint=adjoint, adjoint_arg=False)
 
+    shape = self._shape_tensor(row=row, col=col)
     return math_ops.cast(
-        result[..., :self.domain_dimension_tensor(), :],
+        result[..., :self._domain_dimension_tensor(shape=shape), :],
         self.dtype)
 
   def _trace(self):
@@ -232,11 +235,13 @@
         [self.domain_dimension_tensor()], self.dtype)
 
   def _to_dense(self):
+    row = ops.convert_to_tensor(self.row)
+    col = ops.convert_to_tensor(self.col)
     total_shape = array_ops.broadcast_dynamic_shape(
-        array_ops.shape(self.row), array_ops.shape(self.col))
-    n = array_ops.shape(self.row)[-1]
-    row = array_ops.broadcast_to(self.row, total_shape)
-    col = array_ops.broadcast_to(self.col, total_shape)
+        array_ops.shape(row), array_ops.shape(col))
+    n = array_ops.shape(row)[-1]
+    row = array_ops.broadcast_to(row, total_shape)
+    col = array_ops.broadcast_to(col, total_shape)
     # We concatenate the column in reverse order to the row.
     # This gives us 2*n + 1 elements.
     elements = array_ops.concat(