Internal change: remove the EnqueueTPUEmbeddingRaggedTensorBatch op and its Python ragged enqueue wrappers (RaggedEnqueueData and the `ragged` path in tpu_embedding.py).
PiperOrigin-RevId: 303437665
Change-Id: I4142a9c4b125a8537ae60b4cfd12ffbe0c376574
diff --git a/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt
deleted file mode 100644
index cdcdd6d..0000000
--- a/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt
+++ /dev/null
@@ -1,77 +0,0 @@
-op {
- graph_op_name: "EnqueueTPUEmbeddingRaggedTensorBatch"
- visibility: HIDDEN
- in_arg {
- name: "sample_splits"
- description: <<END
-A list of rank 1 Tensors specifying the break points for splitting
-embedding_indices and aggregation_weights into rows.
-It corresponds to ids.row_splits in embedding_lookup(), when ids is a
-RaggedTensor.
-END
- }
- in_arg {
- name: "embedding_indices"
- description: <<END
-A list of rank 1 Tensors, indices into the embedding tables.
-It corresponds to ids.values in embedding_lookup(), when ids is a RaggedTensor.
-END
- }
- in_arg {
- name: "aggregation_weights"
- description: <<END
-A list of rank 1 Tensors containing per training example
-aggregation weights. It corresponds to the values field of a RaggedTensor
-with the same row_splits as ids in embedding_lookup(), when ids is a
-RaggedTensor.
-END
- }
- in_arg {
- name: "mode_override"
- description: <<END
-A string input that overrides the mode specified in the
-TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
-'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
-in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
-END
- }
- attr {
- name: "device_ordinal"
- description: <<END
-The TPU device to use. Should be >= 0 and less than the number
-of TPU cores in the task on which the node is placed.
-END
- }
- attr {
- name: "combiners"
- description: <<END
-A list of string scalars, one for each embedding table that specify
-how to normalize the embedding activations after weighted summation.
-Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
-the sum of the weights be 0 for 'mean' or the sum of the squared weights be
-0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
-all tables.
-END
- }
- attr {
- name: "table_ids"
- description: <<END
-A list of integers specifying the identifier of the embedding table
-(offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the
-corresponding input. The ith input is looked up using table_ids[i]. The size
-of the table_ids list must be equal to that of sample_splits,
-embedding_indices and aggregation_weights.
-END
- }
- summary: "Eases the porting of code that uses tf.nn.embedding_lookup()."
- description: <<END
-sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond
-to the ith feature. table_ids[i] indicates which embedding table to look up the
-ith feature.
-
-The tensors at corresponding positions in two of the input lists,
-embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1
-with dim_size() equal to the total number of lookups into the table described by
-the corresponding feature.
-END
-}
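
The deleted api_def above maps the op's list inputs onto the pieces of a tf.RaggedTensor: sample_splits is ids.row_splits and embedding_indices is ids.values. A minimal sketch of that correspondence, with illustrative values that are not part of this change:

import tensorflow as tf

# ids plays the role of the RaggedTensor handed to embedding_lookup().
ids = tf.ragged.constant([[3, 1], [4], [1, 5, 9]], dtype=tf.int32)

sample_splits = ids.row_splits      # [0, 2, 3, 6]: break points between rows
embedding_indices = ids.values      # [3, 1, 4, 1, 5, 9]: flat indices into the table
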
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 9175088..c4d9dae 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -13185,101 +13185,6 @@
is_stateful: true
}
op {
- name: "EnqueueTPUEmbeddingRaggedTensorBatch"
- input_arg {
- name: "sample_indices"
- type_attr: "T1"
- number_attr: "N"
- }
- input_arg {
- name: "embedding_indices"
- type_attr: "T2"
- number_attr: "N"
- }
- input_arg {
- name: "aggregation_weights"
- type_attr: "T3"
- number_attr: "N"
- }
- input_arg {
- name: "mode_override"
- type: DT_STRING
- }
- attr {
- name: "T1"
- type: "type"
- default_value {
- type: DT_INT32
- }
- allowed_values {
- list {
- type: DT_INT32
- type: DT_INT64
- }
- }
- }
- attr {
- name: "T2"
- type: "type"
- default_value {
- type: DT_INT32
- }
- allowed_values {
- list {
- type: DT_INT32
- type: DT_INT64
- }
- }
- }
- attr {
- name: "T3"
- type: "type"
- default_value {
- type: DT_FLOAT
- }
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
- attr {
- name: "N"
- type: "int"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "device_ordinal"
- type: "int"
- default_value {
- i: -1
- }
- }
- attr {
- name: "combiners"
- type: "list(string)"
- default_value {
- list {
- }
- }
- }
- attr {
- name: "table_ids"
- type: "list(int)"
- }
- attr {
- name: "max_sequence_lengths"
- type: "list(int)"
- default_value {
- list {
- }
- }
- }
- is_stateful: true
-}
-op {
name: "EnsureShape"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc
index 164d78e..821dff7 100644
--- a/tensorflow/core/ops/tpu_embedding_ops.cc
+++ b/tensorflow/core/ops/tpu_embedding_ops.cc
@@ -168,20 +168,4 @@
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
- .Input("sample_splits: N * T1")
- .Input("embedding_indices: N * T2")
- .Input("aggregation_weights: N * T3")
- .Input("mode_override: string")
- .Attr("T1: {int32,int64} = DT_INT32")
- .Attr("T2: {int32,int64} = DT_INT32")
- .Attr("T3: {float32,float64} = DT_FLOAT")
- .Attr("N: int >= 1")
- .Attr("device_ordinal: int = -1")
- .Attr("combiners: list(string) = []")
- .Attr("table_ids: list(int)")
- .Attr("max_sequence_lengths: list(int) = []")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape);
-
} // namespace tensorflow
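
In the registration removed above, `sample_splits: N * T1` declares a list of N rank-1 tensors of type T1; N is inferred from the length of the Python list and must match across the three list inputs and table_ids. A hedged sketch of building those parallel per-feature lists for two features (data and table ids invented for illustration):

import tensorflow as tf

feature_a = tf.ragged.constant([[1, 2], [3]], dtype=tf.int32)
feature_b = tf.ragged.constant([[7], [8, 9]], dtype=tf.int32)

# One rank-1 tensor per feature in each list; N = 2 here and must match table_ids.
sample_splits = [feature_a.row_splits, feature_b.row_splits]
embedding_indices = [feature_a.values, feature_b.values]
aggregation_weights = [tf.ones_like(feature_a.values, dtype=tf.float32),
                       tf.ones_like(feature_b.values, dtype=tf.float32)]
table_ids = [0, 1]  # hypothetical TableDescriptor offsets, one per feature
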
diff --git a/tensorflow/python/tpu/ops/tpu_ops.py b/tensorflow/python/tpu/ops/tpu_ops.py
index 8facb1f..c1ea364 100644
--- a/tensorflow/python/tpu/ops/tpu_ops.py
+++ b/tensorflow/python/tpu/ops/tpu_ops.py
@@ -444,76 +444,3 @@
enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
-
-
-# pylint: disable=protected-access
-def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
- embedding_indices,
- aggregation_weights,
- table_ids,
- device_ordinal,
- max_sequence_lengths=None,
- combiners=None,
- mode_override=None,
- name=None):
- """A placeholder op for enqueueing embedding IDs to the TPU.
-
- Args:
- sample_splits: A list of rank 1 Tensors specifying the break points for
- splitting embedding_indices and aggregation_weights into rows. It
- corresponds to ids.row_splits in embedding_lookup(), when ids is a
- RaggedTensor. Both int32 and int64 are allowed and will be converted to
- int32 internally.
- embedding_indices: A list of rank 1 Tensors, indices into the embedding
- tables. It corresponds to ids.values in embedding_lookup(), when ids is a
- RaggedTensor. Both int32 and int64 are allowed and will be converted to
- int32 internally.
- aggregation_weights: A list of rank 1 Tensors containing per training
- example aggregation weights. It corresponds to the values field of a
- RaggedTensor with the same row_splits as ids in embedding_lookup(), when
- ids is a RaggedTensor. Both float32 and float64 are allowed and will be
- converted to float32 internally.
- table_ids: A list of integers specifying the identifier of the embedding
- table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
- lookup the corresponding input. The ith input is looked up using
- table_ids[i]. The size of the table_ids list must be equal to that of
- sample_splits, embedding_indices and aggregation_weights.
- device_ordinal: The TPU device to use. Should be >= 0 and less than the
- number of TPU cores in the task on which the node is placed.
- max_sequence_lengths: A list of integers, the size of which is equal to
- sample_splits. If equal to 0, the corresponding feature is considered to
- be a non-sequence feature. If greater than 0, the corresponding feature is
- a sequence feature with the given maximal length. If None, then we assume
- a list of all zeroes.
- combiners: A list of string scalars, one for each embedding table that
- specify how to normalize the embedding activations after weighted
- summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
- invalid to have the sum of the weights be 0 for 'mean' or the sum of the
- squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
- is to use 'sum' for all tables (optional).
- mode_override: A string input that overrides the mode specified in the
- TPUEmbeddingConfiguration. Supported values are {'unspecified',
- 'inference', 'training', 'backward_pass_only'}. When set to 'unspecified',
- the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
- is used (optional).
- name: A name for the operation (optional).
-
- Returns:
- An EnqueueTPUEmbeddingRaggedTensorBatch operation.
- """
- if mode_override is None:
- mode_override = "unspecified"
- return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
- sample_splits=sample_splits,
- embedding_indices=embedding_indices,
- aggregation_weights=aggregation_weights,
- table_ids=table_ids,
- device_ordinal=device_ordinal,
- max_sequence_lengths=max_sequence_lengths,
- combiners=combiners,
- mode_override=mode_override,
- name=name)
-
-
-enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
- gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
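
The wrapper deleted above only defaulted mode_override to 'unspecified' and forwarded to the generated gen_tpu_ops op. For reference, a call to it looked roughly like the sketch below; the single feature, table id 0, and device ordinal 0 are invented, and the op is only meaningful in a graph placed on a TPU host:

import tensorflow as tf
from tensorflow.python.tpu.ops import tpu_ops

ids = tf.ragged.constant([[3, 1], [4]], dtype=tf.int32)
enqueue_op = tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
    sample_splits=[ids.row_splits],
    embedding_indices=[ids.values],
    aggregation_weights=[tf.ones_like(ids.values, dtype=tf.float32)],
    table_ids=[0],     # offset of the TableDescriptor in the TPUEmbeddingConfiguration
    device_ordinal=0)  # TPU core within the local task
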
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index e24188e..e3dbe7f 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -205,48 +205,6 @@
aggregation_weights=weights.values if weights is not None else None)
-class RaggedEnqueueData(
- collections.namedtuple(
- 'RaggedEnqueueData',
- ['embedding_indices', 'sample_splits', 'aggregation_weights'])):
- """RaggedTensor Data to be enqueued through generate_enqueue_ops()."""
-
- def __new__(cls,
- embedding_indices,
- sample_splits=None,
- aggregation_weights=None):
- """Data to be enqueued through generate_enqueue_ops().
-
- Args:
- embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
- corresponds to ids.values in embedding_lookup(), when ids is a
- RaggedTensor. Both int32 and int64 are allowed and will be converted to
- int32 internally.
- sample_splits: A rank 1 Tensor specifying the break points for splitting
- embedding_indices and aggregation_weights into rows. It corresponds to
- ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
- int32 and int64 are allowed and will be converted to int32 internally.
- aggregation_weights: A rank 1 Tensor containing per training example
- aggregation weights. It corresponds to the values field of a
- RaggedTensor with the same row_splits as ids in embedding_lookup(), when
- ids is a RaggedTensor.
-
- Returns:
- A RaggedEnqueueData tuple.
-
- """
- return super(RaggedEnqueueData,
- cls).__new__(cls, embedding_indices, sample_splits,
- aggregation_weights)
-
- @staticmethod
- def from_ragged_tensor(rg_tensor, weights=None):
- return RaggedEnqueueData(
- rg_tensor.values,
- rg_tensor.row_splits,
- aggregation_weights=weights.values if weights is not None else None)
-
-
def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
"""Convenient function for generate_enqueue_ops().
@@ -271,30 +229,6 @@
return enqueue_datas_list
-def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
- """Convenient function for generate_enqueue_ops().
-
- Args:
- rg_tensors_list: a list of dictionaries mapping feature names (strings) to
- RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
- same host should be contiguous in the list.
-
- Returns:
- enqueue_datas_list: a list of dictionaries mapping feature names
- (strings) to RaggedEnqueueData. Each dictionary is for one
- TPU core. Dictionaries for the same host should be contiguous
- in the list.
-
- """
- enqueue_datas_list = []
- for rg_tensors in rg_tensors_list:
- enqueue_datas = collections.OrderedDict(
- (k, RaggedEnqueueData.from_ragged_tensor(v))
- for k, v in six.iteritems(rg_tensors))
- enqueue_datas_list.append(enqueue_datas)
- return enqueue_datas_list
-
-
AdamSlotVariableNames = collections.namedtuple(
'AdamSlotVariableNames', ['m', 'v'])
@@ -1225,12 +1159,7 @@
slot_variables_by_table,
load_ops, retrieve_ops)
- def generate_enqueue_ops(
- self,
- enqueue_datas_list,
- mode_override=None,
- ragged=False,
- ):
+ def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None):
"""Generate enqueue ops.
Args:
@@ -1243,8 +1172,6 @@
'inference', 'training', 'backward_pass_only'}. When set to
'unspecified', the mode set in TPUEmbeddingConfiguration is used,
otherwise mode_override is used (optional).
- ragged: If True, creates RaggedTensor enqueue ops rather than
- SparseTensor.
Returns:
Ops to enqueue to TPU for embedding.
@@ -1255,7 +1182,6 @@
enqueue_datas,
device_ordinal=i % self._num_cores_per_host,
mode_override=mode_override,
- ragged=ragged,
) for i, enqueue_datas in enumerate(enqueue_datas_list)
]
@@ -1285,50 +1211,28 @@
for feature, enqueue_data in six.iteritems(enqueue_datas):
combiner = self._table_to_config_dict[
self._feature_to_config_dict[feature].table_id].combiner
+ if not isinstance(enqueue_data, EnqueueData):
+ raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
+ 'not mapped to `EnqueueData`. `feature`: {}'.format(
+ i, feature))
- if isinstance(enqueue_data, EnqueueData):
- if enqueue_data.sample_indices is None and combiner:
- logging.warn(
- 'No sample indices set for features %f table %f but '
- 'combiner is set to %s.', feature,
- self._feature_to_config_dict[feature].table_id, combiner)
- if (enqueue_data.sample_indices is not None and
- enqueue_data.sample_indices.device !=
- enqueue_data.embedding_indices.device):
- raise ValueError(
- 'Device of sample_indices does not agree with '
- 'that of embedding_indices for feature {}.'.format(feature))
- if (enqueue_data.aggregation_weights is not None and
- enqueue_data.aggregation_weights.device !=
- enqueue_data.embedding_indices.device):
- raise ValueError(
- 'Device of aggregation_weights does not agree with '
- 'that of embedding_indices for feature {}.'.format(feature))
+ if enqueue_data.sample_indices is None and combiner:
+ logging.warn('No sample indices set for features %f table %f but '
+ 'combiner is set to %s.', feature,
+ self._feature_to_config_dict[feature].table_id, combiner)
- elif isinstance(enqueue_data, RaggedEnqueueData):
- if enqueue_data.sample_splits is None and combiner:
- logging.warn(
- 'No sample splits set for features %f table %f but '
- 'combiner is set to %s.', feature,
- self._feature_to_config_dict[feature].table_id, combiner)
- if (enqueue_data.sample_splits is not None and
- enqueue_data.sample_splits.device !=
- enqueue_data.embedding_indices.device):
- raise ValueError(
- 'Device of sample_splits does not agree with '
- 'that of embedding_indices for feature {}.'.format(feature))
- if (enqueue_data.aggregation_weights is not None and
- enqueue_data.aggregation_weights.device !=
- enqueue_data.embedding_indices.device):
- raise ValueError(
- 'Device of aggregation_weights does not agree with '
- 'that of embedding_indices for feature {}.'.format(feature))
-
- else:
+ if (enqueue_data.sample_indices is not None and
+ enqueue_data.sample_indices.device !=
+ enqueue_data.embedding_indices.device):
raise ValueError(
- '`enqueue_datas_list[{}]` has a feature that is not mapped to '
- '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
- i, feature))
+ 'Device of sample_indices does not agree with '
+ 'that of embedding_indices for feature {}.'.format(feature))
+ if (enqueue_data.aggregation_weights is not None and
+ enqueue_data.aggregation_weights.device !=
+ enqueue_data.embedding_indices.device):
+ raise ValueError(
+ 'Device of aggregation_weights does not agree with '
+ 'that of embedding_indices for feature {}.'.format(feature))
# Check all features are on the same device.
if device is None:
device = enqueue_data.embedding_indices.device
@@ -1353,69 +1257,23 @@
else:
contiguous_device = device
- def _generate_enqueue_op(self,
- enqueue_datas,
- device_ordinal,
- mode_override=None,
- ragged=False):
- """Creates op for enqueuing batch to TPU."""
+ def _generate_enqueue_op(
+ self, enqueue_datas, device_ordinal, mode_override=None):
enqueue_data0 = list(enqueue_datas.values())[0]
with ops.colocate_with(enqueue_data0.embedding_indices):
- if ragged:
- # note that this is currently identical in behavior
- return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
- device_ordinal=device_ordinal,
- combiners=self._combiners,
- mode_override=mode_override,
- **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas))
- else:
- return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
- device_ordinal=device_ordinal,
- combiners=self._combiners,
- mode_override=mode_override,
- **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
-
- def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas):
- """Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`.
-
- Args:
- enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
-
- Returns:
- Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`.
- """
-
- kwargs = {
- 'sample_splits': [],
- 'embedding_indices': [],
- 'aggregation_weights': [],
- 'table_ids': [],
- 'max_sequence_lengths': [],
- }
- for table_id, table in enumerate(self._table_to_features_dict):
- features = self._table_to_features_dict[table]
- for feature in features:
- enqueue_data = enqueue_datas[feature]
-
- kwargs['sample_splits'].append(enqueue_data.sample_splits)
-
- kwargs['aggregation_weights'].append(
- enqueue_data.aggregation_weights if enqueue_data.aggregation_weights
- is not None else array_ops.zeros((0,), dtype=dtypes.float32))
-
- kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
-
- kwargs['table_ids'].append(table_id)
- kwargs['max_sequence_lengths'].append(
- self._feature_to_config_dict[feature].max_sequence_length)
-
- return kwargs
+ return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
+ device_ordinal=device_ordinal,
+ combiners=self._combiners,
+ mode_override=mode_override,
+ **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
+ )
def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
"""Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
Args:
- enqueue_datas: a `Dict` of `EnqueueData` objects for embedding.
+ enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
+ dense.
Returns:
Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
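
After this change, generate_enqueue_ops() accepts only EnqueueData built from SparseTensors; the `ragged` argument and RaggedEnqueueData are gone. A minimal sketch of the remaining path, assuming `embedding` is a previously constructed tpu_embedding.TPUEmbedding and 'my_feature' is a configured feature name (both invented here); ragged inputs would presumably be converted with RaggedTensor.to_sparse() first:

import tensorflow as tf
from tensorflow.python.tpu import tpu_embedding

# `embedding` is assumed to be an already-built tpu_embedding.TPUEmbedding whose
# feature config contains a feature named 'my_feature'.
sp_ids = tf.ragged.constant([[3, 1], [4]], dtype=tf.int32).to_sparse()
enqueue_datas_list = tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
    [{'my_feature': sp_ids}])  # one dict per TPU core
enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)
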
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index af2a47f..80aca63 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1245,10 +1245,6 @@
argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
- name: "EnqueueTPUEmbeddingRaggedTensorBatch"
- argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
- }
- member_method {
name: "EnqueueTPUEmbeddingSparseBatch"
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index af2a47f..80aca63 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1245,10 +1245,6 @@
argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
}
member_method {
- name: "EnqueueTPUEmbeddingRaggedTensorBatch"
- argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
- }
- member_method {
name: "EnqueueTPUEmbeddingSparseBatch"
argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "
}