The `linalg.LinearOperator*` Module APIs do not support top-level dispatching because they are classes w/ methods instead of top-level methods in TF's APIs. But, their class methods call out to APIs that do support dispatching.
This CL updates the convert_to_tensor calls in the `linalg.LinearOperator*` APIs to use the publicly exposed, dispatching `convert_to_tensor_v2_with_dispatch`, which enables the Operators to effectively work with dispatching as the APIs they call out to support dispatching as well.
PiperOrigin-RevId: 324834645
Change-Id: If2e9f17be101e74f8835497d8ca51a0174055053
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index 8e1967f..cf14cdb 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -385,7 +385,7 @@
# `shape` may be passed in if this can be pre-computed in a
# more efficient manner, e.g. without excessive Tensor conversions.
if self.tensor_rank is not None:
- return ops.convert_to_tensor(self.tensor_rank)
+ return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
else:
shape = self.shape_tensor() if shape is None else shape
return array_ops.size(shape)
@@ -429,7 +429,7 @@
# more efficient manner, e.g. without excessive Tensor conversions.
dim_value = tensor_shape.dimension_value(self.domain_dimension)
if dim_value is not None:
- return ops.convert_to_tensor(dim_value)
+ return ops.convert_to_tensor_v2_with_dispatch(dim_value)
else:
shape = self.shape_tensor() if shape is None else shape
return shape[-1]
@@ -473,7 +473,7 @@
# more efficient manner, e.g. without excessive Tensor conversions.
dim_value = tensor_shape.dimension_value(self.range_dimension)
if dim_value is not None:
- return ops.convert_to_tensor(dim_value)
+ return ops.convert_to_tensor_v2_with_dispatch(dim_value)
else:
shape = self.shape_tensor() if shape is None else shape
return shape[-2]
@@ -641,7 +641,7 @@
return linear_operator_algebra.matmul(left_operator, right_operator)
with self._name_scope(name):
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
self_dim = -2 if adjoint else -1
@@ -688,7 +688,7 @@
A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
"""
with self._name_scope(name):
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
self_dim = -2 if adjoint else -1
tensor_shape.dimension_at_index(
@@ -834,7 +834,7 @@
return linear_operator_algebra.solve(left_operator, right_operator)
with self._name_scope(name):
- rhs = ops.convert_to_tensor(rhs, name="rhs")
+ rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
self_dim = -1 if adjoint else -2
@@ -891,7 +891,7 @@
NotImplementedError: If `self.is_non_singular` or `is_square` is False.
"""
with self._name_scope(name):
- rhs = ops.convert_to_tensor(rhs, name="rhs")
+ rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
self_dim = -1 if adjoint else -2
tensor_shape.dimension_at_index(
@@ -1054,7 +1054,7 @@
A `Tensor` with broadcast shape and same `dtype` as `self`.
"""
with self._name_scope(name):
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
return self._add_to_tensor(x)
diff --git a/tensorflow/python/ops/linalg/linear_operator_block_diag.py b/tensorflow/python/ops/linalg/linear_operator_block_diag.py
index 7c50d00..7afa15a 100644
--- a/tensorflow/python/ops/linalg/linear_operator_block_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_block_diag.py
@@ -263,7 +263,7 @@
def _shape_tensor(self):
# Avoid messy broadcasting if possible.
if self.shape.is_fully_defined():
- return ops.convert_to_tensor(
+ return ops.convert_to_tensor_v2_with_dispatch(
self.shape.as_list(), dtype=dtypes.int32, name="shape")
domain_dimension = sum(self._block_domain_dimension_tensors())
@@ -330,12 +330,12 @@
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
x[i] = block
else:
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@@ -404,7 +404,7 @@
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
x[i] = block
@@ -412,7 +412,7 @@
y_mat = self.matmul(x_mat, adjoint=adjoint)
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@@ -508,12 +508,12 @@
split_rhs = rhs
for i, block in enumerate(split_rhs):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
split_rhs[i] = block
else:
- rhs = ops.convert_to_tensor(rhs, name="rhs")
+ rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)
@@ -583,7 +583,7 @@
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
for i, block in enumerate(rhs):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
rhs[i] = block
@@ -591,7 +591,7 @@
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
- rhs = ops.convert_to_tensor(rhs, name="rhs")
+ rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)
diff --git a/tensorflow/python/ops/linalg/linear_operator_block_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_block_lower_triangular.py
index b4bf8bd..84f2ff1 100644
--- a/tensorflow/python/ops/linalg/linear_operator_block_lower_triangular.py
+++ b/tensorflow/python/ops/linalg/linear_operator_block_lower_triangular.py
@@ -366,7 +366,7 @@
def _shape_tensor(self):
# Avoid messy broadcasting if possible.
if self.shape.is_fully_defined():
- return ops.convert_to_tensor(
+ return ops.convert_to_tensor_v2_with_dispatch(
self.shape.as_list(), dtype=dtypes.int32, name="shape")
domain_dimension = sum(self._block_domain_dimension_tensors())
@@ -433,12 +433,12 @@
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
x[i] = block
else:
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@@ -543,7 +543,7 @@
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
x[i] = block
@@ -551,7 +551,7 @@
y_mat = self.matmul(x_mat, adjoint=adjoint)
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@@ -674,7 +674,7 @@
if blockwise_arg:
for i, block in enumerate(rhs):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
rhs[i] = block
@@ -684,7 +684,7 @@
split_rhs = rhs
else:
- rhs = ops.convert_to_tensor(rhs, name="rhs")
+ rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)
@@ -795,14 +795,14 @@
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
for i, block in enumerate(rhs):
if not isinstance(block, linear_operator.LinearOperator):
- block = ops.convert_to_tensor(block)
+ block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
rhs[i] = block
rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
- rhs = ops.convert_to_tensor(rhs, name="rhs")
+ rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index ace2769..d4b671c 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -378,7 +378,7 @@
def _broadcast_batch_dims(self, x, spectrum):
"""Broadcast batch dims of batch matrix `x` and spectrum."""
- spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
+ spectrum = ops.convert_to_tensor_v2_with_dispatch(spectrum, name="spectrum")
# spectrum.shape = batch_shape + block_shape
# First make spectrum a batch matrix with
# spectrum.shape = batch_shape + [prod(block_shape), 1]
@@ -755,7 +755,7 @@
name=name)
def _eigvals(self):
- return ops.convert_to_tensor(self.spectrum)
+ return ops.convert_to_tensor_v2_with_dispatch(self.spectrum)
@tf_export("linalg.LinearOperatorCirculant2D")
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py
index d51d6b8..b5e81b2 100644
--- a/tensorflow/python/ops/linalg/linear_operator_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_diag.py
@@ -251,7 +251,7 @@
return array_ops.matrix_set_diag(x, new_diag)
def _eigvals(self):
- return ops.convert_to_tensor(self.diag)
+ return ops.convert_to_tensor_v2_with_dispatch(self.diag)
def _cond(self):
abs_diag = math_ops.abs(self.diag)
diff --git a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
index 8d92d1a..b108225 100644
--- a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
+++ b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py
@@ -160,7 +160,7 @@
dtypes.complex128,
]
- matrix = ops.convert_to_tensor(matrix, name="matrix")
+ matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
dtype = matrix.dtype
if dtype not in allowed_dtypes:
diff --git a/tensorflow/python/ops/linalg/linear_operator_householder.py b/tensorflow/python/ops/linalg/linear_operator_householder.py
index 142d48c..265c862 100644
--- a/tensorflow/python/ops/linalg/linear_operator_householder.py
+++ b/tensorflow/python/ops/linalg/linear_operator_householder.py
@@ -198,7 +198,8 @@
# Note that because this is a reflection, it lies in O(n) (for real vector
# spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
- reflection_axis = ops.convert_to_tensor(self.reflection_axis)
+ reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
+ self.reflection_axis)
x = linalg.adjoint(x) if adjoint_arg else x
normalized_axis = reflection_axis / linalg.norm(
reflection_axis, axis=-1, keepdims=True)
@@ -229,7 +230,8 @@
return self._matmul(rhs, adjoint, adjoint_arg)
def _to_dense(self):
- reflection_axis = ops.convert_to_tensor(self.reflection_axis)
+ reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
+ self.reflection_axis)
normalized_axis = reflection_axis / linalg.norm(
reflection_axis, axis=-1, keepdims=True)
mat = normalized_axis[..., array_ops.newaxis]
@@ -238,7 +240,8 @@
matrix, 1. + array_ops.matrix_diag_part(matrix))
def _diag_part(self):
- reflection_axis = ops.convert_to_tensor(self.reflection_axis)
+ reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
+ self.reflection_axis)
normalized_axis = reflection_axis / linalg.norm(
reflection_axis, axis=-1, keepdims=True)
return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)
diff --git a/tensorflow/python/ops/linalg/linear_operator_identity.py b/tensorflow/python/ops/linalg/linear_operator_identity.py
index 8226e74..a0f7ead 100644
--- a/tensorflow/python/ops/linalg/linear_operator_identity.py
+++ b/tensorflow/python/ops/linalg/linear_operator_identity.py
@@ -394,7 +394,7 @@
A `Tensor` with broadcast shape and same `dtype` as `self`.
"""
with self._name_scope(name):
- mat = ops.convert_to_tensor(mat, name="mat")
+ mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
mat_diag = array_ops.matrix_diag_part(mat)
new_diag = 1 + mat_diag
return array_ops.matrix_set_diag(mat, new_diag)
@@ -720,7 +720,7 @@
multiplier_vector = array_ops.expand_dims(self.multiplier, -1)
# Shape [C1,...,Cc, M, M]
- mat = ops.convert_to_tensor(mat, name="mat")
+ mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
# Shape [C1,...,Cc, M]
mat_diag = array_ops.matrix_diag_part(mat)
diff --git a/tensorflow/python/ops/linalg/linear_operator_permutation.py b/tensorflow/python/ops/linalg/linear_operator_permutation.py
index 3a44cd5..9cc8e15 100644
--- a/tensorflow/python/ops/linalg/linear_operator_permutation.py
+++ b/tensorflow/python/ops/linalg/linear_operator_permutation.py
@@ -197,7 +197,7 @@
return array_ops.shape(perm)[-1]
def _matmul(self, x, adjoint=False, adjoint_arg=False):
- perm = ops.convert_to_tensor(self.perm)
+ perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
if adjoint and not self.is_self_adjoint:
# TODO(srvasude): invert_permutation doesn't work on batches so we use
# argsort.
@@ -232,13 +232,13 @@
return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _to_dense(self):
- perm = ops.convert_to_tensor(self.perm)
+ perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
return math_ops.cast(math_ops.equal(
math_ops.range(0, self._domain_dimension_tensor(perm)),
perm[..., array_ops.newaxis]), self.dtype)
def _diag_part(self):
- perm = ops.convert_to_tensor(self.perm)
+ perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
return math_ops.cast(math_ops.equal(
math_ops.range(0, self._domain_dimension_tensor(perm)),
perm), self.dtype)
diff --git a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
index 71fff44..2d61a53 100644
--- a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
+++ b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py
@@ -209,8 +209,8 @@
# for more details.
x = linalg.adjoint(x) if adjoint_arg else x
expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
- col = ops.convert_to_tensor(self.col)
- row = ops.convert_to_tensor(self.row)
+ col = ops.convert_to_tensor_v2_with_dispatch(self.col)
+ row = ops.convert_to_tensor_v2_with_dispatch(self.row)
circulant_col = array_ops.concat(
[col,
array_ops.zeros_like(col[..., 0:1]),
@@ -236,8 +236,8 @@
[self.domain_dimension_tensor()], self.dtype)
def _to_dense(self):
- row = ops.convert_to_tensor(self.row)
- col = ops.convert_to_tensor(self.col)
+ row = ops.convert_to_tensor_v2_with_dispatch(self.row)
+ col = ops.convert_to_tensor_v2_with_dispatch(self.col)
total_shape = array_ops.broadcast_dynamic_shape(
array_ops.shape(row), array_ops.shape(col))
n = array_ops.shape(row)[-1]
diff --git a/tensorflow/python/ops/linalg/linear_operator_tridiag.py b/tensorflow/python/ops/linalg/linear_operator_tridiag.py
index 4227478..2ba310f 100644
--- a/tensorflow/python/ops/linalg/linear_operator_tridiag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_tridiag.py
@@ -246,7 +246,7 @@
self.diagonals, linalg.adjoint(self.diagonals),
message='Matrix was not equal to its adjoint.')]
elif self.diagonals_format == _COMPACT:
- diagonals = ops.convert_to_tensor(self.diagonals)
+ diagonals = ops.convert_to_tensor_v2_with_dispatch(self.diagonals)
asserts += [linear_operator_util.assert_zero_imag_part(
diagonals[..., 1, :], message=diag_message)]
# Roll the subdiagonal so the shifted argument is at the end.
@@ -353,7 +353,9 @@
align='LEFT_RIGHT',
padding_value=0.)
- diagonals = [ops.convert_to_tensor(d) for d in self.diagonals]
+ diagonals = [
+ ops.convert_to_tensor_v2_with_dispatch(d) for d in self.diagonals
+ ]
diagonals = array_ops.stack(diagonals, axis=-2)
return gen_array_ops.matrix_diag_v3(
diff --git a/tensorflow/python/ops/linalg/linear_operator_util.py b/tensorflow/python/ops/linalg/linear_operator_util.py
index 948f2f8..096ad3f 100644
--- a/tensorflow/python/ops/linalg/linear_operator_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_util.py
@@ -114,7 +114,7 @@
raise TypeError('Mutable type must be of dtype "{}" but is "{}".'.format(
dtype_name(dtype_base), dtype_name(value_dtype_base)))
return value
- return ops.convert_to_tensor(
+ return ops.convert_to_tensor_v2_with_dispatch(
value, dtype=dtype, dtype_hint=dtype_hint, name=name)
@@ -189,10 +189,10 @@
An `Op` that asserts `x` has no entries with modulus zero.
"""
with ops.name_scope(name, values=[x]):
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
dtype = x.dtype.base_dtype
should_be_nonzero = math_ops.abs(x)
- zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
+ zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
return check_ops.assert_less(zero, should_be_nonzero, message=message)
@@ -208,13 +208,13 @@
An `Op` that asserts `x` has no entries with modulus zero.
"""
with ops.name_scope(name, values=[x]):
- x = ops.convert_to_tensor(x, name="x")
+ x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
dtype = x.dtype.base_dtype
if dtype.is_floating:
return control_flow_ops.no_op()
- zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
+ zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
@@ -261,7 +261,7 @@
dtype = dtypes.int32
else:
dtype = None
- return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+ return ops.convert_to_tensor_v2_with_dispatch(shape, dtype=dtype, name=name)
################################################################################
@@ -323,7 +323,7 @@
batch_matrices = list(batch_matrices)
for i, mat in enumerate(batch_matrices):
- batch_matrices[i] = ops.convert_to_tensor(mat)
+ batch_matrices[i] = ops.convert_to_tensor_v2_with_dispatch(mat)
assert_is_batch_matrix(batch_matrices[i])
if len(batch_matrices) < 2:
@@ -366,8 +366,9 @@
def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
"""Solve systems of linear equations."""
with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
- matrix = ops.convert_to_tensor(matrix, name="matrix")
- rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)
+ matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
+ rhs = ops.convert_to_tensor_v2_with_dispatch(
+ rhs, name="rhs", dtype=matrix.dtype)
# If either matrix/rhs has extra dims, we can reshape to get rid of them.
matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
@@ -526,7 +527,8 @@
if not any(nest.is_nested(x) for x in arg):
return True
else:
- arg_dims = [ops.convert_to_tensor(x).shape[arg_split_dim] for x in arg]
+ arg_dims = [ops.convert_to_tensor_v2_with_dispatch(
+ x).shape[arg_split_dim] for x in arg]
self_dims = [dim.value for dim in block_dimensions]
# If none of the operator dimensions are known, interpret the input as
diff --git a/tensorflow/python/util/dispatch_test.py b/tensorflow/python/util/dispatch_test.py
index 2b3946c..f06f2fd 100644
--- a/tensorflow/python/util/dispatch_test.py
+++ b/tensorflow/python/util/dispatch_test.py
@@ -18,11 +18,13 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.proto_ops import decode_proto
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
@@ -60,6 +62,8 @@
self.name = name
self.args = args
self.kwargs = kwargs
+ self.shape = array_ops.ones(shape=(4, 4)).shape
+ self.dtype = dtypes.float32
def __repr__(self):
if self.args is None and self.kwargs is None:
@@ -70,6 +74,10 @@
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
return "{}({})".format(self.name, ", ".join(args))
+ @property
+ def is_tensor_like(self):
+ return True
+
@classmethod
def _overload_all_operators(cls): # pylint: disable=invalid-name
"""Register overloads for all operators."""
@@ -282,5 +290,42 @@
# Clean up.
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
+ def testGlobalDispatcherLinearOperators(self):
+ original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
+ try:
+ TensorTracerOpDispatcher().register()
+
+ x = TensorTracer("x")
+
+ # To grab the eigenvalues the diag operator just calls convert_to_tensor
+ # (twice) in this case.
+ trace = linear_operator_diag.LinearOperatorDiag(x).eigvals()
+ self.assertEqual(
+ str(trace),
+ "convert_to_tensor(convert_to_tensor(x, dtype=None, dtype_hint=None, "
+ "name=diag))")
+
+ # The diagonal tensor addition gets traced even though the linear_operator
+ # API only uses dispatchable ops instead of directly exposing dispatching.
+ trace = linear_operator_diag.LinearOperatorDiag(x).add_to_tensor(x)
+ self.assertIn(
+ "linalg.set_diag(convert_to_tensor(x, name=x), __operators__.add("
+ "convert_to_tensor(x, dtype=None, dtype_hint=None, name=diag), "
+ "linalg.diag_part(convert_to_tensor(x, name=x)), "
+ "name=",
+ str(trace))
+
+ # The dispatch-supporting ops the non-singular check calls out to
+ # get traced.
+ trace = linear_operator_diag.LinearOperatorDiag(x).assert_non_singular()
+ self.assertIn("debugging.assert_less", str(trace))
+ self.assertIn(
+ "message=Singular operator: Diagonal contained zero values.",
+ str(trace))
+
+ finally:
+ # Clean up.
+ dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
+
if __name__ == "__main__":
googletest.main()