Add torch.autograd.profiler.range
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index a6a1c28..1ae80ea 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -8,6 +8,18 @@
from collections import defaultdict, namedtuple
+class range(object):
+ def __init__(self, name):
+ self.name = name
+
+ def __enter__(self):
+ torch.autograd._push_range(self.name)
+
+ def __exit__(self, *args):
+ torch.autograd._pop_range()
+ return False
+
+
class EventList(list):
"""A list of Events (for pretty printing)"""
def __init__(self, *args, **kwargs):
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index d0fedc4..a36ba8a 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -65,5 +65,16 @@
m.def("_enable_profiler", torch::autograd::profiler::enableProfiler);
m.def("_disable_profiler", torch::autograd::profiler::disableProfiler);
+ m.def("_push_range", [](const char *name) {
+ using namespace torch::autograd::profiler;
+ if (!profiling) return;
+ pushRange(name);
+ });
+ m.def("_pop_range", []() {
+ using namespace torch::autograd::profiler;
+ if (!profiling) return;
+ popRange();
+ });
+
Py_RETURN_TRUE;
}