Fix the issue when running tpu embedding test in auto graph.
PiperOrigin-RevId: 414640929
Change-Id: I9121d215e21b7ef99fbc84d9c2c0026b5bae216c
diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py
index 1b1aaa6..f8d681d 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2.py
@@ -15,6 +15,7 @@
"""Mid level API for TPU Embeddings."""
import functools
+import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
from absl import logging
@@ -37,7 +38,6 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import numpy_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
@@ -567,8 +567,8 @@
# optimization configurations we have, the worse the performance will be.
num_features = {table: 0 for table in self._table_config}
for i, feature in enumerate(nest.flatten(self._feature_config)):
- num_features[feature.table] += math_ops.reduce_prod(
- self._output_shapes[i]) / tensor_core_batch_size
+ num_features[feature.table] += self._get_reduce_prod(
+ self._output_shapes[i]) // tensor_core_batch_size
# Map each callable dynamic learning rate to its in index in the list.
learning_rate_index = {r: i for i, r in enumerate(
@@ -1439,12 +1439,19 @@
def _get_tensor_core_batch_size(self, output_shapes):
"""Get the tensor core batch size based on the output shapes."""
- tensor_core_batch_size = math_ops.reduce_prod(output_shapes[0])
+ tensor_core_batch_size = self._get_reduce_prod(output_shapes[0])
for output_shape in output_shapes[1:]:
- tensor_core_batch_size = numpy_ops.gcd(tensor_core_batch_size,
- math_ops.reduce_prod(output_shape))
+ tensor_core_batch_size = math.gcd(tensor_core_batch_size,
+ self._get_reduce_prod(output_shape))
return tensor_core_batch_size
+ def _get_reduce_prod(self, shape: TensorShape) -> int:
+ """Get the reduce prod of a tensorshape."""
+ result = 1
+ for dim in shape.as_list():
+ result *= dim
+ return result
+
def _update_output_shapes(self, incoming_output_shapes: List[TensorShape]):
"""Update the existing output shapes based on the new output shapes.