Implement Map over Array3D.
PiperOrigin-RevId: 438644462
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index ee56aff..d04ad07 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -665,6 +665,43 @@
return result;
}
+/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
+ const Array3D<float>& array,
+ const std::function<float(float)>& map_function) {
+ int64_t n1 = array.n1();
+ int64_t n2 = array.n2();
+ int64_t n3 = array.n3();
+ auto result = absl::make_unique<Array3D<float>>(n1, n2, n3);
+ for (int64_t i = 0; i < n1; ++i) {
+ for (int64_t j = 0; j < n2; ++j) {
+ for (int64_t k = 0; k < n3; ++k) {
+ (*result)(i, j, k) = map_function(array(i, j, k));
+ }
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
+ const Array3D<float>& lhs, const Array3D<float>& rhs,
+ const std::function<float(float, float)>& map_function) {
+ CHECK_EQ(lhs.n1(), rhs.n1());
+ CHECK_EQ(lhs.n2(), rhs.n2());
+ CHECK_EQ(lhs.n3(), rhs.n3());
+ int64_t n1 = lhs.n1();
+ int64_t n2 = rhs.n2();
+ int64_t n3 = rhs.n3();
+ auto result = absl::make_unique<Array3D<float>>(n1, n2, n3);
+ for (int64_t i = 0; i < n1; ++i) {
+ for (int64_t j = 0; j < n2; ++j) {
+ for (int64_t k = 0; k < n3; ++k) {
+ (*result)(i, j, k) = map_function(lhs(i, j, k), rhs(i, j, k));
+ }
+ }
+ }
+ return result;
+}
+
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
const Array2D<float>& matrix,
const std::function<float(float, int64_t, int64_t)>& map_function) {
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 10e71fe..50af95a 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -169,6 +169,18 @@
const Array2D<float>& lhs, const Array2D<float>& rhs,
const std::function<float(float, float)>& map_function);
+ // Applies map_function to each element in the input (3D array) and returns
+ // the result.
+ static std::unique_ptr<Array3D<float>> MapArray3D(
+ const Array3D<float>& array,
+ const std::function<float(float)>& map_function);
+
+ // Applies map_function to each pair of corresponding elements in the two
+ // inputs arrays and returns the result.
+ static std::unique_ptr<Array3D<float>> MapArray3D(
+ const Array3D<float>& lhs, const Array3D<float>& rhs,
+ const std::function<float(float, float)>& map_function);
+
// Number of windows in a given dimension. Calculation taken from
// xla::MakePadding().
static int64_t WindowCount(int64_t unpadded_width, int64_t window_len,
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 9746b3d..79a6bb1 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -101,6 +101,16 @@
ErrorSpec(0.0001));
}
+TEST_F(ReferenceUtilTest, MapArray3D) {
+ auto identity = [](float value) { return std::log(std::exp(value)); };
+ Array3D<float> input(2, 3, 4);
+ input.FillIota(0);
+ auto result = ReferenceUtil::MapArray3D(input, identity);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
+ LiteralTestUtil::ExpectR3NearArray3D(input, actual_literal,
+ ErrorSpec(0.0001));
+}
+
TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
auto add_index = [](float value, int64_t row, int64_t col) {
return value + row + col;