| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/client/value_inference.h" |
| |
| #include <memory> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/strings/match.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/client/client_library.h" |
| #include "tensorflow/compiler/xla/client/global_data.h" |
| #include "tensorflow/compiler/xla/client/lib/arithmetic.h" |
| #include "tensorflow/compiler/xla/client/lib/prng.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/client/xla_computation.h" |
| #include "tensorflow/compiler/xla/layout_util.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/test.h" |
| #include "tensorflow/compiler/xla/tests/literal_test_util.h" |
| #include "tensorflow/compiler/xla/tests/test_macros.h" |
| #include "tensorflow/compiler/xla/tests/test_utils.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/platform/status.h" |
| #include "tensorflow/core/platform/statusor.h" |
| #include "tensorflow/core/platform/types.h" |
| |
| namespace xla { |
| namespace { |
| |
| class ValueInferenceTest : public ::testing::Test { |
| public: |
| string TestName() const { |
| return ::testing::UnitTest::GetInstance()->current_test_info()->name(); |
| } |
| }; |
| |
| class DynamismInferenceTest : public ValueInferenceTest { |
| public: |
| explicit DynamismInferenceTest(se::Platform* platform = nullptr) |
| : platform_(platform) {} |
| |
| StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder, |
| Layout* output_layout = nullptr) { |
| TF_RETURN_IF_ERROR(builder->first_error()); |
| ValueInference value_inference(builder); |
| TF_ASSIGN_OR_RETURN(auto literal_slice, |
| value_inference.AnalyzeIsDynamic(operand)); |
| return literal_slice.Clone(); |
| } |
| |
| StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder, |
| ShapeIndex index = {}) { |
| TF_ASSIGN_OR_RETURN(auto literal, |
| ComputeDynamismLiteral(operand, builder, nullptr)); |
| return literal.Get<bool>({}, index); |
| } |
| |
| se::Platform* platform_; |
| }; |
| |
| TEST_F(DynamismInferenceTest, ScalarInt32Literal) { |
| XlaBuilder b(TestName()); |
| auto computation = ConstantR0<int32>(&b, 42); |
| |
| auto value = ComputeDynamismScalar(computation, &b); |
| ASSERT_TRUE(value.ok()) << value.status(); |
| // A constant is not dynamic. |
| EXPECT_EQ(value.ValueOrDie(), false); |
| } |
| |
| TEST_F(DynamismInferenceTest, Iota) { |
| // The output of iota are consistened static. |
| XlaBuilder b(TestName()); |
| auto computation = Iota(&b, S32, 2); |
| // Iota is not dynamic. |
| EXPECT_FALSE( |
| ComputeDynamismLiteral(computation, &b).ValueOrDie().Get<bool>({0})); |
| } |
| |
| TEST_F(DynamismInferenceTest, TupleSimple) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR0<int32>(&b, 42); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); |
| |
| auto tuple = Tuple(&b, {c, p}); |
| EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {0}).ValueOrDie(), false); |
| EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {1}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR0<int32>(&b, 42); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); |
| |
| auto tuple = Tuple(&b, {c, p}); |
| auto gte0 = GetTupleElement(tuple, 0); |
| auto gte1 = GetTupleElement(tuple, 1); |
| auto tuple_2 = Tuple(&b, {gte0, gte1}); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, PredValueUsedTwice) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR0<int32>(&b, 42); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); |
| auto pred = Eq(c, p); |
| auto result = Select(pred, p, c); |
| EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, ReduceUsedTwice) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR0<int32>(&b, 42); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0"); |
| auto zero = ConstantR0<int32>(&b, 0); |
| XlaComputation add_s32 = CreateScalarAddComputation(S32, &b); |
| auto reduce = Reduce(p, zero, add_s32, {0}); |
| auto pred = Eq(c, reduce); |
| auto result = Select(pred, reduce, c); |
| EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, VariadicReduce) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR2<int32>(&b, {{0, 0}}); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 2}), "p0"); |
| // half_dynamic[0] is static, half_dynamic[0] is dynamic. |
| auto half_dynamic = ConcatInDim(&b, {c, p}, 0); |
| XlaBuilder reduce_add("reduce_add"); |
| auto p0 = Parameter(&reduce_add, 0, ShapeUtil::MakeScalarShape(S32), "p"); |
| auto p1 = Parameter(&reduce_add, 1, ShapeUtil::MakeScalarShape(S32), "p"); |
| auto p2 = Parameter(&reduce_add, 2, ShapeUtil::MakeScalarShape(S32), "p"); |
| auto p3 = Parameter(&reduce_add, 3, ShapeUtil::MakeScalarShape(S32), "p"); |
| auto reduce_result = p0; |
| reduce_result = Add(reduce_result, p1); |
| reduce_result = Add(reduce_result, p2); |
| reduce_result = Add(reduce_result, p3); |
| Tuple(&reduce_add, {reduce_result, reduce_result}); |
| auto init = ConstantR0<int32>(&b, 0); |
| auto variadic_reduce = Reduce(&b, {half_dynamic, half_dynamic}, {init, init}, |
| reduce_add.Build().ConsumeValueOrDie(), {1}); |
| auto result = GetTupleElement(variadic_reduce, 0); |
| |
| // result[0] should be static; result[1] should be dynamic. |
| EXPECT_FALSE( |
| ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({0})); |
| EXPECT_TRUE( |
| ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({1})); |
| } |
| |
| TEST_F(DynamismInferenceTest, DynamicSelectorWithMixedValues) { |
| XlaBuilder b(TestName()); |
| auto constant_pred = ConstantR1<bool>(&b, {true}); |
| auto dynamic_pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {1}), "p0"); |
| auto concat = ConcatInDim(&b, {constant_pred, dynamic_pred}, 0); |
| auto constant_values = ConstantR1<bool>(&b, {true, true}); |
| auto result = Select(concat, constant_values, constant_values); |
| // First result is static (selector is constant, both values are constant). |
| // Iota is not dynamic. |
| EXPECT_FALSE(ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({0})); |
| // Second result is dynamic (selector is dynamic). |
| EXPECT_TRUE(ComputeDynamismLiteral(result, &b).ValueOrDie().Get<bool>({1})); |
| } |
| |
| TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR0<int32>(&b, 42); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); |
| |
| auto concat = ConcatScalars(&b, {c, p}); |
| auto slice0 = SliceInDim(concat, 0, 1, 1, 0); |
| auto reshape0 = Reshape(slice0, {}); |
| auto slice1 = SliceInDim(concat, 1, 2, 1, 0); |
| auto reshape1 = Reshape(slice1, {}); |
| auto tuple_2 = Tuple(&b, {reshape0, reshape1}); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, ParameterIsDynamic) { |
| XlaBuilder b(TestName()); |
| auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); |
| |
| auto value = ComputeDynamismScalar(computation, &b); |
| ASSERT_TRUE(value.ok()) << value.status(); |
| // A parameter is considered dynamic. |
| EXPECT_EQ(value.ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR0<int32>(&b, 42); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); |
| |
| auto neg0 = Neg(c); |
| auto neg1 = Neg(p); |
| auto tuple_2 = Tuple(&b, {neg0, neg1}); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, ParameterWithToken) { |
| // Test that token shape can be handled in a parameter. |
| XlaBuilder b(TestName()); |
| auto p = |
| Parameter(&b, 0, |
| ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape(), |
| ShapeUtil::MakeScalarShape(S32)}), |
| "p0"); |
| EXPECT_EQ(ComputeDynamismScalar(p, &b, {0}).ValueOrDie(), true); |
| EXPECT_EQ(ComputeDynamismScalar(p, &b, {1}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) { |
| XlaBuilder b(TestName()); |
| auto c = ConstantR0<int32>(&b, 42); |
| auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); |
| |
| // Static value + static value = static |
| auto add1 = Add(c, c); |
| // Dynamic value + dynamic value = dynamic |
| auto add2 = Add(p, c); |
| auto tuple_2 = Tuple(&b, {add1, add2}); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true); |
| } |
| |
| TEST_F(DynamismInferenceTest, GetDimensionSize) { |
| XlaBuilder b(TestName()); |
| // param = Param([<=2, 3]) |
| // get_dimension_size(param, 0) is dynamic |
| // get_dimension_size(param, 1) is static |
| auto p = |
| Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0"); |
| |
| auto gds0 = GetDimensionSize(p, 0); |
| auto gds1 = GetDimensionSize(p, 1); |
| auto tuple_2 = Tuple(&b, {gds0, gds1}); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), true); |
| EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), false); |
| } |
| |
| TEST_F(DynamismInferenceTest, DynamicSliceWithConstantOperands) { |
| XlaBuilder b(TestName()); |
| |
| auto constant = ConstantR1<int32>(&b, {0, 1, 2, 3}); |
| auto slice_start = ConstantR0(&b, 1); |
| auto dynamic_slice = DynamicSlice(constant, {slice_start}, {1}); |
| EXPECT_FALSE( |
| ComputeDynamismLiteral(dynamic_slice, &b).ValueOrDie().Get<bool>({0})); |
| } |
| |
| TEST_F(DynamismInferenceTest, GatherWithCommonParent) { |
| XlaBuilder b(TestName()); |
| // Test the analysis on a gather where first operand and second operand have |
| // common parents. |
| Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); |
| |
| auto operand1 = Parameter(&b, 0, indices_shape, "p1"); |
| auto operand2 = Parameter(&b, 1, indices_shape, "p2"); |
| auto indices = Sub(operand1, operand2); |
| GatherDimensionNumbers dim_numbers; |
| dim_numbers.add_offset_dims(1); |
| dim_numbers.add_start_index_map(0); |
| dim_numbers.set_index_vector_dim(1); |
| auto gather = Gather(operand1, indices, dim_numbers, {1}); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| EXPECT_TRUE( |
| ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0})); |
| } |
| |
| TEST_F(DynamismInferenceTest, GatherWithConstantParent) { |
| XlaBuilder b(TestName()); |
| // Test the analysis on a gather. |
| Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); |
| auto data_operand = ConstantR1<int32>(&b, {1, 2}); |
| auto indices = ConstantR1<int32>(&b, {1, 2}); |
| GatherDimensionNumbers dim_numbers; |
| dim_numbers.add_offset_dims(1); |
| dim_numbers.add_start_index_map(0); |
| dim_numbers.set_index_vector_dim(1); |
| auto gather = Gather(data_operand, indices, dim_numbers, {1}); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Everything is constant, result is also contant. |
| EXPECT_FALSE( |
| ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0})); |
| } |
| |
| TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) { |
| XlaBuilder b(TestName()); |
| // Test the analysis on a gather. |
| Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); |
| auto operand1 = ConstantR1<int32>(&b, {1, 2}); |
| auto operand2 = ConstantR1<int32>(&b, {1, 2}); |
| auto indices = Sub(operand1, operand2); |
| GatherDimensionNumbers dim_numbers; |
| dim_numbers.add_offset_dims(1); |
| dim_numbers.add_start_index_map(0); |
| dim_numbers.set_index_vector_dim(1); |
| auto gather = Gather(operand1, indices, dim_numbers, {1}); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Everything is constant, result is also contant. |
| EXPECT_FALSE( |
| ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0})); |
| } |
| |
| TEST_F(DynamismInferenceTest, InferThroughPad) { |
| XlaBuilder b(TestName()); |
| // Test the analysis on a gather. |
| auto operand1 = ConstantR1<int32>(&b, {1, 2}); |
| auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0"); |
| PaddingConfig padding_config; |
| padding_config.add_dimensions()->set_edge_padding_high(1); |
| // After pad the value is [constant, constant, parameter]. |
| auto pad = Pad(operand1, parameter, padding_config); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Everything is constant, result is also contant. |
| EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({0})); |
| EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({1})); |
| EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2})); |
| } |
| |
| TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreSame) { |
| // The result of following conditional is static. |
| // pred = .. # a dynamic value |
| // if (pred) { |
| // return (1) # both branches return the same value |
| // } else { |
| // return (1) |
| // } |
| // |
| |
| auto s32_shape = ShapeUtil::MakeShape(S32, {}); |
| auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape}); |
| XlaBuilder true_builder("true"); |
| Parameter(&true_builder, 0, s32_shape, "cond_param"); |
| Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)}); |
| auto true_computation = true_builder.Build().ValueOrDie(); |
| |
| XlaBuilder false_builder("false"); |
| Parameter(&false_builder, 0, s32_shape, "cond_param"); |
| Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 1)}); |
| auto false_computation = false_builder.Build().ValueOrDie(); |
| |
| XlaBuilder b(TestName()); |
| auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0"); |
| auto constant = ConstantR0<int32>(&b, 0); |
| auto cond = Conditional(parameter, constant, true_computation, constant, |
| false_computation); |
| auto gte = GetTupleElement(cond, 0); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Result is not dynamic. |
| EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({})); |
| } |
| |
| TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) { |
| // The result of following conditional is dynamic. |
| // pred = .. # a dynamic value |
| // if (pred) { |
| // return (1) # These two branches return different values. |
| // } else { |
| // return (2) |
| // } |
| // |
| |
| auto s32_shape = ShapeUtil::MakeShape(S32, {}); |
| auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape}); |
| XlaBuilder true_builder("true"); |
| Parameter(&true_builder, 0, s32_shape, "cond_param"); |
| Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)}); |
| auto true_computation = true_builder.Build().ValueOrDie(); |
| |
| XlaBuilder false_builder("false"); |
| Parameter(&false_builder, 0, s32_shape, "cond_param"); |
| Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 2)}); |
| auto false_computation = false_builder.Build().ValueOrDie(); |
| |
| XlaBuilder b(TestName()); |
| auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0"); |
| auto constant = ConstantR0<int32>(&b, 0); |
| auto cond = Conditional(parameter, constant, true_computation, constant, |
| false_computation); |
| auto gte = GetTupleElement(cond, 0); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Result is dynamic. |
| EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({})); |
| } |
| |
| TEST_F(DynamismInferenceTest, InferThroughConditionalPredIsConstantTrueBranch) { |
| // The result of following conditional is static. |
| // pred = true |
| // if (pred) { |
| // return (1) |
| // } else { |
| // return (..dynamic_value...) |
| // } |
| // |
| |
| auto s32_shape = ShapeUtil::MakeShape(S32, {}); |
| auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape}); |
| XlaBuilder true_builder("true"); |
| Parameter(&true_builder, 0, s32_shape, "cond_param"); |
| Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)}); |
| auto true_computation = true_builder.Build().ValueOrDie(); |
| |
| XlaBuilder false_builder("false"); |
| Tuple(&false_builder, |
| {Parameter(&false_builder, 0, s32_shape, "cond_param")}); |
| auto false_computation = false_builder.Build().ValueOrDie(); |
| |
| XlaBuilder b(TestName()); |
| auto pred = ConstantR0<bool>(&b, true); |
| auto constant = ConstantR0<int32>(&b, 0); |
| auto cond = Conditional(pred, constant, true_computation, constant, |
| false_computation); |
| auto gte = GetTupleElement(cond, 0); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Result is not dynamic. |
| EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({})); |
| } |
| |
| TEST_F(DynamismInferenceTest, |
| InferThroughConditionalPredIsConstantFalseBranch) { |
| // The result of following conditional is dynamic. |
| // pred = false |
| // if (pred) { |
| // return (1) |
| // } else { |
| // return (..dynamic_value...) |
| // } |
| // |
| |
| auto s32_shape = ShapeUtil::MakeShape(S32, {}); |
| auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape}); |
| XlaBuilder true_builder("true"); |
| Parameter(&true_builder, 0, s32_shape, "cond_param"); |
| Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)}); |
| auto true_computation = true_builder.Build().ValueOrDie(); |
| |
| XlaBuilder false_builder("false"); |
| Tuple(&false_builder, |
| {Parameter(&false_builder, 0, s32_shape, "cond_param")}); |
| auto false_computation = false_builder.Build().ValueOrDie(); |
| |
| XlaBuilder b(TestName()); |
| auto param = Parameter(&b, 0, s32_shape, "param"); |
| auto pred = ConstantR0<bool>(&b, false); |
| auto constant = ConstantR0<int32>(&b, 0); |
| auto cond = |
| Conditional(pred, constant, true_computation, param, false_computation); |
| auto gte = GetTupleElement(cond, 0); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Result is dynamic. |
| EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({})); |
| } |
| |
| TEST_F(DynamismInferenceTest, ArgumentForwardingNestedTuple) { |
| // The result of following conditional is considered static. |
| // pred = .. dynamic value.. |
| // |
| // op = 1 |
| // if (pred) { |
| // if (pred) { |
| // return op |
| // } else { |
| // return op |
| // } |
| // } else { |
| // if (pred) { |
| // return op |
| // } else { |
| // return op |
| // } |
| // } |
| // |
| auto pred_shape = ShapeUtil::MakeShape(PRED, {}); |
| auto s32_shape = ShapeUtil::MakeShape(S32, {}); |
| auto tuple_shape = ShapeUtil::MakeTupleShape({pred_shape, s32_shape}); |
| auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape}); |
| XlaBuilder inner_true_builder("inner_true"); |
| Parameter(&inner_true_builder, 0, s32_shape, "cond_param"); |
| Tuple(&inner_true_builder, {ConstantR0<int32>(&inner_true_builder, 0)}); |
| auto inner_true_computation = inner_true_builder.Build().ValueOrDie(); |
| |
| XlaBuilder inner_false_builder("inner_false"); |
| Tuple(&inner_false_builder, |
| {Parameter(&inner_false_builder, 0, s32_shape, "cond_param")}); |
| auto inner_false_computation = inner_false_builder.Build().ValueOrDie(); |
| |
| XlaBuilder true_builder("true"); |
| { |
| auto param = Parameter(&true_builder, 0, tuple_shape, "param"); |
| auto op = GetTupleElement(param, 1); |
| auto pred = GetTupleElement(param, 0); |
| Conditional(pred, op, inner_true_computation, op, inner_false_computation); |
| } |
| auto true_computation = true_builder.Build().ValueOrDie(); |
| XlaBuilder false_builder("false"); |
| { |
| auto param = Parameter(&false_builder, 0, tuple_shape, "param"); |
| auto op = GetTupleElement(param, 1); |
| auto pred = GetTupleElement(param, 0); |
| Conditional(pred, op, inner_true_computation, op, inner_false_computation); |
| } |
| auto false_computation = false_builder.Build().ValueOrDie(); |
| XlaBuilder b(TestName()); |
| auto constant = ConstantR0<int32>(&b, 0); |
| auto pred = Parameter(&b, 0, pred_shape, "param"); |
| auto param = Tuple(&b, {pred, constant}); |
| auto cond = |
| Conditional(pred, param, true_computation, param, false_computation); |
| auto gte = GetTupleElement(cond, 0); |
| ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message(); |
| // Result is static. |
| EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({})); |
| } |
| |
| class UpperBoundInferenceTest : public ValueInferenceTest { |
| public: |
| explicit UpperBoundInferenceTest(se::Platform* platform = nullptr) |
| : platform_(platform) {} |
| |
| StatusOr<OptionalLiteral> ComputeUpperBoundLiteral( |
| XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { |
| ValueInference value_inference(builder); |
| TF_ASSIGN_OR_RETURN(auto literal, |
| value_inference.AnalyzeConstant( |
| operand, ValueInferenceMode::kUpperBound)); |
| return literal; |
| } |
| |
| se::Platform* platform_; |
| }; |
| |
| TEST_F(UpperBoundInferenceTest, GetDimensionSize) { |
| XlaBuilder b(TestName()); |
| auto p = |
| Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0"); |
| |
| auto gds0 = GetDimensionSize(p, 0); |
| auto gds1 = GetDimensionSize(p, 1); |
| auto tuple_2 = Tuple(&b, {gds0, gds1}); |
| EXPECT_EQ( |
| ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32>({}, {0}), |
| 2); |
| EXPECT_EQ( |
| ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32>({}, {1}), |
| 3); |
| } |
| |
| TEST_F(UpperBoundInferenceTest, GetDimensionSizeSub) { |
| XlaBuilder b(TestName()); |
| auto p = |
| Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0"); |
| |
| // The range of the first dimension is [0, 2] |
| auto gds0 = GetDimensionSize(p, 0); |
| // The range of the second dimension is [3, 3] |
| auto gds1 = GetDimensionSize(p, 1); |
| // Upper bound of `second_dimension - first_dimension` is 3 - 0 = 3 |
| auto sub = Sub(gds1, gds0); |
| EXPECT_EQ(ComputeUpperBoundLiteral(sub, &b).ValueOrDie().Get<int32>({}), 3); |
| } |
| |
| TEST_F(UpperBoundInferenceTest, GetDimensionSizeDiv) { |
| XlaBuilder b(TestName()); |
| auto p = |
| Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0"); |
| // The range of the first dimension is [0, 2] |
| auto gds0 = GetDimensionSize(p, 0); |
| // The range of the second dimension is [3, 3] |
| auto gds1 = GetDimensionSize(p, 1); |
| // Upper bound of `second_dimension / first_dimension` is 3 / 1 = 3. Notice we |
| // don't use 0 as the lower bound as it would create divide-by-zero error. |
| auto div = Div(gds1, gds0); |
| EXPECT_EQ(ComputeUpperBoundLiteral(div, &b).ValueOrDie().Get<int32>({}), 3); |
| } |
| |
| TEST_F(UpperBoundInferenceTest, SumSubtract) { |
| // If x = a, y = b - a |
| // upperbound(x + y) should be upperbound(b) |
| XlaBuilder b(TestName()); |
| auto p = |
| Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0"); |
| // The range of the first dimension is [0, 2] |
| auto gds0 = GetDimensionSize(p, 0); |
| // The range of the second dimension is [0, 3] |
| auto gds1 = GetDimensionSize(p, 1); |
| auto sub = Sub(gds1, gds0); |
| auto add = Add(sub, gds0); |
| EXPECT_EQ(ComputeUpperBoundLiteral(add, &b).ValueOrDie().Get<int32>({}), 3); |
| auto add2 = Add(gds1, gds0); |
| // upperbound(gds1 - gds0 + gds1 + gds0) ==> upperbound(2 * gds1) |
| auto add3 = Add(sub, add2); |
| EXPECT_EQ(ComputeUpperBoundLiteral(add3, &b).ValueOrDie().Get<int32>({}), 6); |
| } |
| |
| TEST_F(UpperBoundInferenceTest, SumSubtractWithDataShuffling) { |
| // Similar to the test above, but with some data shuffling ops in it |
| // (broadcast, slice, reshape, identity convert, etc). |
| XlaBuilder b(TestName()); |
| auto p = |
| Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0"); |
| // The range of the first dimension is [0, 2] |
| auto gds0 = GetDimensionSize(p, 0); |
| // The range of the second dimension is [0, 3] |
| auto gds1 = GetDimensionSize(p, 1); |
| auto broadcast = Broadcast(gds0, {1, 10}); |
| auto convert = ConvertElementType(broadcast, S32); // Identity convert. |
| auto slice = SliceInDim(convert, /*start_index=*/0, /*limit_index=*/1, |
| /*stride=*/1, /*dimno=*/1); |
| gds0 = Reshape(slice, {}); |
| auto sub = Sub(gds1, gds0); |
| auto add = Add(sub, gds0); |
| EXPECT_EQ(ComputeUpperBoundLiteral(add, &b).ValueOrDie().Get<int32>({}), 3); |
| auto add2 = Add(gds1, gds0); |
| // upperbound(gds1 - gds0 + gds1 + gds0) ==> upperbound(2 * gds1) |
| auto add3 = Add(sub, add2); |
| EXPECT_EQ(ComputeUpperBoundLiteral(add3, &b).ValueOrDie().Get<int32>({}), 6); |
| } |
| |
| TEST_F(UpperBoundInferenceTest, SumSubtractEquivalentGetDimensionSize) { |
| XlaBuilder b(TestName()); |
| auto p = |
| Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0"); |
| // The range of the first dimension is [0, 2] |
| auto gds0 = GetDimensionSize(p, 0); |
| // The range of the second dimension is [0, 3] |
| auto gds1 = GetDimensionSize(p, 1); |
| // gds2 is equivalent to gds0 |
| auto gds2 = GetDimensionSize(p, 0); |
| auto sub = Sub(gds1, gds2); |
| auto add = Add(sub, gds0); |
| // upperbound(gds0 + gds1 - gds2) is equal to upperbound(gds1) if gds0 == |
| // gds2. |
| EXPECT_EQ(ComputeUpperBoundLiteral(add, &b).ValueOrDie().Get<int32>({}), 3); |
| } |
| |
| TEST_F(UpperBoundInferenceTest, ParamCantInferBound) { |
| // We can infer a parameter's dimension's bound, but not the parameter value's |
| // bound. |
| XlaBuilder b(TestName()); |
| auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}, {true}), "p0"); |
| auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}, {}), "p1"); |
| auto gds = GetDimensionSize(p0, 0); |
| auto sub = Div(gds, p1); |
| EXPECT_FALSE(ComputeUpperBoundLiteral(sub, &b) |
| .ValueOrDie() |
| .Get<int32>({}) |
| .has_value()); |
| } |
| |
| TEST_F(UpperBoundInferenceTest, KeyValueSort) { |
| XlaBuilder comparator_b("comparator"); |
| auto p0 = Parameter(&comparator_b, 0, ShapeUtil::MakeShape(S32, {}), "p0"); |
| auto p1 = Parameter(&comparator_b, 1, ShapeUtil::MakeShape(S32, {}), "p1"); |
| Parameter(&comparator_b, 2, ShapeUtil::MakeShape(S32, {}), "p2"); |
| Parameter(&comparator_b, 3, ShapeUtil::MakeShape(S32, {}), "p3"); |
| Compare(p0, p1, ComparisonDirection::kGe); |
| TF_ASSERT_OK_AND_ASSIGN(auto comparator, comparator_b.Build()); |
| |
| int64_t elem_count = 17; |
| XlaBuilder b(TestName()); |
| auto param = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {elem_count}), "p0"); |
| auto iota = Iota(&b, S32, elem_count); |
| auto sort = Sort({param, iota}, comparator); |
| auto gte = GetTupleElement(sort, 1); |
| |
| for (int64_t i = 0; i < elem_count; ++i) { |
| auto result_first_elem = |
| ComputeUpperBoundLiteral(gte, &b).ValueOrDie().Get<int32>({i}); |
| // We can infer the bound of sort. |
| EXPECT_TRUE(result_first_elem.has_value()); |
| // The bound of the sort result is the max value in the input. |
| EXPECT_EQ(result_first_elem.value(), elem_count - 1); |
| } |
| } |
| |
| class ConstValueInferenceTest : public ValueInferenceTest { |
| public: |
| explicit ConstValueInferenceTest(se::Platform* platform = nullptr) |
| : platform_(platform) {} |
| |
| StatusOr<OptionalLiteral> ComputeConstantValueLiteral( |
| XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) { |
| ValueInference value_inference(builder); |
| TF_ASSIGN_OR_RETURN(auto literal, value_inference.AnalyzeConstant( |
| operand, ValueInferenceMode::kValue)); |
| return literal; |
| } |
| |
| se::Platform* platform_; |
| }; |
| |
| TEST_F(ConstValueInferenceTest, ConstValuePassThroughSetBound) { |
| XlaBuilder b(TestName()); |
| auto p0 = ConstantR0<int32>(&b, 32); |
| Shape shape = ShapeUtil::MakeShape(S32, {}); |
| xla::Literal dynamism = xla::LiteralUtil::CreateR0<bool>(false); |
| xla::Literal bound = xla::LiteralUtil::CreateR0<int32>(32); |
| xla::Literal tuple = |
| xla::LiteralUtil::MakeTupleOwned(std::move(bound), std::move(dynamism)); |
| auto set_bound = |
| CustomCall(&b, "SetBound", {p0}, shape, "", false, {}, &tuple); |
| auto result = |
| ComputeConstantValueLiteral(set_bound, &b).ValueOrDie().Get<int32>({}); |
| EXPECT_TRUE(result.has_value()); |
| EXPECT_EQ(result.value(), 32); |
| } |
| |
| } // namespace |
| } // namespace xla |