| #include <torch/csrc/utils/throughput_benchmark.h> | 
 |  | 
 | #include <pybind11/pybind11.h> | 
 | #include <torch/csrc/jit/python/pybind_utils.h> | 
 | #include <torch/csrc/utils/pybind.h> | 
 |  | 
 | namespace torch::throughput_benchmark { | 
 |  | 
 | std::ostream& operator<<( | 
 |     std::ostream& os, | 
 |     const BenchmarkExecutionStats& value) { | 
 |   return os << "Average latency / iter (ms): " << value.latency_avg_ms | 
 |             << "\n Total number of iters: " << value.num_iters; | 
 | } | 
 |  | 
 | void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) { | 
 |   CHECK(script_module_.initialized() ^ module_.initialized()); | 
 |   if (script_module_.initialized()) { | 
 |     script_module_.addInput(std::move(args), std::move(kwargs)); | 
 |   } else { | 
 |     CHECK(module_.initialized()); | 
 |     module_.addInput(std::move(args), std::move(kwargs)); | 
 |   } | 
 | } | 
 |  | 
 | py::object ThroughputBenchmark::runOnce( | 
 |     const py::args& args, | 
 |     const py::kwargs& kwargs) { | 
 |   CHECK(script_module_.initialized() ^ module_.initialized()); | 
 |   if (script_module_.initialized()) { | 
 |     c10::IValue result; | 
 |     { | 
 |       pybind11::gil_scoped_release no_gil_guard; | 
 |       result = script_module_.runOnce(args, kwargs); | 
 |     } | 
 |     return jit::toPyObject(std::move(result)); | 
 |   } else { | 
 |     CHECK(module_.initialized()); | 
 |     return module_.runOnce(args, kwargs); | 
 |   } | 
 | } | 
 |  | 
// Benchmark a TorchScript module; only script_module_ becomes initialized.
ThroughputBenchmark::ThroughputBenchmark(const jit::Module& script_module)
    : script_module_(script_module) {}
 |  | 
// Benchmark a plain Python nn.Module; only module_ becomes initialized.
ThroughputBenchmark::ThroughputBenchmark(py::object module)
    : module_(std::move(module)) {}
 |  | 
 | BenchmarkExecutionStats ThroughputBenchmark::benchmark( | 
 |     const BenchmarkConfig& config) const { | 
 |   CHECK(script_module_.initialized() ^ module_.initialized()); | 
 |   // Main benchmark thread doesn't hold the GIL after scheduling worker threads | 
 |   // But for now we don't release it as we will be implicitly manipulating with | 
 |   // py::object ref. counts in the case of nn.Module benchmarking. | 
 |   if (script_module_.initialized()) { | 
 |     return script_module_.benchmark(config); | 
 |   } else { | 
 |     CHECK(module_.initialized()); | 
 |     TORCH_WARN( | 
 |         "Starting benchmark on an nn.Module. This can be slow due " | 
 |         "to Python GIL.For proper inference simulation you might want to switch to " | 
 |         "a ScriptModule instead"); | 
 |     return module_.benchmark(config); | 
 |   } | 
 | } | 
 |  | 
 | namespace detail { | 
 |  | 
 | template <> | 
 | void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const { | 
 |   CHECK(initialized_); | 
 |   // TODO: provide guarantees that compiler won't optimize this out | 
 |   model_.get_method("forward").function()(std::move(input)); | 
 | } | 
 |  | 
 | template <> | 
 | ScriptModuleOutput ScriptModuleBenchmark::runOnce( | 
 |     const py::args& args, | 
 |     const py::kwargs& kwargs) const { | 
 |   CHECK(initialized_); | 
 |   auto& function = model_.get_method("forward").function(); | 
 |   ScriptModuleInput stack = jit::createStackForSchema( | 
 |       function.getSchema(), args, kwargs, model_._ivalue()); | 
 |   return function(std::move(stack)); | 
 | } | 
 |  | 
 | template <> | 
 | void ModuleBenchmark::runOnce(ModuleInput&& input) const { | 
 |   CHECK(initialized_); | 
 |   pybind11::gil_scoped_acquire gil_guard; | 
 |   model_(*input.args, **input.kwargs); | 
 | } | 
 |  | 
 | template <> | 
 | ModuleOutput ModuleBenchmark::runOnce( | 
 |     const py::args& args, | 
 |     const py::kwargs& kwargs) const { | 
 |   CHECK(initialized_); | 
 |   pybind11::gil_scoped_acquire gil_guard; | 
 |   return model_(*args, **kwargs); | 
 | } | 
 |  | 
 | template <> | 
 | void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) { | 
 |   jit::Stack stack = jit::createStackForSchema( | 
 |       model_.get_method("forward").function().getSchema(), | 
 |       args, | 
 |       kwargs, | 
 |       model_._ivalue()); | 
 |   inputs_.emplace_back(std::move(stack)); | 
 | } | 
 |  | 
 | template <> | 
 | void ScriptModuleBenchmark::addInput(ScriptModuleInput&& input) { | 
 |   input.insert(input.begin(), model_._ivalue()); | 
 |   inputs_.emplace_back(std::move(input)); | 
 | } | 
 |  | 
 | template <> | 
 | void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) { | 
 |   inputs_.emplace_back(std::move(args), std::move(kwargs)); | 
 | } | 
 |  | 
 | template <> | 
 | ModuleInput cloneInput<ModuleInput>(const ModuleInput& input) { | 
 |   pybind11::gil_scoped_acquire gil_guard; | 
 |   py::args args = input.args; | 
 |   py::kwargs kwargs = input.kwargs; | 
 |   return {std::move(args), std::move(kwargs)}; | 
 | } | 
 |  | 
// Cloning an IValue stack is a plain C++ copy; no GIL needed here
// (unlike the ModuleInput specialization, which copies py::objects).
template <>
ScriptModuleInput cloneInput<ScriptModuleInput>(
    const ScriptModuleInput& input) {
  return input;
}
 |  | 
 | } // namespace detail | 
 |  | 
 | } // namespace torch::throughput_benchmark |