[execution trace] ignore some properties when symbolic size/strides exist (#112458)

Fixes #112235

Otherwise an exception will be thrown when we try to access storage or sizes on a tensor with symbolic size/strides.

Added a test in test/dynamo/test_profiler.py

Differential Revision: [D50821576](https://our.internmc.facebook.com/intern/diff/D50821576)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112458
Approved by: https://github.com/aaronenyeshi
diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py
index 592086a..4bcdf35 100644
--- a/test/dynamo/test_profiler.py
+++ b/test/dynamo/test_profiler.py
@@ -10,6 +10,8 @@
 from torch._dynamo.testing import same
 from torch._dynamo.utils import dynamo_timed
 
+from torch.testing._internal.common_utils import TemporaryFileName
+
 
 class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
     def test_dynamo_timed_profiling_isolated(self):
@@ -92,6 +94,21 @@
         with torch.profiler.profile(record_shapes=True):
             opt_fn(*inputs)
 
+    def test_execution_trace_dynamic_shapes(self):
+        def fn(x, y, z):
+            return x @ y + z
+
+        et = torch.profiler.ExecutionTraceObserver()
+        opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager")
+        inputs = [torch.rand((4, 4)) for _ in range(3)]
+
+        with TemporaryFileName() as fname:
+            et.register_callback(fname)
+            et.start()
+            out = opt_fn(*inputs)
+            et.stop()
+            et.unregister_callback()
+
     def test_profiler_cache_lookup(self):
         def fn(x):
             y = x**2
diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
index 8c6644a..690a07f 100644
--- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp
+++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
@@ -87,7 +87,8 @@
     const size_t maxArrayLen = maxNumElements) {
   if (val.isTensor()) {
     auto& tensor = val.toTensor();
-    if (tensor.defined()) {
+    if (tensor.defined() &&
+        !tensor.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
       return vectorToString(tensor.sizes().vec());
     }
   } else if (val.isTuple()) {
@@ -389,7 +390,8 @@
     size_t numel = 0;
     size_t itemsize = 0;
     std::string device_str = "";
-    if (t->has_storage()) {
+    // symbolic sizes/strides implies t->storage_offset() will fail
+    if (t->has_storage() && !t->has_symbolic_sizes_strides()) {
       auto& t_storage = t->storage();
       storage_id = getObjectID(ob, t_storage.data());
       offset = t->storage_offset();