blob: 29b6278ca9ef8b5f90edb9960a252d8c6a9c396f [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import enum # pylint: disable=g-bad-import-order
import inspect
import itertools
import os
from absl import logging
import numpy as np
import six
# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.
from tensorflow.compiler.xla.python import xla_extension as _xla
from tensorflow.compiler.xla.python.xla_extension import ops
# Most functions are snake_case for consistency with other modules, whereas
# method names of ComputationBuilder and Computation are CamelCase for
# consistency with XLA.
# pylint: disable=invalid-name
@six.add_metaclass(abc.ABCMeta)
class Backend(object):
  """Abstract interface that every XLA backend implements.

  A backend exposes device queries, buffer creation and compilation for one
  platform.
  """

  def __init__(self, platform):
    """Initializes the backend.

    Args:
      platform: platform name string, for example 'gpu'.
    """
    self.platform = platform

  @abc.abstractmethod
  def device_count(self):
    """Returns how many devices the backend knows about."""

  @abc.abstractmethod
  def local_device_count(self):
    """Returns how many devices are attached to this host."""

  @abc.abstractmethod
  def devices(self):
    """Returns the `device_count()` Device objects of this backend."""

  @abc.abstractmethod
  def host_id(self):
    """Returns this host's integer ID."""

  @abc.abstractmethod
  def buffer_from_pyval(self, pyval, device=None):
    """Allocates a new buffer holding a copy of `pyval`."""

  def buffers_from_pyvals(self, pyvals_and_devices):
    """Allocates one buffer per `(pyval, device)` pair, preserving order.

    This default calls `buffer_from_pyval` once per pair; subclasses may
    override it with a batched implementation.
    """
    allocated = []
    for value, target_device in pyvals_and_devices:
      allocated.append(self.buffer_from_pyval(value, target_device))
    return allocated

  @abc.abstractmethod
  def make_tuple(self, c_buffers, device):
    """Combines a sequence of backend buffers into a tuple buffer."""

  @abc.abstractmethod
  def compile(self, computation, compile_options):
    """Compiles `computation` and returns an executable."""
class LocalBackend(Backend):
  """Backend that delegates every call to an in-process xla::LocalClient."""

  def __init__(self, platform, client):
    """Wraps an existing client.

    Args:
      platform: user-visible platform name string, e.g. 'gpu'.
      client: the _xla.PyLocalClient that calls are delegated to.
    """
    super(LocalBackend, self).__init__(platform)
    self.client = client

  def device_count(self):
    return self.client.device_count()

  def local_device_count(self):
    return self.client.local_device_count()

  def devices(self):
    return self.client.devices()

  def local_devices(self):
    return self.client.local_devices()

  def host_id(self):
    return self.client.host_id()

  def buffer_from_pyval(self, pyval, device=None):
    # Without an explicit device, place the buffer on the first local device.
    target_device = self.local_devices()[0] if device is None else device
    return _xla.PyLocalBuffer.from_python(pyval, self.client, target_device)

  def make_tuple(self, c_buffers, device):
    return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)

  def compile(self, c_computation, compile_options):
    """Compiles `c_computation` for this backend's client.

    Of `compile_options`, only num_replicas, result_layout (when truthy),
    argument_layouts and device_assignment are read here.
    """
    build_options = _xla.ExecutableBuildOptions()
    build_options.num_replicas = compile_options.num_replicas
    if compile_options.result_layout:
      build_options.result_layout = compile_options.result_layout
    # Turn on the CPU fast-math "honor" flags (infs, NaNs, division,
    # functions) and turn off fast min/max on GPU.
    build_options.debug_options.xla_cpu_fast_math_honor_infs = True
    build_options.debug_options.xla_cpu_fast_math_honor_nans = True
    build_options.debug_options.xla_cpu_fast_math_honor_division = True
    build_options.debug_options.xla_cpu_fast_math_honor_functions = True
    build_options.debug_options.xla_gpu_enable_fast_min_max = False
    return _xla.LocalExecutable.Compile(
        c_computation, compile_options.argument_layouts, build_options,
        self.client, compile_options.device_assignment)

  def serialize(self, executable):
    return self.client.SerializeExecutable(executable)

  def deserialize(self, serialized_executable):
    # NOTE(review): the client is both the receiver and an argument here;
    # confirm the binding really expects it twice.
    return self.client.DeserializeExecutable(serialized_executable,
                                             self.client)
# Maps user-visible platform names to XLA's platform identifiers.
xla_platform_names = dict(cpu='Host', gpu='CUDA')
def _cpu_backend_factory():
  """Returns a LocalBackend for the host CPU, using an asynchronous client."""
  name = 'cpu'
  return LocalBackend(
      platform=name,
      client=_xla.LocalClient.Get(
          platform=name,
          xla_platform_id=xla_platform_names[name],
          asynchronous=True))
def _gpu_backend_factory():
  """Returns a GPU LocalBackend configured from environment variables.

  Environment variables:
    XLA_PYTHON_CLIENT_ALLOCATOR: 'default' (when unset), 'platform' or 'bfc',
      case-insensitive. Selects the matching _xla.AllocatorConfig.Kind.
    XLA_PYTHON_CLIENT_MEM_FRACTION: optional float forwarded as the
      allocator's memory_fraction.
    XLA_PYTHON_CLIENT_PREALLOCATE: preallocation is enabled unless this is
      '0' or 'false' (case-insensitive, surrounding whitespace ignored).

  Raises:
    ValueError: if the allocator name or the memory fraction is malformed.
  """
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", or '
        '"bfc", got "%s"' % allocator)

  allocator_kinds = {
      'default': _xla.AllocatorConfig.Kind.DEFAULT,
      'platform': _xla.AllocatorConfig.Kind.PLATFORM,
      'bfc': _xla.AllocatorConfig.Kind.BFC,
  }
  config = _xla.AllocatorConfig()
  config.kind = allocator_kinds[allocator]

  if memory_fraction:
    try:
      fraction = float(memory_fraction)
    except ValueError:
      raise ValueError(
          'XLA_PYTHON_CLIENT_MEM_FRACTION env var must be a float, got "%s"' %
          memory_fraction)
    config.memory_fraction = fraction

  # Unset (None) or any value other than a spelling of "false"/"0" keeps
  # preallocation on. The comparison is case-insensitive so 'FALSE' also works.
  config.preallocate = (
      preallocate is None or
      preallocate.strip().lower() not in ('0', 'false'))

  client = _xla.LocalClient.Get(
      platform='gpu',
      xla_platform_id=xla_platform_names['gpu'],
      asynchronous=True,
      allocator_config=config)
  return LocalBackend(platform='gpu', client=client)
# Backend factories keyed by user-visible name. Iteration order is priority
# order, lowest first; CPU is listed before GPU.
_local_backend_factories = collections.OrderedDict()
_local_backend_factories['cpu'] = _cpu_backend_factory
_local_backend_factories['gpu'] = _gpu_backend_factory


def register_local_backend_factory(name, factory):
  """Registers, or replaces, the factory used to create backend `name`.

  A new name is appended after every existing entry, giving it the highest
  priority so far. Re-registering an existing name keeps its original
  position. Backends are created once and cached by _get_local_backends, so
  registrations made after that first call are not picked up.
  """
  _local_backend_factories[name] = factory


# Cache of instantiated backends; None until _get_local_backends() fills it.
_local_backends = None
def _get_local_backends():
  """Instantiates all known local backends, caching the result.

  Factories run in registration (priority) order. A RuntimeError from a
  non-CPU factory means that backend is unavailable, and it is skipped. A
  failure of the CPU factory is re-raised.

  The module-level cache is assigned only after every factory has been
  tried. A failed initialization therefore leaves no partially-filled
  (possibly empty) cache behind, and the next call tries again.

  Returns:
    An OrderedDict mapping backend name to backend, in priority order.
  """
  global _local_backends
  if _local_backends is not None:
    return _local_backends

  backends = collections.OrderedDict()
  for name, factory in _local_backend_factories.items():
    logging.vlog(2, "Initializing backend '%s'", name)
    try:
      backend = factory()
    except RuntimeError as e:
      if name == 'cpu':
        # We always expect CPU to initialize successfully.
        raise
      # If the backend isn't built into the binary, or if it has no devices,
      # we expect a RuntimeError.
      logging.vlog(2, "Backend '%s' is unavailable: %s", name, e)
      continue
    backends[name] = backend

  _local_backends = backends
  return _local_backends
def get_local_backend(name=None):
  """Looks up a local backend.

  Args:
    name: backend name such as 'cpu' or 'gpu'. If None, the
      highest-priority backend that initialized is returned: typically 'gpu'
      when present, otherwise 'cpu'.

  Returns:
    A LocalBackend object.

  Raises:
    RuntimeError: if `name` is given but no such backend is available.
  """
  backends = _get_local_backends()
  if name is None:
    # Factories are registered lowest priority first, so the last one wins.
    return list(backends.values())[-1]
  if name not in backends:
    raise RuntimeError('Unknown backend {}'.format(name))
  return backends[name]
class OpMetadata(object):
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Returns an OpMetadata describing a frame on the current call stack.

  Args:
    op_type: value stored as OpMetadata.op_type.
    op_name: value stored as OpMetadata.op_name.
    skip_frames: how many frames up the stack to look. 0 is this function
      itself; 1 (the default) is its immediate caller.

  Returns:
    An OpMetadata whose source_file is the basename of that frame's file and
    whose source_line is that frame's current line number.

  Raises:
    IndexError: if the stack has fewer than `skip_frames + 1` frames.
  """
  # context=0: only the filename and line number are needed. The default
  # context=1 also reads source lines for every frame on the stack, which is
  # costly on deep stacks.
  full_filename, lineno = inspect.stack(0)[skip_frames][1:3]
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=os.path.basename(full_filename),
      source_line=lineno)
PrimitiveType = _xla.PrimitiveType

# `np.object` was a deprecated alias of the builtin `object` and was removed in
# NumPy 1.24; `np.dtype(object)` is the same dtype('O') on every NumPy version.
XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(object),
    PrimitiveType.TOKEN: np.dtype(object),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
# TUPLE and TOKEN share dtype('O'), so only one of them survives the inversion.
# With insertion-ordered dicts (Python 3.7+), the later TOKEN entry wins.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Returns the XLA PrimitiveType corresponding to a NumPy dtype.

  Args:
    dtype: anything accepted by `np.dtype`.

  Raises:
    KeyError: if the dtype has no entry in DTYPE_TO_XLA_ELEMENT_TYPE.
  """
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
# Shape and ProgramShape are C++ classes from xla_extension (`_xla`). This
# module only re-exports them and attaches docstrings that sketch the
# Python-visible API each one duck types.
Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:
class Shape(object):
'''Represents an XLA shape.
A shape is either an array shape, having rank-many integer
dimensions and an element type (represented by a Numpy dtype), or it
is a tuple shape, having a shape for every tuple component:
type shape =
TupleShape of shape list
| ArrayShape of { dimensions: int list; element_type: dtype }
'''
@staticmethod
def tuple_shape(tuple_shapes) -> Shape:
"Construct a tuple shape."
@staticmethod
def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
@staticmethod
def from_pyval(pyval) -> Shape:
"Returns a Shape that describes a tuple-tree of Numpy arrays."
def __init__(self, str) -> Shape:
"Parses a shape string."
def __eq__(self, other: Shape) -> bool:
def __ne__(self, other: Shape) -> bool:
def __hash__(self):
def __repr__(self):
def is_tuple(self) -> bool:
def is_array(self) -> bool:
def tuple_shapes(self) -> [Shape]:
def numpy_dtype(self) -> np.dtype:
"Like element_type(), but returns dtype('O') for a tuple shape."
def xla_element_type(self) -> PrimitiveType:
def element_type(self) -> np.dtype:
def dimensions(self) -> (int, int, ...):
def rank(self) -> int:
def with_major_to_minor_layout_if_absent(self) -> Shape:
"Returns a copy with missing layouts set to major-to-minor."
def to_serialized_proto(self) -> bytes:
"Returns 'shape' as a serialized proto."
"""
# ProgramShape pairs a list of parameter Shapes with a result Shape.
ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.
class ProgramShape(object):
def __init__(self, parameter_shapes, result_shape):
def parameter_shapes(self) -> [Shape]:
def result_shape(self) -> Shape:
def __repr__(self):
"""
class Buffer(object):
  """Factory namespace for handles to data owned by XLA.

  The data behind such a handle is ready to be used as input when executing
  a local, compiled Computation. On platforms with a device (e.g. GPU), it
  lives in device memory.
  """

  @staticmethod
  def from_pyval(pyval, device=None, backend=None):
    """Copies `pyval` into a freshly allocated on-device buffer.

    Args:
      pyval: the Python value (e.g. a NumPy array) to copy.
      device: forwarded unchanged to the backend's `buffer_from_pyval`.
      backend: a Backend, or None (any falsy value) for the default local
        backend.
    """
    target = backend if backend else get_local_backend()
    return target.buffer_from_pyval(pyval, device)

  @staticmethod
  def from_pyvals(pyvals_and_devices, backend=None):
    """Copies several Python values into freshly allocated on-device buffers.

    Arguments:
      pyvals_and_devices: a list of `(pyval, device)` pairs. `pyval` is the
        Python value to copy (e.g. a NumPy array), and `device` is forwarded
        to the backend unchanged.
      backend: a Backend object, or `None` to use the default local backend.

    Returns:
      A list of buffers, one per entry of `pyvals_and_devices`, in order.
    """
    target = backend if backend else get_local_backend()
    return target.buffers_from_pyvals(pyvals_and_devices)

  @staticmethod
  def make_tuple(buffers, device, backend=None):
    """Builds a tuple buffer from `buffers` using the backend's make_tuple."""
    target = backend if backend else get_local_backend()
    return target.make_tuple(buffers, device)

  # Buffer is never instantiated; it only groups the static factories above.
  # The buffer objects those factories return are C++ objects with this API:
  # def shape(self) -> Shape:
  # def device(self) -> int:
  # def delete(self):
  # def destructure(self) -> [Buffer]
  # def is_deleted(self) -> bool:
  # def block_host_until_ready(self):
  # """Blocks the calling thread until the buffer is ready on device."""
  # def copy_to_host_async(self):
  # """Requests a copy of the buffer to the host.
  #
  # Does not block waiting for the copy. Values fetched are available via
  # `to_py()`; the purpose of `copy_to_host_async` is to prefetch values
  # for subsequent `to_py()` calls, especially when requesting many values
  # at once.
  # """
  # def to_py(self):
  # """Returns the value of the buffer as a Python tuple tree of ndarrays."""
  #
  # TODO(phawkins): remove Buffer and its static methods completely, have
  # clients call methods on Backend to create buffers.


# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
# compatibility with Jaxlib versions older than 0.1.13.
LocalBuffer = Buffer
def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays.

  Tuples become tuple shapes, recursively. Every other leaf must have a
  `dtype` attribute and becomes an array shape built from that dtype and
  `np.shape(leaf)`.
  """
  if not isinstance(pyval, tuple):
    return Shape.array_shape(pyval.dtype, np.shape(pyval))
  return Shape.tuple_shape(tuple(shape_from_pyval(elt) for elt in pyval))
def transfer_to_infeed(value, device_ordinal=0):
  """Enqueues `value` onto a device's XLA infeed queue.

  The infeed queue is a single, totally ordered stream of values that feeds
  the "XLA virtual machine". XLA computations dequeue from it with the
  Infeed() operation.

  Args:
    value: the value to enqueue.
    device_ordinal: which device's infeed queue to use. Every device has its
      own queue.
  """
  # TODO(phawkins): support non-default backends.
  client = get_local_backend().client
  client.TransferToInfeed(value, device_ordinal)


def transfer_from_outfeed(shape, device_ordinal=0):
  """Dequeues a literal of the given shape from a device's outfeed queue.

  Args:
    shape: the Shape of the value to read. Missing layouts are filled in as
      major-to-minor before the transfer.
    device_ordinal: which device's outfeed queue to read. Every device has
      its own queue.

  Returns:
    The literal value produced from the outfeed queue.
  """
  # TODO(phawkins): support non-default backends.
  client = get_local_backend().client
  laid_out = shape.with_major_to_minor_layout_if_absent()
  return client.TransferFromOutfeed(laid_out, device_ordinal)
# DeviceAssignment and Device are C++ classes from xla_extension (`_xla`),
# re-exported here. Only DeviceAssignment is given a docstring describing its
# Python-visible API.
DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.
def create(assignment):
'''Builds a device assignment.
Args:
assignment: a 2D numpy array of device ordinal integers, indexed by
[replica][computation_in_replica].
Returns:
A device assignment.
'''
def replica_count():
'''Returns the number of replicas.'''
def computation_count():
'''Returns the number of computations per replica.'''
"""
Device = _xla.Device
class CompileOptions(object):
  """Settings passed to the compile step of a local XLA client.

  LocalBackend.compile reads num_replicas, result_layout, argument_layouts
  and device_assignment. The dump and profiling fields start out as None.
  """

  def __init__(self):
    # HLO dumping and profiling knobs; all None by default.
    for option in ('xla_dump_to', 'dump_hlo_pass_re', 'dump_hlo_module_re',
                   'dump_hlo_as_text', 'dump_hlo_as_proto', 'hlo_profile'):
      setattr(self, option, None)
    self.num_replicas = 1
    # Optional layout constraints on the arguments and on the result.
    self.argument_layouts = None
    self.result_layout = None
    # Optional DeviceAssignment forwarded to the compiler.
    self.device_assignment = None
class Computation(object):
  """Python wrapper for an XLA Computation.

  A Computation can be compiled to form an Executable, or used as a
  subcomputation in ComputationBuilder methods.
  """

  def __init__(self, c_computation, backend=None):
    self._c_computation = c_computation
    # The backend argument is deprecated. Pass a backend to Compile() instead.
    self._backend = backend

  @property
  def computation(self):
    """The wrapped C++ computation object."""
    return self._c_computation

  def GetSerializedProto(self):
    """Gets the serialized HloModuleProto proto object in this computation.

    Returns:
      A string containing a serialized HloModuleProto proto containing the
      computation and its dependencies.
    """
    return self.computation.GetSerializedProto()

  def GetHloText(self):
    """Get the textual HLO representation of this computation.

    Returns:
      A string containing the textual HLO.
    """
    return self.computation.GetHloText()

  def GetHloDotGraph(self):
    """Get a Graphviz Dot representation of this computation.

    Returns:
      A string containing the graphviz dot graph.
    """
    return self.computation.GetHloDotGraph()

  def Compile(self, argument_shapes=None, compile_options=None, backend=None):
    """Compiles a computation.

    Computations are the result of a "ComputationBuild'ing" process.

    Arguments:
      argument_shapes: Deprecated. Use compile_options.argument_layouts instead.
        When given, it overrides argument_layouts on a copy of
        `compile_options`; the caller's object is not modified.
      compile_options: options to use for compilation, includes an optional laid
        out result shape for the computation.
      backend: a `Backend` for which an executable should be generated. Falls
        back to the deprecated constructor backend, then to the default local
        backend.

    Returns:
      An Executable instance.
    """
    import copy  # pylint: disable=g-import-not-at-top
    backend = backend or self._backend or get_local_backend()
    compile_options = compile_options or CompileOptions()
    if argument_shapes:
      # Shallow-copy before overriding so a caller reusing its options object
      # for later compilations does not inherit these layouts.
      compile_options = copy.copy(compile_options)
      compile_options.argument_layouts = argument_shapes
    return backend.compile(self.computation, compile_options)

  def GetProgramShape(self):
    return self._c_computation.GetProgramShape()

  def GetReturnValueShape(self):
    return self.GetProgramShape().result_shape()
# An Executable is a C++ class that duck types with the following API:
# class Executable(object):
# def DeviceOrdinals(self) -> [int]:
# def Execute(self, arguments : [Buffer]) -> Buffer:
# """Execute on one replica with Buffer arguments and return value."""
#
# def SizeOfGeneratedCodeInBytes(self) -> int:
# """Return generated binary size, or -1 if not known."""
#
# def ExecutePerReplica(self, arguments: [[Buffer]]) -> [Buffer]:
# """Execute on many replicas with Buffer arguments and return value.
#
# Args:
# arguments: A sequence of sequences of Buffers. The i'th inner sequence
# comprises the arguments for execution on the i'th replica.
#
# Returns:
# A list of the computation's outputs for each replica, as a Buffer. If
# a shallow sequence of arguments was passed in for `arguments`, then the
# sole, zero'th replica's output is returned instead, as a Buffer.
# """
#
# There are different implementations of Executable for the Local and XRT
# backends.
def execute_with_python_values(executable, arguments=(), backend=None):
  """Runs `executable` on one replica, converting to and from Python values.

  Each argument is copied to the executable's first device ordinal. The
  result buffer is converted back with `to_py()`.
  """
  backend = backend or get_local_backend()

  def to_device(value):
    first_ordinal = executable.DeviceOrdinals()[0]
    return Buffer.from_pyval(value, device=first_ordinal, backend=backend)

  arg_buffers = list(map(to_device, arguments))
  return executable.Execute(arg_buffers).to_py()
def execute_with_python_values_replicated(executable, arguments, backend=None):
  """Execute on many replicas with Python values as arguments and output.

  Arguments:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by
      `[replica][arg_num]` to pass as inputs. Replica `i`'s values are copied
      to `executable.DeviceOrdinals()[i]`.
    backend: the backend we are targeting.

  Returns:
    A list of python values, one per replica.
  """
  backend = backend or get_local_backend()
  # The nested arguments are traversed twice below (flatten, then regroup),
  # so materialize them once; one-shot iterators would otherwise come back
  # empty on the second pass.
  arguments = [list(replica_args) for replica_args in arguments]
  device_ordinals = executable.DeviceOrdinals()
  # Transfer every argument in one batched call.
  # pylint: disable=g-complex-comprehension
  flat_args = [(arg, device_ordinals[replica])
               for replica, replica_args in enumerate(arguments)
               for arg in replica_args]
  flat_arg_buffers = Buffer.from_pyvals(flat_args, backend=backend)
  # Regroup per replica with a running offset. Re-slicing the remaining tail
  # on each iteration would be quadratic in the total argument count.
  arg_buffers = []
  offset = 0
  for replica_args in arguments:
    count = len(replica_args)
    arg_buffers.append(flat_arg_buffers[offset:offset + count])
    offset += count
  return [out.to_py() for out in executable.ExecutePerReplica(arg_buffers)]
class PaddingType(enum.Enum):
VALID = 1
SAME = 2
def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
window_strides):
"""Maps PaddingType or string to pad values (list of pairs of ints)."""
if not isinstance(padding_type, (str, PaddingType)):
msg = 'padding_type must be str or PaddingType, got {}.'
raise TypeError(msg.format(type(padding_type)))
if isinstance(padding_type, str):
if padding_type.upper() == 'VALID':
padding_type = PaddingType.VALID
elif padding_type.upper() == 'SAME':
padding_type = PaddingType.SAME
else:
msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
raise ValueError(msg.format(padding_type))
if padding_type == PaddingType.VALID:
return [(0, 0)] * len(window_strides)
elif padding_type == PaddingType.SAME:
out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
pad_sizes = [
max((out_size - 1) * stride + filter_size - in_size, 0)
for out_size, stride, filter_size, in_size in zip(
out_shape, window_strides, rhs_dims, lhs_dims)
]
return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
else:
msg = 'Unexpected PaddingType value: {}'
raise ValueError(msg.format(padding_type))
class ComputationBuilder(object):
"""XLA computation builder.
Enqueues XLA ops in sequence and in order to build a
Computation, which in turn can be compiled into a
LocalExecutable, which in turn can be locally executed.
"""
# The methods of this class map 1-to-1 onto the XLA C++
# computation builder API. Therefore, there's no need to laboriously list
# arguments and return values for every method, especially where it's obvious.
#
# pylint: disable=g-doc-return-or-yield
# pylint: disable=g-doc-args
def __init__(self, name):
  self._builder = _xla.XlaBuilder(name)
  # Supplies automatic parameter numbers for ParameterWithShape.
  self._parameter_numbering = itertools.count()

def Build(self, root=None, backend=None):
  """Finishes building and wraps the result in a `Computation`.

  Args:
    root: if not None, the operator containing the return value of the
      computation. When None, the underlying builder's Build() is called
      without a root argument.
    backend: deprecated; forwarded to the Computation constructor.

  Returns:
    A `Computation`.
  """
  if root is None:
    c_computation = self._builder.Build()
  else:
    c_computation = self._builder.Build(root)
  return Computation(c_computation, backend=backend)

def GetShape(self, operand):
  """Returns the Shape of `operand` as known to the builder."""
  return self._builder.GetShape(operand)

def SetOpMetadata(self, op_metadata):
  """Attaches `op_metadata` to operations enqueued from now on."""
  self._builder.SetOpMetadata(op_metadata)

def ClearOpMetadata(self):
  """Stops attaching metadata to subsequently enqueued operations."""
  self._builder.ClearOpMetadata()
def CreateToken(self):
  """Enqueues a CreateToken op.

  Returns:
    An XlaOp holding a fresh token.
  """
  return ops.CreateToken(self._builder)

def AfterAll(self, tokens):
  """Enqueues an after-all op that joins several tokens into one.

  Args:
    tokens: a list of `XlaOp` tokens that must all precede the result.

  Returns:
    An `XlaOp` token.
  """
  return ops.AfterAll(self._builder, tokens)

def Infeed(self, shape, token=None):
  """Enqueues an infeed op that dequeues a value from the infeed queue.

  Args:
    shape: a `Shape` describing the value to read. Missing layouts are
      filled in as major-to-minor.
    token: optional `XlaOp` token the infeed is sequenced after. A fresh
      token is created when omitted.

  Returns:
    An XlaOp holding a (value, token) pair.
  """
  after = ops.CreateToken(self._builder) if token is None else token
  laid_out = shape.with_major_to_minor_layout_if_absent()
  return ops.InfeedWithToken(after, laid_out)

def Outfeed(self, operand, token=None):
  """Enqueues an outfeed op that pushes `operand` onto the outfeed queue.

  The data can then be dequeued through the client API.

  Args:
    operand: an `XlaOp` holding the data to send.
    token: optional `XlaOp` token the outfeed is sequenced after. A fresh
      token is created when omitted.

  Returns:
    An `XlaOp` token.
  """
  after = ops.CreateToken(self._builder) if token is None else token
  operand_shape = self._builder.GetShape(operand)
  return ops.OutfeedWithToken(operand, after, operand_shape, '')
def Constant(self, value):
  """Enqueues a constant op onto the computation.

  Args:
    value: value for the constant, as a np.array with an explicit dtype set to
      one of the supported types.

  Returns:
    An XlaOp.
  """
  return ops.ConstantLiteral(self._builder, value)

def ConstantF32Scalar(self, value):
  """Convenience method to enqueue a scalar F32 constant op.

  Args:
    value: a floating-point number.

  Returns:
    An XlaOp.
  """
  return self.Constant(np.array(value, dtype=np.float32))

def ConstantF64Scalar(self, value):
  """Convenience method to enqueue a scalar F64 constant op.

  Args:
    value: a floating-point number.

  Returns:
    An XlaOp.
  """
  return self.Constant(np.array(value, dtype=np.float64))

def ConstantS32Scalar(self, value):
  """Convenience method to enqueue a scalar S32 constant op.

  Args:
    value: an integer.

  Returns:
    An XlaOp.
  """
  return self.Constant(np.array(value, dtype=np.int32))

def ConstantS64Scalar(self, value):
  """Convenience method to enqueue a scalar S64 constant op.

  Args:
    value: an integer.

  Returns:
    An XlaOp.
  """
  return self.Constant(np.array(value, dtype=np.int64))

def ConstantPredScalar(self, value):
  """Convenience method to enqueue a scalar PRED constant op.

  Args:
    value: a boolean value.

  Returns:
    An XlaOp.
  """
  # `np.bool` was a deprecated alias of the builtin `bool` and was removed in
  # NumPy 1.24. `np.bool_` names the same boolean dtype on every version.
  return self.Constant(np.array(value, dtype=np.bool_))
def ParameterWithShape(self, shape, name=None, parameter_num=None):
  """Enqueues a Parameter op with an explicit shape.

  Args:
    shape: the parameter's shape as a Shape object. Missing layouts are
      filled in as major-to-minor.
    name: optional string name for the parameter. None means ''.
    parameter_num: position of the parameter in the computation signature.
      When None, the next value of this builder's auto-numbering counter is
      used. Mixing explicit and automatic numbering can clash, so use
      auto-numbering for either all parameters or none.

  Returns:
    An XlaOp.
  """
  label = '' if name is None else name
  if parameter_num is None:
    parameter_num = next(self._parameter_numbering)
  laid_out = shape.with_major_to_minor_layout_if_absent()
  return ops.Parameter(self._builder, parameter_num, laid_out,
                       label.encode('utf8'))

def ParameterFromNumpy(self, value, name=None, parameter_num=None):
  """Enqueues a Parameter op whose shape is inferred from `value`.

  Args:
    value: a Numpy array, or a nested tuple of them, describing the shape.
    name: as in ParameterWithShape.
    parameter_num: as in ParameterWithShape.

  Returns:
    An XlaOp.
  """
  inferred = shape_from_pyval(value)
  return self.ParameterWithShape(
      inferred, name=name, parameter_num=parameter_num)
def Iota(self, dtype, size):
  """Enqueues a rank-1 iota constant of `size` elements.

  Args:
    dtype: numpy dtype of the output.
    size: integer, the number of elements.

  Returns:
    An XlaOp for the iota constant.
  """
  return ops.Iota(self._builder, dtype_to_etype(dtype), size)

def BroadcastedIota(self, dtype, shape, dimension):
  """Enqueues an iota constant of the given shape that counts along one axis.

  Args:
    dtype: numpy dtype of the output.
    shape: tuple of integers, the output dimensions.
    dimension: non-negative integer, the dimension along which values
      increase.

  Returns:
    An XlaOp for the broadcasted iota constant.
  """
  xla_shape = _xla.Shape.array_shape(dtype_to_etype(dtype), shape, None)
  return ops.Iota(self._builder, xla_shape, dimension)

def Concatenate(self, operands, dimension):
  """Enqueues a concatenation of `operands` along `dimension`.

  Args:
    operands: iterable of operands to join; converted to a list first.
    dimension: the dimension to concatenate along.

  Returns:
    An XlaOp for the concatenation.
  """
  operand_list = list(operands)
  return ops.ConcatInDim(self._builder, operand_list, dimension)

def ReplicaId(self):
  """Enqueues a ReplicaId operation.

  Returns:
    An XlaOp holding the replica id.
  """
  return ops.ReplicaId(self._builder)
def Pad(self, operand, padding_value, padding_config):
  """Enqueues a Pad operation.

  Args:
    operand: XlaOp for the array to pad.
    padding_value: XlaOp for the scalar value used as padding.
    padding_config: a PaddingConfig, or a list/tuple of integer triples
      (edge_padding_low, edge_padding_high, interior_padding), one per
      dimension.

  Returns:
    An XlaOp for the Pad op.
  """
  if isinstance(padding_config, (tuple, list)):
    # Triples are converted into a PaddingConfig first.
    padding_config = GetPaddingConfigFromTriples(padding_config)
  return ops.Pad(operand, padding_value, padding_config)

def Reshape(self, operand, dimensions, new_sizes):
  """Enqueues a reshape op.

  Args:
    operand: XlaOp for the array to reshape.
    dimensions: sequence of integers giving the order in which dimensions
      are collapsed, or None to collapse them in their natural order.
    new_sizes: sequence of integers, the new dimension sizes (shape).

  Returns:
    An XlaOp for the Reshape op.
  """
  if dimensions is None:
    rank = len(self.GetShape(operand).dimensions())
    dimensions = tuple(range(rank))
  return ops.Reshape(operand, dimensions, new_sizes)
def AllReduce(self, operand, computation, replica_groups=None):
  """Enqueues an AllReduce op.

  Args:
    operand: XlaOp representing the input array.
    computation: a Computation object holding the binary reduction function.
    replica_groups: optional list of lists of ints partitioning the replica
      ids {0, 1, ..., num_replicas - 1} into equally-sized groups within
      which the reduction is performed. If not supplied or None (the
      default), all replicas belong to the same group.

  Returns:
    An XlaOp that represents the all-reduced result.
  """
  groups = _get_replica_groups_protos(replica_groups)
  return ops.AllReduce(operand, computation.computation, groups, None)

def AllToAll(self,
             operand,
             split_dimension,
             concat_dimension,
             replica_groups=None):
  """Enqueues an AllToAll op.

  Args:
    operand: XlaOp representing the input array.
    split_dimension: the dimension along which the operand is split.
    concat_dimension: the dimension along which the split blocks are
      concatenated.
    replica_groups: optional list of lists of ints partitioning the replica
      ids {0, 1, ..., num_replicas - 1} into equally-sized groups within
      which the all-to-all is performed. If not supplied or None (the
      default), all replicas belong to the same group.

  Returns:
    An XlaOp that represents the all-to-all concatenation.

  Raises:
    ValueError: if the replica groups have different sizes.
  """
  groups = _get_replica_groups_protos(replica_groups)
  if replica_groups:
    # The split count is the common group size.
    split_count = len(replica_groups[0])
    if any(len(group) != split_count for group in replica_groups):
      raise ValueError('Replica groups must be equally sized')
  else:
    split_count = 1
  return ops.AllToAll(operand, split_dimension, concat_dimension, split_count,
                      groups)

def CrossReplicaSum(self, operand, replica_groups=None):
  """Enqueues a CrossReplicaSum op.

  Args:
    operand: the operand to sum across replica instances.
    replica_groups: optional list of lists of ints partitioning the replica
      ids {0, 1, ..., num_replicas - 1} into equally-sized groups within
      which the sum is performed. If not supplied or None (the default), all
      replicas belong to the same group.

  Returns:
    An XlaOp holding, on each replica, the sum of its group's values.
  """
  groups = _get_replica_groups_protos(replica_groups)
  return ops.CrossReplicaSum(operand, groups)
def Trans(self, operand):
"""Specialized matrix transpose op."""
return ops.Transpose(operand, [1, 0])
def Transpose(self, operand, permutation):
"""Transpose op."""
return ops.Transpose(operand, permutation)
def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
                     padding, source, init_value, scatter):
  """Select-and-scatter op, used by the gradient of ReduceWindow.

  Args:
    operand: XlaOp for an array of dimension N and type T over which the
      windows slide.
    select: Computation of type (T, T) -> Pred applied to the elements of
      each window to indicate which element is selected.
    window_dimensions: sequence of N integers giving the window dimensions.
    window_strides: sequence of N integers giving the window strides.
    padding: PaddingType representing either 'SAME' or 'VALID' padding.
    source: XlaOp for an array of type T with the values to scatter.
    init_value: XlaOp of scalar type T for the initial output value.
    scatter: Computation of type (T, T) -> T applied to each scatter source
      element together with its destination element.

  Returns:
    An XlaOp representing the added SelectAndScatter op.
  """
  operand_dims = self.GetShape(operand).dimensions()
  # Convert the symbolic padding type into explicit (low, high) pad pairs.
  explicit_pads = _convert_padding_type_to_pad_values(
      padding, operand_dims, window_dimensions, window_strides)
  return ops.SelectAndScatterWithGeneralPadding(
      operand, select.computation, window_dimensions, window_strides,
      explicit_pads, source, init_value, scatter.computation)
def Slice(self, operand, start_indices, limit_indices, strides=None):
  """Enqueues a slice operation onto the computation.

  Args:
    operand: XlaOp for the N dimensional array to be sliced.
    start_indices: iterable of N integers containing the starting indices of
      the slice for each dimension.
    limit_indices: iterable of N integers containing the ending indices
      (exclusive) of the slice for each dimension.
    strides: optional iterable of N integers containing the stride sizes for
      each dimension. Defaults to a stride of 1 in every dimension.

  Returns:
    An XlaOp representing the added Slice op.
  """
  # Materialize every index argument so that any iterable (including a
  # generator) is accepted, not only sequences; previously start_indices was
  # only converted when strides was omitted, and the other two never were.
  start_indices = list(start_indices)
  limit_indices = list(limit_indices)
  if strides is None:
    strides = [1] * len(start_indices)
  else:
    strides = list(strides)
  return ops.Slice(operand, start_indices, limit_indices, strides)
def DynamicSlice(self, operand, start_indices, slice_sizes):
  """Enqueues a slice op with dynamic start indices onto the computation.

  Args:
    operand: XlaOp for the N dimensional array to be sliced.
    start_indices: either an XlaOp for a 1D array of N integers holding the
      starting indices of the slice, or an iterable of N scalar index values.
    slice_sizes: iterable of N integers containing the slice sizes in each
      dimension.

  Returns:
    An XlaOp representing the added DynamicSlice op.
  """
  sizes = list(slice_sizes)
  if isinstance(start_indices, _xla.XlaOp):
    # Split the 1D index array into one scalar XlaOp per sliced dimension.
    index_array = start_indices
    start_indices = []
    for dim in range(len(sizes)):
      element = ops.Slice(index_array, [dim], [dim + 1], [1])
      start_indices.append(ops.Reshape(element, []))
  return ops.DynamicSlice(operand, list(start_indices), sizes)
def DynamicUpdateSlice(self, operand, update, start_indices):
  """Enqueues a dynamic update slice operation onto the computation.

  Args:
    operand: XlaOp for the N dimensional array to be updated.
    update: N dimensional array comprising the slice update.
    start_indices: either an XlaOp for a rank-1 array of N integers holding
      the starting index along each dimension, or an iterable of N scalar
      index values.

  Returns:
    An XlaOp representing the added DynamicUpdateSlice op.
  """
  if isinstance(start_indices, _xla.XlaOp):
    # The number of scalar indices is the length of the rank-1 index array.
    index_array = start_indices
    num_indices = self._builder.GetShape(index_array).dimensions()[0]
    start_indices = []
    for dim in range(num_indices):
      element = ops.Slice(index_array, [dim], [dim + 1], [1])
      start_indices.append(ops.Reshape(element, []))
  return ops.DynamicUpdateSlice(operand, update, list(start_indices))
def Tuple(self, *elems):
  """Enqueues a tuple operation onto the computation.

  Args:
    elems: the tuple operands, each an XlaOp, passed positionally.

  Returns:
    An XlaOp representing the added Tuple op.
  """
  elements = list(elems)
  return ops.Tuple(self._builder, elements)
def Call(self, computation_to_apply, operands):
  """Enqueues a call operation onto the computation.

  Args:
    computation_to_apply: a Computation object to invoke.
    operands: an iterable of XlaOp. Their number and types must match the
      arity of computation_to_apply.

  Returns:
    An XlaOp representing the added call op.
  """
  callee = computation_to_apply.computation
  arguments = list(operands)
  return ops.Call(self._builder, callee, arguments)
def CustomCall(self,
               call_target_name,
               operands,
               shape_with_layout,
               operand_shapes_with_layout,
               opaque=None):
  """Enqueues a custom call operation onto the computation.

  Args:
    call_target_name: the name of the function to call.
    operands: an iterable of XlaOp. Their number and types must match the
      arity of `operand_shapes_with_layout`.
    shape_with_layout: the shape of the operator's output, with layout.
    operand_shapes_with_layout: the shapes of `operands`, including the
      expected layouts.
    opaque: an opaque string passed to the backend; any falsy value
      (including the default None) is replaced by b''.

  Returns:
    An XlaOp representing the added custom call op.
  """
  if not opaque:
    opaque = b''
  operand_list = list(operands)
  operand_shapes = list(operand_shapes_with_layout)
  return ops.CustomCall(self._builder, call_target_name, operand_list,
                        shape_with_layout, operand_shapes, opaque)
def Map(self, operands, computation_to_apply, dimensions):
  """Enqueues a map operation onto the computation.

  Args:
    operands: an iterable of XlaOp.
    computation_to_apply: a Computation object to apply.
    dimensions: dimensions over which to map the function.

  Returns:
    An XlaOp representing the added Map op.
  """
  operand_list = list(operands)
  mapped_fn = computation_to_apply.computation
  # The trailing empty list is passed as the last argument of ops.Map.
  return ops.Map(self._builder, operand_list, mapped_fn, dimensions, [])
def Reduce(self, operand, init_value, computation_to_apply, dimensions):
  """Enqueues a single-operand reduction onto the computation.

  Args:
    operand: reduction operand (XlaOp).
    init_value: reduction initial value (XlaOp).
    computation_to_apply: a Computation object - binary reduction function.
    dimensions: sequence of dimensions (integers) to reduce on.

  Returns:
    An XlaOp representing the added Reduce op.
  """
  # ops.Reduce takes parallel lists of operands and initial values.
  operand_list = [operand]
  init_list = [init_value]
  return ops.Reduce(self._builder, operand_list, init_list,
                    computation_to_apply.computation, dimensions)
def ReduceWindow(self, operand, init_value, computation_to_apply,
                 window_dimensions, window_strides, padding):
  """Enqueues a windowed reduction with symbolic padding onto the computation.

  Args:
    operand: reduction operand (XlaOp).
    init_value: reduction initial value (XlaOp).
    computation_to_apply: a binary reduction function (Computation).
    window_dimensions: dimensions of window (sequence of integers).
    window_strides: strides for window (sequence of integers).
    padding: PaddingType representing either 'SAME' or 'VALID' padding.

  Returns:
    An XlaOp representing the added ReduceWindow op.
  """
  operand_dims = self.GetShape(operand).dimensions()
  explicit_pads = _convert_padding_type_to_pad_values(
      padding, operand_dims, window_dimensions, window_strides)
  # Empty base and window dilations are passed.
  return ops.ReduceWindowWithGeneralPadding(
      operand, init_value, computation_to_apply.computation,
      window_dimensions, window_strides, (), (), explicit_pads)
def ReduceWindowWithGeneralPadding(self, operand, init_value,
                                   computation_to_apply, window_dimensions,
                                   window_strides, base_dilations,
                                   window_dilations, padding):
  """Enqueues a windowed reduction with explicit padding and dilations.

  Args:
    operand: reduction operand (XlaOp).
    init_value: reduction initial value (XlaOp).
    computation_to_apply: a binary reduction function (Computation).
    window_dimensions: dimensions of window (sequence of integers).
    window_strides: strides for window (sequence of integers).
    base_dilations: dilations for the base (sequence of integers).
    window_dilations: dilations for window (sequence of integers).
    padding: length-N array-like of pairs of integers of (low, high) padding.

  Returns:
    An XlaOp representing the added ReduceWindow op.
  """
  reducer = computation_to_apply.computation
  return ops.ReduceWindowWithGeneralPadding(
      operand, init_value, reducer, window_dimensions, window_strides,
      base_dilations, window_dilations, padding)
def RngNormal(self, mu, sigma, dims):
  """Enqueues an RngNormal operation onto the computation.

  Args:
    mu: An XlaOp to an F32 scalar specifying the mean.
    sigma: An XlaOp to an F32 scalar specifying the standard deviation.
    dims: A 1D array-like of nonnegative integers specifying the dimensions.

  Returns:
    An XlaOp to the generated array, whose element type is taken from `mu`.
  """
  element_type = self.GetShape(mu).xla_element_type()
  result_shape = _xla.Shape.array_shape(element_type, dims)
  return ops.RngNormal(mu, sigma, result_shape)
def RngUniform(self, a, b, dims):
  """Enqueues an RngUniform operation onto the computation.

  Args:
    a: an XlaOp to an F32, S32, or U32 scalar (consistent with the type of b)
      specifying the low end of the interval [a, b) over which values are
      generated.
    b: an XlaOp to an F32, S32, or U32 scalar (consistent with the type of a)
      specifying the high end of the interval [a, b) over which values are
      generated.
    dims: A 1D array-like of nonnegative integers specifying the dimensions.

  Returns:
    An XlaOp to the generated array, whose element type is taken from `a`.
  """
  element_type = self.GetShape(a).xla_element_type()
  result_shape = _xla.Shape.array_shape(element_type, dims)
  return ops.RngUniform(a, b, result_shape)
def While(self, cond, body, init):
  """Enqueues a While operation onto the computation.

  Args:
    cond: a Computation for the loop condition, of type T -> PRED.
    body: a Computation for the loop body, of type T -> T.
    init: an XlaOp for the initial loop value, of type T.

  Returns:
    An XlaOp representing the While operation.
  """
  condition_fn = cond.computation
  body_fn = body.computation
  return ops.While(condition_fn, body_fn, init)
def Conditional(self, pred, true_operand, true_computation, false_operand,
                false_computation):
  """Enqueues a Conditional operation onto the computation.

  Args:
    pred: an XlaOp to test, of scalar type PRED.
    true_operand: an XlaOp of type T_0.
    true_computation: a Computation applied to true_operand, type T_0 -> S.
    false_operand: an XlaOp of type T_1.
    false_computation: a Computation applied to false_operand, type T_1 -> S.

  Returns:
    An XlaOp representing the Conditional operation.
  """
  then_branch = true_computation.computation
  else_branch = false_computation.computation
  return ops.Conditional(pred, true_operand, then_branch, false_operand,
                         else_branch)
def IsConstant(self, operand):
  """Checks whether the given operand is a compile-time constant.

  Args:
    operand: the XlaOp to test.

  Returns:
    bool indicating whether `operand` is a compile-time constant, meaning its
    value does not depend on any parameters, or on stateful operators such as
    `RngNormal` or `Infeed`.
  """
  builder = self._builder
  return builder.IsConstant(operand)
def BuildConstantSubGraph(self, operand):
  """Builds a constant sub graph rooted at `operand`.

  Args:
    operand: an XlaOp that is expected to be a compile-time constant.

  Returns:
    A computation rooted on the given `operand`, as returned by
    ops.BuildConstantSubGraph.
  """
  subgraph = ops.BuildConstantSubGraph(operand)
  return subgraph
def DotGeneral(self, lhs, rhs, dimension_numbers, precision_config=None):
  """Enqueues a general dot operation onto the computation.

  Args:
    lhs: XlaOp for the left-hand-side array.
    rhs: XlaOp for the right-hand-side array.
    dimension_numbers: either a DotDimensionNumbers or a nested tuple
      ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.
    precision_config: optional precision configuration, forwarded to
      ops.DotGeneral.

  Returns:
    An XlaOp representing the DotGeneral operation.
  """
  dnums = (
      GetDotDimensionsFromLists(dimension_numbers)
      if isinstance(dimension_numbers, tuple) else dimension_numbers)
  return ops.DotGeneral(lhs, rhs, dnums, precision_config=precision_config)
def Conv(self,
         lhs,
         rhs,
         window_strides,
         padding,
         feature_group_count=1,
         batch_group_count=1,
         precision_config=None):
  """Enqueues a Conv operation onto the computation.

  Uses the default dimension numbering (see ConvGeneralDilated) and no
  dilation.

  Args:
    lhs: XlaOp for the rank N+2 array of inputs.
    rhs: XlaOp for the rank N+2 array of kernel weights.
    window_strides: length-N array-like of integer kernel strides.
    padding: PaddingType representing either 'SAME' or 'VALID' padding.
    feature_group_count: number of feature groups for grouped convolution.
    batch_group_count: number of batch groups for grouped convolution.
    precision_config: optional precision configuration.

  Returns:
    An XlaOp representing the Conv operation.
  """
  # The leading two dimensions of each operand are skipped; the rest are the
  # spatial dimensions used to compute explicit pads.
  input_spatial = self.GetShape(lhs).dimensions()[2:]
  kernel_spatial = self.GetShape(rhs).dimensions()[2:]
  explicit_pads = _convert_padding_type_to_pad_values(
      padding, input_spatial, kernel_spatial, window_strides)
  return self.ConvGeneralDilated(
      lhs, rhs, window_strides, explicit_pads, [], [],
      dimension_numbers=None,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision_config=precision_config)
def ConvWithGeneralPadding(self,
                           lhs,
                           rhs,
                           window_strides,
                           padding,
                           lhs_dilation,
                           rhs_dilation,
                           feature_group_count=1,
                           batch_group_count=1,
                           precision_config=None):
  """Enqueues a convolution with explicit padding and dilation factors.

  Uses the default dimension numbering (see ConvGeneralDilated).

  Args:
    lhs: XlaOp for the rank N+2 array of inputs.
    rhs: XlaOp for the rank N+2 array of kernel weights.
    window_strides: length-N array-like of kernel strides.
    padding: length-N array-like of pairs of integers of (low, high) padding.
    lhs_dilation: length-N array-like of dilation factors.
    rhs_dilation: length-N array-like of dilation factors.
    feature_group_count: number of feature groups for grouped convolution.
    batch_group_count: number of batch groups for grouped convolution.
    precision_config: optional precision configuration.

  Returns:
    An XlaOp representing the added convolution op.
  """
  strides = list(window_strides)
  pads = list(padding)
  input_dilation = list(lhs_dilation)
  kernel_dilation = list(rhs_dilation)
  return self.ConvGeneralDilated(
      lhs, rhs, strides, pads, input_dilation, kernel_dilation,
      dimension_numbers=None,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision_config=precision_config)
def _GetConvDimensionNumbers(self, num_spatial_dims):
  """Returns the default ConvolutionDimensionNumbers used by Conv.

  Layout: batch/output-feature at index 0, feature/input-feature at index 1,
  and spatial dimensions at indices 2 .. 1 + num_spatial_dims for the input,
  kernel and output alike (('NCHW', 'OIHW', 'NCHW') in the 2-D case).
  """
  spatial = range(2, 2 + num_spatial_dims)
  dnums = ConvolutionDimensionNumbers()
  dnums.input_batch_dimension = dnums.output_batch_dimension = 0
  dnums.input_feature_dimension = dnums.output_feature_dimension = 1
  dnums.kernel_output_feature_dimension = 0
  dnums.kernel_input_feature_dimension = 1
  for field in (dnums.input_spatial_dimensions,
                dnums.kernel_spatial_dimensions,
                dnums.output_spatial_dimensions):
    field.extend(spatial)
  return dnums
def ConvGeneralDilated(self,
                       lhs,
                       rhs,
                       window_strides,
                       padding,
                       lhs_dilation,
                       rhs_dilation,
                       dimension_numbers=None,
                       feature_group_count=1,
                       batch_group_count=1,
                       precision_config=None):
  """Enqueues a ConvGeneralDilated operation onto the computation.

  Args:
    lhs: XlaOp for the rank N+2 array of inputs.
    rhs: XlaOp for the rank N+2 array of kernel weights.
    window_strides: length-N array-like of integer kernel strides.
    padding: length-N array-like of pairs of integers of (low, high) padding.
    lhs_dilation: length-N array-like of integer dilation factors.
    rhs_dilation: length-N array-like of integer dilation factors.
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object
      or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
      length N+2 identifying by position: (1) batch dimensions in lhs, rhs,
      and the output with the character 'N', (2) feature dimensions in lhs
      and the output with the character 'C', (3) input and output feature
      dimensions in rhs with the characters 'I' and 'O' respectively, and
      (4) spatial dimension correspondences between lhs, rhs, and the output
      using any distinct characters. For example, to indicate dimension
      numbers consistent with the Conv operation with two spatial
      dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another
      example, to indicate dimension numbers consistent with the TensorFlow
      Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using
      the latter form of convolution dimension specification, window strides
      are associated with spatial dimension character labels according to
      the order in which the labels appear in the rhs_spec string, so that
      window_strides[0] is matched with the dimension corresponding to the
      first character appearing in rhs_spec that is not 'I' or 'O'. By
      default, use the same dimension numbering as Conv and
      ConvWithGeneralPadding.
    feature_group_count: number of feature groups for grouped convolution.
    batch_group_count: number of batch groups for grouped convolution.
    precision_config: optional precision configuration.

  Returns:
    An XlaOp representing the ConvGeneralDilated operation.
  """
  if dimension_numbers is None:
    dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
  elif isinstance(dimension_numbers, tuple):
    lhs_spec, rhs_spec, out_spec = dimension_numbers

    def spatial_in_rhs_order(spec):
      # Non-'N'/'C' positions of `spec`, ordered by where their label appears
      # in rhs_spec, so spatial dims line up with the kernel's spatial dims.
      positions = (pos for pos, label in enumerate(spec)
                   if label not in {'N', 'C'})
      return sorted(positions, key=lambda pos: rhs_spec.index(spec[pos]))

    dnums = ConvolutionDimensionNumbers()
    dnums.input_batch_dimension = lhs_spec.index('N')
    dnums.input_feature_dimension = lhs_spec.index('C')
    dnums.output_batch_dimension = out_spec.index('N')
    dnums.output_feature_dimension = out_spec.index('C')
    dnums.kernel_output_feature_dimension = rhs_spec.index('O')
    dnums.kernel_input_feature_dimension = rhs_spec.index('I')
    # Kernel spatial dims keep their natural order in rhs_spec.
    dnums.kernel_spatial_dimensions.extend(
        pos for pos, label in enumerate(rhs_spec) if label not in {'I', 'O'})
    dnums.input_spatial_dimensions.extend(spatial_in_rhs_order(lhs_spec))
    dnums.output_spatial_dimensions.extend(spatial_in_rhs_order(out_spec))
    dimension_numbers = dnums
  return ops.ConvGeneralDilated(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config=precision_config)
def Sort(self, operands, dimension=-1, comparator=None):
  """Enqueues a sort operation onto the computation.

  Args:
    operands: either an XlaOp or a sequence of XlaOps to sort. All operands
      must be arrays with the same dimensions.
    dimension: the array dimension over which to sort.
    comparator: an optional comparator Computation. See the XLA operation
      semantics for details.

  Returns:
    The XlaOp built by ops.Sort. A single XlaOp and a sequence of XlaOps are
    both passed on as a list. NOTE(review): confirm against ops.Sort that a
    multi-operand sort yields a tuple-shaped op.
  """
  # `collections.Sequence` was removed in Python 3.10; the ABC lives in
  # `collections.abc` on Python 3 and only in `collections` on Python 2.
  try:
    from collections.abc import Sequence  # pylint: disable=g-import-not-at-top
  except ImportError:  # Python 2.
    from collections import Sequence  # pylint: disable=g-import-not-at-top
  operands = list(operands) if isinstance(operands, Sequence) else [operands]
  return ops.Sort(self._builder, operands, dimension,
                  comparator.computation if comparator else None)
def SortKeyVal(self, keys, values, dimension=-1):
  """Enqueues a key-value sort operation onto the computation.

  Deprecated. Use `Sort` instead.
  """
  key_and_value = [keys, values]
  return ops.Sort(self._builder, key_and_value, dimension)
def QR(self, a, full_matrices=True):
  """Enqueues a QR decomposition; returns its outputs packed with `Tuple`."""
  factors = ops.QR(a, full_matrices)
  return self.Tuple(*factors)
def TriangularSolve(self,
                    a,
                    b,
                    left_side=False,
                    lower=False,
                    transpose_a=False,
                    conjugate_a=False,
                    unit_diagonal=False):
  """Enqueues a triangular-solve operation onto the computation.

  `transpose_a` and `conjugate_a` together select how `a` is used:
  neither -> NO_TRANSPOSE; conjugate only -> `a` is conjugated via `Conj`
  and solved with NO_TRANSPOSE; transpose only -> TRANSPOSE;
  both -> ADJOINT.
  """
  options = _xla.TriangularSolveOptions_Transpose
  if transpose_a:
    transpose = options.ADJOINT if conjugate_a else options.TRANSPOSE
  else:
    transpose = options.NO_TRANSPOSE
    if conjugate_a:
      a = self.Conj(a)
  return ops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)
def Eigh(self, a, full_matrices=True):
  """Enqueues a symmetric/Hermitian eigendecomposition.

  Returns the outputs of ops.Eigh packed with `Tuple`.

  NOTE(review): `full_matrices` is forwarded unchanged as the second
  positional argument of ops.Eigh; confirm what that binding parameter
  actually controls.
  """
  decomposition = ops.Eigh(a, full_matrices)
  return self.Tuple(*decomposition)
def SVD(self, a):
  """Enqueues a singular value decomposition; outputs are packed with `Tuple`."""
  decomposition = ops.SVD(a)
  return self.Tuple(*decomposition)
def Gather(self,
           a,
           start_indices,
           dimension_numbers,
           slice_sizes,
           indices_are_sorted=False):
  """Enqueues a Gather operation onto the computation.

  All arguments are forwarded unchanged to ops.Gather; `dimension_numbers`
  is presumably a GatherDimensionNumbers (NOTE(review): confirm).
  """
  gathered = ops.Gather(a, start_indices, dimension_numbers, slice_sizes,
                        indices_are_sorted)
  return gathered
def Scatter(self,
            a,
            scatter_indices,
            updates,
            update_computation,
            dimension_numbers,
            indices_are_sorted=False,
            unique_indices=False):
  """Enqueues a Scatter operation onto the computation.

  `update_computation` is a Computation whose `.computation` is handed to
  ops.Scatter; the other arguments are forwarded unchanged.
  """
  combiner = update_computation.computation
  return ops.Scatter(a, scatter_indices, updates, combiner, dimension_numbers,
                     indices_are_sorted, unique_indices)
def Fft(self, operand, fft_type, fft_lengths):
  """Enqueues an FFT operation; arguments are forwarded to ops.Fft."""
  transformed = ops.Fft(operand, fft_type, fft_lengths)
  return transformed
# Module-level alias of the extension's `_xla.FftType`; presumably the type of
# the `fft_type` argument accepted by `ComputationBuilder.Fft`.
FftType = _xla.FftType
# Names of unary operations. _forward_methods_to_local_builder() attaches each
# one to ComputationBuilder as a method that calls the same-named `ops`
# function.
_UNARY_OPS = [
    'Not', 'Clz', 'Abs', 'Exp', 'Expm1', 'Floor', 'Round', 'Ceil',
    'Log', 'Log1p', 'Sign', 'Cos', 'Sin', 'Tanh', 'IsFinite', 'Sqrt',
    'Rsqrt', 'Square', 'Reciprocal', 'Neg', 'Erf', 'Erfc', 'ErfInv',
    'Lgamma', 'Digamma', 'BesselI0e', 'BesselI1e', 'Acos', 'Asin', 'Atan',
    'Tan', 'Acosh', 'Asinh', 'Atanh', 'Cosh', 'Sinh', 'Real', 'Imag', 'Conj',
]

# Names of binary operations, forwarded the same way.
_BINARY_OPS = [
    'Eq', 'Ne', 'Ge', 'Gt', 'Lt', 'Le',
    'Add', 'Sub', 'Mul', 'Div', 'Rem', 'Max', 'Min',
    'And', 'Or', 'Xor', 'Pow',
    'ShiftLeft', 'ShiftRightArithmetic', 'ShiftRightLogical',
    'Atan2', 'Complex',
]

# Remaining operations whose ComputationBuilder methods are plain forwarders.
_OTHER_OPS = [
    'BitcastConvertType', 'Broadcast', 'BroadcastInDim', 'Cholesky', 'Clamp',
    'Collapse', 'CollectivePermute', 'ConvertElementType', 'Dot',
    'GetTupleElement', 'ReducePrecision', 'Rev', 'Select', 'SliceInDim',
]
def _forward_methods_to_local_builder():
  """Attaches boilerplate XLA-op methods to ComputationBuilder.

  For every name in _UNARY_OPS, _BINARY_OPS and _OTHER_OPS, installs a
  ComputationBuilder method of that name which discards the builder instance
  and calls the same-named function of the `ops` extension module with the
  remaining positional and keyword arguments.
  """

  def make_forwarder(op_fn):
    def forwarder(builder, *args, **kwargs):
      del builder  # Only the remaining arguments are passed on to `op_fn`.
      return op_fn(*args, **kwargs)

    return forwarder

  for name in itertools.chain(_UNARY_OPS, _BINARY_OPS, _OTHER_OPS):
    method = make_forwarder(getattr(ops, name))
    method.__name__ = name
    setattr(ComputationBuilder, name, method)


_forward_methods_to_local_builder()
def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform key; it is translated through
      `xla_platform_names` before being passed to
      `_xla.RegisterCustomCallTarget`.
  """
  xla_platform = xla_platform_names[platform]
  _xla.RegisterCustomCallTarget(name, fn, xla_platform)


# Deprecated alias kept for backwards compatibility; use
# register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target
class PaddingConfigDimension(object):
  """Python mirror of the xla.PaddingConfigDimension protobuf message.

  All three fields are integers initialised to zero.
  """
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = self.edge_padding_high = self.interior_padding = 0
class PaddingConfig(object):
  """Python mirror of the xla.PaddingConfig protobuf message."""
  __slots__ = ('dimensions',)

  def __init__(self):
    # Per-dimension padding entries; a fresh list for every instance.
    self.dimensions = list()
def GetPaddingConfigFromTriples(triples):
  """Builds a PaddingConfig from (low, high, interior) integer triples.

  Args:
    triples: iterable of 3-element sequences, one per dimension, giving the
      low edge padding, the high edge padding and the interior padding.

  Returns:
    A PaddingConfig whose `dimensions` list holds one PaddingConfigDimension
    per input triple, in input order.
  """
  config = PaddingConfig()
  for triple in triples:
    low, high, interior = triple
    dim = PaddingConfigDimension()
    dim.edge_padding_low, dim.edge_padding_high, dim.interior_padding = (
        low, high, interior)
    config.dimensions.append(dim)
  return config
class DotDimensionNumbers(object):
  """Python mirror of the xla.DotDimensionNumbers protobuf message.

  Every field is a list of dimension indices, initially empty.
  """
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    # Give each field its own fresh list.
    for field_name in DotDimensionNumbers.__slots__:
      setattr(self, field_name, [])
def GetDotDimensionsFromLists(dimension_numbers):
  """Builds DotDimensionNumbers from nested lists of dimension indices.

  Args:
    dimension_numbers: a pair ((lhs_contract, rhs_contract),
      (lhs_batch, rhs_batch)) of integer iterables naming the contracting and
      batch dimensions of each operand.

  Returns:
    A DotDimensionNumbers populated from `dimension_numbers`.
  """
  contracting, batch = dimension_numbers
  lhs_contract, rhs_contract = contracting
  lhs_batch, rhs_batch = batch
  dnums = DotDimensionNumbers()
  dnums.lhs_contracting_dimensions.extend(lhs_contract)
  dnums.rhs_contracting_dimensions.extend(rhs_contract)
  dnums.lhs_batch_dimensions.extend(lhs_batch)
  dnums.rhs_batch_dimensions.extend(rhs_batch)
  return dnums
class ConvolutionDimensionNumbers(object):
  """Python mirror of the xla.ConvolutionDimensionNumbers protobuf message.

  Scalar dimension indices start at 0; spatial dimension lists start empty.
  """
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    # Scalar dimension indices.
    self.input_batch_dimension = self.input_feature_dimension = 0
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.output_batch_dimension = self.output_feature_dimension = 0
    # Spatial dimension lists; each gets its own fresh list.
    self.input_spatial_dimensions = []
    self.kernel_spatial_dimensions = []
    self.output_spatial_dimensions = []
class PrecisionConfig(object):
  """Python mirror of the xla.PrecisionConfig protobuf message."""
  __slots__ = ('operand_precision',)
  # Alias of the extension's precision enum, reachable as
  # PrecisionConfig.Precision.
  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    # Presumably one Precision value per operand; NOTE(review): confirm.
    self.operand_precision = list()
class GatherDimensionNumbers(object):
  """Python mirror of the xla.GatherDimensionNumbers protobuf message.

  The three list fields start empty; `index_vector_dim` starts at 0.
  """
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims, self.collapsed_slice_dims, self.start_index_map = (
        [], [], [])
    self.index_vector_dim = 0
class ScatterDimensionNumbers(object):
  """Python mirror of the xla.ScatterDimensionNumbers protobuf message.

  The three list fields start empty; `index_vector_dim` starts at 0.
  """
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims, self.inserted_window_dims = [], []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0
class ReplicaGroup(object):
  """Python mirror of the xla.ReplicaGroup protobuf message."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    # Replica ids belonging to this group; a fresh list per instance.
    self.replica_ids = list()
def _make_replica_group_proto(replica_group):
  """Returns a ReplicaGroup whose `replica_ids` are the ids in `replica_group`."""
  group = ReplicaGroup()
  group.replica_ids.extend(replica_group)
  return group
def _get_replica_groups_protos(replica_groups):
  """Converts optional replica groups into a list of ReplicaGroup objects.

  Args:
    replica_groups: None, or an iterable of iterables of replica ids.

  Returns:
    One ReplicaGroup per group, in input order. None maps to an empty list,
    which the XLA API treats as a special value.
  """
  if replica_groups is None:
    return []  # Special value for the XLA API.
  return [_make_replica_group_proto(group) for group in replica_groups]