Adds back a `DoCopy` in StridedSliceAssignOp that was accidentally deleted.
PiperOrigin-RevId: 323074864
Change-Id: Ie46f51353a85aa423e562da7e2f3009238cca07e
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 4638382..b430299 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -885,10 +885,14 @@
// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
- // allocate_output() to allocate a new output buffer.
+ // allocate_output() to allocate a new output buffer. The index of the
+ // forwarded input will be assigned to the output argument forwarded_input
+ // (if it's not nullptr). If no inputs are forwarded, forwarded_input will be
+ // assigned -1.
Status forward_input_or_allocate_output(
gtl::ArraySlice<int> candidate_input_indices, int output_index,
- const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
+ const TensorShape& output_shape, Tensor** output,
+ int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
Status forward_input_or_allocate_output(
gtl::ArraySlice<StringPiece> candidate_input_names,
StringPiece output_name, const TensorShape& output_shape,
@@ -1636,13 +1640,19 @@
inline Status OpKernelContext::forward_input_or_allocate_output(
gtl::ArraySlice<int> candidate_input_indices, int output_index,
- const TensorShape& output_shape, Tensor** output) {
+ const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
for (int input_index : candidate_input_indices) {
if (forward_input_to_output_with_shape(input_index, output_index,
output_shape, output)) {
+ if (forwarded_input != nullptr) {
+ *forwarded_input = input_index;
+ }
return Status::OK();
}
}
+ if (forwarded_input != nullptr) {
+ *forwarded_input = -1;
+ }
return allocate_output(output_index, output_shape, output);
}
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index fc08fa8..7d9dfa4 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -306,8 +306,15 @@
if (isTensor) {
const Tensor& input = context->input(0);
- OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {0}, 0, input.shape(), &old_lhs));
+ int forwarded_input;
+ OP_REQUIRES_OK(context,
+ context->forward_input_or_allocate_output(
+ {0}, 0, input.shape(), &old_lhs, &forwarded_input));
+ if (forwarded_input < 0) {
+ OP_REQUIRES_OK(context,
+ tensorflow::functor::DoCopy(
+ context->eigen_device<Device>(), input, old_lhs));
+ }
} else {
if (context->input_dtype(0) == DT_RESOURCE) {
core::RefCountPtr<Var> v;
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 829be7f..0a7e4e5 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1228,13 +1228,25 @@
sess.run(v[:].assign(too_small_val))
@test_util.run_in_graph_and_eager_modes
- def testTensorStridedSliceAssign(self):
+ def testTensorStridedSliceAssignWithInputForward(self):
+ """Tests tensor_strided_slice_update with input-forwarding taking effect."""
@def_function.function
def assign(x):
y = x + 1
return gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0])
self.assertAllEqual([0, 1], self.evaluate(assign(array_ops.zeros([2]))))
+ @test_util.run_in_graph_and_eager_modes
+ def testTensorStridedSliceAssignNoInputForward(self):
+ """Tests tensor_strided_slice_update with no input-forwarding."""
+ x = constant_op.constant([0.2, 0.3])
+ y = x + 1
+ # y's buffer won't be forwarded to z because y and z will be alive at the
+ # same time later.
+ z = gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0.4])
+ ans = y + z
+ self.assertAllClose([1.6, 2.6], self.evaluate(ans))
+
class ShapeSizeRankTest(test_util.TensorFlowTestCase):