file:line for tracing (#21247)
Summary:
Stacked on https://github.com/pytorch/pytorch/pull/21217
This adds support for recording file and line information during tracing, by extracting the top Python interpreter frame
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21247
Reviewed By: suo, driazati
Differential Revision: D15594553
Pulled By: jamesr66a
fbshipit-source-id: 72e1b3a46f1dabe3e83a608ec1a7d083bd1720f9
diff --git a/test/test_jit.py b/test/test_jit.py
index b88e6ff..f226ef7 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3733,6 +3733,16 @@
FileCheck().check('<string>:2:12').run(scripted.foo.graph)
+ def test_file_line_trace(self):
+ def foobar(xyz):
+ return torch.neg(xyz)
+
+ scripted = torch.jit.trace(foobar, (torch.rand(3, 4)))
+
+ _, lineno = inspect.getsourcelines(foobar)
+ FileCheck().check('test_jit.py:{}:0'.format(lineno + 1))\
+ .run(scripted.graph)
+
def test_tensor_shape(self):
x = torch.empty(34, 56, 78)
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index 110034e..383503e 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -23,18 +23,30 @@
// Python interpreter retrieval routine adapted from
// https://stackoverflow.com/a/8706144
-std::string getPythonInterpreterStackTrace() {
+SourceRange getPythonInterpreterSourceRange() {
+ c10::optional<std::string> source_filename;
+ size_t source_line = 0;
std::stringstream stack_trace;
+
AutoGIL gil;
PyFrameObject* frame = PyEval_GetFrame();
+
while (nullptr != frame) {
int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti);
std::string filename = THPUtils_unpackString(frame->f_code->co_filename);
std::string funcname = THPUtils_unpackString(frame->f_code->co_name);
stack_trace << filename << "(" << line << "): " << funcname << "\n";
+ if (!source_filename) {
+ source_filename = filename;
+ source_line = line;
+ }
frame = frame->f_back;
}
- return stack_trace.str();
+
+ auto stack_trace_text = stack_trace.str();
+ auto source =
+ std::make_shared<Source>(stack_trace_text, source_filename, source_line);
+ return SourceRange(source, 0, stack_trace_text.size());
}
std::shared_ptr<torch::jit::Graph> createGraphByTracing(
@@ -104,7 +116,7 @@
}
void pythonRecordSourceLocation(Node* n) {
- n->setSourceRange(SourceRange(getPythonInterpreterStackTrace()));
+ n->setSourceRange(getPythonInterpreterSourceRange());
}
void pythonWarn(const std::string& reason) {