| # Lint as: python3 |
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """An XLA client in Python.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import enum # pylint: disable=g-bad-import-order |
| import inspect |
| import os |
| from typing import List, Sequence, Tuple, Union |
| |
| from absl import logging |
| import numpy as np |
| |
| # Note this module does *not* depend on any Python protocol buffers. The XLA |
| # Python bindings are currently packaged both as part of jaxlib and as part |
| # of TensorFlow. If we use protocol buffers here, then importing both jaxlib |
| # and TensorFlow may fail with duplicate protocol buffer message definitions. |
| |
| from tensorflow.compiler.xla.python import xla_extension as _xla |
| |
| # Most functions are snake_case for consistency with other modules, some |
| # method names are CamelCase for consistency with XLA. |
| # pylint: disable=invalid-name |
| |
| # Pylint has false positives for type annotations. |
| # pylint: disable=invalid-sequence-index |
| |
# Handles to submodules exposed by the C++ extension module.
ops = _xla.ops
profiler = _xla.profiler


# Maps the platform names accepted by this module (e.g. the `platform`
# argument of register_custom_call_target) to the platform names that are
# passed on to the C++ extension.
xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}
| |
| |
def _cpu_backend_factory():
  """Creates the CPU client, requesting asynchronous=True from the extension."""
  client = _xla.get_cpu_client(asynchronous=True)
  return client
| |
| |
def _gpu_backend_factory(distributed_client=None, node_id=0):
  """Returns a GPU backend. BFC allocator is used by default.

  The allocator is configured from environment variables:

    XLA_PYTHON_CLIENT_ALLOCATOR: "default", "platform" or "bfc"
      (case-insensitive); defaults to "default".
    XLA_PYTHON_CLIENT_MEM_FRACTION: if set and non-empty, parsed as a float
      and stored as the allocator config's memory fraction.
    XLA_PYTHON_CLIENT_PREALLOCATE: preallocation is enabled unless this is
      "0" or "false" (case-insensitive).

  Args:
    distributed_client: forwarded unchanged to `get_nvidia_gpu_client`.
    node_id: forwarded unchanged to `get_nvidia_gpu_client`.

  Returns:
    The client returned by `_xla.get_nvidia_gpu_client`.

  Raises:
    ValueError: if XLA_PYTHON_CLIENT_ALLOCATOR is not a supported value or
      XLA_PYTHON_CLIENT_MEM_FRACTION is not a valid float.
  """
  allocator_kinds = {
      'default': _xla.GpuAllocatorConfig.Kind.DEFAULT,
      'platform': _xla.GpuAllocatorConfig.Kind.PLATFORM,
      'bfc': _xla.GpuAllocatorConfig.Kind.BFC,
  }
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in allocator_kinds:
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", or '
        '"bfc", got "%s"' % allocator)

  config = _xla.GpuAllocatorConfig()
  config.kind = allocator_kinds[allocator]
  if memory_fraction:
    try:
      fraction = float(memory_fraction)
    except ValueError:
      # Name the offending variable instead of surfacing float()'s message.
      raise ValueError(
          'XLA_PYTHON_CLIENT_MEM_FRACTION env var must be a float, got "%s"' %
          memory_fraction) from None
    config.memory_fraction = fraction
  # Unset (None) means "preallocate"; only an explicit 0/false disables it.
  # Compare case-insensitively, matching how the allocator variable is read.
  config.preallocate = (
      preallocate is None or preallocate.lower() not in ('0', 'false'))

  return _xla.get_nvidia_gpu_client(
      asynchronous=True,
      allocator_config=config,
      distributed_client=distributed_client,
      node_id=node_id)
| |
| |
# Backend factories, keyed by user-visible name, in increasing priority order.
# get_local_backend() with no name returns the last backend that initialized
# successfully, so later entries take precedence over earlier ones.
_local_backend_factories = collections.OrderedDict([
    ('cpu', _cpu_backend_factory),
    ('gpu', _gpu_backend_factory),
])
| |
| |
def register_local_backend_factory(name, factory):
  """Registers, or replaces, the factory for the local backend `name`.

  `factory` is later called with no arguments by `_get_local_backends`. A new
  name is appended after the existing ones, which gives it the highest
  priority. Because `_get_local_backends` caches its result, factories
  registered after the first backend lookup are not picked up.
  """
  _local_backend_factories.update({name: factory})
| |
| |
# Cache of instantiated backends, filled lazily by _get_local_backends().
# None means "not initialized yet".
_local_backends = None
| |
| |
def _get_local_backends():
  """Instantiates all known local backends.

  The result is cached in the module-level `_local_backends` after the first
  successful call. Factories that raise `RuntimeError` are skipped, except
  for 'cpu', whose failure propagates to the caller.

  Returns:
    An OrderedDict mapping backend name to backend, in the priority order of
    `_local_backend_factories`.
  """
  global _local_backends
  if _local_backends is not None:
    return _local_backends

  # Build into a local and publish only once every factory has run. If the
  # CPU factory raises, the cache stays None and a later call retries.
  # Otherwise it would permanently hold an empty or partially filled dict.
  backends = collections.OrderedDict()
  for name, factory in _local_backend_factories.items():
    logging.vlog(2, "Initializing backend '%s'", name)
    try:
      backend = factory()
    except RuntimeError:
      if name == 'cpu':
        # We always expect CPU to initialize successfully.
        raise
      # If the backend isn't built into the binary, or if it has no devices,
      # we expect a RuntimeError.
      continue
    backends[name] = backend
  _local_backends = backends
  return _local_backends
| |
| |
def get_local_backend(name=None):
  """Returns a local backend.

  Args:
    name: the backend name. If `None`, a default local backend is returned,
      typically `gpu` if one is present, or `cpu` if not. If a string, the
      named backend is returned or an exception raised.

  Returns:
    A LocalBackend object.

  Raises:
    RuntimeError: if `name` is given but no backend of that name was
      successfully initialized.
  """
  backends = _get_local_backends()
  if name is not None:
    try:
      return backends[name]
    except KeyError:
      # The KeyError adds nothing; suppress it so the traceback does not read
      # "During handling of the above exception, another exception occurred".
      raise RuntimeError('Unknown backend {}'.format(name)) from None

  # Backends are ordered by increasing priority, so the last one wins.
  return list(backends.values())[-1]
| |
| |
class OpMetadata(object):
  """Python representation of a xla.OpMetadata protobuf.

  Attributes:
    op_type: operation type string; defaults to ''.
    op_name: operation name string; defaults to ''.
    source_file: source file name; defaults to ''.
    source_line: source line number; defaults to 0.
  """
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type, self.op_name = op_type, op_name
    self.source_file, self.source_line = source_file, source_line
| |
| |
def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object.

  Args:
    op_type: stored as the result's `op_type`.
    op_name: stored as the result's `op_name`.
    skip_frames: how many frames to walk up from this function's own frame to
      find the location to record. The default of 1 records the direct
      caller's location.

  Returns:
    An `OpMetadata` whose `source_file` is the base name of the selected
    frame's file and whose `source_line` is that frame's current line.

  Raises:
    IndexError: if `skip_frames` is larger than the current stack depth.
  """
  frame = inspect.currentframe()
  if frame is None or skip_frames < 0:
    # No frame support in this interpreter, or a negative (from-the-outermost)
    # index: use inspect.stack(), whose indexing semantics apply here.
    full_filename, lineno = inspect.stack()[skip_frames][1:3]
  else:
    # Walk f_back pointers directly. inspect.stack() would build a FrameInfo,
    # including source-context lines, for every frame on the stack just so
    # that one of them could be read.
    try:
      for _ in range(skip_frames):
        frame = frame.f_back
        if frame is None:
          raise IndexError('skip_frames exceeds the current stack depth')
      full_filename = frame.f_code.co_filename
      lineno = frame.f_lineno
    finally:
      # Drop the frame reference promptly; a frame held in one of its own
      # locals forms a reference cycle.
      del frame
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)
| |
| |
PrimitiveType = _xla.PrimitiveType

bfloat16 = _xla.bfloat16_dtype()

# Maps each XLA element type to the NumPy dtype used for it on the Python
# side. Tuples and tokens have no array representation and map to the
# generic object dtype.
XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    # `np.object` was a deprecated alias of the builtin `object` and was
    # removed in NumPy 1.24. The builtin yields the identical dtype('O').
    PrimitiveType.TUPLE: np.dtype(object),
    PrimitiveType.TOKEN: np.dtype(object),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
# TUPLE and TOKEN share dtype('O'). The later entry (TOKEN) therefore wins for
# the 'object' key.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}
| |
| |
def dtype_to_etype(dtype):
  """Returns the XLA element type for anything `np.dtype` accepts.

  Raises:
    KeyError: if the dtype has no entry in DTYPE_TO_XLA_ELEMENT_TYPE.
  """
  # The table is keyed by dtype strings rather than dtype objects.
  key = str(np.dtype(dtype))
  return DTYPE_TO_XLA_ELEMENT_TYPE[key]
| |
| |
# `Shape` is implemented in C++ by the extension module. The string assigned
# to its __doc__ below sketches the Python-visible API as pseudo-code; it is
# documentation only and is never executed.
Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape(object):
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""
| |
# `ProgramShape` is implemented in C++; the __doc__ text below is a
# pseudo-code sketch of its Python-visible API, not executable code.
ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape(object):
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""
| |
| |
def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays.

  Python tuples become tuple shapes, recursively. Every other leaf is
  described by its `.dtype` attribute and `np.shape(leaf)`.
  """

  def _describe(node):
    if not isinstance(node, tuple):
      return Shape.array_shape(node.dtype, np.shape(node))
    children = tuple(_describe(child) for child in node)
    return Shape.tuple_shape(children)

  return _describe(pyval)
| |
| |
# `DeviceAssignment` is implemented in C++; the __doc__ text below sketches
# its Python-visible API as pseudo-code.
DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

   Args:
     assignment: a 2D numpy array of device ordinal integers, indexed by
       [replica][computation_in_replica].
   Returns:
     A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""
| |
# Aliases for C++ types exported by the extension module.
Device = _xla.Device
CompileOptions = _xla.CompileOptions
| |
| |
| # An Executable is a C++ class that duck types with the following API: |
| # class Executable(object): |
| # def local_devices(self) -> [Device]: |
#   def Execute(self, arguments : [Buffer]) -> [Buffer]:
#     """Execute on one replica with Buffer arguments; returns a list of output Buffers."""
| # |
| # def SizeOfGeneratedCodeInBytes(self) -> int: |
| # """Return generated binary size, or -1 if not known.""" |
| # |
#   def ExecuteOnLocalDevices(self, arguments: [[Buffer]]) -> [[Buffer]]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th inner sequence
#         comprises the arguments for execution on the i'th local device.
#
#     Returns:
#       A list with one entry per local device, each a list of the
#       computation's output Buffers on that device.
#       If a shallow sequence of arguments was passed in for `arguments`, then
#       the sole, zero'th device's output is returned instead, as a Buffer.
#     """
| # |
| # There are different implementations of Executable for different backends. |
| |
| |
def execute_with_python_values(executable, arguments, backend):
  """Execute on one replica with Python values as arguments and output.

  Args:
    executable: the program to run; its first local device receives the
      arguments.
    arguments: an iterable of Python values. Each one is transferred with
      `backend.buffer_from_pyval`.
    backend: the backend used to create the argument buffers.

  Returns:
    A list holding `to_py()` of every buffer returned by `executable.Execute`.
  """
  arg_buffers = [
      backend.buffer_from_pyval(value, device=executable.local_devices()[0])
      for value in arguments
  ]
  results = executable.Execute(arg_buffers)
  return [result.to_py() for result in results]
| |
| |
def execute_with_python_values_replicated(executable, arguments, backend):
  """Execute on many replicas with Python values as arguments and output.

  Args:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs. The values of replica `i` are transferred to
      `executable.local_devices()[i]`. Any iterable of iterables is accepted.
    backend: the backend we are targeting.

  Returns:
    A list with one entry per replica. Each entry is a list of that replica's
    outputs, converted with `to_py()`.
  """
  devices = executable.local_devices()
  # Build the per-replica buffer lists directly, in one pass. Transfers still
  # happen in replica-major order, the cost stays linear in the number of
  # arguments, and `arguments` is consumed only once, so one-shot iterators
  # work too.
  arg_buffers = [
      [backend.buffer_from_pyval(pyval, devices[replica])
       for pyval in replica_args]
      for replica, replica_args in enumerate(arguments)
  ]
  return [[x.to_py() for x in xs]
          for xs in executable.ExecuteOnLocalDevices(arg_buffers)]
| |
| |
class PaddingType(enum.Enum):
  """Named padding schemes accepted by window_padding_type_to_pad_values."""
  # No padding at all.
  VALID = 1
  # Pad so that the output size is ceil(input / stride).
  SAME = 2
| |
| |
def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                      window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints).

  Args:
    padding_type: a PaddingType, or the string 'VALID' or 'SAME' in any case.
    lhs_dims: input sizes of the windowed dimensions.
    rhs_dims: window (filter) sizes, aligned with `lhs_dims`.
    window_strides: strides, aligned with `lhs_dims`.

  Returns:
    A list of (low, high) padding pairs, one per stride. VALID gives all
    zeros. SAME pads so that the output size is ceil(input / stride), putting
    any odd leftover on the high side.

  Raises:
    TypeError: if `padding_type` is neither a str nor a PaddingType.
    ValueError: if a string names no known padding type.
  """
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    member_name = padding_type.upper()
    if member_name not in ('VALID', 'SAME'):
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))
    padding_type = PaddingType[member_name]

  if padding_type is PaddingType.VALID:
    return [(0, 0)] * len(window_strides)

  if padding_type is PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pads = []
    for out_size, stride, filter_size, in_size in zip(
        out_shape, window_strides, rhs_dims, lhs_dims):
      total = max((out_size - 1) * stride + filter_size - in_size, 0)
      low = total // 2
      pads.append((low, total - low))
    return pads

  msg = 'Unexpected PaddingType value: {}'
  raise ValueError(msg.format(padding_type))
| |
| |
# Aliases for C++ types exported by the extension module.
XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
FftType = _xla.FftType
| |
| |
def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform, a key of `xla_platform_names` ('cpu' or
      'gpu'). It is translated to the corresponding XLA platform name.

  Raises:
    KeyError: if `platform` is not a key of `xla_platform_names`.
  """
  xla_platform = xla_platform_names[platform]
  _xla.register_custom_call_target(name, fn, xla_platform)
| |
| |
# Deprecated. Use register_custom_call_target instead.
# Kept as a plain alias for backward compatibility. It targets 'cpu' only
# because that is register_custom_call_target's default `platform`.
register_cpu_custom_call_target = register_custom_call_target
| |
| |
class PaddingConfigDimension(object):
  """Python representation of a xla.PaddingConfigDimension protobuf.

  Every field starts at 0.
  """
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = self.edge_padding_high = 0
    self.interior_padding = 0
| |
| |
class PaddingConfig(object):
  """Python representation of a xla.PaddingConfig protobuf.

  `dimensions` starts as a fresh empty list. make_padding_config fills it
  with one PaddingConfigDimension per padded dimension.
  """
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = list()
| |
| |
def make_padding_config(
    padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]]
) -> PaddingConfig:
  """Create PaddingConfig proto from list of triples of integers.

  Args:
    padding_config: either a PaddingConfig or a list (or tuple) of integer
      triples (edge_padding_low, edge_padding_high, interior_padding), one
      per dimension, describing the padding operation.

  Returns:
    A `PaddingConfig` object. Input that is not a list or tuple is returned
    unchanged.
  """
  if not isinstance(padding_config, (tuple, list)):
    return padding_config

  config = PaddingConfig()
  for low, high, interior in padding_config:
    dim = PaddingConfigDimension()
    dim.edge_padding_low = low
    dim.edge_padding_high = high
    dim.interior_padding = interior
    config.dimensions.append(dim)
  return config
| |
| |
class DotDimensionNumbers(object):
  """Python representation of a xla.DotDimensionNumbers protobuf.

  Each of the four dimension lists starts as its own fresh empty list.
  """
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    for field in DotDimensionNumbers.__slots__:
      setattr(self, field, [])
| |
| |
def make_dot_dimension_numbers(
    dimension_numbers: Union[DotDimensionNumbers,
                             Tuple[Tuple[List[int], List[int]],
                                   Tuple[List[int], List[int]]]]
) -> DotDimensionNumbers:
  """Builds a DotDimensionNumbers object from a specification.

  Args:
    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.

  Returns:
    A `DotDimensionNumbers` object. Input that is not a list or tuple is
    returned unchanged.
  """
  if not isinstance(dimension_numbers, (list, tuple)):
    return dimension_numbers

  contracting, batch = dimension_numbers
  lhs_contract, rhs_contract = contracting
  lhs_batch, rhs_batch = batch

  result = DotDimensionNumbers()
  result.lhs_contracting_dimensions.extend(lhs_contract)
  result.rhs_contracting_dimensions.extend(rhs_contract)
  result.lhs_batch_dimensions.extend(lhs_batch)
  result.rhs_batch_dimensions.extend(rhs_batch)
  return result
| |
| |
class ConvolutionDimensionNumbers(object):
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf.

  The `*_spatial_dimensions` fields start as fresh empty lists. Every other
  field is a single dimension index and starts at 0.
  """
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    for field in ConvolutionDimensionNumbers.__slots__:
      is_list = field.endswith('_spatial_dimensions')
      setattr(self, field, [] if is_list else 0)
| |
| |
def make_convolution_dimension_numbers(
    dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str,
                                                                      str]],
    num_spatial_dimensions: int) -> ConvolutionDimensionNumbers:
  """Builds a ConvolutionDimensionNumbers object from a specification.

  Args:
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object or
      a tuple or list (lhs_spec, rhs_spec, out_spec). Each element is a string
      of length N+2 identifying by position: (1) batch dimensions in lhs, rhs,
      and the output with the character 'N', (2) feature dimensions in lhs and
      the output with the character 'C', (3) input and output feature
      dimensions in rhs with the characters 'I' and 'O' respectively, and (4)
      spatial dimension correspondences between lhs, rhs, and the output using
      any distinct characters. For example, to indicate dimension numbers
      consistent with the Conv operation with two spatial dimensions, one
      could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
      dimension numbers consistent with the TensorFlow Conv2D operation, one
      could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
      convolution dimension specification, window strides are associated with
      spatial dimension character labels according to the order in which the
      labels appear in the rhs_spec string, so that window_strides[0] is
      matched with the dimension corresponding to the first character
      appearing in rhs_spec that is not 'I' or 'O'. By default, use the same
      dimension numbering as Conv and ConvWithGeneralPadding.
    num_spatial_dimensions: the number of spatial dimensions. Only used when
      `dimension_numbers` is None.

  Returns:
    A `ConvolutionDimensionNumbers` object. Input that is neither None nor a
    tuple/list is returned unchanged.

  Raises:
    ValueError: if a spec sequence does not have exactly three elements, a
      spec string lacks a required label ('N' or 'C' in lhs_spec/out_spec,
      'I' or 'O' in rhs_spec), or a spatial label of lhs_spec/out_spec does
      not occur in rhs_spec.
  """
  if dimension_numbers is None:
    nd = num_spatial_dimensions
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
  elif isinstance(dimension_numbers, (tuple, list)):
    # Lists are accepted as well as tuples, as in make_dot_dimension_numbers.
    # Without this, a list spec would be returned unconverted.
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    dimension_numbers = ConvolutionDimensionNumbers()

    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    kernel_feature_labels = ('I', 'O')
    batch_feature_labels = ('N', 'C')
    # Kernel spatial dims are listed in rhs_spec order. The lhs/out spatial
    # dims are sorted to follow that same order, matching labels by character.
    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in kernel_feature_labels)
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec)
                if c not in batch_feature_labels),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec)
                if c not in batch_feature_labels),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return dimension_numbers
| |
| |
class OpSharding(object):
  """Python representation of a xla.OpSharding protobuf.

  `Type` re-exports `_xla.OpSharding_Type`. A new instance has type
  `Type.REPLICATED` and three fresh empty lists.
  """
  __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices',
               'tuple_shardings')

  Type = _xla.OpSharding_Type

  def __init__(self):
    self.type = self.Type.REPLICATED
    (self.tile_assignment_dimensions, self.tile_assignment_devices,
     self.tuple_shardings) = [], [], []
| |
| |
class PrecisionConfig(object):
  """Python representation of a xla.PrecisionConfig protobuf.

  `Precision` re-exports `_xla.PrecisionConfig_Precision`.
  `operand_precision` starts as a fresh empty list.
  """
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = list()
| |
| |
class GatherDimensionNumbers(object):
  """Python representation of a xla.GatherDimensionNumbers protobuf.

  The three dimension lists start empty; `index_vector_dim` starts at 0.
  """
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    (self.offset_dims, self.collapsed_slice_dims,
     self.start_index_map) = [], [], []
    self.index_vector_dim = 0
| |
| |
class ScatterDimensionNumbers(object):
  """Python representation of a xla.ScatterDimensionNumbers protobuf.

  The three dimension lists start empty; `index_vector_dim` starts at 0.
  """
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    (self.update_window_dims, self.inserted_window_dims,
     self.scatter_dims_to_operand_dims) = [], [], []
    self.index_vector_dim = 0
| |
| |
class ReplicaGroup(object):
  """Python representation of a xla.ReplicaGroup protobuf.

  `replica_ids` starts as a fresh empty list.
  """
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = list()
| |
| |
def _make_replica_group_proto(replica_group):
  """Returns a ReplicaGroup whose `replica_ids` holds the ids in `replica_group`."""
  group_proto = ReplicaGroup()
  group_proto.replica_ids.extend(replica_group)
  return group_proto
| |
| |
def make_replica_groups(replica_groups):
  """Converts replica-id groups into a list of ReplicaGroup objects.

  Args:
    replica_groups: None, or an iterable whose elements are iterables of
      replica ids. The outer iterable is fully materialized before any group
      is converted.

  Returns:
    One ReplicaGroup per group, in order. For None, an empty list.
  """
  if replica_groups is None:
    return []  # special value for XLA API
  groups = list(replica_groups)
  return [_make_replica_group_proto(group) for group in groups]