| #pragma once |
| |
| #include <ATen/ATen.h> |
| |
| namespace at::caching { |
| |
// Some systems (currently just cudagraphs) persist a static tensor output
// whose TensorImpl does not change across iterations: the same TensorImpl is
// reused while its underlying data is overwritten each iteration. For these
// tensors, caching dtype conversions is invalid, because a conversion cached
// against the TensorImpl would go stale as soon as the data changes.
// Additionally, the cache holds an extra reference to these tensors, which
// would defeat buffer inplacing and other checks that rely on tensor
// uniqueness. When these systems are not in use, the enabled flag is false
// and we avoid the hash lookup entirely.
| |
| TORCH_API bool is_cached_tensor(const at::Tensor& t); |
| TORCH_API void add_cached_tensor(const at::Tensor& t); |
| TORCH_API void remove_cached_tensor(const at::Tensor& t); |
| TORCH_API void set_cached_tensors_enabled(bool enable); |
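
// A minimal usage sketch (illustrative only; `static_output` stands in for a
// hypothetical tensor persisted by a cudagraphs-like system):
//
//   at::caching::set_cached_tensors_enabled(true);
//   at::Tensor static_output = at::empty({8}, at::kCUDA);
//   at::caching::add_cached_tensor(static_output);
//   TORCH_INTERNAL_ASSERT(at::caching::is_cached_tensor(static_output));
//   // ... replay iterations; dtype-conversion caches must not key on
//   // static_output, and its raw use count now includes the registry's
//   // extra reference ...
//   at::caching::remove_cached_tensor(static_output);
//   at::caching::set_cached_tensors_enabled(false);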
| |
// For gradient buffer stealing we adjust the use count of tensors that are
// persisted by cudagraphs, just as we need to adjust the reference count of
// tensors with hooks.
| TORCH_API size_t adjusted_use_count(const at::Tensor& t); |
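
// A hedged sketch of a uniqueness check that accounts for the registry's
// extra reference (`can_steal_gradient_buffer` is a hypothetical helper, and
// the `== 1` threshold assumes the cache's reference has been subtracted
// out of the adjusted count):
//
//   bool can_steal_gradient_buffer(const at::Tensor& grad) {
//     // With the cudagraphs reference discounted, a count of 1 means the
//     // caller holds the only remaining reference and may reuse the buffer.
//     return at::caching::adjusted_use_count(grad) == 1;
//   }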
| |
| } // namespace at::caching |