blob: c3f34fa2730da1e8f109f3a2fa559bc626186104 [file] [log] [blame]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Global flags for aot autograd
"""
import os
import sys
from typing import TYPE_CHECKING
# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False
# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)
# Today, if you are in a situation where there is "false aliasing"
# (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
# our checks for this situation are very slow if these inputs have dynamic shapes.
# This config is set to ensure that there aren't too many aliased inputs in this situation,
# so that we error loudly instead of compiling forever.
# Eventually, we should make these checks faster.
# For now, however, you can simply turn off dynamic shapes by marking your inputs static
# when you run into this situation.
_max_aliased_inputs_with_dynamic_shapes_enabled = 5
static_weight_shapes = True
# Applies CSE to the graph before partitioning
cse = True
# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000
# Bans recomputation of nodes that are reading from nodes that is far before
# the current node
ban_recompute_used_far_apart = True
# Breaks up long chain of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True
# Bans recomputation of nodes that must be materialized in the backwards pass
# (used by a non-fusible node)
ban_recompute_materialized_backward = True
# Chooses to ban recomputation of nodes based off an allowlist. Setting it to
# False changes it to use a denylist. Main change is on operators like
# sort/pool/stuff that isn't cheap enough to be fusible for free but also isn't
# that expensive
ban_recompute_not_in_allowlist = True
# Chooses to ban recomputation of reductions. This is generally a good idea, as
# the result of reductions is generally very small but recomputing reductions in
# a fusion can be expensive.
ban_recompute_reductions = True
# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions
# Generally, this will probably result in some memory improvement, but at the
# cost of some performance
aggressive_recomputation = False
# If FakeTensor.data_ptr() should error.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True
# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
from torch.utils._config_module import install_config_module
# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])