#include <torch/csrc/utils/pybind.h> | |
#include <nvToolsExt.h> | |
namespace torch { namespace cuda { namespace shared { | |
void initNvtxBindings(PyObject* module) { | |
auto m = py::handle(module).cast<py::module>(); | |
auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings"); | |
nvtx.def("rangePushA", nvtxRangePushA); | |
nvtx.def("rangePop", nvtxRangePop); | |
nvtx.def("markA", nvtxMarkA); | |
} | |
} // namespace shared | |
} // namespace cuda | |
} // namespace torch |