// blob: 57125961a5d251b0da97591f5c49b93ac2845663 [file] [log] [blame]
#pragma once
#include <torch/csrc/jit/ir.h>
/* `getCustomPasses()` returns a vector of passes that will be executed after
* differentiation but before any fusion. This is the de-facto location
* for compiler backends to insert passes.
*
* Static registration of a pass can be done by creating a global
* `RegisterPass r(Pass)` variable in a compilation unit.
*
* pass_manager.h uses a Meyers singleton
* to store a vector of `Pass`es, each of which modifies the IR graph in place.
*/
namespace torch {
namespace jit {
// A Pass is any callable that accepts a non-const reference to a
// std::shared_ptr<Graph> and returns nothing. A pass modifies the Graph in
// place rather than returning a new one.
using Pass = std::function<void(std::shared_ptr<Graph>&)>;
// Returns the vector of custom passes by non-const reference (not a copy), so
// callers can both iterate over the passes and modify the vector through it.
// Per the header comment of this file, these passes are executed after
// differentiation but before any fusion.
TORCH_API std::vector<Pass>& getCustomPasses();
// Helper type for static registration of a custom pass: per the header comment
// of this file, a pass is registered by defining a global `RegisterPass r(pass)`
// variable in a compilation unit.
//
// Only the constructor declaration is visible here; its definition lives
// elsewhere. NOTE(review): presumably it appends `p` to the vector returned by
// getCustomPasses() -- confirm against the implementation file.
struct TORCH_API RegisterPass {
// Takes the pass by value. The constructor is not marked `explicit`, so a
// `Pass` converts implicitly to a `RegisterPass`.
RegisterPass(Pass p);
};
} // namespace jit
} // namespace torch