Roll-forward of: [Grappler] Extend the RemoveStackStridedSliceSameAxis optimization to support Slice.
`tf.slice(tf.pack([x, y, z]), ...)` will always have rank one greater than
`x, y, z`. The previous version had a bug where it would incorrectly return the
original `x, y, z` tensor directly if insufficient shape information was
available. The fix is to explicitly track whether the rank should be expanded
or not, based on the slicing operation used.
PiperOrigin-RevId: 297961546
Change-Id: I80b10bd0cd39e783f8a9a41e9d0d5a3dcc250b82
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d256a66..3281f97 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -48,6 +48,7 @@
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
@@ -3095,6 +3096,7 @@
// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
// with
// expand_dims(a_i, axis=k)
+// where the slice operator can be StridedSlice or Slice.
//
// TODO(ebrevdo): Extend to also replace operations of the form
// concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
@@ -3103,17 +3105,16 @@
// when
// s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
// and slicing is in the k'th axis.
-class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
+class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
public:
- explicit RemoveStackStridedSliceSameAxis(
- const GraphOptimizerContext& ctx,
- const ArithmeticOptimizerContext& ctx_ext)
+ explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
: ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
ctx_ext) {}
- ~RemoveStackStridedSliceSameAxis() override = default;
+ ~RemoveStackSliceSameAxis() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsStridedSlice(*node);
+ return IsStridedSlice(*node) || IsSlice(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
@@ -3131,14 +3132,16 @@
CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
if (return_early) return Status::OK();
- int slice_start_value;
+ int64 slice_start_value;
bool found;
+ bool must_expand_dims;
TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
- &slice_start_value, &found));
+ &slice_start_value, &found,
+ &must_expand_dims));
if (!found) return Status::OK();
return RewriteGraph(node, pack, slice_start_value, pack_axis,
- simplified_node_name);
+ must_expand_dims, simplified_node_name);
}
protected:
@@ -3171,8 +3174,113 @@
Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
const PartialTensorShape& pack_output_shape,
- int pack_axis, int* slice_start_value, bool* found) {
+ int pack_axis, int64* slice_start_value, bool* found,
+ bool* must_expand_dims) {
*found = false;
+ if (IsSlice(*node)) {
+ *must_expand_dims = true;
+ return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
+ slice_start_value, found);
+ } else {
+ return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
+ slice_start_value, found, must_expand_dims);
+ }
+ }
+
+ Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
+ const PartialTensorShape& pack_output_shape,
+ int pack_axis, int64* slice_start_value,
+ bool* found) {
+ NodeDef* slice_begin;
+ NodeDef* slice_size;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
+ for (const auto* n : {slice_begin, slice_size}) {
+ if (!IsReallyConstant(*n)) return Status::OK();
+ }
+
+ Tensor slice_begin_t;
+ Tensor slice_size_t;
+ TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
+ if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
+ if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
+ return Status::OK();
+ }
+
+ auto copy_tensor_values_to_vector =
+ [node](const Tensor& t, gtl::InlinedVector<int64, 4>* vec) {
+ if (t.dtype() == DT_INT32) {
+ auto t_flat = t.flat<int32>();
+ vec->assign(&t_flat(0), &t_flat(t.NumElements()));
+ } else if (t.dtype() == DT_INT64) {
+ auto t_flat = t.flat<int64>();
+ vec->assign(&t_flat(0), &t_flat(t.NumElements()));
+ } else {
+ return errors::InvalidArgument("Node ", node->name(),
+ " has invalid type for Index attr: ",
+ DataTypeString(t.dtype()));
+ }
+ return Status::OK();
+ };
+
+ gtl::InlinedVector<int64, 4> slice_begin_vec;
+ gtl::InlinedVector<int64, 4> slice_size_vec;
+ TF_RETURN_IF_ERROR(
+ copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
+ TF_RETURN_IF_ERROR(
+ copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));
+
+ if (slice_begin_vec.size() != slice_size_vec.size()) {
+ return errors::InvalidArgument("Node ", node->name(),
+ " has mismatched lengths for begin (",
+ slice_begin_vec.size(), ") and size (",
+ slice_size_vec.size(), ") vectors.");
+ }
+ if (!pack_output_shape.unknown_rank() &&
+ slice_begin_vec.size() != pack_output_shape.dims()) {
+ return Status::OK();
+ }
+ if (pack_axis >= slice_begin_vec.size()) {
+ return errors::InvalidArgument(
+ "Input to node ", node->name(), " had pack_axis ", pack_axis,
+ " but rank was ", slice_begin_vec.size(), ".");
+ }
+
+ *slice_start_value = slice_begin_vec[pack_axis];
+ if (slice_size_vec[pack_axis] != 1) {
+ // Not slicing a single value out.
+ return Status::OK();
+ }
+
+ for (size_t i = 0; i < slice_begin_vec.size(); ++i) {
+ if (i != pack_axis) {
+ if (slice_begin_vec[i] != 0 ||
+ !(slice_size_vec[i] == -1 ||
+ slice_size_vec[i] == pack_output_shape.dim_size(i))) {
+ // Not slicing on the same axis as the Pack op.
+ return Status::OK();
+ }
+ }
+ }
+
+ if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
+ return errors::InvalidArgument(
+ "Node ", node->name(), " requested invalid slice index ",
+ *slice_start_value, " on axis ", pack_axis,
+ " from tensor of shape: ", pack_output_shape.DebugString());
+ }
+
+ *found = true; // slice_start_value is valid.
+ return Status::OK();
+ }
+
+ Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
+ const PartialTensorShape& pack_output_shape,
+ int pack_axis, int64* slice_start_value,
+ bool* found, bool* must_expand_dims) {
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
"new_axis_mask", "shrink_axis_mask"}));
@@ -3286,13 +3394,22 @@
" from tensor of shape: ", pack_output_shape.DebugString());
}
+ if (shrink_axis_mask == 0) {
+ *must_expand_dims = true;
+ } else if (shrink_axis_mask == (1 << slice_axis)) {
+ *must_expand_dims = false;
+ } else {
+ // Shrinking on a different axis from the one that we are slicing on.
+ return Status::OK();
+ }
+
*found = true; // slice_start_value is valid.
return Status::OK();
}
Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
- int slice_start_value, int pack_axis,
- string* simplified_node_name) {
+ int64 slice_start_value, int pack_axis,
+ bool must_expand_dims, string* simplified_node_name) {
const string& input_slice = pack->input(slice_start_value);
const OpInfo::TensorProperties* input_slice_properties;
@@ -3306,7 +3423,7 @@
PartialTensorShape output_shape(output_properties->shape());
NodeDef* output =
AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
- if (input_slice_shape.IsCompatibleWith(output_shape)) {
+ if (!must_expand_dims) {
output->set_op("Identity");
output->set_device(node->device());
SetDataTypeToAttr(output_properties->dtype(), "T", output);
@@ -3449,7 +3566,6 @@
if (node.device().find("SPU") != string::npos) {
return false;
}
- // Workaround for Assert and Print mistakenly being labeled as stateful.
if (IsAssert(node) || IsPrint(node)) {
return true;
}
@@ -3621,8 +3737,8 @@
pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
if (options_.unary_ops_composition)
pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
- if (options_.remove_stack_strided_slice_same_axis)
- pipeline.AddStage<RemoveStackStridedSliceSameAxis>(ctx, ctx_ext);
+ if (options_.remove_stack_slice_same_axis)
+ pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
if (options_.fuse_squared_diff)
pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index e7c847c..76aca8b 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -84,7 +84,7 @@
bool convert_log_softmax = true;
bool convert_expm1 = true;
bool unary_ops_composition = true;
- bool remove_stack_strided_slice_same_axis = true;
+ bool remove_stack_slice_same_axis = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 4c70343..a421daa 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
+#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
@@ -4002,7 +4003,7 @@
GraphDef output;
ArithmeticOptimizer optimizer;
- EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer);
+ EnableOnlyRemoveStackSliceSameAxis(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
for (const auto& node : output.node()) {
@@ -4050,6 +4051,122 @@
tensors_expected[fExpandedC]);
}
+TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto a_in =
+ ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+ auto b_in =
+ ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
+ auto c_in =
+ ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
+ auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
+ PartialTensorShape({-1, -1}));
+ auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
+ PartialTensorShape({-1, -1}));
+ auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
+ PartialTensorShape({-1, -1}));
+ // stacked = tf.stack((a, b, c), axis=1).
+ // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
+ auto stacked =
+ ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
+ ops::Stack::Axis(1));
+ auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
+ auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
+ auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
+ auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
+ auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
+ auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
+ auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});
+
+ // stacked[:, 0:1, :]
+ auto pa_slice = ops::Identity(
+ s.WithOpName("pa_slice_out"),
+ ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));
+
+ // stacked[:, 1:2, :]
+ auto pb_slice = ops::Identity(
+ s.WithOpName("pb_slice_out"),
+ ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));
+
+ // stacked[:, 2:3, :]
+ auto pc_slice = ops::Identity(
+ s.WithOpName("pc_slice_out"),
+ ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));
+
+ GrapplerItem item;
+ item.fetch = {"a",
+ "b",
+ "c",
+ "pa_slice_out",
+ "pb_slice_out",
+ "pc_slice_out",
+ "expanded_a",
+ "expanded_b",
+ "expanded_c"};
+ enum FetchItem {
+ fA,
+ fB,
+ fC,
+ fASliceOut,
+ fBSliceOut,
+ fCSliceOut,
+ fExpandedA,
+ fExpandedB,
+ fExpandedC,
+ };
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // stacked[:, 0:1, :] == a.
+ test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
+ tensors_expected[fExpandedA]);
+ // stacked[:, 1:2, :] == b.
+ test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
+ tensors_expected[fExpandedB]);
+ // stacked[:, 2:3, :] == c.
+ test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
+ tensors_expected[fExpandedC]);
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveStackSliceSameAxis(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ const string kExpandDimsNamePrefix(
+ "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");
+
+ for (const auto& node : output.node()) {
+ if (node.name() == "pa_slice_out") {
+ ASSERT_EQ(node.input_size(), 1);
+ EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
+ } else if (node.name() == "pb_slice_out") {
+ ASSERT_EQ(node.input_size(), 1);
+ EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
+ } else if (node.name() == "pc_slice_out") {
+ ASSERT_EQ(node.input_size(), 1);
+ EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
+ } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
+ EXPECT_EQ(node.op(), "ExpandDims");
+ // The input is "a", "b", or "c", as appropriate.
+ EXPECT_EQ(node.input(0),
+ node.name().substr(kExpandDimsNamePrefix.size(), 1));
+ }
+ }
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+
+ // stacked[:, 0:1, :] == a.
+ test::ExpectTensorEqual<float>(tensors[fASliceOut],
+ tensors_expected[fExpandedA]);
+
+ // stacked[:, 1:2, :] == b.
+ test::ExpectTensorEqual<float>(tensors[fBSliceOut],
+ tensors_expected[fExpandedB]);
+ // stacked[:, 2:3, :] == c.
+ test::ExpectTensorEqual<float>(tensors[fCSliceOut],
+ tensors_expected[fExpandedC]);
+}
+
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
index d3ad437..4d3ba97 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
@@ -202,10 +202,9 @@
optimizer->options_.unary_ops_composition = true;
}
- void EnableOnlyRemoveStackStridedSliceSameAxis(
- ArithmeticOptimizer* optimizer) {
+ void EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
- optimizer->options_.remove_stack_strided_slice_same_axis = true;
+ optimizer->options_.remove_stack_slice_same_axis = true;
}
private: