| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/framework/resource_op_kernel.h" |
| |
| #include <memory> |
| |
| #include "tensorflow/core/framework/allocator.h" |
| #include "tensorflow/core/framework/node_def_builder.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/test.h" |
| #include "tensorflow/core/platform/thread_annotations.h" |
| #include "tensorflow/core/public/version.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| // Stub DeviceBase subclass which only returns allocators. |
| class StubDevice : public DeviceBase { |
| public: |
| StubDevice() : DeviceBase(nullptr) {} |
| |
| Allocator* GetAllocator(AllocatorAttributes) override { |
| return cpu_allocator(); |
| } |
| }; |
| |
| // Stub resource for testing resource op kernel. |
| class StubResource : public ResourceBase { |
| public: |
| string DebugString() const override { return ""; } |
| int code; |
| }; |
| |
| class StubResourceOpKernel : public ResourceOpKernel<StubResource> { |
| public: |
| using ResourceOpKernel::ResourceOpKernel; |
| |
| StubResource* resource() LOCKS_EXCLUDED(mu_) { |
| mutex_lock lock(mu_); |
| return resource_; |
| } |
| |
| private: |
| Status CreateResource(StubResource** resource) override { |
| *resource = CHECK_NOTNULL(new StubResource); |
| return GetNodeAttr(def(), "code", &(*resource)->code); |
| } |
| |
| Status VerifyResource(StubResource* resource) override { |
| int code; |
| TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code)); |
| if (code != resource->code) { |
| return errors::InvalidArgument("stub has code ", resource->code, |
| " but requested code ", code); |
| } |
| return Status::OK(); |
| } |
| }; |
| |
| REGISTER_OP("StubResourceOp") |
| .Attr("code: int") |
| .Attr("container: string = ''") |
| .Attr("shared_name: string = ''") |
| .Output("output: Ref(string)"); |
| |
| REGISTER_KERNEL_BUILDER(Name("StubResourceOp").Device(DEVICE_CPU), |
| StubResourceOpKernel); |
| |
| class ResourceOpKernelTest : public ::testing::Test { |
| protected: |
| std::unique_ptr<StubResourceOpKernel> CreateOp(int code, |
| const string& shared_name) { |
| NodeDef node_def; |
| TF_CHECK_OK( |
| NodeDefBuilder(strings::StrCat("test-node", count_++), "StubResourceOp") |
| .Attr("code", code) |
| .Attr("shared_name", shared_name) |
| .Finalize(&node_def)); |
| Status status; |
| std::unique_ptr<OpKernel> op(CreateOpKernel( |
| DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()), |
| node_def, TF_GRAPH_DEF_VERSION, &status)); |
| TF_EXPECT_OK(status) << status; |
| EXPECT_TRUE(op != nullptr); |
| |
| // Downcast to StubResourceOpKernel to call resource() later. |
| std::unique_ptr<StubResourceOpKernel> resource_op( |
| dynamic_cast<StubResourceOpKernel*>(op.get())); |
| EXPECT_TRUE(resource_op != nullptr); |
| if (resource_op != nullptr) { |
| op.release(); |
| } |
| return resource_op; |
| } |
| |
| Status RunOpKernel(OpKernel* op) { |
| OpKernelContext::Params params; |
| |
| params.device = &device_; |
| params.resource_manager = &mgr_; |
| params.op_kernel = op; |
| |
| OpKernelContext context(¶ms); |
| op->Compute(&context); |
| return context.status(); |
| } |
| |
| StubDevice device_; |
| ResourceMgr mgr_; |
| int count_ = 0; |
| }; |
| |
| TEST_F(ResourceOpKernelTest, PrivateResource) { |
| // Empty shared_name means private resource. |
| const int code = -100; |
| auto op = CreateOp(code, ""); |
| ASSERT_TRUE(op != nullptr); |
| TF_EXPECT_OK(RunOpKernel(op.get())); |
| |
| // Default non-shared name provided from ContainerInfo. |
| const string key = "_0_" + op->name(); |
| |
| StubResource* resource; |
| TF_ASSERT_OK( |
| mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource)); |
| EXPECT_EQ(op->resource(), resource); // Check resource identity. |
| EXPECT_EQ(code, resource->code); // Check resource stored information. |
| resource->Unref(); |
| |
| // Destroy the op kernel. Expect the resource to be released. |
| op = nullptr; |
| Status s = |
| mgr_.Lookup<StubResource>(mgr_.default_container(), key, &resource); |
| |
| EXPECT_FALSE(s.ok()); |
| } |
| |
| TEST_F(ResourceOpKernelTest, SharedResource) { |
| const string shared_name = "shared_stub"; |
| const int code = -201; |
| auto op = CreateOp(code, shared_name); |
| ASSERT_TRUE(op != nullptr); |
| TF_EXPECT_OK(RunOpKernel(op.get())); |
| |
| StubResource* resource; |
| TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name, |
| &resource)); |
| EXPECT_EQ(op->resource(), resource); // Check resource identity. |
| EXPECT_EQ(code, resource->code); // Check resource stored information. |
| resource->Unref(); |
| |
| // Destroy the op kernel. Expect the resource not to be released. |
| op = nullptr; |
| TF_ASSERT_OK(mgr_.Lookup<StubResource>(mgr_.default_container(), shared_name, |
| &resource)); |
| resource->Unref(); |
| } |
| |
| TEST_F(ResourceOpKernelTest, LookupShared) { |
| auto op1 = CreateOp(-333, "shared_stub"); |
| auto op2 = CreateOp(-333, "shared_stub"); |
| ASSERT_TRUE(op1 != nullptr); |
| ASSERT_TRUE(op2 != nullptr); |
| |
| TF_EXPECT_OK(RunOpKernel(op1.get())); |
| TF_EXPECT_OK(RunOpKernel(op2.get())); |
| EXPECT_EQ(op1->resource(), op2->resource()); |
| } |
| |
| TEST_F(ResourceOpKernelTest, VerifyResource) { |
| auto op1 = CreateOp(-444, "shared_stub"); |
| auto op2 = CreateOp(0, "shared_stub"); // Different resource code. |
| ASSERT_TRUE(op1 != nullptr); |
| ASSERT_TRUE(op2 != nullptr); |
| |
| TF_EXPECT_OK(RunOpKernel(op1.get())); |
| EXPECT_FALSE(RunOpKernel(op2.get()).ok()); |
| EXPECT_TRUE(op1->resource() != nullptr); |
| EXPECT_TRUE(op2->resource() == nullptr); |
| } |
| |
| } // namespace |
| } // namespace tensorflow |