Switch TORCH_TRACE to accept a directory by default (#121331)
Directory is better because it works smoothly with distributed
runs; otherwise you'd need to modify torchrun to setup distinct
log names for each file.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Differential Revision: [D54597814](https://our.internmc.facebook.com/intern/diff/D54597814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121331
Approved by: https://github.com/albanD
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index 13d86dc..b2b3475 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -904,20 +904,13 @@
# Setup handler for the special trace_log, with different default
# configuration
- #
- # TODO: Automatically initialize this in Tupperware environment to point
- # to /logs/dedicated_logs_XXX
- trace_file_name = os.environ.get(TRACE_ENV_VAR, None)
- handler: Optional[logging.Handler] = None
- if trace_file_name is not None:
- handler = logging.FileHandler(trace_file_name)
- else:
- # This handler may remove itself if we are not actually in an FB
- # environment. This allows us to defer actually initializing it until
- # we actually need to log anything. This is important because JK
- # initializes a C++ singleton, which will pork our process if we
- # subsequently fork.
- handler = LazyFbTraceHandler()
+ trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
+ # This handler may remove itself if trace_dir_name is None and we are not
+ # actually in an FB environment. This allows us to defer actually
+ # initializing it until we actually need to log anything. This is
+ # important because JK initializes a C++ singleton, which will pork our
+ # process if we subsequently fork.
+ handler = LazyTraceHandler(trace_dir_name)
# This log is ALWAYS at debug level. We will additionally test if there
# are any handlers before deciding to actually call logging on this. Do
# not manually call
@@ -927,12 +920,13 @@
trace_log.addHandler(trace_log_handler)
-class LazyFbTraceHandler(logging.StreamHandler):
+class LazyTraceHandler(logging.StreamHandler):
"""Like FileHandler, but the file is allocated lazily only upon the first log message"""
- def __init__(self):
+ def __init__(self, root_dir: Optional[str]):
# This is implemented in the same way that delay is implemented on
# FileHandler
+ self.root_dir = root_dir
logging.Handler.__init__(self)
self.stream = None
self._builtin_open = open
@@ -961,35 +955,34 @@
def emit(self, record):
if self.stream is None:
- # TODO: more robust is_fbcode test
- import torch.version
-
- TRACE_LOG_DIR = "/logs"
- open_func = self._builtin_open
-
ok = False
- import torch.version as torch_version
+ if self.root_dir is None:
+ TRACE_LOG_DIR = "/logs"
+ open_func = self._builtin_open
- if hasattr(torch_version, "git_version"):
- log.info("LazyFbTraceHandler: disabled because not fbcode")
- elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
- log.info(
- "LazyFbTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
- )
- elif not os.path.exists(TRACE_LOG_DIR):
- log.info(
- "LazyFbTraceHandler: disabled because %s does not exist",
- TRACE_LOG_DIR,
- )
- elif not os.access(TRACE_LOG_DIR, os.W_OK):
- log.info(
- "LazyFbTraceHandler: disabled because %s is not writeable",
- TRACE_LOG_DIR,
- )
- else:
- ok = True
+ import torch.version as torch_version
- if ok:
+ if hasattr(torch_version, "git_version"):
+ log.info("LazyTraceHandler: disabled because not fbcode")
+ elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
+ log.info(
+ "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
+ )
+ elif not os.path.exists(TRACE_LOG_DIR):
+ log.info(
+ "LazyTraceHandler: disabled because %s does not exist",
+ TRACE_LOG_DIR,
+ )
+ elif not os.access(TRACE_LOG_DIR, os.W_OK):
+ log.info(
+ "LazyTraceHandler: disabled because %s is not writeable",
+ TRACE_LOG_DIR,
+ )
+ else:
+ self.root_dir = TRACE_LOG_DIR
+
+ if self.root_dir is not None:
+ os.makedirs(self.root_dir, exist_ok=True)
ranksuffix = ""
if dist.is_available() and dist.is_initialized():
ranksuffix = f"rank_{dist.get_rank()}_"
@@ -997,10 +990,10 @@
mode="w+",
suffix=".log",
prefix=f"dedicated_log_torch_trace_{ranksuffix}",
- dir=TRACE_LOG_DIR,
+ dir=self.root_dir,
delete=False,
)
- log.info("LazyFbTraceHandler: logging to %s", self.stream.name)
+ log.info("LazyTraceHandler: logging to %s", self.stream.name)
else:
# We go poof, remove and no-op
trace_log.removeHandler(self)