Internal change

PiperOrigin-RevId: 303437665
Change-Id: I4142a9c4b125a8537ae60b4cfd12ffbe0c376574
diff --git a/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt
deleted file mode 100644
index cdcdd6d..0000000
--- a/tensorflow/core/api_def/base_api/api_def_EnqueueTPUEmbeddingRaggedTensorBatch.pbtxt
+++ /dev/null
@@ -1,77 +0,0 @@
-op {
-  graph_op_name: "EnqueueTPUEmbeddingRaggedTensorBatch"
-  visibility: HIDDEN
-  in_arg {
-    name: "sample_splits"
-    description: <<END
-A list of rank 1 Tensors specifying the break points for splitting
-embedding_indices and aggregation_weights into rows.
-It corresponds to ids.row_splits in embedding_lookup(), when ids is a
-RaggedTensor.
-END
-  }
-  in_arg {
-    name: "embedding_indices"
-    description: <<END
-A list of rank 1 Tensors, indices into the embedding tables.
-It corresponds to ids.values in embedding_lookup(), when ids is a RaggedTensor.
-END
-  }
-  in_arg {
-    name: "aggregation_weights"
-    description: <<END
-A list of rank 1 Tensors containing per training example
-aggregation weights. It corresponds to the values field of a RaggedTensor
-with the same row_splits as ids in embedding_lookup(), when ids is a
-RaggedTensor.
-END
-  }
-  in_arg {
-    name: "mode_override"
-    description: <<END
-A string input that overrides the mode specified in the
-TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
-'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
-in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
-END
-  }
-  attr {
-    name: "device_ordinal"
-    description: <<END
-The TPU device to use. Should be >= 0 and less than the number
-of TPU cores in the task on which the node is placed.
-END
-  }
-  attr {
-    name: "combiners"
-    description: <<END
-A list of string scalars, one for each embedding table that specify
-how to normalize the embedding activations after weighted summation.
-Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
-the sum of the weights be 0 for 'mean' or the sum of the squared weights be
-0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
-all tables.
-END
-  }
-  attr {
-    name: "table_ids"
-    description: <<END
-A list of integers specifying the identifier of the embedding table
-(offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the
-corresponding input. The ith input is looked up using table_ids[i]. The size
-of the table_ids list must be equal to that of sample_indices,
-embedding_indices and aggregation_weights.
-END
-  }
-  summary: "Eases the porting of code that uses tf.nn.embedding_lookup()."
-  description: <<END
-sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond
-to the ith feature. table_ids[i] indicates which embedding table to look up ith
-feature.
-
-The tensors at corresponding positions in two of the input lists,
-embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1
-with dim_size() equal to the total number of lookups into the table described by
-the corresponding feature.
-END
-}
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 9175088..c4d9dae 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -13185,101 +13185,6 @@
   is_stateful: true
 }
 op {
-  name: "EnqueueTPUEmbeddingRaggedTensorBatch"
-  input_arg {
-    name: "sample_indices"
-    type_attr: "T1"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "embedding_indices"
-    type_attr: "T2"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "aggregation_weights"
-    type_attr: "T3"
-    number_attr: "N"
-  }
-  input_arg {
-    name: "mode_override"
-    type: DT_STRING
-  }
-  attr {
-    name: "T1"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T2"
-    type: "type"
-    default_value {
-      type: DT_INT32
-    }
-    allowed_values {
-      list {
-        type: DT_INT32
-        type: DT_INT64
-      }
-    }
-  }
-  attr {
-    name: "T3"
-    type: "type"
-    default_value {
-      type: DT_FLOAT
-    }
-    allowed_values {
-      list {
-        type: DT_FLOAT
-        type: DT_DOUBLE
-      }
-    }
-  }
-  attr {
-    name: "N"
-    type: "int"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "device_ordinal"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "combiners"
-    type: "list(string)"
-    default_value {
-      list {
-      }
-    }
-  }
-  attr {
-    name: "table_ids"
-    type: "list(int)"
-  }
-  attr {
-    name: "max_sequence_lengths"
-    type: "list(int)"
-    default_value {
-      list {
-      }
-    }
-  }
-  is_stateful: true
-}
-op {
   name: "EnsureShape"
   input_arg {
     name: "input"
diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc
index 164d78e..821dff7 100644
--- a/tensorflow/core/ops/tpu_embedding_ops.cc
+++ b/tensorflow/core/ops/tpu_embedding_ops.cc
@@ -168,20 +168,4 @@
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape);
 
-REGISTER_OP("EnqueueTPUEmbeddingRaggedTensorBatch")
-    .Input("sample_splits: N * T1")
-    .Input("embedding_indices: N * T2")
-    .Input("aggregation_weights: N * T3")
-    .Input("mode_override: string")
-    .Attr("T1: {int32,int64} = DT_INT32")
-    .Attr("T2: {int32,int64} = DT_INT32")
-    .Attr("T3: {float32,float64} = DT_FLOAT")
-    .Attr("N: int >= 1")
-    .Attr("device_ordinal: int = -1")
-    .Attr("combiners: list(string) = []")
-    .Attr("table_ids: list(int)")
-    .Attr("max_sequence_lengths: list(int) = []")
-    .SetIsStateful()
-    .SetShapeFn(shape_inference::UnknownShape);
-
 }  // namespace tensorflow
diff --git a/tensorflow/python/tpu/ops/tpu_ops.py b/tensorflow/python/tpu/ops/tpu_ops.py
index 8facb1f..c1ea364 100644
--- a/tensorflow/python/tpu/ops/tpu_ops.py
+++ b/tensorflow/python/tpu/ops/tpu_ops.py
@@ -444,76 +444,3 @@
 
 enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
     gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
-
-
-# pylint: disable=protected-access
-def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
-                                              embedding_indices,
-                                              aggregation_weights,
-                                              table_ids,
-                                              device_ordinal,
-                                              max_sequence_lengths=None,
-                                              combiners=None,
-                                              mode_override=None,
-                                              name=None):
-  """A placeholder op for enqueueing embedding IDs to the TPU.
-
-  Args:
-    sample_splits: A list of rank 1 Tensors specifying the break points for
-      splitting embedding_indices and aggregation_weights into rows. It
-      corresponds to ids.row_splits in embedding_lookup(), when ids is a
-      RaggedTensor. Both int32 and int64 are allowed and will be converted to
-      int32 internally.
-    embedding_indices: A list of rank 1 Tensors, indices into the embedding
-      tables. It corresponds to ids.values in embedding_lookup(), when ids is a
-      RaggedTensor. Both int32 and int64 are allowed and will be converted to
-      int32 internally.
-    aggregation_weights: A list of rank 1 Tensors containing per training
-      example aggregation weights. It corresponds to the values field of a
-      RaggedTensor with the same row_splits as ids in embedding_lookup(), when
-      ids is a RaggedTensor. Both float32 and float64 are allowed and will be
-      converted to float32 internally.
-    table_ids: A list of integers specifying the identifier of the embedding
-      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
-      lookup the corresponding input. The ith input is looked up using
-      table_ids[i]. The size of the table_ids list must be equal to that of
-      sample_indices, embedding_indices and aggregation_weights.
-    device_ordinal: The TPU device to use. Should be >= 0 and less than the
-      number of TPU cores in the task on which the node is placed.
-    max_sequence_lengths: A list of integers, the size of which is equal to
-      sample_indices. If equal to 0, the corresponding feature is considered to
-      be a non-sequence feature, If greater than 0, the corresponding feature is
-      a sequence feature with the given maximal length. If None, then we assume
-      a list of all zeroes.
-    combiners: A list of string scalars, one for each embedding table that
-      specify how to normalize the embedding activations after weighted
-      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
-      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
-      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
-      is to use 'sum' for all tables (optional).
-    mode_override: A string input that overrides the mode specified in the
-      TPUEmbeddingConfiguration. Supported values are {'unspecified',
-      'inference', 'training', 'backward_pass_only'}. When set to 'unspecified',
-      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
-      is used (optional).
-    name: A name for the operation (optional).
-
-  Returns:
-    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
-  """
-  if mode_override is None:
-    mode_override = "unspecified"
-  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
-      sample_splits=sample_splits,
-      embedding_indices=embedding_indices,
-      aggregation_weights=aggregation_weights,
-      table_ids=table_ids,
-      device_ordinal=device_ordinal,
-      max_sequence_lengths=max_sequence_lengths,
-      combiners=combiners,
-      mode_override=mode_override,
-      name=name)
-
-
-enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
-    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index e24188e..e3dbe7f 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -205,48 +205,6 @@
         aggregation_weights=weights.values if weights is not None else None)
 
 
-class RaggedEnqueueData(
-    collections.namedtuple(
-        'RaggedEnqueueData',
-        ['embedding_indices', 'sample_splits', 'aggregation_weights'])):
-  """RaggedTensor Data to be enqueued through generate_enqueue_ops()."""
-
-  def __new__(cls,
-              embedding_indices,
-              sample_splits=None,
-              aggregation_weights=None):
-    """Data to be enqueued through generate_enqueue_ops().
-
-    Args:
-      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
-        corresponds to ids.values in embedding_lookup(), when ids is a
-        RaggedTensor. Both int32 and int64 are allowed and will be converted to
-        int32 internally.
-      sample_splits: A rank 1 Tensor specifying the break points for splitting
-        embedding_indices and aggregation_weights into rows. It corresponds to
-        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
-        int32 and int64 are allowed and will be converted to int32 internally.
-      aggregation_weights: A rank 1 Tensor containing per training example
-        aggregation weights. It corresponds to the values field of a
-        RaggedTensor with the same row_splits as ids in embedding_lookup(), when
-        ids is a RaggedTensor.
-
-    Returns:
-      An RaggedEnqueueData tuple.
-
-    """
-    return super(RaggedEnqueueData,
-                 cls).__new__(cls, embedding_indices, sample_splits,
-                              aggregation_weights)
-
-  @staticmethod
-  def from_ragged_tensor(rg_tensor, weights=None):
-    return RaggedEnqueueData(
-        rg_tensor.values,
-        rg_tensor.row_splits,
-        aggregation_weights=weights.values if weights is not None else None)
-
-
 def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
   """Convenient function for generate_enqueue_ops().
 
@@ -271,30 +229,6 @@
   return enqueue_datas_list
 
 
-def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
-  """Convenient function for generate_enqueue_ops().
-
-  Args:
-    rg_tensors_list: a list of dictionary mapping from string of feature names
-      to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
-      same host should be contiguous on the list.
-
-  Returns:
-    enqueue_datas_list: a list of dictionary mapping from string
-      of feature names to RaggedEnqueueData. Each dictionary is for one
-      TPU core. Dictionaries for the same host should be contiguous
-      on the list.
-
-  """
-  enqueue_datas_list = []
-  for rg_tensors in rg_tensors_list:
-    enqueue_datas = collections.OrderedDict(
-        (k, RaggedEnqueueData.from_ragged_tensor(v))
-        for k, v in six.iteritems(rg_tensors))
-    enqueue_datas_list.append(enqueue_datas)
-  return enqueue_datas_list
-
-
 AdamSlotVariableNames = collections.namedtuple(
     'AdamSlotVariableNames', ['m', 'v'])
 
@@ -1225,12 +1159,7 @@
                            slot_variables_by_table,
                            load_ops, retrieve_ops)
 
-  def generate_enqueue_ops(
-      self,
-      enqueue_datas_list,
-      mode_override=None,
-      ragged=False,
-  ):
+  def generate_enqueue_ops(self, enqueue_datas_list, mode_override=None):
     """Generate enqueue ops.
 
     Args:
@@ -1243,8 +1172,6 @@
         'inference', 'training', 'backward_pass_only'}. When set to
         'unspecified', the mode set in TPUEmbeddingConfiguration is used,
         otherwise mode_override is used (optional).
-      ragged: If True, creates RaggedTensor enqueue ops rather than
-        SparseTensor.
 
     Returns:
       Ops to enqueue to TPU for embedding.
@@ -1255,7 +1182,6 @@
             enqueue_datas,
             device_ordinal=i % self._num_cores_per_host,
             mode_override=mode_override,
-            ragged=ragged,
         ) for i, enqueue_datas in enumerate(enqueue_datas_list)
     ]
 
@@ -1285,50 +1211,28 @@
       for feature, enqueue_data in six.iteritems(enqueue_datas):
         combiner = self._table_to_config_dict[
             self._feature_to_config_dict[feature].table_id].combiner
+        if not isinstance(enqueue_data, EnqueueData):
+          raise ValueError('`enqueue_datas_list[{}]` has a feature that is '
+                           'not mapped to `EnqueueData`. `feature`: {}'.format(
+                               i, feature))
 
-        if isinstance(enqueue_data, EnqueueData):
-          if enqueue_data.sample_indices is None and combiner:
-            logging.warn(
-                'No sample indices set for features %f table %f but '
-                'combiner is set to %s.', feature,
-                self._feature_to_config_dict[feature].table_id, combiner)
-          if (enqueue_data.sample_indices is not None and
-              enqueue_data.sample_indices.device !=
-              enqueue_data.embedding_indices.device):
-            raise ValueError(
-                'Device of sample_indices does not agree with '
-                'that of embedding_indices for feature {}.'.format(feature))
-          if (enqueue_data.aggregation_weights is not None and
-              enqueue_data.aggregation_weights.device !=
-              enqueue_data.embedding_indices.device):
-            raise ValueError(
-                'Device of aggregation_weights does not agree with '
-                'that of embedding_indices for feature {}.'.format(feature))
+        if enqueue_data.sample_indices is None and combiner:
+          logging.warn('No sample indices set for features %f table %f but '
+                       'combiner is set to %s.', feature,
+                       self._feature_to_config_dict[feature].table_id, combiner)
 
-        elif isinstance(enqueue_data, RaggedEnqueueData):
-          if enqueue_data.sample_splits is None and combiner:
-            logging.warn(
-                'No sample splits set for features %f table %f but '
-                'combiner is set to %s.', feature,
-                self._feature_to_config_dict[feature].table_id, combiner)
-          if (enqueue_data.sample_splits is not None and
-              enqueue_data.sample_splits.device !=
-              enqueue_data.embedding_indices.device):
-            raise ValueError(
-                'Device of sample_splits does not agree with '
-                'that of embedding_indices for feature {}.'.format(feature))
-          if (enqueue_data.aggregation_weights is not None and
-              enqueue_data.aggregation_weights.device !=
-              enqueue_data.embedding_indices.device):
-            raise ValueError(
-                'Device of aggregation_weights does not agree with '
-                'that of embedding_indices for feature {}.'.format(feature))
-
-        else:
+        if (enqueue_data.sample_indices is not None and
+            enqueue_data.sample_indices.device !=
+            enqueue_data.embedding_indices.device):
           raise ValueError(
-              '`enqueue_datas_list[{}]` has a feature that is not mapped to '
-              '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
-                  i, feature))
+              'Device of sample_indices does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
+        if (enqueue_data.aggregation_weights is not None and
+            enqueue_data.aggregation_weights.device !=
+            enqueue_data.embedding_indices.device):
+          raise ValueError(
+              'Device of aggregation_weights does not agree with '
+              'that of embedding_indices for feature {}.'.format(feature))
         # Check all features are on the same device.
         if device is None:
           device = enqueue_data.embedding_indices.device
@@ -1353,69 +1257,23 @@
       else:
         contiguous_device = device
 
-  def _generate_enqueue_op(self,
-                           enqueue_datas,
-                           device_ordinal,
-                           mode_override=None,
-                           ragged=False):
-    """Creates op for enqueuing batch to TPU."""
+  def _generate_enqueue_op(
+      self, enqueue_datas, device_ordinal, mode_override=None):
     enqueue_data0 = list(enqueue_datas.values())[0]
     with ops.colocate_with(enqueue_data0.embedding_indices):
-      if ragged:
-        # note that this is currently identical in behavior
-        return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
-            device_ordinal=device_ordinal,
-            combiners=self._combiners,
-            mode_override=mode_override,
-            **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas))
-      else:
-        return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
-            device_ordinal=device_ordinal,
-            combiners=self._combiners,
-            mode_override=mode_override,
-            **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
-
-  def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas):
-    """Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`.
-
-    Args:
-      enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
-
-    Returns:
-      Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`.
-    """
-
-    kwargs = {
-        'sample_splits': [],
-        'embedding_indices': [],
-        'aggregation_weights': [],
-        'table_ids': [],
-        'max_sequence_lengths': [],
-    }
-    for table_id, table in enumerate(self._table_to_features_dict):
-      features = self._table_to_features_dict[table]
-      for feature in features:
-        enqueue_data = enqueue_datas[feature]
-
-        kwargs['sample_splits'].append(enqueue_data.sample_splits)
-
-        kwargs['aggregation_weights'].append(
-            enqueue_data.aggregation_weights if enqueue_data.aggregation_weights
-            is not None else array_ops.zeros((0,), dtype=dtypes.float32))
-
-        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
-
-        kwargs['table_ids'].append(table_id)
-        kwargs['max_sequence_lengths'].append(
-            self._feature_to_config_dict[feature].max_sequence_length)
-
-    return kwargs
+      return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
+          device_ordinal=device_ordinal,
+          combiners=self._combiners,
+          mode_override=mode_override,
+          **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
+      )
 
   def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
     """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
 
     Args:
-      enqueue_datas: a `Dict` of `EnqueueData` objects for embedding.
+      enqueue_datas: a `Dict` of tensors for embedding. Can be sparse or
+      dense.
 
     Returns:
       Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index af2a47f..80aca63 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1245,10 +1245,6 @@
     argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
   member_method {
-    name: "EnqueueTPUEmbeddingRaggedTensorBatch"
-    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
-  }
-  member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
     argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index af2a47f..80aca63 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1245,10 +1245,6 @@
     argspec: "args=[\'batch\', \'mode_override\', \'device_ordinal\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
   member_method {
-    name: "EnqueueTPUEmbeddingRaggedTensorBatch"
-    argspec: "args=[\'sample_splits\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'table_ids\', \'device_ordinal\', \'combiners\', \'max_sequence_lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'[]\', \'None\'], "
-  }
-  member_method {
     name: "EnqueueTPUEmbeddingSparseBatch"
     argspec: "args=[\'sample_indices\', \'embedding_indices\', \'aggregation_weights\', \'mode_override\', \'device_ordinal\', \'combiners\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'[]\', \'None\'], "
   }