[dynamo][cpp-guards] Improve the logs (#123780)

For this program

~~~
@torch.compile(backend="eager")
def fn(x, y, d):
    return x * y * d["foo"] * d["bar"]
~~~

Python logs are

~~~
V0410 15:48:57.778000 140318524949632 torch/_dynamo/guards.py:1785] [0/0] [__guards] GUARDS:
V0410 15:48:57.778000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] ___check_type_id(L['d'], 8833952)                             # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.778000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] len(L['d']) == 2                                              # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] list(L['d'].keys()) == ['foo', 'bar']                         # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] hasattr(L['y'], '_dynamo_dynamic_indices') == False           # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] ___check_type_id(L['d']['bar'], 8842592)                      # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] L['d']['bar'] == 2                                            # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] ___check_type_id(L['d']['foo'], 8842592)                      # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] L['d']['foo'] == 4                                            # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:450 in init_ambient_guards
V0410 15:48:57.779000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[4], stride=[1])  # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:48:57.780000 140318524949632 torch/_dynamo/guards.py:1803] [0/0] [__guards] check_tensor(L['y'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[4], stride=[1])  # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
~~~

CPP logs are

~~~
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1792] [0/0] [__guards] GUARDS:
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards]
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] TREE_GUARD_MANAGER:
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] +- RootGuardManager
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:450 in init_ambient_guards
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | +- DictSubclassGuardManager: source=L['d'], accessed_by=DictGetItemGuardAccessor(d)
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- KeyValueManager pair at index=0
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | +- KeyManager: GuardManager: source=list(L['d'].keys())[0]
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | | +- EQUALS_MATCH: list(L['d'].keys())[0] == 'foo'                               # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | +- ValueManager: GuardManager: source=L['d']['foo']
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | | +- EQUALS_MATCH: L['d']['foo'] == 4                                            # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- KeyValueManager pair at index=1
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | +- KeyManager: GuardManager: source=list(L['d'].keys())[1]
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | | +- EQUALS_MATCH: list(L['d'].keys())[1] == 'bar'                               # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | +- ValueManager: GuardManager: source=L['d']['bar']
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | | | +- EQUALS_MATCH: L['d']['bar'] == 2                                            # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[4], stride=[1])  # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=DictGetItemGuardAccessor(y)
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[4], stride=[1])  # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False           # return x * y * d["foo"] * d["bar"]  # examples/ord_dicts.py:24 in fn
V0410 15:49:41.607000 140481927914624 torch/_dynamo/guards.py:1769] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
~~~~

This info is also present in this gist for better viewing - https://gist.github.com/anijain2305/b418706e4ad4ec2d601530bc24cf8a20

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123780
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: #123773, #123787
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index e775d63..f54b3cf 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -139,23 +139,25 @@
         parts = [guard_name + ": " + part for part in parts]
         return parts
 
-    def get_manager_line(self, accessor_str, guard_manager):
+    def get_manager_line(self, guard_manager, accessor_str=None):
         source = guard_manager.get_source()
         t = guard_manager.__class__.__name__
-        s = t + "(source = " + source + ", accessor = " + accessor_str + ")"
+        s = t + ": source=" + source
+        if accessor_str:
+            s += ", " + accessor_str
         return s
 
     def construct_dict_manager_string(self, mgr, body):
         for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()):
-            if key_mgr:
-                accessor = f"KeyManager(index={idx})"
-                body.writeline(self.get_manager_line(accessor, key_mgr))
-                self.construct_manager_string(key_mgr, body)
+            body.writeline(f"KeyValueManager pair at index={idx}")
+            with body.indent():
+                if key_mgr:
+                    body.writeline(f"KeyManager: {self.get_manager_line(key_mgr)}")
+                    self.construct_manager_string(key_mgr, body)
 
-            if val_mgr:
-                accessor = f"ValueManager(index={idx})"
-                body.writeline(self.get_manager_line(accessor, val_mgr))
-                self.construct_manager_string(val_mgr, body)
+                if val_mgr:
+                    body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}")
+                    self.construct_manager_string(val_mgr, body)
 
     def construct_manager_string(self, mgr, body):
         with body.indent():
@@ -170,7 +172,9 @@
             for accessor, child_mgr in zip(
                 mgr.get_accessors(), mgr.get_child_managers()
             ):
-                body.writeline(self.get_manager_line(accessor.repr(), child_mgr))
+                body.writeline(
+                    self.get_manager_line(child_mgr, f"accessed_by={accessor.repr()}")
+                )
                 self.construct_manager_string(child_mgr, body)
 
     def __str__(self):
@@ -346,7 +350,7 @@
 
 
 def get_key_index_source(source, index):
-    return f"{source}.keys()[{index}]"
+    return f"list({source}.keys())[{index}]"
 
 
 def getitem_on_dict_manager(
@@ -361,13 +365,19 @@
         index = get_key_index(base_example_value, source.index)
 
     key_source = get_key_index_source(base_source_name, index)
-    value_source = f"{base_source_name}[{key_source}]"
+    key_example_value = list(base_example_value.keys())[index]
+    if isinstance(key_example_value, (int, str)):
+        value_source = f"{base_source_name}[{key_example_value!r}]"
+    else:
+        value_source = f"{base_source_name}[{key_source}]"
     if not isinstance(source.index, ConstDictKeySource):
         # We have to insert a key manager guard here
         # TODO - source debug string is probably wrong here.
         base_guard_manager.get_key_manager(
             index=index, source=key_source, example_value=source.index
-        ).add_equals_match_guard(source.index, [f"{key_source} == {source.index}"])
+        ).add_equals_match_guard(
+            source.index, [f"{key_source} == {key_example_value!r}"]
+        )
 
     return base_guard_manager.get_value_manager(
         index=index, source=value_source, example_value=example_value
@@ -453,7 +463,7 @@
         dict_mgr = self.get_guard_manager(guard)
         assert isinstance(dict_mgr, DictGuardManager)
         for idx, key in enumerate(value.keys()):
-            key_source = guard.name + f".keys()[{idx}]"
+            key_source = get_key_index_source(guard.name, idx)
             key_manager = dict_mgr.get_key_manager(
                 index=idx, source=key_source, example_value=key
             )
@@ -469,7 +479,7 @@
             else:
                 # Install EQUALS_MATCH guard
                 key_manager.add_equals_match_guard(
-                    key, get_verbose_code_parts(f"{key_source} == {key}", guard)
+                    key, get_verbose_code_parts(f"{key_source} == {key!r}", guard)
                 )
 
     def get_global_guard_manager(self):