| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/framework/variant.h" |
| |
| #include <xmmintrin.h> |
| |
| #include <vector> |
| |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/framework/variant_encode_decode.h" |
| #include "tensorflow/core/framework/variant_tensor_data.h" |
| #include "tensorflow/core/kernels/ops_testutil.h" |
| #include "tensorflow/core/lib/core/coding.h" |
| #include "tensorflow/core/platform/test.h" |
| |
| namespace tensorflow { |
| |
| namespace { |
| |
// Minimal aggregate stored in a Variant by the tests below. `value` is the
// payload and `big` is padding only. BIG=true adds 256 bytes, presumably to
// push the object past Variant's inline storage (the tests call this the
// "Big" case).
// ISO C++ forbids zero-length arrays, so the small instantiation carries a
// single padding byte instead of zero.
template <typename T, bool BIG>
struct Wrapper {
  T value;
  char big[BIG ? 256 : 1];
  string TypeName() const { return "POD"; }
};
| |
| template <bool BIG> |
| using Int = Wrapper<int, BIG>; |
| |
| template <bool BIG> |
| using Float = Wrapper<float, BIG>; |
| |
| template <bool BIG> |
| class MaybeAlive { |
| public: |
| MaybeAlive() : alive_(false) {} |
| |
| explicit MaybeAlive(bool alive) : alive_(alive) { |
| if (alive) ++live_counter_; |
| } |
| |
| ~MaybeAlive() { |
| if (alive_) --live_counter_; |
| } |
| |
| MaybeAlive(const MaybeAlive& rhs) : alive_(rhs.alive_) { |
| if (alive_) ++live_counter_; |
| } |
| |
| MaybeAlive& operator=(const MaybeAlive& rhs) { |
| if (this == &rhs) return *this; |
| if (alive_) --live_counter_; |
| alive_ = rhs.alive_; |
| if (alive_) ++live_counter_; |
| return *this; |
| } |
| |
| MaybeAlive(MaybeAlive&& rhs) : alive_(false) { |
| alive_ = std::move(rhs.alive_); |
| if (alive_) ++live_counter_; |
| } |
| |
| MaybeAlive& operator=(MaybeAlive&& rhs) { |
| if (this == &rhs) return *this; |
| if (alive_) --live_counter_; |
| alive_ = std::move(rhs.alive_); |
| if (alive_) ++live_counter_; |
| return *this; |
| } |
| |
| static int LiveCounter() { return live_counter_; } |
| |
| string TypeName() const { return "MaybeAlive"; } |
| void Encode(VariantTensorData* data) const {} |
| bool Decode(VariantTensorData data) { return false; } |
| |
| private: |
| bool alive_; |
| char big_[BIG ? 256 : 0]; |
| static int live_counter_; |
| }; |
| |
| template <> |
| int MaybeAlive<false>::live_counter_ = 0; |
| template <> |
| int MaybeAlive<true>::live_counter_ = 0; |
| |
| template <bool BIG> |
| class DeleteCounter { |
| public: |
| DeleteCounter() : big_{}, counter_(nullptr) {} |
| explicit DeleteCounter(int* counter) : big_{}, counter_(counter) {} |
| ~DeleteCounter() { |
| if (counter_) ++*counter_; |
| } |
| // Need custom move operations because int* just gets copied on move, but we |
| // need to clear counter_ on move. |
| DeleteCounter& operator=(const DeleteCounter& rhs) = default; |
| DeleteCounter& operator=(DeleteCounter&& rhs) { |
| if (this == &rhs) return *this; |
| counter_ = rhs.counter_; |
| rhs.counter_ = nullptr; |
| return *this; |
| } |
| DeleteCounter(DeleteCounter&& rhs) { |
| counter_ = rhs.counter_; |
| rhs.counter_ = nullptr; |
| } |
| DeleteCounter(const DeleteCounter& rhs) = default; |
| char big_[BIG ? 256 : 0]; |
| int* counter_; |
| |
| string TypeName() const { return "DeleteCounter"; } |
| void Encode(VariantTensorData* data) const {} |
| bool Decode(VariantTensorData data) { return false; } |
| }; |
| |
| } // end namespace |
| |
| TEST(VariantTest, MoveAndCopyBetweenBigAndSmall) { |
| Variant x; |
| int deleted_big = 0; |
| int deleted_small = 0; |
| x = DeleteCounter</*BIG=*/true>(&deleted_big); |
| EXPECT_EQ(deleted_big, 0); |
| x = DeleteCounter</*BIG=*/false>(&deleted_small); |
| EXPECT_EQ(deleted_big, 1); |
| EXPECT_EQ(deleted_small, 0); |
| x = DeleteCounter</*BIG=*/true>(&deleted_big); |
| EXPECT_EQ(deleted_big, 1); |
| EXPECT_EQ(deleted_small, 1); |
| x.clear(); |
| EXPECT_EQ(deleted_big, 2); |
| EXPECT_EQ(deleted_small, 1); |
| DeleteCounter</*BIG=*/true> big(&deleted_big); |
| DeleteCounter</*BIG=*/false> small(&deleted_small); |
| EXPECT_EQ(deleted_big, 2); |
| EXPECT_EQ(deleted_small, 1); |
| x = big; |
| EXPECT_EQ(deleted_big, 2); |
| EXPECT_EQ(deleted_small, 1); |
| x = small; |
| EXPECT_EQ(deleted_big, 3); |
| EXPECT_EQ(deleted_small, 1); |
| x = std::move(big); |
| EXPECT_EQ(deleted_big, 3); |
| EXPECT_EQ(deleted_small, 2); |
| x = std::move(small); |
| EXPECT_EQ(deleted_big, 4); |
| EXPECT_EQ(deleted_small, 2); |
| x.clear(); |
| EXPECT_EQ(deleted_big, 4); |
| EXPECT_EQ(deleted_small, 3); |
| } |
| |
| TEST(VariantTest, MoveAndCopyBetweenBigAndSmallVariants) { |
| int deleted_big = 0; |
| int deleted_small = 0; |
| { |
| Variant x = DeleteCounter</*BIG=*/true>(&deleted_big); |
| Variant y = DeleteCounter</*BIG=*/false>(&deleted_small); |
| EXPECT_EQ(deleted_big, 0); |
| EXPECT_EQ(deleted_small, 0); |
| x = y; |
| EXPECT_EQ(deleted_big, 1); |
| EXPECT_EQ(deleted_small, 0); |
| x = x; |
| EXPECT_EQ(deleted_big, 1); |
| EXPECT_EQ(deleted_small, 0); |
| EXPECT_NE(x.get<DeleteCounter<false>>(), nullptr); |
| EXPECT_NE(y.get<DeleteCounter<false>>(), nullptr); |
| x = std::move(y); |
| EXPECT_EQ(deleted_small, 1); |
| EXPECT_NE(x.get<DeleteCounter<false>>(), nullptr); |
| } |
| EXPECT_EQ(deleted_big, 1); |
| EXPECT_EQ(deleted_small, 2); |
| |
| deleted_big = 0; |
| deleted_small = 0; |
| { |
| Variant x = DeleteCounter</*BIG=*/false>(&deleted_small); |
| Variant y = DeleteCounter</*BIG=*/true>(&deleted_big); |
| EXPECT_EQ(deleted_big, 0); |
| EXPECT_EQ(deleted_small, 0); |
| x = y; |
| EXPECT_EQ(deleted_big, 0); |
| EXPECT_EQ(deleted_small, 1); |
| x = x; |
| EXPECT_EQ(deleted_big, 0); |
| EXPECT_EQ(deleted_small, 1); |
| EXPECT_NE(x.get<DeleteCounter<true>>(), nullptr); |
| EXPECT_NE(y.get<DeleteCounter<true>>(), nullptr); |
| x = std::move(y); |
| EXPECT_EQ(deleted_big, 1); |
| EXPECT_NE(x.get<DeleteCounter<true>>(), nullptr); |
| } |
| EXPECT_EQ(deleted_big, 2); |
| EXPECT_EQ(deleted_small, 1); |
| } |
| |
| template <bool BIG> |
| void TestDestructOnVariantMove() { |
| CHECK_EQ(MaybeAlive<BIG>::LiveCounter(), 0); |
| { |
| Variant a = MaybeAlive<BIG>(true); |
| Variant b = std::move(a); |
| } |
| EXPECT_EQ(MaybeAlive<BIG>::LiveCounter(), 0); |
| } |
| |
| TEST(VariantTest, RHSDestructOnVariantMoveBig) { |
| TestDestructOnVariantMove</*BIG=*/true>(); |
| } |
| |
| TEST(VariantTest, RHSDestructOnVariantMoveSmall) { |
| TestDestructOnVariantMove</*BIG=*/false>(); |
| } |
| |
| TEST(VariantTest, Int) { |
| Variant x; |
| EXPECT_EQ(x.get<void>(), nullptr); |
| x = 3; |
| EXPECT_NE(x.get<void>(), nullptr); |
| EXPECT_EQ(*x.get<int>(), 3); |
| EXPECT_EQ(x.TypeName(), "int"); |
| } |
| |
// File-local test helpers get internal linkage, so they cannot clash with
// same-named symbols in other translation units.
namespace {

// Aggregate pairing an int with a 16-byte SSE vector. The __m128 member
// presumably gives the struct a stricter alignment requirement than a plain
// int would.
struct MayCreateAlignmentDifficulties {
  int a;
  __m128 b;
};

// Returns true iff all four float lanes of `a` and `b` compare equal.
// _mm_cmpeq_ps produces an all-ones lane for each equal pair (NaN never
// compares equal). _mm_movemask_ps packs the four lane sign bits into the low
// four bits, so 0xf means every lane matched.
bool M128AllEqual(const __m128& a, const __m128& b) {
  return _mm_movemask_ps(_mm_cmpeq_ps(a, b)) == 0xf;
}

}  // namespace
| |
| TEST(VariantTest, NotAlignable) { |
| Variant x; |
| EXPECT_EQ(x.get<void>(), nullptr); |
| __m128 v = _mm_set_ps(1.0, 2.0, 3.0, 4.0); |
| x = MayCreateAlignmentDifficulties{-1, v}; |
| EXPECT_NE(x.get<void>(), nullptr); |
| auto* x_val = x.get<MayCreateAlignmentDifficulties>(); |
| // check that *x_val == x |
| Variant y = x; |
| EXPECT_EQ(x_val->a, -1); |
| EXPECT_TRUE(M128AllEqual(x_val->b, v)); |
| auto* y_val = y.get<MayCreateAlignmentDifficulties>(); |
| EXPECT_EQ(y_val->a, -1); |
| EXPECT_TRUE(M128AllEqual(y_val->b, v)); |
| Variant z = std::move(y); |
| auto* z_val = z.get<MayCreateAlignmentDifficulties>(); |
| EXPECT_EQ(z_val->a, -1); |
| EXPECT_TRUE(M128AllEqual(z_val->b, v)); |
| } |
| |
| template <bool BIG> |
| void TestBasic() { |
| Variant x; |
| EXPECT_EQ(x.get<void>(), nullptr); |
| |
| x = Int<BIG>{42}; |
| |
| EXPECT_NE(x.get<void>(), nullptr); |
| EXPECT_NE(x.get<Int<BIG>>(), nullptr); |
| EXPECT_EQ(x.get<Int<BIG>>()->value, 42); |
| EXPECT_EQ(x.TypeName(), "POD"); |
| } |
| |
| TEST(VariantTest, Basic) { TestBasic<false>(); } |
| |
| TEST(VariantTest, BasicBig) { TestBasic<true>(); } |
| |
| template <bool BIG> |
| void TestConstGet() { |
| Variant x; |
| EXPECT_EQ(x.get<void>(), nullptr); |
| |
| x = Int<BIG>{42}; |
| |
| const Variant y = x; |
| |
| EXPECT_NE(y.get<void>(), nullptr); |
| EXPECT_NE(y.get<Int<BIG>>(), nullptr); |
| EXPECT_EQ(y.get<Int<BIG>>()->value, 42); |
| } |
| |
| TEST(VariantTest, ConstGet) { TestConstGet<false>(); } |
| |
| TEST(VariantTest, ConstGetBig) { TestConstGet<true>(); } |
| |
| template <bool BIG> |
| void TestClear() { |
| Variant x; |
| EXPECT_EQ(x.get<void>(), nullptr); |
| |
| x = Int<BIG>{42}; |
| |
| EXPECT_NE(x.get<void>(), nullptr); |
| EXPECT_NE(x.get<Int<BIG>>(), nullptr); |
| EXPECT_EQ(x.get<Int<BIG>>()->value, 42); |
| |
| x.clear(); |
| EXPECT_EQ(x.get<void>(), nullptr); |
| } |
| |
| TEST(VariantTest, Clear) { TestClear<false>(); } |
| |
| TEST(VariantTest, ClearBig) { TestClear<true>(); } |
| |
| template <bool BIG> |
| void TestClearDeletes() { |
| Variant x; |
| EXPECT_EQ(x.get<void>(), nullptr); |
| |
| int deleted_count = 0; |
| using DC = DeleteCounter<BIG>; |
| DC dc(&deleted_count); |
| EXPECT_EQ(deleted_count, 0); |
| x = dc; |
| EXPECT_EQ(deleted_count, 0); |
| |
| EXPECT_NE(x.get<void>(), nullptr); |
| EXPECT_NE(x.get<DC>(), nullptr); |
| |
| x.clear(); |
| EXPECT_EQ(x.get<void>(), nullptr); |
| EXPECT_EQ(deleted_count, 1); |
| |
| x = dc; |
| EXPECT_EQ(deleted_count, 1); |
| |
| Variant y = x; |
| EXPECT_EQ(deleted_count, 1); |
| |
| x.clear(); |
| EXPECT_EQ(deleted_count, 2); |
| |
| y.clear(); |
| EXPECT_EQ(deleted_count, 3); |
| } |
| |
| TEST(VariantTest, ClearDeletesOnHeap) { TestClearDeletes</*BIG=*/true>(); } |
| |
| TEST(VariantTest, ClearDeletesOnStack) { TestClearDeletes</*BIG=*/false>(); } |
| |
| TEST(VariantTest, Tensor) { |
| Variant x; |
| Tensor t(DT_FLOAT, {}); |
| t.flat<float>()(0) = 42.0f; |
| x = t; |
| |
| EXPECT_NE(x.get<Tensor>(), nullptr); |
| EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f); |
| x.get<Tensor>()->flat<float>()(0) += 1.0f; |
| EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 43.0f); |
| EXPECT_EQ(x.TypeName(), "tensorflow::Tensor"); |
| } |
| |
| TEST(VariantTest, NontrivialTensorVariantCopy) { |
| Tensor variants(DT_VARIANT, {}); |
| Tensor t(true); |
| test::FillValues<Variant>(&variants, gtl::ArraySlice<Variant>({t})); |
| const Tensor* t_c = variants.flat<Variant>()(0).get<Tensor>(); |
| EXPECT_EQ(t_c->dtype(), t.dtype()); |
| EXPECT_EQ(t_c->shape(), t.shape()); |
| EXPECT_EQ(t_c->scalar<bool>()(), t.scalar<bool>()()); |
| } |
| |
| TEST(VariantTest, TensorProto) { |
| Variant x; |
| TensorProto t; |
| t.set_dtype(DT_FLOAT); |
| t.mutable_tensor_shape()->set_unknown_rank(true); |
| x = t; |
| |
| EXPECT_EQ(x.TypeName(), "tensorflow.TensorProto"); |
| EXPECT_NE(x.get<TensorProto>(), nullptr); |
| EXPECT_EQ(x.get<TensorProto>()->dtype(), DT_FLOAT); |
| EXPECT_EQ(x.get<TensorProto>()->tensor_shape().unknown_rank(), true); |
| } |
| |
| template <bool BIG> |
| void TestCopyValue() { |
| Variant x, y; |
| x = Int<BIG>{10}; |
| y = x; |
| |
| EXPECT_EQ(x.get<Int<BIG>>()->value, 10); |
| EXPECT_EQ(x.get<Int<BIG>>()->value, y.get<Int<BIG>>()->value); |
| } |
| |
| TEST(VariantTest, CopyValue) { TestCopyValue<false>(); } |
| |
| TEST(VariantTest, CopyValueBig) { TestCopyValue<true>(); } |
| |
| template <bool BIG> |
| void TestMoveValue() { |
| Variant x; |
| x = []() -> Variant { |
| Variant y; |
| y = Int<BIG>{10}; |
| return y; |
| }(); |
| EXPECT_EQ(x.get<Int<BIG>>()->value, 10); |
| } |
| |
| TEST(VariantTest, MoveValue) { TestMoveValue<false>(); } |
| |
| TEST(VariantTest, MoveValueBig) { TestMoveValue<true>(); } |
| |
| TEST(VariantTest, TypeMismatch) { |
| Variant x; |
| x = Int<false>{10}; |
| EXPECT_EQ(x.get<float>(), nullptr); |
| EXPECT_EQ(x.get<int>(), nullptr); |
| EXPECT_NE(x.get<Int<false>>(), nullptr); |
| } |
| |
| struct TensorList { |
| void Encode(VariantTensorData* data) const { data->tensors_ = vec; } |
| |
| bool Decode(VariantTensorData data) { |
| vec = std::move(data.tensors_); |
| return true; |
| } |
| |
| string TypeName() const { return "TensorList"; } |
| |
| std::vector<Tensor> vec; |
| }; |
| |
| TEST(VariantTest, TensorListTest) { |
| Variant x; |
| |
| TensorList vec; |
| for (int i = 0; i < 4; ++i) { |
| Tensor elem(DT_INT32, {1}); |
| elem.flat<int>()(0) = i; |
| vec.vec.push_back(elem); |
| } |
| |
| for (int i = 0; i < 4; ++i) { |
| Tensor elem(DT_FLOAT, {1}); |
| elem.flat<float>()(0) = 2 * i; |
| vec.vec.push_back(elem); |
| } |
| |
| x = vec; |
| |
| EXPECT_EQ(x.TypeName(), "TensorList"); |
| EXPECT_EQ(x.DebugString(), "Variant<type: TensorList value: ?>"); |
| const TensorList& stored_vec = *x.get<TensorList>(); |
| for (int i = 0; i < 4; ++i) { |
| EXPECT_EQ(stored_vec.vec[i].flat<int>()(0), i); |
| } |
| for (int i = 0; i < 4; ++i) { |
| EXPECT_EQ(stored_vec.vec[i + 4].flat<float>()(0), 2 * i); |
| } |
| |
| VariantTensorData serialized; |
| x.Encode(&serialized); |
| |
| Variant y = TensorList(); |
| y.Decode(serialized); |
| |
| const TensorList& decoded_vec = *y.get<TensorList>(); |
| for (int i = 0; i < 4; ++i) { |
| EXPECT_EQ(decoded_vec.vec[i].flat<int>()(0), i); |
| } |
| for (int i = 0; i < 4; ++i) { |
| EXPECT_EQ(decoded_vec.vec[i + 4].flat<float>()(0), 2 * i); |
| } |
| |
| VariantTensorDataProto data; |
| serialized.ToProto(&data); |
| const Variant y_unknown = data; |
| EXPECT_EQ(y_unknown.TypeName(), "TensorList"); |
| EXPECT_EQ(y_unknown.TypeId(), MakeTypeIndex<VariantTensorDataProto>()); |
| EXPECT_EQ(y_unknown.DebugString(), |
| strings::StrCat( |
| "Variant<type: TensorList value: ", data.DebugString(), ">")); |
| } |
| |
| template <bool BIG> |
| void TestVariantArray() { |
| Variant x[2]; |
| x[0] = Int<BIG>{2}; |
| x[1] = Float<BIG>{2.0f}; |
| |
| EXPECT_EQ(x[0].get<Int<BIG>>()->value, 2); |
| EXPECT_EQ(x[1].get<Float<BIG>>()->value, 2.0f); |
| } |
| |
| TEST(VariantTest, VariantArray) { TestVariantArray<false>(); } |
| |
| TEST(VariantTest, VariantArrayBig) { TestVariantArray<true>(); } |
| |
| template <bool BIG> |
| void PodUpdateTest() { |
| struct Pod { |
| int x; |
| float y; |
| char big[BIG ? 256 : 0]; |
| |
| string TypeName() const { return "POD"; } |
| }; |
| |
| Variant x = Pod{10, 20.f}; |
| EXPECT_NE(x.get<Pod>(), nullptr); |
| EXPECT_EQ(x.TypeName(), "POD"); |
| EXPECT_EQ(x.DebugString(), "Variant<type: POD value: ?>"); |
| |
| x.get<Pod>()->x += x.get<Pod>()->y; |
| EXPECT_EQ(x.get<Pod>()->x, 30); |
| } |
| |
| TEST(VariantTest, PodUpdate) { PodUpdateTest<false>(); } |
| |
| TEST(VariantTest, PodUpdateBig) { PodUpdateTest<true>(); } |
| |
| template <bool BIG> |
| void TestEncodeDecodePod() { |
| struct Pod { |
| int x; |
| float y; |
| char big[BIG ? 256 : 0]; |
| |
| string TypeName() const { return "POD"; } |
| }; |
| |
| Variant x; |
| Pod p{10, 20.0f}; |
| x = p; |
| |
| VariantTensorData serialized; |
| x.Encode(&serialized); |
| |
| Variant y = Pod{}; |
| y.Decode(serialized); |
| |
| EXPECT_EQ(p.x, y.get<Pod>()->x); |
| EXPECT_EQ(p.y, y.get<Pod>()->y); |
| } |
| |
| TEST(VariantTest, EncodeDecodePod) { TestEncodeDecodePod<false>(); } |
| |
| TEST(VariantTest, EncodeDecodePodBig) { TestEncodeDecodePod<true>(); } |
| |
| TEST(VariantTest, EncodeDecodeTensor) { |
| Variant x; |
| Tensor t(DT_INT32, {}); |
| t.flat<int>()(0) = 42; |
| x = t; |
| |
| VariantTensorData serialized; |
| x.Encode(&serialized); |
| |
| Variant y = Tensor(); |
| y.Decode(serialized); |
| EXPECT_EQ(y.DebugString(), |
| "Variant<type: tensorflow::Tensor value: Tensor<type: int32 shape: " |
| "[] values: 42>>"); |
| EXPECT_EQ(x.get<Tensor>()->flat<int>()(0), y.get<Tensor>()->flat<int>()(0)); |
| } |
| |
| } // end namespace tensorflow |