blob: e6fefd77dc27bb32e5737fc7d2b393eb1f259543 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
#define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace tensorflow {
// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase
// object which lives within the ResourceMgr. There is only one XRT memory
// manager object within the ResourceMgr container.
class XRTMemoryManager : public ResourceBase {
  // The DeviceContext class, defined and implemented locally inside the
  // xrt_memory_manager.cc file, holds, for each device, all the information
  // related to the XRT memory management for such device.
  class DeviceContext;

 public:
  // A working set is a set of tuple allocations which are the input of a given
  // operation, and as such they must be pinned on the device memory. The tuple
  // allocations added to the WorkingSet will be unpinned at object destruction.
  class WorkingSet {
   public:
    explicit WorkingSet(RefPtr<XRTMemoryManager> memory_manager);
    ~WorkingSet();

    // Looks up the tuple handle within the memory manager, and pins it to the
    // device (if not already pinned).
    Status LookupAndPin(xla::Backend* backend, int64_t handle,
                        se::DeviceMemoryAllocator* allocator);

    // Returns the tuple allocations held by this working set (presumably the
    // ones pinned through LookupAndPin(); confirm in xrt_memory_manager.cc).
    const std::vector<RefPtr<XRTTupleAllocation>>& PinnedTuples() const {
      return pinned_tuples_;
    }

    // Returns the memory manager backing this working set.
    const RefPtr<XRTMemoryManager>& MemoryManager() const {
      return memory_manager_;
    }

   private:
    RefPtr<XRTMemoryManager> memory_manager_;
    std::vector<RefPtr<XRTTupleAllocation>> pinned_tuples_;
  };

  // Retrieves the XRTMemoryManager singleton stored within the ResourceMgr.
  static RefPtr<XRTMemoryManager> Get(ResourceMgr* rm);

  // Registers an XRTTupleAllocation and returns the unique handle identifying
  // it.
  int64_t Register(RefPtr<XRTTupleAllocation> tuple);

  // Looks up a handle returned by the Register() API and returns the
  // XRTTupleAllocation behind it.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(int64_t handle);

  // Status-returning variant of the Lookup() above: on success the looked-up
  // allocation is stored into *tuple; on failure the lookup error is returned
  // and *tuple is left untouched.
  Status Lookup(int64_t handle, RefPtr<XRTTupleAllocation>* tuple) {
    TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle));
    return OkStatus();
  }

  // Releases a handle by dropping the reference count held on the
  // XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation
  // references will continue to be valid.
  Status Release(int64_t handle);

  // Tries to compact all the memory allocations on a given device. This is
  // currently done by swapping-out all the existing allocations, and swapping
  // them back in.
  Status CompactAllocations(xla::Backend* backend, int device_ordinal,
                            se::DeviceMemoryAllocator* allocator);

  // Releases all the device memory allocated by XRT within the resource
  // manager.
  void ReleaseAllAllocations();

  // Tries to allocate size bytes of device memory from the device_ordinal
  // device. If the underlying allocator call fails, it might attempt to free
  // some unpinned device memory and try the allocation again.
  xla::StatusOr<se::OwningDeviceMemory> Allocate(
      xla::Backend* backend, int device_ordinal, size_t size,
      se::DeviceMemoryAllocator* allocator);

  // Runs the specified function, handling the error::RESOURCE_EXHAUSTED
  // status code coming out of it. In such cases, we run different memory
  // freeing operations trying to make runfn succeed. The requested_free_size
  // argument represents a hint of the requested memory size which would make
  // runfn succeed. Any other status (including OK) from runfn is returned
  // as-is.
  template <typename T>
  xla::StatusOr<T> Run(const std::function<xla::StatusOr<T>()>& runfn,
                       xla::Backend* backend, int device_ordinal,
                       size_t requested_free_size,
                       se::DeviceMemoryAllocator* allocator);

  // ResourceBase override returning a debug description of this manager.
  string DebugString() const override;

  // Returns the invalid key value, which will never be generated by the
  // Register() API.
  static int64_t InvalidKey() { return 0; }

 private:
  // Structure used to track the progress of a try-to-free operation. It is
  // initialized and then passed to the TryFreeMemoryStep() API.
  struct MemoryReclaimContext {
    MemoryReclaimContext(xla::Backend* backend, int device_ordinal,
                         size_t requested_free_size,
                         se::DeviceMemoryAllocator* specific_allocator)
        // Initializers are listed in member declaration order.
        : backend(backend),
          allocator(specific_allocator),
          device_ordinal(device_ordinal),
          requested_free_size(requested_free_size) {}

    xla::Backend* const backend = nullptr;
    se::DeviceMemoryAllocator* allocator = nullptr;
    const int device_ordinal = 0;
    const size_t requested_free_size = 0;
    // Progress state updated across TryFreeMemoryStep() calls.
    size_t free_size = 0;
    bool done_freeing = false;
    bool done_compacting = false;
  };

  // Returns the DeviceContext for device_ordinal.
  // NOTE(review): presumably creates it when create_if_missing is true and
  // returns nullptr otherwise -- confirm in xrt_memory_manager.cc.
  DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing);

  // Called multiple times while trying to make a memory consuming function call
  // to fit. Performs progressively more expensive memory reduction operations,
  // until returning error::RESOURCE_EXHAUSTED when no further reductions are
  // possible.
  Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status);

  // NOTE(review): presumably guards device_contexts_; confirm in the .cc file.
  mutex lock_;
  // Per-device state; see the DeviceContext declaration above.
  std::vector<std::unique_ptr<DeviceContext>> device_contexts_;
};
// Repeatedly invokes runfn. Whenever it fails with error::RESOURCE_EXHAUSTED,
// one TryFreeMemoryStep() is performed against a reclaim context shared across
// attempts, and runfn is retried. Any other outcome of runfn (success or a
// different error) is returned unchanged. If TryFreeMemoryStep() itself fails,
// that error is propagated instead.
template <typename T>
xla::StatusOr<T> XRTMemoryManager::Run(
    const std::function<xla::StatusOr<T>()>& runfn, xla::Backend* backend,
    int device_ordinal, size_t requested_free_size,
    se::DeviceMemoryAllocator* allocator) {
  MemoryReclaimContext reclaim_ctx(backend, device_ordinal,
                                   requested_free_size, allocator);
  for (;;) {
    // runfn is expected to fail fast relative to the cost of reclaiming
    // memory, so each retry escalates to a progressively more expensive
    // TryFreeMemoryStep() operation.
    xla::StatusOr<T> attempt = runfn();
    const bool out_of_memory =
        attempt.status().code() == error::RESOURCE_EXHAUSTED;
    if (!out_of_memory) {
      return attempt;
    }
    TF_RETURN_IF_ERROR(TryFreeMemoryStep(&reclaim_ctx, attempt.status()));
  }
}
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_