#include <torch/csrc/autograd/profiler_python.h>
#include <cstdint>
#include <deque>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include <Python.h>
#include <frameobject.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/pybind.h>
namespace py = pybind11;
namespace torch {
namespace profiler {
namespace impl {
namespace {
enum CallType { PyCall = 0, PyModuleCall, PyCCall };
static constexpr size_t CallTypeSize = 3;
// ============================================================================
// == Miscellaneous structs and utils =========================================
// ============================================================================
struct CodeLocation {
CodeLocation() = default;
explicit CodeLocation(const PyFrameObject* frame)
: code_{frame->f_code}, lasti_{frame->f_lasti} {}
bool operator==(const CodeLocation& other) const {
return code_ == other.code_ && lasti_ == other.lasti_;
}
PyCodeObject* code_{nullptr};
int lasti_{0};
};
PyObject* nnModuleCode() {
static auto module_call_code = []() {
pybind11::gil_scoped_acquire gil;
return py::module::import("torch.nn")
.attr("Module")
.attr("__call__")
.attr("__code__")
.ptr();
}();
return module_call_code;
}
} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch
template <>
struct std::hash<torch::profiler::impl::CodeLocation> {
size_t operator()(const torch::profiler::impl::CodeLocation& x) {
return c10::get_hash(x.code_, x.lasti_);
}
};
namespace torch {
namespace profiler {
namespace impl {
namespace {
// ============================================================================
// == CallTypeHelper: Tools for generic programming on specializations. =======
// ============================================================================
template <template <CallType> class ClassT>
class CallTypeHelper final {
private:
static_assert(
CallType::PyCall == 0,
"CallTypeHelper uses integer math which depends on a zero start.");
static constexpr size_t End = CallTypeSize;
template <size_t... I>
static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
std::index_sequence<I...>);
template <size_t C, typename T, typename FunctorT, typename... Args>
static void map(T& t, FunctorT& f, Args... args) {
f(std::get<C>(t), args...);
c10::guts::if_constexpr<C + 1 < End>(
[&](auto _) { map<C + 1>(_(t), f, std::forward<Args>(args)...); });
}
public:
using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));
template <typename FunctorT, typename... Args>
static void map(tuple_type& t, FunctorT& f, Args... args) {
map<0>(t, f, std::forward<Args>(args)...);
}
};
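// A minimal (hypothetical) illustration: given a template such as
// `template <CallType> struct Counter { int n{0}; };`,
// `CallTypeHelper<Counter>::tuple_type` is
// `std::tuple<Counter<PyCall>, Counter<PyModuleCall>, Counter<PyCCall>>`, and
// `CallTypeHelper<Counter>::map(t, f)` invokes `f` on each element in enum
// order. ValueCache and ThreadLocalResults below use this helper to hold one
// sub-cache per CallType.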
// ============================================================================
// == Event type definitions. =================================================
// ============================================================================
// When we are tracing a Python program, the general procedure is to record
// every time we enter or exit a function and later replay these events during
// post processing. Thus, during the profiling phase we want to do the MINIMAL
// amount of work to capture all of the information that we need; otherwise we
// will distort the profile. (While we don't wish to be terribly inefficient
// during post processing, we are willing to do extra fixup work in post if it
// reduces overhead in the profiling phase.)
//
// When the tracer first enters a frame, it constructs a CallKey for that
// location. The contents of the key vary by context. For a python function
// the key is the (PyCodeObject*, int) pair that defines the bytecode of the
// function. For an `nn.Module` the key is a (non-owning) pointer to `self`.
// For a bound C function it is a (non-owning) pointer to the bound function.
// A CallKey should be small, inexpensive, and POD.
//
// We then collect a CallKey<CallType::PyCall> for the calling frame for better
// source tracking. This pair is a `Callsite`, and serves as a first level key
// during tracing. We look up the Callsite in a thread local cache which maps
// Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
// TraceKey and return. On a cache miss, we use a global value cache to store
// whatever fields we need from the two CallKeys, generate a new TraceKey, and
// update the local cache.
//
// During post processing we:
// 1) Determine the type represented by a TraceKey by checking which
//    sub-cache of the thread local cache it appears in.
// 2) Look up the pair of CallKeys from the thread local cache.
// 3) Look up the expanded values of each CallKey from the global value cache.
//
// To add a new event type to the cache:
// 1) Add an entry to the `CallType` enum.
// 2) Add a specialization of Config which defines key_t and cache_t. (An
//    illustrative sketch follows the existing specializations below.)
// 3) Add a specialization of ValueCache::store and ValueCache::load.
template <CallType>
struct Config;
template <>
struct Config<CallType::PyCall> {
using key_t = CodeLocation;
using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
static constexpr EventType event_type = EventType::PyCall;
};
template <>
struct Config<CallType::PyModuleCall> {
using key_t = PyModuleSelf;
struct cache_t {
c10::optional<CodeLocation> module_forward_;
ska::flat_hash_map<PyModuleSelf, PyModuleCls> modules_;
ska::flat_hash_map<PyModuleCls, at::StringView> module_cls_names_;
};
static constexpr EventType event_type = EventType::PyCall;
};
template <>
struct Config<CallType::PyCCall> {
using key_t = torch::profiler::impl::PyCFunction;
using cache_t = ska::flat_hash_map<key_t, at::StringView>;
static constexpr EventType event_type = EventType::PyCCall;
};
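// An illustrative sketch (hypothetical names, not part of the tracer) of what
// steps (2) and (3) above could look like for a new call type `PyFooCall`
// keyed by a trivially copyable `FooKey`:
//
//   template <>
//   struct Config<CallType::PyFooCall> {
//     using key_t = FooKey;
//     using cache_t = ska::flat_hash_map<key_t, at::StringView>;
//     static constexpr EventType event_type = EventType::PyCall;
//   };
//
//   template <>
//   void ValueCache::store<CallType::PyFooCall>(const FooKey& key) {
//     // Populate std::get<CallType::PyFooCall>(state_) on first sight of `key`.
//   }
//
//   template <>
//   ExtraFields<EventType::PyCall>::args_t
//   ValueCache::load<CallType::PyFooCall>(const FooKey& key) const {
//     // Read the stored value back out of std::get<CallType::PyFooCall>(state_).
//   }
//
// Step (1) would also add `PyFooCall` to the `CallType` enum and bump
// `CallTypeSize` to match.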
// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
template <CallType C>
class Callsite {
public:
static constexpr CallType call_type = C;
using key_t = typename Config<C>::key_t;
static_assert(
std::is_trivially_copyable<key_t>::value,
"Key should be trivial, as it is passed by value.");
template <typename U>
Callsite(U value, const PyFrameObject* f_back)
: value_(value), caller_(f_back) {}
bool operator==(const Callsite<C>& other) const {
return value_ == other.value_ && caller_ == other.caller_;
}
key_t value_;
Config<CallType::PyCall>::key_t caller_;
};
class ValueCache {
public:
template <CallType C>
void store(const typename Config<C>::key_t&);
template <CallType C>
auto load(const Callsite<C>& callsite, size_t python_tid) const {
auto caller = load<CallType::PyCall>(callsite.caller_);
TORCH_INTERNAL_ASSERT(!caller.second.has_value());
return ExtraFields<Config<C>::event_type>{
/*end_time_ns=*/std::numeric_limits<time_t>::min(),
python_tid,
caller.first,
load<C>(callsite.value_)};
}
void trimPrefixes();
private:
template <CallType C>
typename ExtraFields<Config<C>::event_type>::args_t load(
const typename Config<C>::key_t&) const;
template <CallType C>
using State = typename Config<C>::cache_t;
CallTypeHelper<State>::tuple_type state_;
};
// ============================================================================
// == Type specific store and load implementations. ===========================
// ============================================================================
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
template <>
void ValueCache::store<CallType::PyCall>(const PyCallKey& key) {
auto& locations = std::get<CallType::PyCall>(state_);
if (C10_UNLIKELY(locations.find(key) == locations.end())) {
TORCH_INTERNAL_ASSERT(key.code_ != nullptr);
locations[key] = {
PyCode_Addr2Line(key.code_, key.lasti_),
at::StringView(THPUtils_unpackString(key.code_->co_filename)),
at::StringView(THPUtils_unpackString(key.code_->co_name))};
}
}
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
const PyCallKey& key) const {
return {std::get<CallType::PyCall>(state_).at(key), c10::nullopt};
}
template <>
void ValueCache::store<CallType::PyModuleCall>(const PyModuleCallKey& key) {
auto& cache = std::get<CallType::PyModuleCall>(state_);
if (C10_UNLIKELY(cache.modules_.find(key) == cache.modules_.end())) {
if (C10_UNLIKELY(!cache.module_forward_.has_value())) {
auto frame = PyEval_GetFrame();
TORCH_INTERNAL_ASSERT((PyObject*)(frame->f_code) == nnModuleCode());
cache.module_forward_ = PyCallKey(frame);
store<CallType::PyCall>(*cache.module_forward_);
}
auto cls_handle = py::handle((PyObject*)key).attr("__class__");
auto cls = PyModuleCls(cls_handle.ptr());
cache.modules_[key] = cls;
if (cache.module_cls_names_.find(cls) == cache.module_cls_names_.end()) {
cache.module_cls_names_[cls] =
at::StringView(py::str(cls_handle.attr("__name__")));
}
}
}
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
const PyModuleCallKey& key) const {
auto& cache = std::get<CallType::PyModuleCall>(state_);
TORCH_INTERNAL_ASSERT(cache.module_forward_.has_value());
auto cls = cache.modules_.at(key);
auto fwd = std::get<CallType::PyCall>(state_).at(*cache.module_forward_);
return {fwd, NNModuleInfo{key, cls, cache.module_cls_names_.at(cls)}};
}
template <>
void ValueCache::store<CallType::PyCCall>(const PyCCallKey& key) {
auto& names = std::get<CallType::PyCCall>(state_);
if (C10_UNLIKELY(names.find(key) == names.end())) {
names[key] = at::StringView(py::repr((PyObject*)key));
}
}
template <>
ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
const PyCCallKey& key) const {
return std::get<CallType::PyCCall>(state_).at(key);
}
// TODO: Use re2.
void ValueCache::trimPrefixes() {
static auto prefixes = py::module::import("torch.profiler.python_tracer")
.attr("_prefix_regex")().cast<std::vector<std::string>>();
for (auto& it : std::get<CallType::PyCall>(state_)) {
std::string filename = it.second.filename_.str();
for (const auto& p : prefixes) {
if (filename.compare(0, p.size(), p) == 0) {
filename.erase(0, p.size());
it.second.filename_ = at::StringView(filename);
break;
}
}
}
}
// ============================================================================
// == TraceKey cache ==========================================================
// ============================================================================
using python_tracer::TraceKey;
TraceKey nextKey() {
static std::atomic<uint64_t> key{0};
return TraceKey{++key};
}
template <CallType C>
struct TraceKeyCacheState {
struct Hash {
size_t operator()(const Callsite<C>& key) {
return c10::get_hash(key.value_, key.caller_);
}
};
TraceKey intern(Callsite<C> callsite, ValueCache& value_cache) {
auto it = state_.find(callsite);
if (C10_UNLIKELY(it == state_.end())) {
value_cache.store<C>(callsite.value_);
value_cache.store<CallType::PyCall>(callsite.caller_);
it = state_.insert({callsite, nextKey()}).first;
}
return it->second;
}
auto lookup(Callsite<C>& callsite, ValueCache& value_cache) const {
return std::make_pair(
value_cache.load<C>(callsite.value_),
value_cache.load<CallType::PyCall>(callsite.caller_));
}
ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
};
// ============================================================================
// == Core CPython data types =================================================
// ============================================================================
// PyObject that allows different threads to record events without colliding.
// It is passed as the second argument when enabling tracing via
// `PyEval_SetProfile`.
struct ThreadLocalResults;
struct TraceContext {
PyObject_HEAD
ThreadLocalResults* thread_local_results_;
};
// CPython boilerplate to define `TraceContext` as a proper python object.
static PyTypeObject TraceContextType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"TraceContext", /* tp_name */
sizeof(TraceContext), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */ // NOLINT: modernize-use-nullptr
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"Python tracer TLS", /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
PyType_GenericNew, /* tp_new */
nullptr /* tp_free */
};
// ============================================================================
// == Thread local cache ======================================================
// ============================================================================
struct ThreadLocalResults {
ThreadLocalResults(PyThreadState* thread_state, ValueCache* value_cache)
: thread_state_{thread_state},
ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
value_cache_{value_cache} {
ctx_->thread_local_results_ = this;
}
ThreadLocalResults() = delete;
ThreadLocalResults(const ThreadLocalResults&) = delete;
ThreadLocalResults(ThreadLocalResults&&) = delete;
ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
ThreadLocalResults& operator=(ThreadLocalResults&&) = delete;
~ThreadLocalResults() {
Py_DECREF((PyObject*)ctx_);
}
template <CallType C, EventType E, typename... Args>
TraceKey intern(Args... args) {
static_assert(
Config<C>::event_type == E,
"ThreadLocalResults.intern called from the wrong typed context.");
return std::get<C>(trace_keys_)
.intern(Callsite<C>(std::forward<Args>(args)...), *value_cache_);
}
static constexpr size_t BLOCK_SIZE = 1024;
PyThreadState* thread_state_;
TraceContext* ctx_;
ValueCache* value_cache_;
CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
AppendOnlyList<approx_time_t, BLOCK_SIZE> exit_times_;
AppendOnlyList<approx_time_t, BLOCK_SIZE> c_exit_times_;
};
// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
class PythonTracer final : public python_tracer::PythonTracerBase {
public:
static int pyProfileFn(
PyObject* obj,
PyFrameObject* frame,
int what,
PyObject* arg);
static PythonTracer& singleton();
void start(torch::profiler::impl::RecordQueue* queue) override;
void stop() override;
std::vector<std::shared_ptr<Result>> getEvents(
std::function<time_t(approx_time_t)> time_converter,
std::vector<python_tracer::CompressedEvent>& enters) override;
void clear() override;
private:
PythonTracer();
void recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame);
void recordCCall(ThreadLocalResults& tls, PyFrameObject* frame, PyObject* arg);
torch::profiler::impl::RecordQueue* queue_;
PyObject* module_call_code_;
std::deque<ThreadLocalResults> thread_local_results_;
ValueCache value_cache_;
};
PythonTracer& PythonTracer::singleton() {
static PythonTracer singleton_;
return singleton_;
}
PythonTracer::PythonTracer()
: queue_(nullptr), module_call_code_(nnModuleCode()) {}
void PythonTracer::start(torch::profiler::impl::RecordQueue* queue) {
TORCH_CHECK(queue_ == nullptr, "PythonTracer is already active");
TORCH_CHECK(
!thread_local_results_.size(),
"PythonTracer should not have active contexts");
queue_ = queue;
pybind11::gil_scoped_acquire gil;
// Loop over all threads within the current interpreter. We will need to
// register a trace function with each thread. We set the current thread to
// position zero to ensure that it is traced, and so we can restore the
// thread state after registration. The profiler cannot yet post-process
// multiple Python threads, so this section is temporarily disabled.
std::vector<PyThreadState*> thread_states{PyThreadState_Get()};
/*
if (all_threads) {
auto thread_state = thread_states[0];
while (thread_state != nullptr) {
if (thread_state != thread_states[0]) {
thread_states.push_back(thread_state);
}
thread_state = PyThreadState_Next(thread_state);
}
}
*/
// Register the tracer in each thread.
for (const auto i : c10::irange(thread_states.size())) {
PyThreadState* thread_state = thread_states[i];
PyThreadState_Swap(thread_state);
thread_local_results_.emplace_back(thread_state, &value_cache_);
auto* ctx = thread_local_results_.back().ctx_;
// When we begin profiling there are already frames on the Python
// interpreter stack. To ensure a complete trace, we must push calls
// to all the prior frames onto our event stack. (We stop at depth=128)
std::vector<PyFrameObject*> current_stack;
auto frame = PyEval_GetFrame();
size_t depth = 0; // Guard against an infinite loop.
while (frame != nullptr && depth <= 128) {
current_stack.push_back(frame);
frame = frame->f_back;
depth++;
}
for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
recordPyCall(thread_local_results_.back(), *it);
}
// Note:
// This profile will not compose with other CPython profilers, and cannot be
// round-tripped via `sys.settrace(sys.gettrace())`.
PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
}
// Restore the thread state to its initial value.
PyThreadState_Swap(thread_states[0]);
}
void PythonTracer::stop() {
TORCH_INTERNAL_ASSERT(queue_ != nullptr, "PythonTracer is not running.");
queue_ = nullptr;
pybind11::gil_scoped_acquire gil;
PyThreadState* initial_thread_state = PyThreadState_Get();
for (const auto& i : thread_local_results_) {
PyThreadState_Swap(i.thread_state_);
PyEval_SetProfile(nullptr, nullptr);
}
PyThreadState_Swap(initial_thread_state);
}
void PythonTracer::clear() {
TORCH_CHECK(queue_ == nullptr, "Cannot clear state while PythonTracer is active.");
thread_local_results_.clear();
value_cache_ = ValueCache();
}
void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) {
static constexpr auto E = EventType::PyCall;
auto get_key = [&]() -> TraceKey {
if ((PyObject*)(frame->f_code) == module_call_code_) {
// By default, CPython stores locals in a "fast" format, with an array
// of names and an array of values. Consequently, frame->f_locals is
// NULL since the interpreter has no need to populate it.
//
// If these arrays were part of the public API then we could very
// quickly access `self`. Unfortunately they are not, and moreover are
// not stable across versions. As a result, we are forced to call
// `PyFrame_FastToLocals` which forces the interpreter to materialize
// the full dict of locals.
PyFrame_FastToLocals(frame);
auto self = PyDict_GetItemString(frame->f_locals, "self");
PyFrame_LocalsToFast(frame, 0);
TORCH_INTERNAL_ASSERT(frame->f_back != nullptr);
return tls.intern<CallType::PyModuleCall, E>(self, frame->f_back);
} else {
auto f_back = frame->f_back != nullptr ? frame->f_back : frame;
return tls.intern<CallType::PyCall, E>(frame, f_back);
}
};
queue_->getSubqueue()->emplace_py_call(get_key(), getApproximateTime());
}
void PythonTracer::recordCCall(
ThreadLocalResults& tls,
PyFrameObject* frame,
PyObject* arg) {
// NB: For C calls a new frame is not created, so we use `frame` rather than
// `frame->f_back`.
auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(arg, frame);
queue_->getSubqueue()->emplace_py_call(key, getApproximateTime());
}
// ============================================================================
// == Post processing =========================================================
// ============================================================================
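// Exit timestamps are replayed from a min-heap ordered by time
// (std::greater over Exit::operator>), so post processing can pop them in
// chronological order while walking the sorted enter events.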
struct Exit {
bool operator>(const Exit& other) const {
return t_ > other.t_;
}
time_t t_;
size_t python_tid_;
};
class PostProcess {
public:
PostProcess(
std::function<time_t(approx_time_t)> time_converter,
std::deque<ThreadLocalResults>& tls,
const ValueCache& value_cache)
: time_converter_{time_converter} {
for (size_t python_tid : c10::irange(tls.size())) {
CallTypeHelper<TraceKeyCacheState>::map(
tls[python_tid].trace_keys_, *this, value_cache, python_tid);
addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
}
}
template <CallType C>
void operator()(
const TraceKeyCacheState<C>& trace_cache,
const ValueCache& value_cache,
size_t python_tid) {
for (const auto& it : trace_cache.state_) {
const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
{it.second, value_cache.load(it.first, python_tid)});
TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
}
}
template <EventType E, size_t N>
void addExits(
AppendOnlyList<approx_time_t, N>& exits,
size_t python_tid) {
for (const auto i : exits) {
get_state<E>().exits_.push({time_converter_(i), python_tid});
}
}
std::vector<std::shared_ptr<Result>> run(
std::vector<python_tracer::CompressedEvent>& enters) {
std::stable_sort(
enters.begin(), enters.end(), [](const auto a, const auto b) {
return a.enter_t_ < b.enter_t_;
});
std::vector<std::shared_ptr<Result>> out;
populate<EventType::PyCall>(enters, out);
populate<EventType::PyCCall>(enters, out);
return out;
}
private:
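// Replays the recorded events for one event type: enters arrive sorted by
// time, and any exit with an earlier timestamp first closes the innermost
// open frame on that thread's replay stack by stamping its end time. Each
// matched enter then becomes a Result whose extra fields were resolved from
// the value cache in the constructor above.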
template <EventType E>
void populate(
std::vector<python_tracer::CompressedEvent>& enters,
std::vector<std::shared_ptr<Result>>& out) {
using stack_t = std::vector<std::shared_ptr<Result>>;
ska::flat_hash_map<size_t, stack_t> stacks;
auto& state = get_state<E>();
for (const auto& enter : enters) {
auto fields_it = state.fields_.find(enter.key_);
if (fields_it != state.fields_.end()) {
while (!state.exits_.empty() &&
state.exits_.top().t_ < enter.enter_t_) {
auto& stack = stacks[state.exits_.top().python_tid_];
TORCH_INTERNAL_ASSERT(stack.size(), "Python replay stack is empty.");
c10::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ =
state.exits_.top().t_;
state.exits_.pop();
stack.pop_back();
}
out.push_back(Result::create(
enter.enter_t_,
enter.system_tid_,
enter.kineto_info_,
fields_it->second));
stacks[fields_it->second.python_tid_].push_back(out.back());
}
}
}
template <EventType E>
struct State {
ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
std::priority_queue<Exit, std::vector<Exit>, std::greater<Exit>> exits_;
};
template <EventType E>
auto& get_state() {
return std::get<E == EventType::PyCall ? 0 : 1>(state_);
}
std::function<time_t(approx_time_t)> time_converter_;
std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
};
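// Assigns each Python event a sequential id. nn.Module calls additionally get
// a per-class instance index (e.g. the third distinct Linear module observed
// would get id 2) so that repeated calls to the same module instance can be
// grouped during analysis.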
struct PythonIDVisitor {
void operator()(ExtraFields<EventType::PyCall>& py_call) {
py_call.id_ = ++current_python_id_;
if (py_call.module_.has_value()) {
auto& m = py_call.module_;
auto& module_ids = module_ids_[m->cls_];
m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
}
}
void operator()(ExtraFields<EventType::PyCCall>& py_call) {
py_call.id_ = ++current_python_id_;
}
template <typename T>
void operator()(T&) {}
size_t current_python_id_{0};
ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
module_ids_;
};
std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
std::function<time_t(approx_time_t)> time_converter,
std::vector<python_tracer::CompressedEvent>& enters) {
value_cache_.trimPrefixes();
PostProcess post_process(time_converter, thread_local_results_, value_cache_);
auto out = post_process.run(enters);
std::stable_sort(
out.begin(), out.end(), [](const auto& a, const auto& b) {
return a->start_time_ns_ < b->start_time_ns_;
});
PythonIDVisitor id_visitor;
for (auto& i : out) {
c10::visit(id_visitor, i->extra_fields_);
}
return out;
}
// ============================================================================
// == API =====================================================================
// ============================================================================
int PythonTracer::pyProfileFn(
PyObject* obj,
PyFrameObject* frame,
int what,
PyObject* arg) {
auto& local_results =
*reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
switch (what) {
case PyTrace_CALL:
PythonTracer::singleton().recordPyCall(local_results, frame);
break;
case PyTrace_C_CALL:
PythonTracer::singleton().recordCCall(local_results, frame, arg);
break;
case PyTrace_EXCEPTION:
case PyTrace_RETURN:
local_results.exit_times_.emplace_back(getApproximateTime());
break;
case PyTrace_C_EXCEPTION:
case PyTrace_C_RETURN:
local_results.c_exit_times_.emplace_back(getApproximateTime());
break;
}
return 0;
}
python_tracer::PythonTracerBase& getTracer() {
return PythonTracer::singleton();
}
} // namespace
} // namespace impl
} // namespace profiler
} // namespace torch
namespace torch {
namespace autograd {
namespace profiler {
namespace python_tracer {
void init() {
pybind11::gil_scoped_acquire gil;
TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
torch::profiler::impl::python_tracer::registerTracer(
&torch::profiler::impl::getTracer);
}
} // namespace python_tracer
} // namespace profiler
} // namespace autograd
} // namespace torch