[execution trace] ignore some properties when symbolic size/strides exist (#112458)
Fixes #112235
Otherwise an exception will be thrown when we try to access storage or sizes on a tensor with symbolic size/strides.
Added a test in test/dynamo/test_profiler.py
Differential Revision: [D50821576](https://our.internmc.facebook.com/intern/diff/D50821576)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112458
Approved by: https://github.com/aaronenyeshi
diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py
index 592086a..4bcdf35 100644
--- a/test/dynamo/test_profiler.py
+++ b/test/dynamo/test_profiler.py
@@ -10,6 +10,8 @@
from torch._dynamo.testing import same
from torch._dynamo.utils import dynamo_timed
+from torch.testing._internal.common_utils import TemporaryFileName
+
class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
def test_dynamo_timed_profiling_isolated(self):
@@ -92,6 +94,21 @@
with torch.profiler.profile(record_shapes=True):
opt_fn(*inputs)
+ def test_execution_trace_dynamic_shapes(self):
+ def fn(x, y, z):
+ return x @ y + z
+
+ et = torch.profiler.ExecutionTraceObserver()
+ opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager")
+ inputs = [torch.rand((4, 4)) for _ in range(3)]
+
+ with TemporaryFileName() as fname:
+ et.register_callback(fname)
+ et.start()
+ out = opt_fn(*inputs)
+ et.stop()
+ et.unregister_callback()
+
def test_profiler_cache_lookup(self):
def fn(x):
y = x**2
diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
index 8c6644a..690a07f 100644
--- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp
+++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
@@ -87,7 +87,8 @@
const size_t maxArrayLen = maxNumElements) {
if (val.isTensor()) {
auto& tensor = val.toTensor();
- if (tensor.defined()) {
+ if (tensor.defined() &&
+ !tensor.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
return vectorToString(tensor.sizes().vec());
}
} else if (val.isTuple()) {
@@ -389,7 +390,8 @@
size_t numel = 0;
size_t itemsize = 0;
std::string device_str = "";
- if (t->has_storage()) {
+ // symbolic sizes/strides implies t->storage_offset() will fail
+ if (t->has_storage() && !t->has_symbolic_sizes_strides()) {
auto& t_storage = t->storage();
storage_id = getObjectID(ob, t_storage.data());
offset = t->storage_offset();