| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Experimental library that exposes XLA operations directly in TensorFlow. |
| |
It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics of
HLO operators as closely as possible.
| |
| Note: Most of the operators defined in this module are used by the jax2tf |
| converter (see go/jax2tf for details) and are used in SavedModel produced |
| by jax2tf. Hence, we need to maintain backwards compatibility for these |
| operators. Please reach out to the JAX team if you want to make changes. |
| """ |
| |
| from tensorflow.compiler.tf2xla.ops import gen_xla_ops |
| from tensorflow.compiler.xla import xla_data_pb2 |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import bitwise_ops |
| from tensorflow.python.ops import gen_math_ops |
| from tensorflow.python.ops import gen_random_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import special_math_ops |
| from tensorflow.python.ops import stateless_random_ops |
| from tensorflow.python.ops.numpy_ops import np_utils |
| |
# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
#   infeed/outfeed (available via tf.contrib.tpu)
#   collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
#   conditional
#   collapse
| |
# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max). The module-level definitions further down shadow the Python
# builtins of the same name, so the builtin versions are captured here under
# private names before that happens.
# pylint: disable=redefined-builtin
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

# `constant` is re-exported unchanged from constant_op.
constant = constant_op.constant
| |
| # Unary operators. |
| |
| # For most arithmetic operators there is a TensorFlow operator |
| # that exactly corresponds to each XLA operator. Rather than defining |
| # XLA-specific variants, we reuse the corresponding TensorFlow operator. |
| # TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 |
| # wrap every HLO operator, because that would allow us to be confident that the |
| # semantics match. |
| |
| |
| def _unary_op(fn): |
| """Wrapper that restricts `fn` to have the correct signature.""" |
| |
| def unary_op_wrapper(x, name=None): |
| return fn(x, name=name) |
| |
| return unary_op_wrapper |
| |
| |
# Each XLA unary operator below is the corresponding TensorFlow operator
# wrapped by `_unary_op`, so that it only accepts `(x, name=None)`.
abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, which rounds numbers halfway between two
# integers away from zero, this rounds them to the nearest even integer.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tanh = _unary_op(math_ops.tanh)

# Bessel functions, taken from special_math_ops rather than math_ops.
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)
| |
| # Binary operators |
| |
| # The main difference between TensorFlow and XLA binary ops is the broadcasting |
| # semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA |
| # requires an explicit specification of which dimensions to broadcast if the |
| # arguments have different ranks. |
| |
| |
def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a callable invoked as `fn(x, y, name=name)` on the operands returned by
      the XlaBroadcastHelper op.

  Returns:
    A function with signature `(x, y, broadcast_dims=None, name=None)` that
    broadcasts its operands and then applies `fn` to them.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Broadcasts `x` and `y` XLA-style, then applies the wrapped operator.

    Args:
      x: the left operand.
      y: the right operand.
      broadcast_dims: optional sequence of integer dimensions handed to the
        XlaBroadcastHelper op; `None` means no explicit dimensions.
      name: an optional name for the wrapped operator.

    Returns:
      The result of the wrapped operator on the broadcast operands.
    """
    # Compare against None instead of relying on truthiness: a NumPy array
    # such as np.array([0]) is falsy (so `broadcast_dims or []` would silently
    # discard it) and longer arrays raise when converted to bool.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper
| |
| |
# Map from each TF signed integer type to the unsigned type of the same width.
_SIGNED_TO_UNSIGNED_TABLE = dict((
    (dtypes.int8, dtypes.uint8),
    (dtypes.int16, dtypes.uint16),
    (dtypes.int32, dtypes.uint32),
    (dtypes.int64, dtypes.uint64),
))

# Map from TF unsigned types to TF signed types; the exact inverse of the table
# above, so it is derived from it instead of being spelled out again.
_UNSIGNED_TO_SIGNED_TABLE = {
    unsigned: signed for signed, unsigned in _SIGNED_TO_UNSIGNED_TABLE.items()
}
| |
| |
def _shift_right_logical_helper(x, y, name=None):
  """Right-shifts `x` by `y` logically, for signed and unsigned integers.

  Args:
    x: the tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    A tensor with the dtype of `x` holding the shifted values.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE.get(original_dtype)
  if unsigned_dtype is None:
    # Not a signed integer type: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)
  # Signed input: do the shift in the unsigned type of the same width so it is
  # a logical shift, then cast the result back to the caller's dtype.
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, unsigned_dtype),
      math_ops.cast(y, unsigned_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)
| |
| |
def _shift_right_arithmetic_helper(x, y, name=None):
  """Right-shifts `x` by `y` arithmetically, for signed and unsigned integers.

  Args:
    x: the tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    A tensor with the dtype of `x` holding the shifted values.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  signed_dtype = _UNSIGNED_TO_SIGNED_TABLE.get(original_dtype)
  if signed_dtype is None:
    # Not an unsigned integer type: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)
  # Unsigned input: do the shift in the signed type of the same width so it is
  # an arithmetic shift, then cast the result back to the caller's dtype.
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, signed_dtype),
      math_ops.cast(y, signed_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)
| |
| |
# Each XLA binary operator below is a TensorFlow operator wrapped by
# `_broadcasting_binary_op`, which adds the optional `broadcast_dims` argument.
add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
# `rem` wraps the generated `mod` op directly.
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
# The right shifts go through helpers defined in this file, which cast between
# signed and unsigned integer types around the shift.
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

# Gamma-function family and related special functions.
igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)
| |
| |
| def _binary_op(fn): |
| """Wrapper that restricts `fn` to have the correct signature.""" |
| |
| def binary_op_wrapper(x, y, name=None): |
| return fn(x, y, name=name) |
| |
| return binary_op_wrapper |
| |
| |
# `transpose` and `rev` take a tensor plus one extra argument, so they are
# restricted to the `(x, y, name=None)` signature by `_binary_op`.
transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

# Direct alias; no signature restriction is applied.
bitcast_convert_type = array_ops.bitcast
| |
| |
def broadcast(x, dims, name=None):
  """Broadcasts `x` by prepending the dimensions `dims` to its shape.

  Args:
    x: the value to broadcast; converted to a tensor.
    dims: sequence of integer sizes of the new leading dimensions. May be
      empty, in which case the result has the same shape as `x`.
    name: an optional name for the broadcast op.

  Returns:
    A tensor of shape `dims + shape(x)`.
  """
  x = ops.convert_to_tensor(x)
  # Pin the dtype to int32 so it matches `array_ops.shape(x)`. Without it an
  # empty `dims` is inferred as float32 and the concat below fails on the dtype
  # mismatch.
  leading_dims = constant_op.constant(dims, dtype=dtypes.int32)
  shape = array_ops.concat([leading_dims, array_ops.shape(x)], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)
| |
| |
def clamp(a, x, b, name=None):
  """Clamps `x` between `a` and `b`, computed as `min(max(a, x), b)`.

  Args:
    a: the lower bound.
    x: the value to clamp.
    b: the upper bound.
    name: an optional name, passed to both the `max` and the `min` op.

  Returns:
    The result of applying this module's `max` and then `min` operators.
  """
  # `max` and `min` here are this module's broadcasting operators, not the
  # Python builtins.
  bounded_below = max(a, x, name=name)
  return min(bounded_below, b, name=name)
| |
| |
# Direct alias of array_ops.concat under the XLA operator name.
concatenate = array_ops.concat
| |
| |
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False,
         batch_group_count=1):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  The XlaConvV2 op is emitted when `preferred_element_type` is given, when
  `lhs` and `rhs` have different dtypes, when `batch_group_count > 1`, or when
  `use_v2` is set. Otherwise the XlaConv op is emitted.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. When `None` and the V2 op is
      used, it defaults to `np_utils.result_type(lhs.dtype, rhs.dtype)`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # Test against None, not truthiness: some dtype objects are falsy (e.g.
  # `numpy.dtype` instances, whose len() is 0), and an explicitly requested
  # result type must never be silently dropped by falling back to XlaConv.
  needs_v2 = (
      preferred_element_type is not None or lhs.dtype != rhs.dtype or
      batch_group_count > 1)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  # The original op takes neither `batch_group_count` nor
  # `preferred_element_type`.
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)
| |
| |
# Direct alias of math_ops.cast under the XLA operator name.
convert_element_type = math_ops.cast
| |
| |
def dot(lhs, rhs, name=None):
  """Computes `math_ops.tensordot(lhs, rhs)` with `axes=1`.

  Args:
    lhs: the left operand.
    rhs: the right operand.
    name: an optional name for the operator.

  Returns:
    The tensordot result.
  """
  product = math_ops.tensordot(lhs, rhs, axes=1, name=name)
  return product
| |
| |
# Re-export the XLA proto message classes so callers can build them from this
# module.
DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig
| |
| |
def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  """Wraps the XlaDot / XlaDotV2 operators.

  The XlaDotV2 op is emitted when `preferred_element_type` is given, when `lhs`
  and `rhs` have different dtypes, or when `use_v2` is set. Otherwise the XlaDot
  op is emitted.

  Args:
    lhs: the left operand tensor.
    rhs: the right operand tensor.
    dimension_numbers: a proto that is serialized and passed to the op
      (presumably a `DotDimensionNumbers`).
    precision_config: an optional `PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. When `None` and the V2 op is
      used, it defaults to `np_utils.result_type(lhs.dtype, rhs.dtype)`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaDotV2 op even if not necessary.

  Returns:
    A tensor representing the output of the dot operation.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # Test against None, not truthiness: some dtype objects are falsy (e.g.
  # `numpy.dtype` instances, whose len() is 0), and an explicitly requested
  # result type must never be silently dropped by falling back to XlaDot.
  needs_v2 = preferred_element_type is not None or lhs.dtype != rhs.dtype
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  # The original op does not take `preferred_element_type`.
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)
| |
| |
def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Wraps the generated `xla_self_adjoint_eig` op.

  Args:
    a: the input passed as the op's first argument.
    lower: passed through to the op unchanged.
    max_iter: passed through to the op unchanged.
    epsilon: passed through to the op unchanged.

  Returns:
    Whatever `gen_xla_ops.xla_self_adjoint_eig` returns.
  """
  outputs = gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
  return outputs
| |
| |
def svd(a, max_iter, epsilon, precision_config=None):
  """Wraps the generated `xla_svd` op.

  Args:
    a: the input passed as the op's first argument.
    max_iter: passed through to the op unchanged.
    epsilon: passed through to the op unchanged.
    precision_config: an optional proto; serialized when provided, otherwise
      an empty string is passed to the op.

  Returns:
    Whatever `gen_xla_ops.xla_svd` returns.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, serialized_precision)
| |
| |
# Direct aliases of the generated XLA op wrappers.
dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad
| |
| |
def random_normal(mu, sigma, dims, name=None):
  """Draws values via `random_ops.random_normal`, with the dtype of `mu`.

  Args:
    mu: the mean; converted to a tensor, and its dtype is the output dtype.
    sigma: the standard deviation.
    dims: the output shape.
    name: an optional name for the operator.

  Returns:
    The tensor produced by `random_ops.random_normal`.
  """
  mean = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mean, stddev=sigma, dtype=mean.dtype, name=name)
| |
| |
def random_uniform(minval, maxval, dims, name=None):
  """Draws values via `random_ops.random_uniform`, with the dtype of `minval`.

  Args:
    minval: the lower bound; converted to a tensor, and its dtype is the
      output dtype.
    maxval: the upper bound.
    dims: the output shape.
    name: an optional name for the operator.

  Returns:
    The tensor produced by `random_ops.random_uniform`.
  """
  lower = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, lower, maxval, dtype=lower.dtype, name=name)
| |
| |
def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}. It is converted to
      its integer form before being passed to the op.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
      should be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    a tuple with a new state and generated data of the given shape.
  """
  algorithm_id = stateless_random_ops.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(
      algorithm_id, initial_state, shape, dtype=dtype)
| |
| |
# Direct aliases of the generated XLA op wrappers. Note that `variadic_reduce`
# is bound to the V2 wrapper.
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

# Registers the op type "XlaVariadicReduce" as having no gradient.
# NOTE(review): `variadic_reduce` above is the V2 wrapper while this names the
# un-suffixed op type; confirm whether that is intentional.
ops.no_gradient("XlaVariadicReduce")
| |
| |
def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors passed to the op as `base_dilations`, as
      a list of integers. Optional; if omitted, defaults to 1 for every window
      dimension.
    window_dilations: dilation factors passed to the op as `window_dilations`,
      as a list of integers. Optional; if omitted, defaults to 1 for every
      window dimension.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  # Each default is built only when the argument is missing (or empty), so
  # `len(window_dimensions)` is evaluated only if a default is needed.
  if not window_strides:
    window_strides = [1] * len(window_dimensions)
  if not base_dilations:
    base_dilations = [1] * len(window_dimensions)
  if not window_dilations:
    window_dilations = [1] * len(window_dimensions)
  if not padding:
    padding = [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)
| |
| |
# Direct alias of the generated XlaReplicaId op wrapper.
replica_id = gen_xla_ops.xla_replica_id

# Set a static bound for the given input value as a hint to the XLA compiler;
# returns the same value.
# Usage:
#   def f(t, p):
#     p = xla.set_bound(p, 3)  # Tells xla the constraint that p <= 3.
#     return t[:p]  # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound


# Make a static dimension into an xla bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes the
# dynamic size of the dimension.
#
# This should mostly be used for testing.
#
#   def f():
#     array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
#     # Tells xla the valid size of the array is 3.
#     dim = 1  # The axis of static size 5; axis 0 only has size 1.
#     p = xla.set_dynamic_dimension_size(array, dim, 3)
#     assert(reduce_sum(p) == 6)  # xla knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size


# Inverse of set_dynamic_dimension_size. Make an xla bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size
| |
| |
def reshape(x, new_sizes, dimensions=None, name=None):
  """Optionally transposes `x`, then reshapes it to `new_sizes`.

  Args:
    x: the tensor to reshape.
    new_sizes: the target shape.
    dimensions: optional permutation; when given, `x` is transposed with it
      before reshaping.
    name: an optional name for the reshape op.

  Returns:
    The reshaped tensor.
  """
  source = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(source, new_sizes, name=name)
| |
| |
def select(condition, x, y, name=None):
  """Selects between `x` and `y` via `array_ops.where`.

  Args:
    condition: the predicate passed as the first argument of `where`.
    x: the second argument of `where`.
    y: the third argument of `where`.
    name: an optional name for the operator.

  Returns:
    The result of `array_ops.where`.
  """
  selected = array_ops.where(condition, x, y, name)
  return selected
| |
| |
# Direct aliases of the generated XLA op wrappers.
select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
| |
| |
def slice(x, start_dims, limit_dims, strides):
  """Takes a strided slice of `x` using Python slice indexing.

  Args:
    x: the value to index.
    start_dims: the start index for each dimension.
    limit_dims: the end index (exclusive, as in a Python slice) for each
      dimension.
    strides: the stride for each dimension.

  Returns:
    `x` indexed with one builtin slice object per
    `(start, limit, stride)` triple.
  """
  # `_slice` is the captured builtin; the name `slice` is this function.
  index = tuple(
      _slice(*bounds) for bounds in zip(start_dims, limit_dims, strides))
  return x[index]
| |
| |
# Direct alias of the generated XlaSharding op wrapper.
sharding = gen_xla_ops.xla_sharding
| |
| |
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  Args:
    op: the forward XlaSharding operation.
    grad: the gradient with respect to the output of `op`.

  Returns:
    A single-element list holding `grad` passed through an XlaSharding op that
    carries the same `sharding` and `unspecified_dims` attributes as `op`.
  """
  forward_sharding = op.get_attr("sharding")
  forward_unspecified_dims = op.get_attr("unspecified_dims")
  sharded_grad = gen_xla_ops.xla_sharding(
      grad,
      sharding=forward_sharding,
      unspecified_dims=forward_unspecified_dims)
  # Also record the sharding on the new op under the "_XlaSharding" attribute.
  # pylint: disable=protected-access
  sharded_grad.op._set_attr("_XlaSharding",
                            attr_value_pb2.AttrValue(s=forward_sharding))
  return [sharded_grad]
| |
| |
# Direct aliases of the generated SPMD shape-conversion op wrappers.
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape
| |
| |
@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient for XlaSpmdFullToShardShape op.

  Args:
    op: the forward XlaSpmdFullToShardShape operation.
    grad: the gradient with respect to the output of `op`.

  Returns:
    A single-element list holding `grad` converted by XlaSpmdShardToFullShape,
    using the static shape of the forward op's input as `full_shape`.
  """
  manual_sharding = op.get_attr("manual_sharding")
  # The forward input's static shape is the full shape to restore.
  forward_input_shape = op.inputs[0].shape.as_list()
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  full_grad = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad,
      manual_sharding=manual_sharding,
      full_shape=forward_input_shape,
      dim=dim,
      unspecified_dims=unspecified_dims)
  return [full_grad]
| |
| |
@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient for XlaSpmdShardToFullShape op.

  Args:
    op: the forward XlaSpmdShardToFullShape operation.
    grad: the gradient with respect to the output of `op`.

  Returns:
    A single-element list holding `grad` converted by XlaSpmdFullToShardShape
    with the forward op's `manual_sharding`, `dim` and `unspecified_dims`.
  """
  manual_sharding = op.get_attr("manual_sharding")
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  shard_grad = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad,
      manual_sharding=manual_sharding,
      dim=dim,
      unspecified_dims=unspecified_dims)
  return [shard_grad]
| |
| |
# Direct aliases of the generated XLA op wrappers.
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
# Exposed as `while_loop` because `while` is a Python keyword.
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call
| |
| |
def call_module(args, *, module, Tout, Sout, dim_args_spec=()):
  """Wraps the generated `xla_call_module` op.

  Args:
    args: the op's positional input.
    module: keyword-only; passed through as the op's `module` argument.
    Tout: keyword-only; passed through as the op's `Tout` argument.
    Sout: keyword-only; passed through as the op's `Sout` argument.
    dim_args_spec: keyword-only; passed through as the op's `dim_args_spec`
      argument. Defaults to an empty tuple.

  Returns:
    Whatever `gen_xla_ops.xla_call_module` returns.
  """
  return gen_xla_ops.xla_call_module(
      args,
      module=module,
      dim_args_spec=dim_args_spec,
      Tout=Tout,
      Sout=Sout)
| |
| |
def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Wraps the generated `xla_gather` op.

  Args:
    operand: the tensor to gather from.
    start_indices: the indices tensor passed to the op.
    dimension_numbers: a proto; it is serialized before being passed to the op.
    slice_sizes: passed through as the op's `slice_sizes` argument.
    indices_are_sorted: passed through to the op. Defaults to False.
    name: an optional name for the operator.

  Returns:
    Whatever `gen_xla_ops.xla_gather` returns.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
| |
| |
def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Wraps the generated `xla_scatter` op.

  Args:
    operand: the tensor to scatter into.
    scatter_indices: the indices tensor passed to the op.
    updates: the update values passed to the op.
    update_computation: passed through as the op's `update_computation`
      argument.
    dimension_numbers: a proto; it is serialized before being passed to the op.
    indices_are_sorted: passed through to the op. Defaults to False.
    name: an optional name for the operator.

  Returns:
    Whatever `gen_xla_ops.xla_scatter` returns.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
| |
| |
def optimization_barrier(*args):
  """Wraps the generated `xla_optimization_barrier` op.

  Args:
    *args: the values to pass to the op; they are handed over together as a
      single tuple.

  Returns:
    Whatever `gen_xla_ops.xla_optimization_barrier` returns.
  """
  operands = args
  return gen_xla_ops.xla_optimization_barrier(operands)
| |
| |
def reduce_precision(operand, exponent_bits, mantissa_bits):
  """Wraps the generated `xla_reduce_precision` op.

  Args:
    operand: the input tensor.
    exponent_bits: passed through to the op unchanged.
    mantissa_bits: passed through to the op unchanged.

  Returns:
    Whatever `gen_xla_ops.xla_reduce_precision` returns.
  """
  reduced = gen_xla_ops.xla_reduce_precision(
      operand, exponent_bits, mantissa_bits)
  return reduced