Adding additional debug logging and documentation for shape functions (#77115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77115
Approved by: https://github.com/eellison
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 88c1f53..52dcb2f 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -249,7 +249,13 @@
std::unordered_map<std::string, std::shared_ptr<Graph>>& reused_functions,
const CompilationUnit& module) {
std::shared_ptr<Graph> graph;
+ GRAPH_DEBUG(
+ "Registering schema: ",
+ *schema_string,
+ " with shape compute func: ",
+ shape_compute_function_name);
if (reused_functions.count(shape_compute_function_name)) {
+ GRAPH_DEBUG("Registering reused schema");
graph = reused_functions[shape_compute_function_name];
} else {
Function& shape_compute_function =
diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py
index 1c4ac59..550b924 100644
--- a/torch/jit/_shape_functions.py
+++ b/torch/jit/_shape_functions.py
@@ -5,9 +5,17 @@
###
# There are generated files that depend on this file
-# To re-generate, please run:
-# cd ~/pytorch && python
-# torchgen/shape_functions/gen_jit_shape_functions.py
+# To re-generate, please run from the root of the repo:
+# python torchgen/shape_functions/gen_jit_shape_functions.py
+
+# How to test:
+# After regenerating files, compile PyTorch.
+# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
+# If you have enabled opinfo testing for the op, also run:
+# python test/test_ops_jit.py TestJitCPU::test_variant_consistency_jit_[FAILING_OP]_cpu_float32
+# to reproduce errors from opinfo tests.
+
+# Example PR: https://github.com/pytorch/pytorch/pull/80860/files
####
import torch