change MirrorPad packet region
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f66102a..8a24193 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3166,6 +3166,7 @@
"adjust_contrast_op_test.cc",
"colorspace_op_test.cc",
"crop_and_resize_op_test.cc",
+ "mirror_pad_op_test.cc",
"non_max_suppression_op_test.cc",
"resize_area_op_test.cc",
"resize_bicubic_op_test.cc",
@@ -3178,6 +3179,7 @@
}),
deps = [
":image",
+ ":mirror_pad_op",
":ops_testutil",
":ops_util",
":sampling_kernels",
@@ -3245,6 +3247,22 @@
)
tf_cuda_cc_test(
+ name = "mirror_pad_op_benchmark_test",
+ srcs = ["mirror_pad_op_benchmark_test.cc"],
+ deps = [
+ ":mirror_pad_op",
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cuda_cc_test(
name = "non_max_suppression_op_gpu_test",
srcs = ["non_max_suppression_op_gpu_test.cc"],
tags = tf_cuda_tests_tags() + ["no_cuda_on_cpu_tap"],
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index eda3b2b..b94aec9 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -16,9 +16,9 @@
#ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
#define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace Eigen {
template <typename PaddingDimensions, typename XprType>
@@ -223,7 +223,8 @@
const Index right =
(dimensions_[dim] - padding_[dim].second) * output_strides_[dim];
- if (left <= index && (index + kPacketSize - 1) < right) {
+ const Index index_mod = index % (dimensions_[dim] * output_strides_[dim]);
+ if (left <= index_mod && (index_mod + kPacketSize - 1) < right) {
return impl_.template packet<Unaligned>(input_index);
}
diff --git a/tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc b/tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc
new file mode 100644
index 0000000..733d235
--- /dev/null
+++ b/tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc
@@ -0,0 +1,59 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static Graph* BM_MirrorPad(int batches, int height, int width, int depth,
+ int pad, const char* mode) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor in(DT_FLOAT, TensorShape({batches, height, width, depth}));
+ in.flat<float>().setRandom();
+ Tensor padding(DT_INT32, TensorShape({4, 2}));
+ auto boxes_tensor = padding.flat<int>().setZero();
+ for (int i = 2; i < 6; i++) boxes_tensor(i) = pad;
+
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MirrorPad")
+ .Input(test::graph::Constant(g, in))
+ .Input(test::graph::Constant(g, padding))
+ .Attr("mode", mode)
+ .Finalize(g, &ret));
+ return g;
+}
+
+#define BM_MirrorPadDev(DEVICE, B, W, H, D, P, MODE) \
+ static void BM_MirrorPad_##DEVICE##_##B##_##W##_##H##_##D##_##P##_##MODE( \
+ int iters) { \
+ testing::ItemsProcessed(iters* B*(W + 2 * P) * (H + 2 * P) * D / 32); \
+ test::Benchmark(#DEVICE, BM_MirrorPad(B, W, H, D, P, #MODE)).Run(iters); \
+ } \
+ BENCHMARK(BM_MirrorPad_##DEVICE##_##B##_##W##_##H##_##D##_##P##_##MODE);
+
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 1, REFLECT);
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 8, REFLECT);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 1, REFLECT);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 256, REFLECT);
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 1, SYMMETRIC);
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 8, SYMMETRIC);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 1, SYMMETRIC);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 256, SYMMETRIC);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mirror_pad_op_test.cc b/tensorflow/core/kernels/mirror_pad_op_test.cc
new file mode 100644
index 0000000..0afae5d
--- /dev/null
+++ b/tensorflow/core/kernels/mirror_pad_op_test.cc
@@ -0,0 +1,201 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+class MirrorPadOpTest : public OpsTestBase {
+ protected:
+ template <typename T>
+ void MakeOp(const string& mode) {
+ TF_EXPECT_OK(NodeDefBuilder("mirror_pad_op", "MirrorPad")
+ .Input(FakeInput(DataTypeToEnum<T>::value))
+ .Input(FakeInput(DT_INT32))
+ .Attr("mode", mode)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+#define REGISTER_TEST(T) \
+ TEST_F(MirrorPadOpTest, TestMirrorPadReflect##T) { \
+ MakeOp<T>("REFLECT"); \
+ AddInputFromArray<T>(TensorShape({1, 2, 3, 1}), {1, 2, 3, 4, 5, 6}); \
+ AddInputFromArray<int32>(TensorShape({4, 2}), {0, 0, 1, 1, 2, 2, 0, 0}); \
+ TF_ASSERT_OK(RunOpKernel()); \
+ \
+ Tensor expected(allocator(), DataTypeToEnum<T>::value, \
+ TensorShape({1, 4, 7, 1})); \
+ test::FillValues<T>(&expected, \
+ {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, \
+ 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); \
+ test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
+ } \
+ \
+ TEST_F(MirrorPadOpTest, TestMirrorPadSymmetric##T) { \
+ MakeOp<T>("SYMMETRIC"); \
+ AddInputFromArray<T>(TensorShape({1, 2, 1, 3}), {1, 2, 3, 4, 5, 6}); \
+ AddInputFromArray<int32>(TensorShape({4, 2}), {1, 1, 0, 0, 0, 0, 2, 2}); \
+ TF_ASSERT_OK(RunOpKernel()); \
+ \
+ Tensor expected(allocator(), DataTypeToEnum<T>::value, \
+ TensorShape({3, 2, 1, 7})); \
+ test::FillValues<T>( \
+ &expected, \
+ {2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, \
+ 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5}); \
+ test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
+ }
+
+REGISTER_TEST(float)
+REGISTER_TEST(double)
+REGISTER_TEST(uint8)
+REGISTER_TEST(uint16)
+REGISTER_TEST(int8)
+REGISTER_TEST(int16)
+REGISTER_TEST(int32)
+REGISTER_TEST(int64)
+
+#undef REGISTER_TEST
+
+TEST_F(MirrorPadOpTest, TestMirrorPadReflectLargeInput) {
+ MakeOp<float>("REFLECT");
+ // Generate a relatively large input
+ const int kInput = 1000;
+ const int kPad = 10;
+ const int kOutput = kInput + 2 * kPad;
+
+ // Input:
+ // 0, 1, 2, ..., 999
+ // 0, 1, 2, ..., 999
+ // ... (altogether 1000 lines)
+ // 0, 1, 2, ..., 999
+ AddInput<float>(TensorShape({1, kInput, kInput, 1}),
+ [](int i) -> float { return i % kInput; });
+ AddInputFromArray<int32>(TensorShape({4, 2}),
+ {0, 0, kPad, kPad, kPad, kPad, 0, 0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, kOutput, kOutput, 1}));
+ test::FillFn<float>(&expected, [](int i) -> float {
+ i = i % kOutput;
+ if (0 <= i && i < kPad)
+ return kPad - i;
+ else if (kPad <= i && i < kInput + kPad)
+ return i - kPad;
+ else if (kInput + kPad <= i && i < kOutput)
+ return 2 * kInput + kPad - 2 - i;
+ });
+
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(MirrorPadOpTest, TestMirrorPadSymmetricLargeInput) {
+ MakeOp<float>("SYMMETRIC");
+ // Generate a relatively large input
+ const int kInput = 1000;
+ const int kPad = 10;
+ const int kOutput = kInput + 2 * kPad;
+
+ // Input:
+ // 0, 1, 2, ..., 999
+ // 0, 1, 2, ..., 999
+ // ... (altogether 1000 lines)
+ // 0, 1, 2, ..., 999
+ AddInput<float>(TensorShape({1, kInput, kInput, 1}),
+ [](int i) -> float { return i % kInput; });
+ AddInputFromArray<int32>(TensorShape({4, 2}),
+ {0, 0, kPad, kPad, kPad, kPad, 0, 0});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, kOutput, kOutput, 1}));
+ test::FillFn<float>(&expected, [](int i) -> float {
+ i = i % kOutput;
+ if (0 <= i && i < kPad)
+ return kPad - i - 1;
+ else if (kPad <= i && i < kInput + kPad)
+ return i - kPad;
+ else if (kInput + kPad <= i && i < kOutput)
+ return 2 * kInput + kPad - 1 - i;
+ });
+
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+class MirrorPadGradOpTest : public OpsTestBase {
+ protected:
+ template <typename T>
+ void MakeOp(const string& mode) {
+ TF_EXPECT_OK(NodeDefBuilder("mirror_pad_grad_op", "MirrorPadGrad")
+ .Input(FakeInput(DataTypeToEnum<T>::value))
+ .Input(FakeInput(DT_INT32))
+ .Attr("mode", mode)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+#define REGISTER_TEST(T) \
+ TEST_F(MirrorPadGradOpTest, TestMirrorPadGradReflect##T) { \
+ MakeOp<T>("REFLECT"); \
+ AddInput<T>(TensorShape({1, 4, 7, 1}), [](int i) -> T { return i % 7; }); \
+ AddInputFromArray<int32>(TensorShape({4, 2}), {0, 0, 1, 1, 2, 2, 0, 0}); \
+ TF_ASSERT_OK(RunOpKernel()); \
+ \
+ Tensor expected(allocator(), DataTypeToEnum<T>::value, \
+ TensorShape({1, 2, 3, 1})); \
+ test::FillValues<T>(&expected, {16, 18, 8, 16, 18, 8}); \
+ test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
+ } \
+ \
+ TEST_F(MirrorPadGradOpTest, TestMirrorPadGradSymmetric##T) { \
+ MakeOp<T>("SYMMETRIC"); \
+ AddInput<T>(TensorShape({3, 2, 1, 7}), [](int i) -> T { return i % 7; }); \
+ AddInputFromArray<int32>(TensorShape({4, 2}), {1, 1, 0, 0, 0, 0, 2, 2}); \
+ TF_ASSERT_OK(RunOpKernel()); \
+ \
+ Tensor expected(allocator(), DataTypeToEnum<T>::value, \
+ TensorShape({1, 2, 1, 3})); \
+ test::FillValues<T>(&expected, {9, 27, 27, 9, 27, 27}); \
+ test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
+ }
+
+REGISTER_TEST(float)
+REGISTER_TEST(double)
+REGISTER_TEST(uint8)
+REGISTER_TEST(uint16)
+REGISTER_TEST(int8)
+REGISTER_TEST(int16)
+REGISTER_TEST(int32)
+REGISTER_TEST(int64)
+
+#undef REGISTER_TEST
+
+} // namespace tensorflow