Fix a failure when running the TPU embedding test under AutoGraph.

Compute the tensor-core batch size and per-table feature counts with
Python-level math (math.gcd and a pure-Python product over the static
output shapes) instead of graph ops (numpy_ops.gcd, math_ops.reduce_prod),
and use integer division (//) so the results stay plain Python ints rather
than tensors.

PiperOrigin-RevId: 414640929
Change-Id: I9121d215e21b7ef99fbc84d9c2c0026b5bae216c
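
The core of the problem: math_ops.reduce_prod and numpy_ops.gcd return
tensors, which are symbolic inside a tf.function/AutoGraph trace and so
cannot feed Python-level integer arithmetic. A minimal standalone sketch of
the contrast (assumes TensorFlow 2.x is installed; the variable names are
illustrative, not from this patch):

    import tensorflow as tf

    shape = tf.TensorShape([8, 4])

    # What the patch computes: a static Python int from the shape's dims.
    static_size = 1
    for dim in shape.as_list():
      static_size *= dim
    print(static_size)  # 32 -- a plain int, usable with math.gcd and //

    # What the old code computed: a Tensor; inside a tf.function this is
    # symbolic, so it cannot be passed to Python-level int math.
    dynamic_size = tf.math.reduce_prod(shape.as_list())
    print(dynamic_size)  # tf.Tensor(32, shape=(), dtype=int32)
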
diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py
index 1b1aaa6..f8d681d 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2.py
@@ -15,6 +15,7 @@
 """Mid level API for TPU Embeddings."""
 
 import functools
+import math
 from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
 
 from absl import logging
@@ -37,7 +38,6 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import numpy_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as tf_variables
@@ -567,8 +567,8 @@
     #    optimization configurations we have, the worse the performance will be.
     num_features = {table: 0 for table in self._table_config}
     for i, feature in enumerate(nest.flatten(self._feature_config)):
-      num_features[feature.table] += math_ops.reduce_prod(
-          self._output_shapes[i]) / tensor_core_batch_size
+      num_features[feature.table] += self._get_reduce_prod(
+          self._output_shapes[i]) // tensor_core_batch_size
 
     # Map each callable dynamic learning rate to its index in the list.
     learning_rate_index = {r: i for i, r in enumerate(
@@ -1439,12 +1439,19 @@
 
   def _get_tensor_core_batch_size(self, output_shapes):
     """Get the tensor core batch size based on the output shapes."""
-    tensor_core_batch_size = math_ops.reduce_prod(output_shapes[0])
+    tensor_core_batch_size = self._get_reduce_prod(output_shapes[0])
     for output_shape in output_shapes[1:]:
-      tensor_core_batch_size = numpy_ops.gcd(tensor_core_batch_size,
-                                             math_ops.reduce_prod(output_shape))
+      tensor_core_batch_size = math.gcd(tensor_core_batch_size,
+                                        self._get_reduce_prod(output_shape))
     return tensor_core_batch_size
 
+  def _get_reduce_prod(self, shape: TensorShape) -> int:
+    """Get the reduce prod of a tensorshape."""
+    result = 1
+    for dim in shape.as_list():
+      result *= dim
+    return result
+
   def _update_output_shapes(self, incoming_output_shapes: List[TensorShape]):
     """Update the existing output shapes based on the new output shapes.