| #pragma once |
| |
| #include <torch/csrc/WindowsTorchApiMacro.h> |
| |
| namespace torch { namespace autograd { |
| |
| // Static accessors for the autograd "grad mode" flag. Only declarations |
| // live in this header; the storage and definitions are elsewhere (their |
| // scope -- presumably thread-local, per the AutoGradMode note -- should be |
| // confirmed in the implementation file). |
| struct TORCH_API GradMode { |
|   /// Sets the grad-mode flag to `value`. |
|   static void set_enabled(bool value); |
|   /// Returns the current value of the grad-mode flag. |
|   static bool is_enabled(); |
| }; |
| |
| // Scope guard for grad mode. The constructor records the current |
| // GradMode::is_enabled() value and then calls GradMode::set_enabled(enabled); |
| // the destructor calls GradMode::set_enabled() with the recorded value. |
| // The original note says the underlying state is thread local (!); that is a |
| // property of GradMode's implementation (not visible here), so presumably a |
| // guard must be destroyed on the thread that created it -- verify. |
| // |
| // Non-copyable and non-movable: every live instance restores prev_mode when |
| // it is destroyed, so a copy destroyed early would reset grad mode while the |
| // original guard is still in scope, and copy-assignment would overwrite the |
| // saved value that the destructor restores. |
| struct TORCH_API AutoGradMode { |
|   /// @param enabled grad-mode value to apply for this guard's lifetime. |
|   explicit AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { |
|     GradMode::set_enabled(enabled); |
|   } |
|   // Deleting the copy operations also suppresses the implicit moves. |
|   AutoGradMode(const AutoGradMode&) = delete; |
|   AutoGradMode& operator=(const AutoGradMode&) = delete; |
|   /// Restores the grad-mode value captured at construction. |
|   ~AutoGradMode() { |
|     GradMode::set_enabled(prev_mode); |
|   } |
|   /// Grad-mode value observed at construction; restored by the destructor. |
|   bool prev_mode; |
| }; |
| |
| }} |