Implement Map over Array3D.

PiperOrigin-RevId: 438644462
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index ee56aff..d04ad07 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -665,6 +665,43 @@
   return result;
 }
 
+/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
+    const Array3D<float>& array,
+    const std::function<float(float)>& map_function) {
+  int64_t n1 = array.n1();
+  int64_t n2 = array.n2();
+  int64_t n3 = array.n3();
+  auto result = absl::make_unique<Array3D<float>>(n1, n2, n3);
+  for (int64_t i = 0; i < n1; ++i) {
+    for (int64_t j = 0; j < n2; ++j) {
+      for (int64_t k = 0; k < n3; ++k) {
+        (*result)(i, j, k) = map_function(array(i, j, k));
+      }
+    }
+  }
+  return result;
+}
+
+/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
+    const Array3D<float>& lhs, const Array3D<float>& rhs,
+    const std::function<float(float, float)>& map_function) {
+  CHECK_EQ(lhs.n1(), rhs.n1());
+  CHECK_EQ(lhs.n2(), rhs.n2());
+  CHECK_EQ(lhs.n3(), rhs.n3());
+  int64_t n1 = lhs.n1();
+  int64_t n2 = rhs.n2();
+  int64_t n3 = rhs.n3();
+  auto result = absl::make_unique<Array3D<float>>(n1, n2, n3);
+  for (int64_t i = 0; i < n1; ++i) {
+    for (int64_t j = 0; j < n2; ++j) {
+      for (int64_t k = 0; k < n3; ++k) {
+        (*result)(i, j, k) = map_function(lhs(i, j, k), rhs(i, j, k));
+      }
+    }
+  }
+  return result;
+}
+
 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
     const Array2D<float>& matrix,
     const std::function<float(float, int64_t, int64_t)>& map_function) {
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 10e71fe..50af95a 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -169,6 +169,18 @@
       const Array2D<float>& lhs, const Array2D<float>& rhs,
       const std::function<float(float, float)>& map_function);
 
+  // Applies map_function to each element in the input (3D array) and returns
+  // the result.
+  static std::unique_ptr<Array3D<float>> MapArray3D(
+      const Array3D<float>& array,
+      const std::function<float(float)>& map_function);
+
+  // Applies map_function to each pair of corresponding elements in the two
+  // inputs arrays and returns the result.
+  static std::unique_ptr<Array3D<float>> MapArray3D(
+      const Array3D<float>& lhs, const Array3D<float>& rhs,
+      const std::function<float(float, float)>& map_function);
+
   // Number of windows in a given dimension. Calculation taken from
   // xla::MakePadding().
   static int64_t WindowCount(int64_t unpadded_width, int64_t window_len,
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 9746b3d..79a6bb1 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -101,6 +101,16 @@
                                        ErrorSpec(0.0001));
 }
 
+TEST_F(ReferenceUtilTest, MapArray3D) {
+  auto identity = [](float value) { return std::log(std::exp(value)); };
+  Array3D<float> input(2, 3, 4);
+  input.FillIota(0);
+  auto result = ReferenceUtil::MapArray3D(input, identity);
+  auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
+  LiteralTestUtil::ExpectR3NearArray3D(input, actual_literal,
+                                       ErrorSpec(0.0001));
+}
+
 TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
   auto add_index = [](float value, int64_t row, int64_t col) {
     return value + row + col;