| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Operator dispatch for RaggedTensors.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import numpy as np |
| |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import clip_ops |
| from tensorflow.python.ops import gen_bitwise_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import parsing_ops |
| from tensorflow.python.ops import string_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops.ragged import ragged_array_ops |
| from tensorflow.python.ops.ragged import ragged_batch_gather_ops |
| from tensorflow.python.ops.ragged import ragged_concat_ops |
| from tensorflow.python.ops.ragged import ragged_gather_ops |
| from tensorflow.python.ops.ragged import ragged_math_ops |
| from tensorflow.python.ops.ragged import ragged_squeeze_op |
| from tensorflow.python.ops.ragged import ragged_string_ops |
| from tensorflow.python.ops.ragged import ragged_tensor |
| from tensorflow.python.ops.ragged import ragged_tensor_shape |
| from tensorflow.python.ops.ragged import ragged_util |
| from tensorflow.python.ops.ragged import ragged_where_op |
| from tensorflow.python.util import deprecation |
| from tensorflow.python.util import dispatch |
| from tensorflow.python.util import tf_decorator |
| from tensorflow.python.util import tf_export |
| from tensorflow.python.util import tf_inspect |
| |
# TODO(edloper): Set this to True in the CL that exports RaggedTensors.
| _UPDATE_DOCSTRINGS = False |
| |
| # Information about an argument to an operation: The name of the argument, its |
| # position in the argument list, and a boolean flag indicating whether it |
| # expects a list of tensors. |
| _ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list']) |
| |
| |
| def _get_arg_infos(func, arg_names): |
| """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`. |
| |
  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for. A name wrapped in
      square brackets (e.g. '[values]') marks an argument that expects a
      list of tensors.

  Returns:
    A list of `_ArgInfo`s, one for each name in `arg_names`.
  """
| arg_infos = [] |
| |
| # Inspect the func's argspec to find the position of each arg. |
| arg_spec = tf_inspect.getargspec(func) |
| for argname in arg_names: |
| assert isinstance(argname, str) |
| is_list = argname.startswith('[') and argname.endswith(']') |
| if is_list: |
| argname = argname[1:-1] |
| if argname not in arg_spec.args: |
      raise ValueError('Argument %r not found in function %s. Args=%s' %
                       (argname, func, arg_spec.args))
| arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list)) |
| return arg_infos |
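
# A minimal sketch of the square-bracket convention handled above. The
# `fake_concat` function is hypothetical, shown for exposition only:
#
#   >>> def fake_concat(values, axis, name=None):
#   ...   pass
#   >>> _get_arg_infos(fake_concat, ['[values]', 'axis'])
#   [ArgInfo(name='values', position=0, is_list=True),
#    ArgInfo(name='axis', position=1, is_list=False)]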
| |
| |
| def _is_convertible_to_tensor(value): |
| """Returns true if `value` is convertible to a `Tensor`.""" |
| if value is None: |
| return True |
| if isinstance(value, |
| (ops.Tensor, variables.Variable, np.ndarray, int, float, str)): |
| return True |
| elif isinstance(value, (sparse_tensor.SparseTensor,)): |
| return False |
| else: |
| try: |
| ops.convert_to_tensor(value) |
| return True |
| except (TypeError, ValueError): |
| return False |
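
# Illustrative examples: a ragged Python list has no dense `Tensor`
# conversion, so it is rejected, while a rectangular list converts fine.
#
#   >>> _is_convertible_to_tensor([[1, 2], [3, 4]])
#   True
#   >>> _is_convertible_to_tensor([[1, 2], [3]])
#   False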
| |
| |
| class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher): |
| """OpDispatcher for unary ops that map a base op across ragged values.""" |
| |
| def __init__(self, original_op, arg_is_list=False): |
| self._original_op = original_op |
| self._arg_is_list = arg_is_list |
| arg_names = tf_inspect.getfullargspec(original_op)[0] |
| self._x = arg_names[0] |
| if _UPDATE_DOCSTRINGS: |
| original_op.__doc__ = ( |
| original_op.__doc__.rstrip() + '\n\n' + |
| ' `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x)) |
| |
| def handle(self, args, kwargs): |
| if args: |
| x, args = args[0], args[1:] |
| else: |
| kwargs = kwargs.copy() |
| x = kwargs.pop(self._x, None) |
| if x is None: |
| return self.NOT_SUPPORTED |
| if self._arg_is_list: |
| found_ragged = False |
| for elt in x: |
| if ragged_tensor.is_ragged(elt): |
| found_ragged = True |
| elif not _is_convertible_to_tensor(elt): |
| return self.NOT_SUPPORTED |
| if found_ragged: |
| x = ragged_tensor.match_row_splits_dtypes(*x) |
| nested_splits_lists = [ |
| elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt) |
| ] |
| flat_values = [ |
| elt.flat_values if ragged_tensor.is_ragged(elt) else elt |
| for elt in x |
| ] |
| with ops.control_dependencies( |
| ragged_util.assert_splits_match(nested_splits_lists)): |
| return ragged_tensor.RaggedTensor.from_nested_row_splits( |
| self._original_op(flat_values, *args, **kwargs), |
| nested_splits_lists[0], validate=False) |
| else: |
| return self.NOT_SUPPORTED |
| else: |
| found_ragged = ragged_tensor.is_ragged(x) |
| if found_ragged: |
| mapped_values = self._original_op(x.flat_values, *args, **kwargs) |
| return x.with_flat_values(mapped_values) |
| else: |
| return self.NOT_SUPPORTED |
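
# A minimal sketch of the unary dispatch path (illustrative; relies on
# register_dispatchers() being called at the bottom of this module, and on
# eager execution for `.to_list()`):
#
#   >>> rt = ragged_tensor.RaggedTensor.from_row_splits(
#   ...     values=[1.0, -2.0, 3.0], row_splits=[0, 2, 3])
#   >>> math_ops.abs(rt).to_list()  # tf.abs mapped across rt.flat_values
#   [[1.0, 2.0], [3.0]]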
| |
| |
| class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher): |
| """OpDispatcher for binary ops that map a base op across ragged values. |
| |
| Supports broadcasting. |
| """ |
| |
| def __init__(self, original_op): |
| self._original_op = original_op |
| arg_names = tf_inspect.getfullargspec(original_op)[0] |
| self._x = arg_names[0] |
| self._y = arg_names[1] |
| if _UPDATE_DOCSTRINGS: |
| original_op.__doc__ = ( |
| original_op.__doc__.rstrip() + '\n\n' + |
          ' `{x}` and `{y}` may each be a `tf.RaggedTensor`.\n'.format(
              x=self._x, y=self._y))
| |
| def handle(self, args, kwargs): |
| # Extract the binary args. |
| if len(args) > 1: |
| x = args[0] |
| y = args[1] |
| args = args[2:] |
| elif args: |
| kwargs = kwargs.copy() |
| x = args[0] |
| y = kwargs.pop(self._y, None) |
| args = args[1:] |
| else: |
| kwargs = kwargs.copy() |
| x = kwargs.pop(self._x, None) |
| y = kwargs.pop(self._y, None) |
| |
| # Bail if we don't have at least one ragged argument. |
| x_is_ragged = ragged_tensor.is_ragged(x) |
| y_is_ragged = ragged_tensor.is_ragged(y) |
| if not (x_is_ragged or y_is_ragged): |
| return self.NOT_SUPPORTED |
| |
| # Convert args to tensors. Bail if conversion fails. |
| try: |
| if not x_is_ragged: |
| x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype) |
| if not y_is_ragged: |
| y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype) |
| except (TypeError, ValueError): |
| return self.NOT_SUPPORTED |
| |
| if x_is_ragged and y_is_ragged: |
| x, y = ragged_tensor.match_row_splits_dtypes(x, y) |
| |
| if ((x_is_ragged and y_is_ragged) or |
| (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or |
| (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)): |
| bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape( |
| ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x), |
| ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)) |
| x = ragged_tensor_shape.broadcast_to( |
| x, bcast_shape, broadcast_inner_dimensions=False) |
| y = ragged_tensor_shape.broadcast_to( |
| y, bcast_shape, broadcast_inner_dimensions=False) |
| |
| x_values = x.flat_values if ragged_tensor.is_ragged(x) else x |
| y_values = y.flat_values if ragged_tensor.is_ragged(y) else y |
| mapped_values = self._original_op(x_values, y_values, *args, **kwargs) |
| if ragged_tensor.is_ragged(x): |
| return x.with_flat_values(mapped_values) |
| else: |
| return y.with_flat_values(mapped_values) |
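
# A minimal sketch of the binary dispatch path, including broadcasting
# (illustrative; assumes eager execution for `.to_list()`):
#
#   >>> x = ragged_tensor.RaggedTensor.from_row_splits([1, 2, 3], [0, 2, 3])
#   >>> math_ops.add(x, x).to_list()        # ragged + ragged
#   [[2, 4], [6]]
#   >>> math_ops.multiply(x, 10).to_list()  # a scalar broadcasts
#   [[10, 20], [30]]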
| |
| |
| class RaggedDispatcher(dispatch.OpDispatcher): |
| """OpDispatcher for ragged ops. |
| |
  Dispatches to a wrapped op-handler if at least one of the `ragged_args`
  arguments is a RaggedTensor or a RaggedTensorValue; and all of the
  `ragged_args` arguments are convertible to Tensor or RaggedTensor.
| """ |
| |
| def __init__(self, original_op, ragged_op, ragged_args): |
| op_arg_names = tf_inspect.getfullargspec(original_op)[0] |
| ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0] |
| if op_arg_names != ragged_arg_names: |
| raise AssertionError( |
| 'Signature must exactly match when overriding %s with %s: %s vs %s' % |
| (original_op, ragged_op, op_arg_names, ragged_arg_names)) |
| self._ragged_op = ragged_op |
| self._ragged_args = _get_arg_infos(ragged_op, ragged_args) |
| if _UPDATE_DOCSTRINGS: |
| arg_list = ' and '.join('`%s`' % arg for arg in ragged_args) |
| original_op.__doc__ = ( |
| original_op.__doc__.rstrip() + '\n\n' + |
| ' {0} may be a `tf.RaggedTensor`.\n'.format(arg_list)) |
| |
| def handle(self, args, kwargs): |
| if self.is_supported(args, kwargs): |
| return self._ragged_op(*args, **kwargs) |
| else: |
| return self.NOT_SUPPORTED |
| |
| def is_supported(self, args, kwargs): |
| found_ragged = False |
| for arg_info in self._ragged_args: |
| if arg_info.position < len(args): |
| arg = args[arg_info.position] |
| else: |
| arg = kwargs.get(arg_info.name, None) |
| |
| if arg_info.is_list: |
| if not isinstance(arg, (list, tuple)): |
| return False |
| for elt in arg: |
| if ragged_tensor.is_ragged(elt): |
| found_ragged = True |
| elif not _is_convertible_to_tensor(elt): |
| return False |
| else: |
| if ragged_tensor.is_ragged(arg): |
| found_ragged = True |
| elif not _is_convertible_to_tensor(arg): |
| return False |
| return found_ragged |
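
# A minimal sketch (illustrative; assumes eager execution): once registered
# below, array_ops.concat accepts RaggedTensor inputs by delegating to
# ragged_concat_ops.concat:
#
#   >>> a = ragged_tensor.RaggedTensor.from_row_splits([1, 2, 3], [0, 2, 3])
#   >>> b = ragged_tensor.RaggedTensor.from_row_splits([4, 5], [0, 1, 2])
#   >>> array_ops.concat([a, b], axis=0).to_list()
#   [[1, 2], [3], [4], [5]]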
| |
| |
| _UNARY_ELEMENTWISE_OPS = [ |
| array_ops.check_numerics, |
| array_ops.identity, |
| array_ops.ones_like, |
| array_ops.ones_like_v2, |
| array_ops.zeros_like, |
| array_ops.zeros_like_v2, |
| clip_ops.clip_by_value, |
| gen_bitwise_ops.invert, |
| math_ops.abs, |
| math_ops.acos, |
| math_ops.acosh, |
| math_ops.angle, |
| math_ops.asin, |
| math_ops.asinh, |
| math_ops.atan, |
| math_ops.atanh, |
| math_ops.cast, |
| math_ops.ceil, |
| math_ops.conj, |
| math_ops.cos, |
| math_ops.cosh, |
| math_ops.digamma, |
| math_ops.erf, |
| math_ops.erfc, |
| math_ops.exp, |
| math_ops.expm1, |
| math_ops.floor, |
| math_ops.imag, |
| math_ops.is_finite, |
| math_ops.is_inf, |
| math_ops.is_nan, |
| math_ops.lgamma, |
| math_ops.log, |
| math_ops.log1p, |
| math_ops.log_sigmoid, |
| math_ops.logical_not, |
| math_ops.negative, |
| math_ops.real, |
| math_ops.reciprocal, |
| math_ops.rint, |
| math_ops.round, |
| math_ops.rsqrt, |
| math_ops.saturate_cast, |
| math_ops.sign, |
| math_ops.sin, |
| math_ops.sinh, |
| math_ops.sqrt, |
| math_ops.square, |
| math_ops.tan, |
| parsing_ops.decode_compressed, |
| string_ops.string_to_number, |
| string_ops.as_string, |
| string_ops.decode_base64, |
| string_ops.encode_base64, |
| string_ops.regex_full_match, |
| string_ops.regex_replace, |
| string_ops.string_strip, |
| string_ops.string_to_hash_bucket, |
| string_ops.string_to_hash_bucket_fast, |
| string_ops.string_to_hash_bucket_strong, |
| string_ops.substr, |
| string_ops.substr_v2, |
| string_ops.string_length, |
| string_ops.string_length_v2, |
| string_ops.unicode_script, |
| ] |
| |
| _UNARY_LIST_ELEMENTWISE_OPS = [ |
| math_ops.add_n, |
| string_ops.string_join, |
| ] |
| |
| _BINARY_ELEMENTWISE_OPS = [ |
| gen_bitwise_ops.bitwise_and, |
| gen_bitwise_ops.bitwise_or, |
| gen_bitwise_ops.bitwise_xor, |
| gen_bitwise_ops.left_shift, |
| gen_bitwise_ops.right_shift, |
| math_ops.add, |
| math_ops.atan2, |
| math_ops.complex, |
| math_ops.div_no_nan, |
| math_ops.divide, |
| math_ops.equal, |
| math_ops.floordiv, |
| math_ops.floormod, |
| math_ops.greater, |
| math_ops.greater_equal, |
| math_ops.less, |
| math_ops.less_equal, |
| math_ops.logical_and, |
| math_ops.logical_or, |
| math_ops.logical_xor, |
| math_ops.maximum, |
| math_ops.minimum, |
| math_ops.multiply, |
| math_ops.not_equal, |
| math_ops.pow, |
| math_ops.realdiv, |
| math_ops.squared_difference, |
| math_ops.subtract, |
| math_ops.truediv, |
| math_ops.truncatediv, |
| math_ops.truncatemod, |
| ] |
| |
| |
# The v1 versions of the ops below don't need their own dispatch handlers,
# since they delegate to these v2 ops (which already have handlers
# registered). But we still want the v1 ops to appear in the
# ragged_op_list() output, so we track the v2 ops they delegate to here.
| _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS = [ |
| math_ops.reduce_sum, |
| math_ops.reduce_prod, |
| math_ops.reduce_min, |
| math_ops.reduce_max, |
| math_ops.reduce_mean, |
| math_ops.reduce_any, |
| math_ops.reduce_all, |
| string_ops.string_to_number, |
| string_ops.string_to_hash_bucket, |
| string_ops.reduce_join_v2, |
| ] |
| |
| |
| def _ragged_gather_v1(params, indices, validate_indices=None, name=None, |
| axis=0, batch_dims=0): |
| return ragged_gather_ops.gather( |
| params=params, |
| indices=indices, |
| validate_indices=validate_indices, |
| axis=axis, |
| batch_dims=batch_dims, |
| name=name) |
| |
| |
| def _ragged_gather_nd_v1(params, indices, name=None, batch_dims=0): |
| return ragged_gather_ops.gather_nd( |
| params=params, |
| indices=indices, |
| batch_dims=batch_dims, |
| name=name) |
| |
| |
| def _ragged_expand_dims_v1(input, axis=None, name=None, dim=None): # pylint: disable=redefined-builtin |
| if dim is not None: |
| axis = dim |
| return ragged_array_ops.expand_dims(input=input, axis=axis, name=name) |
| |
| |
| def _ragged_size_v1(input, name=None, out_type=dtypes.int32): # pylint: disable=redefined-builtin |
| return ragged_array_ops.size(input=input, out_type=out_type, name=name) |
| |
| |
| def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None): # pylint: disable=redefined-builtin |
| axis = deprecation.deprecated_argument_lookup('axis', axis, 'squeeze_dims', |
| squeeze_dims) |
| return ragged_squeeze_op.squeeze(input, axis, name) |
| |
# Ops that are dispatched to a separate ragged implementation, as
# (original_op, ragged_op, ragged_args) triples.
| _RAGGED_DISPATCH_OPS = [ |
| (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather, |
| ['params', 'indices']), |
| (array_ops.concat, ragged_concat_ops.concat, ['[values]']), |
| (array_ops.expand_dims, _ragged_expand_dims_v1, ['input']), |
| (array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']), |
| (array_ops.gather, _ragged_gather_v1, ['params', 'indices']), |
| (array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']), |
| (array_ops.gather_nd, _ragged_gather_nd_v1, ['params', 'indices']), |
| (array_ops.gather_nd_v2, ragged_gather_ops.gather_nd, ['params', |
| 'indices']), |
| (array_ops.rank, ragged_array_ops.rank, ['input']), |
| (array_ops.size, _ragged_size_v1, ['input']), |
| (array_ops.size_v2, ragged_array_ops.size, ['input']), |
| (array_ops.squeeze, _ragged_squeeze_v1, ['input']), |
| (array_ops.squeeze_v2, ragged_squeeze_op.squeeze, ['input']), |
| (array_ops.stack, ragged_concat_ops.stack, ['[values]']), |
| (array_ops.tile, ragged_array_ops.tile, ['input']), |
| (array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']), |
| (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum, |
| ['data', 'segment_ids']), |
| (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod, |
| ['data', 'segment_ids']), |
| (math_ops.unsorted_segment_min, ragged_math_ops.segment_min, |
| ['data', 'segment_ids']), |
| (math_ops.unsorted_segment_max, ragged_math_ops.segment_max, |
| ['data', 'segment_ids']), |
| (math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean, |
| ['data', 'segment_ids']), |
| (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n, |
| ['data', 'segment_ids']), |
| (string_ops.reduce_join_v2, ragged_string_ops.reduce_join, ['inputs']), |
| (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']), |
| (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']), |
| (math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']), |
| (math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']), |
| (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']), |
| (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']), |
| (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']), |
| ] |
| |
| |
| def register_dispatchers(): |
| """Constructs & registers OpDispatchers for ragged ops.""" |
| |
| op_list = ( |
| _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS + |
| _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS]) |
| for op in op_list: |
| _, undecorated_op = tf_decorator.unwrap(op) |
| if not hasattr(undecorated_op, |
| tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names): |
      raise AssertionError('Expected %s to be an exported symbol '
                           '(while adding a RaggedTensor dispatcher)' % op)
| |
| for op in _UNARY_ELEMENTWISE_OPS: |
| UnaryRaggedElementwiseDispatcher(op).register(op) |
| |
| for op in _UNARY_LIST_ELEMENTWISE_OPS: |
| UnaryRaggedElementwiseDispatcher(op, True).register(op) |
| |
| for op in _BINARY_ELEMENTWISE_OPS: |
| BinaryRaggedElementwiseDispatcher(op).register(op) |
| |
| for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS: |
| RaggedDispatcher(original_op, ragged_op, args).register(original_op) |
| |
| |
| def _ragged_op_signature(op, ragged_args): |
| """Returns a signature for the given op, marking ragged args in bold.""" |
| op_name = tf_export.get_canonical_name_for_symbol(op) |
| argspec = tf_inspect.getfullargspec(op) |
| arg_names = argspec.args |
| |
| # Mark ragged arguments in bold. |
| for pos in ragged_args: |
| arg_names[pos] = '**' + arg_names[pos] + '**' |
| |
  # Add argument defaults.
  if argspec.defaults:
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
      arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])
| |
| # Add varargs and keyword args |
| if argspec.varargs: |
| arg_names.append('*' + argspec.varargs) |
| if argspec.varkw: |
| arg_names.append('**' + argspec.varkw) |
| |
| return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names)) |
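
# For example (illustrative), for math_ops.add with ragged_args=[0, 1] this
# returns a markdown line along the lines of:
#
#   * `tf.math.add`(**x**, **y**, name=`None`)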
| |
| |
| def _op_is_in_tf_version(op, version): |
| if version == 1: |
| return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or |
| op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS) |
| elif version == 2: |
| return tf_export.get_v2_names(tf_decorator.unwrap(op)[1]) |
| else: |
| raise ValueError('Expected version 1 or 2.') |
| |
| |
| def ragged_op_list(tf_version=1): |
| """Returns a string listing operators that have dispathers registered.""" |
| lines = [] |
| for op in _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS: |
| if _op_is_in_tf_version(op, tf_version): |
| lines.append(_ragged_op_signature(op, [0])) |
| for op in _BINARY_ELEMENTWISE_OPS: |
| if _op_is_in_tf_version(op, tf_version): |
| lines.append(_ragged_op_signature(op, [0, 1])) |
| for op, _, ragged_args in _RAGGED_DISPATCH_OPS: |
| if _op_is_in_tf_version(op, tf_version): |
| arginfos = _get_arg_infos(op, ragged_args) |
| ragged_args = [arginfo.position for arginfo in arginfos] |
| lines.append(_ragged_op_signature(op, ragged_args)) |
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')
| |
| |
| register_dispatchers() |