Revert "[LTC] Make some LazyGraphExecutor private data structures protected (#90457)"

This reverts commit 93aa6e3e36c022a01076d84047acd58b59244348.

Reverted https://github.com/pytorch/pytorch/pull/90457 on behalf of https://github.com/clee2000 because it broke XLA somehow; see https://hud.pytorch.org/pytorch/pytorch/commit/93aa6e3e36c022a01076d84047acd58b59244348 and https://github.com/pytorch/pytorch/actions/runs/3659842773/jobs/6186552659
diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp
index 7719abf..acab845 100644
--- a/torch/csrc/lazy/core/lazy_graph_executor.cpp
+++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp
@@ -48,6 +48,190 @@
              contiguous_t1.numel() * contiguous_t1.itemsize()) == 0;
 }
 
+// Locking:
+// We perform two kinds of operations on tensors, synchronous and asynchronous.
+// The ApplyPendingGraph() calls are synchronous, as we need the device data
+// result immediately. Before the synchronous operations can start, they need
+// to wait until the pending asynchronous operations have completed.
+// Synchronous operations do not hold device locks, since they are strictly
+// sequential, dictated by the PyTorch execution order.
+// SyncTensorsGraph() is asynchronous and returns immediately after having
+// scheduled the asynchronous operation. While executing, the asynchronous
+// operations will hold locks on all the participating devices (in most common
+// cases there will be only one device).
+// Since asynchronous operations capture device locks, only one asynchronous
+// operation can execute at the same time on a given device. Tensor operations
+// which send data to the device do not need to hold any device locks while
+// doing so. Only operations which _use_ device data (computations, and
+// transfers from the server) need to wait for asynchronous operations to
+// complete (barrier).
+
+class DeviceLocker {
+ public:
+  explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {}
+
+  const BackendDevice& device() const {
+    return device_;
+  }
+
+  void Lock() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return !locked_; });
+    CheckResetException();
+    locked_ = true;
+  }
+
+  void Unlock(std::exception_ptr exptr) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    locked_ = false;
+    exptr_ = std::move(exptr);
+    cv_.notify_all();
+  }
+
+  void Barrier() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return !locked_; });
+    cv_.notify_all();
+    CheckResetException();
+  }
+
+ private:
+  void CheckResetException() {
+    std::exception_ptr exptr = std::move(exptr_);
+    exptr_ = nullptr;
+    if (exptr != nullptr) {
+      std::rethrow_exception(exptr);
+    }
+  }
+
+  BackendDevice device_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  bool locked_ = false;
+  std::exception_ptr exptr_;
+};
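// A minimal usage sketch of the locking protocol described in the comment
// above (not part of the original change; Example and worker are hypothetical
// names, and the snippet assumes the torch::lazy types declared in this file
// plus <thread> and <memory>):
void Example(const BackendDevice& device) {
  auto locker = std::make_shared<DeviceLocker>(device);
  locker->Lock(); // the scheduled asynchronous operation takes the device lock
  std::thread worker([locker] {
    std::exception_ptr error; // would capture any failure during execution
    // ... run the scheduled computation ...
    locker->Unlock(std::move(error)); // release, handing errors to waiters
  });
  locker->Barrier(); // a synchronous consumer blocks until the device is free
                     // and rethrows any exception set by Unlock()
  worker.join();
}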
+
+class DeviceLockerArena {
+ public:
+  static DeviceLockerArena* Get() {
+    static DeviceLockerArena* arena = new DeviceLockerArena();
+    return arena;
+  }
+
+  std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto it = lockers_.find(device);
+    if (it == lockers_.end()) {
+      it = lockers_.emplace(device, std::make_shared<DeviceLocker>(device))
+               .first;
+    }
+    return it->second;
+  }
+
+  void DeviceBarrier(const BackendDevice& device) {
+    auto locker = DeviceLockerArena::Get()->GetLocker(device);
+    locker->Barrier();
+  }
+
+  // Use a set to impose an order on the device locking sequence (ABBA
+  // prevention).
+  std::vector<ExceptionCleanup> LockDevices(
+      const std::set<BackendDevice>& devices) {
+    std::vector<ExceptionCleanup> unlocker;
+    unlocker.reserve(devices.size());
+    for (auto& device : devices) {
+      unlocker.emplace_back(LockDevice(device));
+    }
+    return unlocker;
+  }
+
+ private:
+  ExceptionCleanup LockDevice(const BackendDevice& device) {
+    auto locker = DeviceLockerArena::Get()->GetLocker(device);
+    locker->Lock();
+    return ExceptionCleanup(
+        [locker = std::move(locker)](ExceptionCleanup::StatusType status) {
+          locker->Unlock(std::move(status));
+        });
+  }
+
+  std::mutex mutex_;
+  std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_;
+};
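// A minimal sketch of how LockDevices() is meant to be used (not part of the
// original change; RunOnDevices is a hypothetical caller). Because the devices
// arrive in a std::set, every caller acquires the locks in the same sorted
// order, which is what rules out ABBA deadlocks between two asynchronous
// operations touching the same pair of devices:
void RunOnDevices(const std::set<BackendDevice>& devices) {
  std::vector<ExceptionCleanup> unlockers =
      DeviceLockerArena::Get()->LockDevices(devices);
  // ... schedule work that uses data on all of `devices` ...
} // each ExceptionCleanup unlocks its DeviceLocker when destroyed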
+
+class DataCacheArena {
+ public:
+  static DataCacheArena* Get() {
+    static DataCacheArena* arena =
+        new DataCacheArena(FLAGS_torch_lazy_device_data_cache_size);
+    return arena;
+  }
+
+  explicit DataCacheArena(size_t max_cache_size)
+      : max_cache_size_(max_cache_size) {}
+
+  BackendDataPtr GetDeviceData(
+      const at::Tensor& tensor,
+      const BackendDevice& device) {
+    DataCacheArena::DataCache* cache = Get()->GetDataCache(device);
+    BackendDataPtr device_data = cache->Get(tensor);
+    if (device_data == nullptr) {
+      at::Tensor tensor_copy = CopyTensor(tensor);
+      device_data = TensorToDataHandle(tensor_copy, device);
+      cache->Add(std::move(tensor_copy), device_data);
+      TORCH_LAZY_COUNTER("DeviceDataCacheMiss", 1);
+    }
+    return device_data;
+  }
+
+  BackendDataPtr GetDeviceData(
+      const at::Scalar& value,
+      at::ScalarType scalar_type,
+      const BackendDevice& device) {
+    // Workaround since at::scalar_tensor doesn't support bfloat16 yet.
+    at::Tensor t = at::scalar_tensor(
+        value,
+        at::TensorOptions(
+            scalar_type == at::ScalarType::BFloat16 ? at::ScalarType::Float
+                                                    : scalar_type));
+    if (scalar_type == at::ScalarType::BFloat16) {
+      t = t.to(scalar_type);
+    }
+    return GetDeviceData(t, device);
+  }
+
+ private:
+  struct TensorHasher {
+    size_t operator()(const at::Tensor& tensor) const {
+      return HashReduce(
+          HashCombine(GetEnumValue(tensor.scalar_type()), TensorHash(tensor)));
+    }
+  };
+  struct TensorComparer {
+    bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2)
+        const {
+      return TensorCompare(tensor1, tensor2);
+    }
+  };
+
+  using DataCache =
+      Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>;
+
+  DataCache* GetDataCache(const BackendDevice& device) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto it = device_caches_.find(device);
+    if (it == device_caches_.end()) {
+      std::unique_ptr<DataCache> cache(new DataCache(max_cache_size_));
+      it = device_caches_.emplace(device, std::move(cache)).first;
+    }
+    return it->second.get();
+  }
+
+  size_t max_cache_size_ = 0;
+  std::mutex mutex_;
+  std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_;
+};
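// A minimal sketch of the caching behaviour (not part of the original change;
// UploadTwice is a hypothetical helper). The first call misses the per-device
// cache, copying the tensor and uploading it to the device; the second call
// returns the cached handle because the tensor contents hash and compare equal:
void UploadTwice(const at::Tensor& cpu_tensor, const BackendDevice& device) {
  BackendDataPtr first =
      DataCacheArena::Get()->GetDeviceData(cpu_tensor, device); // cache miss
  BackendDataPtr second =
      DataCacheArena::Get()->GetDeviceData(cpu_tensor, device); // cache hit
  TORCH_CHECK(first == second); // both refer to the same device allocation
}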
+
 // The DeviceContextArena holds per device live information and statistics,
 // among which are the lazy tensors currently alive in the system. This is
 // used to create computation "barriers" in order to flush pending operations
@@ -209,146 +393,6 @@
 std::atomic<LazyGraphExecutor*> lazy_graph_executor_registry;
 } // namespace
 
-void LazyGraphExecutor::DeviceLocker::Lock() {
-  std::unique_lock<std::mutex> lock(mutex_);
-  cv_.wait(lock, [this] { return !locked_; });
-  CheckResetException();
-  locked_ = true;
-}
-
-void LazyGraphExecutor::DeviceLocker::Unlock(std::exception_ptr exptr) {
-  std::lock_guard<std::mutex> lock(mutex_);
-  locked_ = false;
-  exptr_ = std::move(exptr);
-  cv_.notify_all();
-}
-
-void LazyGraphExecutor::DeviceLocker::Barrier() {
-  std::unique_lock<std::mutex> lock(mutex_);
-  cv_.wait(lock, [this] { return !locked_; });
-  cv_.notify_all();
-  CheckResetException();
-}
-
-void LazyGraphExecutor::DeviceLocker::CheckResetException() {
-  std::exception_ptr exptr = std::move(exptr_);
-  exptr_ = nullptr;
-  if (exptr != nullptr) {
-    std::rethrow_exception(exptr);
-  }
-}
-
-auto LazyGraphExecutor::DeviceLockerArena::Get() -> DeviceLockerArena* {
-  static DeviceLockerArena* arena = new DeviceLockerArena();
-  return arena;
-}
-
-auto LazyGraphExecutor::DeviceLockerArena::GetLocker(
-    const BackendDevice& device) -> std::shared_ptr<DeviceLocker> {
-  std::lock_guard<std::mutex> lock(mutex_);
-  auto it = lockers_.find(device);
-  if (it == lockers_.end()) {
-    it = lockers_.emplace(device, std::make_shared<DeviceLocker>(device)).first;
-  }
-  return it->second;
-}
-
-void LazyGraphExecutor::DeviceLockerArena::DeviceBarrier(
-    const BackendDevice& device) {
-  auto locker = DeviceLockerArena::Get()->GetLocker(device);
-  locker->Barrier();
-}
-
-std::vector<ExceptionCleanup> LazyGraphExecutor::DeviceLockerArena::LockDevices(
-    const std::set<BackendDevice>& devices) {
-  std::vector<ExceptionCleanup> unlocker;
-  unlocker.reserve(devices.size());
-  for (auto& device : devices) {
-    unlocker.emplace_back(LockDevice(device));
-  }
-  return unlocker;
-}
-
-ExceptionCleanup LazyGraphExecutor::DeviceLockerArena::LockDevice(
-    const BackendDevice& device) {
-  VLOG(4) << "Waiting on device barrier for device " << device << " ...";
-  std::shared_ptr<DeviceLocker> locker;
-  {
-    TORCH_LAZY_TIMED("DeviceLockWait");
-    locker = DeviceLockerArena::Get()->GetLocker(device);
-    locker->Lock();
-  }
-  VLOG(4) << "Waiting on device barrier for device " << device << " done!";
-  return torch::lazy::ExceptionCleanup(
-      [locker = std::move(locker)](
-          torch::lazy::ExceptionCleanup::StatusType status) {
-        locker->Unlock(std::move(status));
-      });
-}
-
-auto LazyGraphExecutor::DataCacheArena::Get() -> DataCacheArena* {
-  static DataCacheArena* arena =
-      new DataCacheArena(FLAGS_torch_lazy_device_data_cache_size);
-  return arena;
-}
-
-LazyGraphExecutor::DataCacheArena::DataCacheArena(size_t max_cache_size)
-    : max_cache_size_(max_cache_size) {}
-
-BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
-    const at::Tensor& tensor,
-    const BackendDevice& device) {
-  DataCacheArena::DataCache* cache = Get()->GetDataCache(device);
-  BackendDataPtr device_data = cache->Get(tensor);
-  if (device_data == nullptr) {
-    at::Tensor tensor_copy = CopyTensor(tensor);
-    device_data = TensorToDataHandle(tensor_copy, device);
-    cache->Add(std::move(tensor_copy), device_data);
-    TORCH_LAZY_COUNTER("DeviceDataCacheMiss", 1);
-  }
-  return device_data;
-}
-
-BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
-    const at::Scalar& value,
-    at::ScalarType scalar_type,
-    const BackendDevice& device) {
-  // Workaround since at::scalar_tensor doesn't support bfloat16 yet.
-  at::Tensor t = at::scalar_tensor(
-      value,
-      at::TensorOptions(
-          scalar_type == at::ScalarType::BFloat16 ? at::ScalarType::Float
-                                                  : scalar_type));
-  if (scalar_type == at::ScalarType::BFloat16) {
-    t = t.to(scalar_type);
-  }
-  return GetDeviceData(t, device);
-}
-
-size_t LazyGraphExecutor::DataCacheArena::TensorHasher::operator()(
-    const at::Tensor& tensor) const {
-  return HashReduce(
-      HashCombine(GetEnumValue(tensor.scalar_type()), TensorHash(tensor)));
-}
-
-bool LazyGraphExecutor::DataCacheArena::TensorComparer::operator()(
-    const at::Tensor& tensor1,
-    const at::Tensor& tensor2) const {
-  return TensorCompare(tensor1, tensor2);
-}
-
-auto LazyGraphExecutor::DataCacheArena::GetDataCache(
-    const BackendDevice& device) -> DataCache* {
-  std::lock_guard<std::mutex> lock(mutex_);
-  auto it = device_caches_.find(device);
-  if (it == device_caches_.end()) {
-    std::unique_ptr<DataCache> cache(new DataCache(max_cache_size_));
-    it = device_caches_.emplace(device, std::move(cache)).first;
-  }
-  return it->second.get();
-}
-
 void LazyGraphExecutor::Register(LazyGraphExecutor* executor) {
   lazy_graph_executor_registry.store(executor);
 }
@@ -429,7 +473,7 @@
   TORCH_LAZY_COUNTER("MarkStep", 1);
   DeviceContextArena::Get()->MarkStep(device);
   ScopePusher::ResetScopes();
-  ResetTrimCounter();
+  g_tls_data.Reset();
   // Move TrieCache's current pointer back to its root
   TrieCache::Get()->ResetCurrent();
 }
@@ -459,11 +503,7 @@
   return GetTensorsFused(tensors);
 }
 
-void LazyGraphExecutor::ResetTrimCounter() const {
-  g_tls_data.Reset();
-}
-
-size_t LazyGraphExecutor::IncTrimCounter() const {
+size_t LazyGraphExecutor::IncTrimCounter() {
   return ++g_tls_data.trim_counter;
 }
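The hunks above also drop the ResetTrimCounter() wrapper: MarkStep() now clears the thread-local trim counter directly, and IncTrimCounter() reverts to its original non-const signature. For context, the thread-local state presumably looks roughly like the sketch below (reconstructed from the calls visible above, not copied from the file):

struct TlsData {
  void Reset() {
    trim_counter = 0;
  }
  size_t trim_counter = 0;
};
thread_local TlsData g_tls_data;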
 
diff --git a/torch/csrc/lazy/core/lazy_graph_executor.h b/torch/csrc/lazy/core/lazy_graph_executor.h
index 5221d20..10b41b6 100644
--- a/torch/csrc/lazy/core/lazy_graph_executor.h
+++ b/torch/csrc/lazy/core/lazy_graph_executor.h
@@ -87,7 +87,7 @@
   // All the tensors must be on the same device.
   std::vector<at::Tensor> GetTensors(std::vector<LazyTensorPtr>* tensors);
 
-  size_t IncTrimCounter() const;
+  size_t IncTrimCounter();
 
   // Dumps the backend-specific text of the computation accumulated in the
   // graph which is attached to the tensors.
@@ -161,102 +161,6 @@
     std::vector<size_t> parameter_sequence;
   };
 
-  // Locking:
-  // We perform two kinds of operations on tensors, synchronous and
-  // asynchronous. The ApplyPendingGraph() calls are synchronous, as we need
-  // the device data result immediately. Before the synchronous operations can
-  // start, they need to wait until the pending asynchronous operations have
-  // completed. Synchronous operations do not hold device locks, since they
-  // are strictly sequential, dictated by the PyTorch execution order.
-  // SyncTensorsGraph() is asynchronous and returns immediately after having
-  // scheduled the asynchronous operation. While executing, the asynchronous
-  // operations will hold locks on all the participating devices (in most
-  // common cases there will be only one device).
-  // Since asynchronous operations capture device locks, only one asynchronous
-  // operation can execute at the same time on a given device. Tensor
-  // operations which send data to the device do not need to hold any device
-  // locks while doing so. Only operations which _use_ device data
-  // (computations, and transfers from the server) need to wait for
-  // asynchronous operations to complete (barrier).
-
-  class DeviceLocker {
-   public:
-    explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {}
-
-    const BackendDevice& device() const {
-      return device_;
-    }
-
-    void Lock();
-    void Unlock(std::exception_ptr exptr);
-    void Barrier();
-
-   private:
-    void CheckResetException();
-
-    BackendDevice device_;
-    std::mutex mutex_;
-    std::condition_variable cv_;
-    bool locked_ = false;
-    std::exception_ptr exptr_;
-  };
-
-  class DeviceLockerArena {
-   public:
-    static DeviceLockerArena* Get();
-
-    std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device);
-
-    void DeviceBarrier(const BackendDevice& device);
-
-    // Use a set to impose an order on the device locking sequence (ABBA
-    // prevention).
-    std::vector<ExceptionCleanup> LockDevices(
-        const std::set<BackendDevice>& devices);
-
-   private:
-    ExceptionCleanup LockDevice(const BackendDevice& device);
-
-    std::mutex mutex_;
-    std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_;
-  };
-
-  class DataCacheArena {
-   public:
-    static DataCacheArena* Get();
-
-    BackendDataPtr GetDeviceData(
-        const at::Tensor& tensor,
-        const BackendDevice& device);
-
-    BackendDataPtr GetDeviceData(
-        const at::Scalar& value,
-        at::ScalarType scalar_type,
-        const BackendDevice& device);
-
-   private:
-    struct TensorHasher {
-      size_t operator()(const at::Tensor& tensor) const;
-    };
-    struct TensorComparer {
-      bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2)
-          const;
-    };
-
-    explicit DataCacheArena(size_t max_cache_size);
-
-    using DataCache =
-        Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>;
-
-    DataCache* GetDataCache(const BackendDevice& device);
-
-    size_t max_cache_size_ = 0;
-    std::mutex mutex_;
-    std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_;
-  };
-
-  void ResetTrimCounter() const;
-
  private:
   struct CompilationResult {
     BackendDevice device;