Revert "[LTC] Make some LazyGraphExecutor private data structures protected (#90457)"
This reverts commit 93aa6e3e36c022a01076d84047acd58b59244348.
Reverted https://github.com/pytorch/pytorch/pull/90457 on behalf of https://github.com/clee2000 because it somehow broke XLA: https://hud.pytorch.org/pytorch/pytorch/commit/93aa6e3e36c022a01076d84047acd58b59244348 https://github.com/pytorch/pytorch/actions/runs/3659842773/jobs/6186552659
diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp
index 7719abf..acab845 100644
--- a/torch/csrc/lazy/core/lazy_graph_executor.cpp
+++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp
@@ -48,6 +48,190 @@
contiguous_t1.numel() * contiguous_t1.itemsize()) == 0;
}
+// Locking:
+// We perform two kinds of operations on tensors: synchronous and asynchronous.
+// ApplyPendingGraph() calls are synchronous, as we need the device data
+// results immediately. Before a synchronous operation can start, it must wait
+// until the pending asynchronous operations have completed.
+// Synchronous operations do not hold device locks, since they are strictly
+// sequential, dictated by the PyTorch execution order.
+// SyncTensorsGraph() is asynchronous, and returns immediately after having
+// scheduled the asynchronous operation. While executing, the asynchronous
+// operations hold locks on all the participating devices (in the most common
+// case there will be only one device).
+// Since asynchronous operations capture device locks, only one asynchronous
+// operation can execute at a time on a given device. Tensor operations that
+// send data to the device do not need to hold any device locks while doing
+// so. Only operations that _use_ device data (computations, and transfers
+// from the server) need to wait for asynchronous operations to complete
+// (barrier). A usage sketch follows the DeviceLocker class below.
+
+class DeviceLocker {
+ public:
+ explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {}
+
+ const BackendDevice& device() const {
+ return device_;
+ }
+
+ void Lock() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ cv_.wait(lock, [this] { return !locked_; });
+ CheckResetException();
+ locked_ = true;
+ }
+
+ void Unlock(std::exception_ptr exptr) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ locked_ = false;
+ exptr_ = std::move(exptr);
+ cv_.notify_all();
+ }
+
+ void Barrier() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ cv_.wait(lock, [this] { return !locked_; });
+ cv_.notify_all();
+ CheckResetException();
+ }
+
+ private:
+ void CheckResetException() {
+ std::exception_ptr exptr = std::move(exptr_);
+ exptr_ = nullptr;
+ if (exptr != nullptr) {
+ std::rethrow_exception(exptr);
+ }
+ }
+
+ BackendDevice device_;
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ bool locked_ = false;
+ std::exception_ptr exptr_;
+};
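+
+// Illustrative usage sketch (editorial, not part of this change; "device" and
+// the computation body are placeholders): an asynchronous executor takes the
+// lock, runs, and hands any captured exception to Unlock(); a later
+// synchronous consumer calls Barrier(), which waits for the unlock and
+// rethrows that exception on the consumer's thread:
+//
+//   DeviceLocker locker(device);
+//   locker.Lock();
+//   std::thread worker([&locker] {
+//     std::exception_ptr eptr;
+//     try {
+//       // ... run the scheduled computation ...
+//     } catch (...) {
+//       eptr = std::current_exception();
+//     }
+//     locker.Unlock(std::move(eptr));
+//   });
+//   locker.Barrier();  // blocks until Unlock(); rethrows eptr if non-null
+//   worker.join();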
+
+class DeviceLockerArena {
+ public:
+ static DeviceLockerArena* Get() {
+ static DeviceLockerArena* arena = new DeviceLockerArena();
+ return arena;
+ }
+
+ std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ auto it = lockers_.find(device);
+ if (it == lockers_.end()) {
+ it = lockers_.emplace(device, std::make_shared<DeviceLocker>(device))
+ .first;
+ }
+ return it->second;
+ }
+
+ void DeviceBarrier(const BackendDevice& device) {
+ auto locker = DeviceLockerArena::Get()->GetLocker(device);
+ locker->Barrier();
+ }
+
+ // Use a set to impose an order on the device locking sequence (ABBA
+ // prevention).
+ std::vector<ExceptionCleanup> LockDevices(
+ const std::set<BackendDevice>& devices) {
+ std::vector<ExceptionCleanup> unlocker;
+ unlocker.reserve(devices.size());
+ for (auto& device : devices) {
+ unlocker.emplace_back(LockDevice(device));
+ }
+ return unlocker;
+ }
+
+ private:
+ ExceptionCleanup LockDevice(const BackendDevice& device) {
+ auto locker = DeviceLockerArena::Get()->GetLocker(device);
+ locker->Lock();
+ return ExceptionCleanup(
+ [locker = std::move(locker)](ExceptionCleanup::StatusType status) {
+ locker->Unlock(std::move(status));
+ });
+ }
+
+ std::mutex mutex_;
+ std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_;
+};
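+
+// Illustrative sketch of the ABBA prevention noted above (editorial, not part
+// of this change; device_a and device_b are placeholders): because
+// LockDevices() takes a std::set, every caller acquires the locks in the same
+// sorted device order, so one caller locking {A, B} and another locking
+// {B, A} cannot deadlock. The returned ExceptionCleanup objects release the
+// locks when they go out of scope:
+//
+//   std::set<BackendDevice> devices = {device_a, device_b};
+//   {
+//     auto unlockers = DeviceLockerArena::Get()->LockDevices(devices);
+//     // ... schedule asynchronous work on device_a and device_b ...
+//   }  // each ExceptionCleanup destructor calls Unlock() on its device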
+
+class DataCacheArena {
+ public:
+ static DataCacheArena* Get() {
+ static DataCacheArena* arena =
+ new DataCacheArena(FLAGS_torch_lazy_device_data_cache_size);
+ return arena;
+ }
+
+ explicit DataCacheArena(size_t max_cache_size)
+ : max_cache_size_(max_cache_size) {}
+
+ BackendDataPtr GetDeviceData(
+ const at::Tensor& tensor,
+ const BackendDevice& device) {
+ DataCacheArena::DataCache* cache = Get()->GetDataCache(device);
+ BackendDataPtr device_data = cache->Get(tensor);
+ if (device_data == nullptr) {
+ at::Tensor tensor_copy = CopyTensor(tensor);
+ device_data = TensorToDataHandle(tensor_copy, device);
+ cache->Add(std::move(tensor_copy), device_data);
+ TORCH_LAZY_COUNTER("DeviceDataCacheMiss", 1);
+ }
+ return device_data;
+ }
+
+ BackendDataPtr GetDeviceData(
+ const at::Scalar& value,
+ at::ScalarType scalar_type,
+ const BackendDevice& device) {
+ // Workaround since at::scalar_tensor doesn't support bfloat16 yet.
+ at::Tensor t = at::scalar_tensor(
+ value,
+ at::TensorOptions(
+ scalar_type == at::ScalarType::BFloat16 ? at::ScalarType::Float
+ : scalar_type));
+ if (scalar_type == at::ScalarType::BFloat16) {
+ t = t.to(scalar_type);
+ }
+ return GetDeviceData(t, device);
+ }
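+
+ // For illustration (editorial, not part of this change): a BFloat16 scalar
+ // request first materializes a Float tensor and then converts it, e.g.
+ //   GetDeviceData(at::Scalar(1.5), at::ScalarType::BFloat16, device);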
+
+ private:
+ struct TensorHasher {
+ size_t operator()(const at::Tensor& tensor) const {
+ return HashReduce(
+ HashCombine(GetEnumValue(tensor.scalar_type()), TensorHash(tensor)));
+ }
+ };
+ struct TensorComparer {
+ bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2)
+ const {
+ return TensorCompare(tensor1, tensor2);
+ }
+ };
+
+ using DataCache =
+ Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>;
+
+ DataCache* GetDataCache(const BackendDevice& device) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ auto it = device_caches_.find(device);
+ if (it == device_caches_.end()) {
+ std::unique_ptr<DataCache> cache(new DataCache(max_cache_size_));
+ it = device_caches_.emplace(device, std::move(cache)).first;
+ }
+ return it->second.get();
+ }
+
+ size_t max_cache_size_ = 0;
+ std::mutex mutex_;
+ std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_;
+};
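+
+// Illustrative usage sketch (editorial, not part of this change; "device" is
+// a placeholder): the cache is keyed on tensor contents via TensorHasher and
+// TensorComparer, so uploading the same tensor twice returns the cached
+// handle instead of copying and re-transferring. Only misses bump the
+// "DeviceDataCacheMiss" counter:
+//
+//   at::Tensor t = at::ones({2, 2});
+//   BackendDataPtr d1 = DataCacheArena::Get()->GetDeviceData(t, device);  // miss
+//   BackendDataPtr d2 = DataCacheArena::Get()->GetDeviceData(t, device);  // hit
+//   // d1 and d2 refer to the same device data handle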
+
// The DeviceContextArena holds per-device live information and statistics,
// among which are the lazy tensors currently alive in the system. This is
// used to create computation "barriers" in order to flush pending operations
@@ -209,146 +393,6 @@
std::atomic<LazyGraphExecutor*> lazy_graph_executor_registry;
} // namespace
-void LazyGraphExecutor::DeviceLocker::Lock() {
- std::unique_lock<std::mutex> lock(mutex_);
- cv_.wait(lock, [this] { return !locked_; });
- CheckResetException();
- locked_ = true;
-}
-
-void LazyGraphExecutor::DeviceLocker::Unlock(std::exception_ptr exptr) {
- std::lock_guard<std::mutex> lock(mutex_);
- locked_ = false;
- exptr_ = std::move(exptr);
- cv_.notify_all();
-}
-
-void LazyGraphExecutor::DeviceLocker::Barrier() {
- std::unique_lock<std::mutex> lock(mutex_);
- cv_.wait(lock, [this] { return !locked_; });
- cv_.notify_all();
- CheckResetException();
-}
-
-void LazyGraphExecutor::DeviceLocker::CheckResetException() {
- std::exception_ptr exptr = std::move(exptr_);
- exptr_ = nullptr;
- if (exptr != nullptr) {
- std::rethrow_exception(exptr);
- }
-}
-
-auto LazyGraphExecutor::DeviceLockerArena::Get() -> DeviceLockerArena* {
- static DeviceLockerArena* arena = new DeviceLockerArena();
- return arena;
-}
-
-auto LazyGraphExecutor::DeviceLockerArena::GetLocker(
- const BackendDevice& device) -> std::shared_ptr<DeviceLocker> {
- std::lock_guard<std::mutex> lock(mutex_);
- auto it = lockers_.find(device);
- if (it == lockers_.end()) {
- it = lockers_.emplace(device, std::make_shared<DeviceLocker>(device)).first;
- }
- return it->second;
-}
-
-void LazyGraphExecutor::DeviceLockerArena::DeviceBarrier(
- const BackendDevice& device) {
- auto locker = DeviceLockerArena::Get()->GetLocker(device);
- locker->Barrier();
-}
-
-std::vector<ExceptionCleanup> LazyGraphExecutor::DeviceLockerArena::LockDevices(
- const std::set<BackendDevice>& devices) {
- std::vector<ExceptionCleanup> unlocker;
- unlocker.reserve(devices.size());
- for (auto& device : devices) {
- unlocker.emplace_back(LockDevice(device));
- }
- return unlocker;
-}
-
-ExceptionCleanup LazyGraphExecutor::DeviceLockerArena::LockDevice(
- const BackendDevice& device) {
- VLOG(4) << "Waiting on device barrier for device " << device << " ...";
- std::shared_ptr<DeviceLocker> locker;
- {
- TORCH_LAZY_TIMED("DeviceLockWait");
- locker = DeviceLockerArena::Get()->GetLocker(device);
- locker->Lock();
- }
- VLOG(4) << "Waiting on device barrier for device " << device << " done!";
- return torch::lazy::ExceptionCleanup(
- [locker = std::move(locker)](
- torch::lazy::ExceptionCleanup::StatusType status) {
- locker->Unlock(std::move(status));
- });
-}
-
-auto LazyGraphExecutor::DataCacheArena::Get() -> DataCacheArena* {
- static DataCacheArena* arena =
- new DataCacheArena(FLAGS_torch_lazy_device_data_cache_size);
- return arena;
-}
-
-LazyGraphExecutor::DataCacheArena::DataCacheArena(size_t max_cache_size)
- : max_cache_size_(max_cache_size) {}
-
-BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
- const at::Tensor& tensor,
- const BackendDevice& device) {
- DataCacheArena::DataCache* cache = Get()->GetDataCache(device);
- ;
- BackendDataPtr device_data = cache->Get(tensor);
- if (device_data == nullptr) {
- at::Tensor tensor_copy = CopyTensor(tensor);
- device_data = TensorToDataHandle(tensor_copy, device);
- cache->Add(std::move(tensor_copy), device_data);
- TORCH_LAZY_COUNTER("DeviceDataCacheMiss", 1);
- }
- return device_data;
-}
-
-BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
- const at::Scalar& value,
- at::ScalarType scalar_type,
- const BackendDevice& device) {
- // Workaround since at::scalar_tensor doesn't support bfloat16 yet.
- at::Tensor t = at::scalar_tensor(
- value,
- at::TensorOptions(
- scalar_type == at::ScalarType::BFloat16 ? at::ScalarType::Float
- : scalar_type));
- if (scalar_type == at::ScalarType::BFloat16) {
- t = t.to(scalar_type);
- }
- return GetDeviceData(t, device);
-}
-
-size_t LazyGraphExecutor::DataCacheArena::TensorHasher::operator()(
- const at::Tensor& tensor) const {
- return HashReduce(
- HashCombine(GetEnumValue(tensor.scalar_type()), TensorHash(tensor)));
-}
-
-bool LazyGraphExecutor::DataCacheArena::TensorComparer::operator()(
- const at::Tensor& tensor1,
- const at::Tensor& tensor2) const {
- return TensorCompare(tensor1, tensor2);
-}
-
-auto LazyGraphExecutor::DataCacheArena::GetDataCache(
- const BackendDevice& device) -> DataCache* {
- std::lock_guard<std::mutex> lock(mutex_);
- auto it = device_caches_.find(device);
- if (it == device_caches_.end()) {
- std::unique_ptr<DataCache> cache(new DataCache(max_cache_size_));
- it = device_caches_.emplace(device, std::move(cache)).first;
- }
- return it->second.get();
-}
-
void LazyGraphExecutor::Register(LazyGraphExecutor* executor) {
lazy_graph_executor_registry.store(executor);
}
@@ -429,7 +473,7 @@
TORCH_LAZY_COUNTER("MarkStep", 1);
DeviceContextArena::Get()->MarkStep(device);
ScopePusher::ResetScopes();
- ResetTrimCounter();
+ g_tls_data.Reset();
// Move TrieCache's current pointer back to its root
TrieCache::Get()->ResetCurrent();
}
@@ -459,11 +503,7 @@
return GetTensorsFused(tensors);
}
-void LazyGraphExecutor::ResetTrimCounter() const {
- g_tls_data.Reset();
-}
-
-size_t LazyGraphExecutor::IncTrimCounter() const {
+size_t LazyGraphExecutor::IncTrimCounter() {
return ++g_tls_data.trim_counter;
}
diff --git a/torch/csrc/lazy/core/lazy_graph_executor.h b/torch/csrc/lazy/core/lazy_graph_executor.h
index 5221d20..10b41b6 100644
--- a/torch/csrc/lazy/core/lazy_graph_executor.h
+++ b/torch/csrc/lazy/core/lazy_graph_executor.h
@@ -87,7 +87,7 @@
// All the tensors must be on the same device.
std::vector<at::Tensor> GetTensors(std::vector<LazyTensorPtr>* tensors);
- size_t IncTrimCounter() const;
+ size_t IncTrimCounter();
// Dumps the backend specific text of the computation accumulated in the graph
// which is attached the tensors.
@@ -161,102 +161,6 @@
std::vector<size_t> parameter_sequence;
};
- // Locking:
- // We perform two kinds of operations of tensors, synchronous and
- // asynchronous. The ApplyPendingGraph() are synchronous, as we need the
- // device data result immediately. Before the synchronous operations can
- // start, they need to wait that the pending asynchronous operations have
- // completed. Synchronous operations do not hold device locks, since they are
- // strictly sequential, dictated by the PyTorch execution order. The
- // SyncTensorsGraph() is asynchronous, and returns immediately after having
- // scheduled the asynchronous operation. While executing, the asynchronous
- // operations will hold locks on all the participating devices (in most common
- // cases there will be only one device).
- // Since asynchronous operations capture device locks, only one asynchronous
- // operation can execute at the same time, on a given device. Tensor
- // operations which send data to device do not need to hold any device locks
- // while doing so. Only operations which _use_ device data (computations, and
- // transfer from server) need to wait for asynchronous operations to complete
- // (barrier).
-
- class DeviceLocker {
- public:
- explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {}
-
- const BackendDevice& device() const {
- return device_;
- }
-
- void Lock();
- void Unlock(std::exception_ptr exptr);
- void Barrier();
-
- private:
- void CheckResetException();
-
- BackendDevice device_;
- std::mutex mutex_;
- std::condition_variable cv_;
- bool locked_ = false;
- std::exception_ptr exptr_;
- };
-
- class DeviceLockerArena {
- public:
- static DeviceLockerArena* Get();
-
- std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device);
-
- void DeviceBarrier(const BackendDevice& device);
-
- // Use a set to impose an order on the device locking sequence (ABBA
- // prevention).
- std::vector<ExceptionCleanup> LockDevices(
- const std::set<BackendDevice>& devices);
-
- private:
- ExceptionCleanup LockDevice(const BackendDevice& device);
-
- std::mutex mutex_;
- std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_;
- };
-
- class DataCacheArena {
- public:
- static DataCacheArena* Get();
-
- BackendDataPtr GetDeviceData(
- const at::Tensor& tensor,
- const BackendDevice& device);
-
- BackendDataPtr GetDeviceData(
- const at::Scalar& value,
- at::ScalarType scalar_type,
- const BackendDevice& device);
-
- private:
- struct TensorHasher {
- size_t operator()(const at::Tensor& tensor) const;
- };
- struct TensorComparer {
- bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2)
- const;
- };
-
- explicit DataCacheArena(size_t max_cache_size);
-
- using DataCache =
- Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>;
-
- DataCache* GetDataCache(const BackendDevice& device);
-
- size_t max_cache_size_ = 0;
- std::mutex mutex_;
- std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_;
- };
-
- void ResetTrimCounter() const;
-
private:
struct CompilationResult {
BackendDevice device;