[PyPer] Skip printing out per node time when do_profile is on (#63256)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63256
This suppresses printing the per-node time, whose output can get very long when the net has many ops. The per-node output can easily be turned back on by setting `--pt_sr_print_per_node_time=1`.
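
As an illustrative sketch (not part of this diff), here is how a caller might wire a command-line flag through to the new `print_per_node_time` parameter of `StaticRuntime::benchmark`. The gflags definition, the flag plumbing, and the run counts below are assumptions for illustration; only the parameter name and its default of `false` come from this change.

#include <torch/csrc/jit/runtime/static/impl.h>
#include <gflags/gflags.h>

// Illustrative flag definition; the real --pt_sr_print_per_node_time flag
// lives in the benchmark driver, not in this diff.
DEFINE_bool(
    pt_sr_print_per_node_time,
    false,
    "Print per-node timings in StaticRuntime::benchmark");

void run_profile(
    torch::jit::StaticRuntime& runtime,
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  // Per-node timings are now suppressed by default; opt in via the flag,
  // which maps onto the new trailing bool parameter added by this change.
  runtime.benchmark(
      args,
      kwargs,
      /*warmup_runs=*/10,
      /*main_runs=*/100,
      FLAGS_pt_sr_print_per_node_time);
}
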
Reviewed By: ajyu, mikeiovine
Differential Revision: D30298331
fbshipit-source-id: 32b3f93b3fe19d335654168311fda93331a1e706
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 2d8b6c4..f51c4e0 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -845,7 +845,8 @@
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const int warmup_runs,
- const int main_runs) {
+ const int main_runs,
+ bool print_per_node_time) {
float time_per_iter = benchmark_model(args, kwargs, warmup_runs, main_runs);
std::cout << "Static runtime ms per iter: " << time_per_iter
<< ". Iters per second: " << 1000.0 / time_per_iter << std::endl;
@@ -853,11 +854,13 @@
IndividualMetrics results =
benchmark_individual_ops(args, kwargs, warmup_runs, main_runs);
- for (const auto i : c10::irange(nodes_.size())) {
- const Node* node = nodes_[i].node();
- std::cout << "Node #" << i << ": " << results.time_per_node[i]
- << " ms/iter, ";
- node->print(std::cout, 0, nullptr, false);
+ if (print_per_node_time) {
+ for (const auto i : c10::irange(nodes_.size())) {
+ const Node* node = nodes_[i].node();
+ std::cout << "Node #" << i << ": " << results.time_per_node[i]
+ << " ms/iter, ";
+ node->print(std::cout, 0, nullptr, false);
+ }
}
std::vector<std::pair<std::string, double>> time_per_node_type_vec{
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index bf28dfc..cc36df0 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -209,7 +209,8 @@
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const int warmup_runs,
- const int main_runs);
+ const int main_runs,
+ bool print_per_node_time = false);
float benchmark_model(
const std::vector<c10::IValue>& args,