Switch TORCH_TRACE to accept a directory by default (#121331)

A directory is better because it works smoothly with distributed
runs; otherwise you'd need to modify torchrun to set up a distinct
log file name for each rank.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Differential Revision: [D54597814](https://our.internmc.facebook.com/intern/diff/D54597814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121331
Approved by: https://github.com/albanD
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index 13d86dc..b2b3475 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -904,20 +904,13 @@
 
     # Setup handler for the special trace_log, with different default
     # configuration
-    #
-    # TODO: Automatically initialize this in Tupperware environment to point
-    # to /logs/dedicated_logs_XXX
-    trace_file_name = os.environ.get(TRACE_ENV_VAR, None)
-    handler: Optional[logging.Handler] = None
-    if trace_file_name is not None:
-        handler = logging.FileHandler(trace_file_name)
-    else:
-        # This handler may remove itself if we are not actually in an FB
-        # environment.  This allows us to defer actually initializing it until
-        # we actually need to log anything.  This is important because JK
-        # initializes a C++ singleton, which will pork our process if we
-        # subsequently fork.
-        handler = LazyFbTraceHandler()
+    trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
+    # This handler may remove itself if trace_dir_name is None and we are not
+    # actually in an FB environment.  This allows us to defer actually
+    # initializing it until we actually need to log anything.  This is
+    # important because JK initializes a C++ singleton, which will pork our
+    # process if we subsequently fork.
+    handler = LazyTraceHandler(trace_dir_name)
     # This log is ALWAYS at debug level.  We will additionally test if there
     # are any handlers before deciding to actually call logging on this.  Do
     # not manually call
@@ -927,12 +920,13 @@
     trace_log.addHandler(trace_log_handler)
 
 
-class LazyFbTraceHandler(logging.StreamHandler):
+class LazyTraceHandler(logging.StreamHandler):
     """Like FileHandler, but the file is allocated lazily only upon the first log message"""
 
-    def __init__(self):
+    def __init__(self, root_dir: Optional[str]):
         # This is implemented in the same way that delay is implemented on
         # FileHandler
+        self.root_dir = root_dir
         logging.Handler.__init__(self)
         self.stream = None
         self._builtin_open = open
@@ -961,35 +955,34 @@
 
     def emit(self, record):
         if self.stream is None:
-            # TODO: more robust is_fbcode test
-            import torch.version
-
-            TRACE_LOG_DIR = "/logs"
-            open_func = self._builtin_open
-
             ok = False
-            import torch.version as torch_version
+            if self.root_dir is None:
+                TRACE_LOG_DIR = "/logs"
+                open_func = self._builtin_open
 
-            if hasattr(torch_version, "git_version"):
-                log.info("LazyFbTraceHandler: disabled because not fbcode")
-            elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
-                log.info(
-                    "LazyFbTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
-                )
-            elif not os.path.exists(TRACE_LOG_DIR):
-                log.info(
-                    "LazyFbTraceHandler: disabled because %s does not exist",
-                    TRACE_LOG_DIR,
-                )
-            elif not os.access(TRACE_LOG_DIR, os.W_OK):
-                log.info(
-                    "LazyFbTraceHandler: disabled because %s is not writeable",
-                    TRACE_LOG_DIR,
-                )
-            else:
-                ok = True
+                import torch.version as torch_version
 
-            if ok:
+                if hasattr(torch_version, "git_version"):
+                    log.info("LazyTraceHandler: disabled because not fbcode")
+                elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
+                    log.info(
+                        "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
+                    )
+                elif not os.path.exists(TRACE_LOG_DIR):
+                    log.info(
+                        "LazyTraceHandler: disabled because %s does not exist",
+                        TRACE_LOG_DIR,
+                    )
+                elif not os.access(TRACE_LOG_DIR, os.W_OK):
+                    log.info(
+                        "LazyTraceHandler: disabled because %s is not writeable",
+                        TRACE_LOG_DIR,
+                    )
+                else:
+                    self.root_dir = TRACE_LOG_DIR
+
+            if self.root_dir is not None:
+                os.makedirs(self.root_dir, exist_ok=True)
                 ranksuffix = ""
                 if dist.is_available() and dist.is_initialized():
                     ranksuffix = f"rank_{dist.get_rank()}_"
@@ -997,10 +990,10 @@
                     mode="w+",
                     suffix=".log",
                     prefix=f"dedicated_log_torch_trace_{ranksuffix}",
-                    dir=TRACE_LOG_DIR,
+                    dir=self.root_dir,
                     delete=False,
                 )
-                log.info("LazyFbTraceHandler: logging to %s", self.stream.name)
+                log.info("LazyTraceHandler: logging to %s", self.stream.name)
             else:
                 # We go poof, remove and no-op
                 trace_log.removeHandler(self)