# blob: e75a09492ec12b95bad32b221a8e78a1b79f3a6b [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Helper library for handling infeed between hosts and TPUs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.compiler.xla.python_api import xla_shape
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through the
        queue, must be present unless it can be inferred from other arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_tuple_elements is None:
      # Infer the tuple arity from the first list-valued argument supplied.
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor"
        )
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in xrange(number_of_tuple_elements)
    ]
    # Start from an unset configuration; the setters below overwrite these
    # only for the arguments that were actually supplied.
    self._tuple_types = None
    self._tuple_shapes = None
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()

  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Called only for its side effect: it raises an error if the policy
        # is incompatible with the shape.
        policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements; or the queue is frozen and
        tuple_types differs from the frozen types.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError("tuple_types is %s, but must be a list of length %d" %
                       (str(tuple_types), self.number_of_tuple_elements))
    if self._frozen:
      # A frozen queue only accepts updates that restate the frozen types.
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible type. Frozen types are %s, updated types are %s" % (
                  str(self._tuple_types), str(tuple_types)))
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except (TypeError) as e:
        raise TypeError(
            "tuple_types is %s, but must be a list of elements each "
            "convertible to dtype: got error %s" % (str(tuple_types), str(e)))

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements; or the queue is frozen and
        tuple_shapes differs from the frozen shapes.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
                       (str(tuple_shapes), self.number_of_tuple_elements))
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          "tuple_shapes is %s, but must be a list of elements each "
          "convertible to TensorShape: got error %s" % (str(tuple_shapes),
                                                        str(e)))
    if self._frozen:
      # A frozen queue only accepts updates that restate the frozen shapes.
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible shape. Frozen shapes are %s, updated shapes are %s"
              % (str(self._tuple_shapes), str(tuple_shapes)))
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # Each tuple element has its own policy, so report one entry per policy.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError("shard_dimensions is %s, but must be a list of length %d"
                       % (str(shard_dimensions),
                          self.number_of_tuple_elements))
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      # Format the message with `%`; passing the arguments as a second
      # positional argument would leave the placeholders unexpanded.
      raise ValueError(
          "input_tensors is %s, but should be a list of %d Tensors" % (
              str(input_tensors), self.number_of_tuple_elements))
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: list of lists of Tensors. The outer list length corresponds
        to the desired number of shards, and each inner list is the size
        and shape of the desired configuration of the corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            "input_tensors is %s but must be a list of lists, where each inner"
            " list has length number_of_tuple_elements=%d" % (
                str(input_tensors), self.number_of_tuple_elements))
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape for t in input_tensors]
                      for i in xrange(self.number_of_tuple_elements)]
    # For each tuple, get the unsharded shape using that tuple's policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    # Every shard must agree with shard 0 on the dtype of each tuple element.
    for i in xrange(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors %s are not "
              "consistent" % str(input_tensors))
    self.set_tuple_types([t.dtype for t in input_tensors[0]])

  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    If the types or shapes are missing, the queue is left unfrozen so the
    caller can still supply them and retry.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    # Mark the queue frozen only once the preconditions hold; marking it
    # earlier would leave a half-configured queue that rejects every setter.
    self._frozen = True
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    if tpu_device is None:
      # No explicit placement: the caller is responsible for the device scope.
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    with ops.device(tpu.core(tpu_device)):
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)

  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generate a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      # Colocation with inputs[0] only makes sense if every input shares its
      # device specification.
      if any(d != devices[0] for d in devices[1:]):
        raise ValueError(
            "input devices for shard %d are %s, but should all be the same" %
            (index, str(devices)))
      placement = ops.colocate_with(inputs[0])
    else:
      placement = ops.device(device)
    with placement:
      return tpu_ops.infeed_enqueue_tuple(
          inputs=inputs,
          shapes=shapes,
          name=full_name,
          device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self,
                           sharded_inputs,
                           tpu_ordinal_function=None,
                           placement_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
    shard 0 of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i unless placement_function is given.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.
      placement_function: if not None, a function that takes the shard index as
        input and returns the host device where the enqueue op should be placed
        on.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if enqueue Ops have already been
        generated from this queue.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      # -1 is the ordinal _generate_enqueue_op expects for TPU-placed ops.
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    # The number of shards was set to len(sharded_inputs) above, so
    # enumerating the inputs visits exactly one entry per shard.
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            tpu_ordinal=tpu_ordinal_function(index),
            device=placement_function(index) if placement_function else None)
        for (index, shard) in enumerate(sharded_inputs)
    ]

  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    """Returns the host CPU device for shard `index`, assuming 8 per host."""
    # Floor division: the module enables true division, and a task id must be
    # an integer.
    return "/task:%d/device:CPU:0" % (index // 8)

  def _default_ordinal_function(self, index):
    """Returns the TPU ordinal for shard `index`, assuming 8 per host."""
    return index % 8

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `tpu_ordinal_function` are None, then `device_assignment` will be used
        to place infeeds on the first k TPU shards, where k is the number of
        shards in the queue. If all three are `None`, then default placement
        and ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration; or if enqueue Ops have already been generated from this
        queue.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      # Explicitly supplied functions take precedence over the assignment.
      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      # Nothing to split: each input is its own single shard.
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        # Keep the split on the same device as the tensor being split.
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (index, (inp, policy)) in enumerate(
              zip(inputs, self._sharding_policies))
      ]
    # transposed_sharded_inputs is indexed [tuple element][shard]; flip it to
    # [shard][tuple element] so each entry is one shard's complete tuple.
    sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
                      for i in xrange(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (index, shard) in enumerate(sharded_inputs)
    ]
class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partition.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue, must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions to different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    # NOTE(review): `tuple_shapes` is accepted but not forwarded to the base
    # class; shapes are instead derived from the inputs passed to
    # generate_enqueue_ops. Confirm this is intentional.
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        tuple_shapes=None,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

  def generate_dequeue_op(self, tpu_device=0):
    """Generate TPU dequeue ops.

    Builds the dequeue op exactly as the base class does, then tags each
    dequeued tensor with the XLA sharding implied by the input partition dims.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    # The base class freezes the queue, rejects a second dequeue op and
    # places the op on the requested TPU core.
    values = super(_PartitionedInfeedQueue, self).generate_dequeue_op(
        tpu_device=tpu_device)
    return self._tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, per_host_sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    per_host_sharded_inputs is a list, one for each replica, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

    For example, if sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E ,F, G, H]]
    self._input_partition_dims[j] is [2, 4].

    sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids:
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      per_host_sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if the partition dims are invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs)
    number_of_tuple_elements = len(per_host_sharded_inputs[0])
    partition_dims = self._input_partition_dims
    # Raise rather than assert so the check survives `python -O`.
    if (partition_dims is None or
        len(partition_dims) != number_of_tuple_elements):
      raise ValueError(
          "input_partition_dims is {}, but must have one entry for each of "
          "the {} tuple elements.".format(partition_dims,
                                          number_of_tuple_elements))
    per_host_enqueue_ops = []
    for replica_index, replica_inputs in enumerate(per_host_sharded_inputs):
      inputs_part_dims_flat = nest.flatten_up_to(replica_inputs,
                                                 partition_dims)
      # One iterator per tuple element, yielding that element's piece for
      # each successive logical core.
      inputs_parted_iters = [
          iter(self._partition_or_replicate_on_host(x, dims))
          for x, dims in zip(replica_inputs, inputs_part_dims_flat)
      ]
      for logical_core in xrange(self._device_assignment.num_cores_per_replica):
        # Places different partitions to different logic cores.
        replica_id = self._device_assignment.lookup_replicas(
            self._host_id, logical_core)[replica_index]
        ordinal = self._device_assignment.tpu_ordinal(
            replica=replica_id, logical_core=logical_core)
        # An exhausted iterator means that element has no piece for this
        # core, so it is left out of this core's tuple.
        pieces = [next(it, None) for it in inputs_parted_iters]
        infeed_inputs = [piece for piece in pieces if piece is not None]
        if infeed_inputs:
          per_host_enqueue_ops.append(
              tpu_ops.infeed_enqueue_tuple(
                  inputs=infeed_inputs,
                  shapes=[x.shape for x in infeed_inputs],
                  name="enqueue/replica_{0}/input_{1}".format(
                      replica_index, logical_core),
                  device_ordinal=ordinal))
    return per_host_enqueue_ops

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integers describing how to partition the input tensor,
        or None for no partitioning.

    Raises:
      ValueError: If the tensor can't be partitioned by dims or the
        num_cores_per_replica doesn't match the number of
        partitions(dims.prod()).
    """
    if dims is None:
      return
    dims = np.array(dims)
    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")
    # No partitioning, so don't perform further checks.
    if dims.prod() == 1:
      return
    if dims.prod() != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of each input parition dim should equal to "
          "num_cores_per_replica. (dim = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))
    # The divisibility check below needs every dimension to be known.
    tensor.shape.assert_is_fully_defined()
    if (np.array(tensor.shape.as_list()) % dims != 0).any():
      raise ValueError(
          "All input partition dims must divide exactly into the `Tensor` "
          "shape (tensor shape = {}, input partition dims = {}).".format(
              tensor.shape.as_list(), dims))

  def _partition_or_replicate_on_host(self, tensor, dims):
    """Partitions or replicates the input tensor.

    The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integers describing how to partition the input tensor,
        or None to replicate it.

    Returns:
      An endless iterator that repeats `tensor` when dims is None; otherwise
      a flat list of partitioned tensors.
    """
    self._check_input_partition_dims(tensor, dims)
    if dims is None:
      return itertools.repeat(tensor)
    output = [tensor]
    # Split along each axis in turn, flattening after every pass so the next
    # axis is applied to each piece produced so far.
    for axis, dim in enumerate(dims):
      if dim > 1:
        output = [array_ops.split(x, dim, axis=axis) for x in output]
        output = nest.flatten(output)
    return output

  def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
    """Tags appropriate XLA sharding attribute to the dequeued tensor.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned, or
        None if it is replicated.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    if dims is None:
      return xla_sharding.replicate(tensor)
    elif np.prod(dims) == 1:
      # A single partition: assign the whole tensor to device 0.
      return xla_sharding.assign_device(tensor, 0)
    else:
      # Each tile is the full shape divided element-wise by the partition
      # counts; tiles are numbered in row-major order over `dims`.
      tile_shape = np.array(tensor.shape.as_list()) // dims
      tile_assignment = np.arange(np.prod(dims)).reshape(dims)
      return xla_sharding.tile(
          tensor=tensor,
          tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
              dtype=np.dtype(tensor.dtype.as_numpy_dtype),
              shape_tuple=tile_shape),
          tile_assignment=tile_assignment)

  def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims):
    """Tags appropriate XLA sharding attribute to the dequeued tensors.

    Args:
      dequeues: A list of dequeued tensors on TPU.
      dims: A list with one entry per dequeued tensor, each describing how
        that tensor is partitioned.

    Returns:
      The same dequeues with appropriate xla_sharding attribute.
    """
    nest.assert_shallow_structure(dequeues, dims)
    return nest.map_structure_up_to(
        dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues,
        dims)