[inductor] FX graph cache: Add support for symbolic shapes (#111421)
Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards with the on-disk cached object and validate that those guards pass before serving a cached object.
Test Plan: New unit tests
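
The validation step can be illustrated with a small, self-contained sketch (the ToyCache class and the "s[...]" guard-expression format here are illustrative assumptions only; the actual implementation uses ShapeEnv.produce_guards_expression() and ShapeEnv.evaluate_guards_expression() as shown in the diff below):

import pickle

class ToyCache:
    def __init__(self):
        # key -> list of pickled (guards_expr, artifact) entries
        self._entries = {}

    def save(self, key, guards_expr, artifact):
        # Serialize the guard expression alongside the compiled artifact.
        self._entries.setdefault(key, []).append(pickle.dumps((guards_expr, artifact)))

    def load(self, key, sizes):
        # Deserialize each candidate for this key and evaluate its guards
        # against the current concrete sizes; serve only if they pass.
        for blob in self._entries.get(key, []):
            guards_expr, artifact = pickle.loads(blob)
            if guards_expr is None or eval(guards_expr, {}, {"s": sizes}):
                return artifact
        return None  # all guards failed: treat as a cache miss and recompile

cache = ToyCache()
cache.save("graph0", "s[0] * s[1] <= 2**31 - 1", "kernel compiled with int32 indexing")
assert cache.load("graph0", [5, 6]) is not None        # guards pass -> cache hit
assert cache.load("graph0", [47000, 47001]) is None    # guards fail -> cache miss

In the real cache, the expression is produced by ShapeEnv.produce_guards_expression() over the SymInt example inputs at save time, and at lookup time it is evaluated by ShapeEnv.evaluate_guards_expression(), first against the hints (so a miss doesn't add guards to the current env) and then against the live SymInts on a hit.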
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111421
Approved by: https://github.com/ezyang
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index bad39cc..b2bbf01 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -17,6 +17,7 @@
TensorMetadataAndValues,
)
from torch.testing._internal.common_cuda import SM80OrLater
+from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -98,7 +99,8 @@
@config.patch({"fx_graph_cache": True})
@parametrize("device", ("cuda", "cpu"))
@parametrize("dtype", (torch.float32, torch.bfloat16))
- def test_cache_load_function(self, device, dtype):
+ @parametrize("dynamic", (False, True))
+ def test_cache_load_function(self, device, dtype, dynamic):
"""
Verify that we can populate and load functions from the cache.
"""
@@ -114,7 +116,7 @@
b = torch.rand(5, 5, dtype=dtype, device=device)
c = a.view(5, 5)
- compiled_fn = torch.compile(fn, dynamic=False)
+ compiled_fn = torch.compile(fn, dynamic=dynamic)
# A first call should miss in the cache.
self.assertEqual(fn(a, b), compiled_fn(a, b))
@@ -138,7 +140,8 @@
@config.patch({"fx_graph_cache": True})
@parametrize("device", ("cuda", "cpu"))
@parametrize("dtype", (torch.float32, torch.float64))
- def test_cache_load_model(self, device, dtype):
+ @parametrize("dynamic", (False, True))
+ def test_cache_load_model(self, device, dtype, dynamic):
"""
Verify that we can populate and load models from the cache.
"""
@@ -150,7 +153,7 @@
mod(x).sum().backward()
return [p.grad for p in mod.parameters()]
- compiled_fn = torch.compile(fn, dynamic=False)
+ compiled_fn = torch.compile(fn, dynamic=dynamic)
mod = MyModelConv2d().to(device=device, dtype=dtype)
inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)
@@ -172,6 +175,132 @@
# And the results should be the same.
self.assertEqual(grads1, grads2)
+ @largeTensorTest("64GB", device="cuda")
+ @config.patch({"fx_graph_cache": True})
+ @parametrize("device", ("cuda",))
+ @parametrize("dtype", (torch.float16, torch.bfloat16))
+ def test_cache_load_with_guards_int32_bounds(self, device, dtype):
+ """
+ Test caching the same graph, but under conditions that introduce guards
+ for tensor sizes < int32.
+ """
+ if device == "cuda" and not HAS_CUDA:
+ raise unittest.SkipTest("requires CUDA")
+ if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+ raise unittest.SkipTest("requires SM80 or later")
+
+ def fn(x, y):
+ return (x + x, y + y)
+
+ compiled_fn = torch.compile(fn, dynamic=True)
+
+ # Iterate over different shapes, varying whether the total
+ # size is below or above int32. For each combination, we expect
+ # different guards around whether the symbolic sizes do or do
+ # not exceed int32.
+ shapes = (
+ ((5, 6), (7, 8)),
+ ((5, 6), (47000, 47001)),
+ ((47000, 47001), (5, 6)),
+ )
+ for a_shape, b_shape in shapes:
+ a = torch.rand(a_shape, device=device, dtype=dtype)
+ b = torch.rand(b_shape, device=device, dtype=dtype)
+
+ # AVOID a dynamo reset here. We expect guards to have been
+ # added that will be violated with the new shape. We should
+ # see a recompilation (along with a cache miss).
+ counters.clear()
+ res1 = compiled_fn(a, b)
+ self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+ # A second call should hit. (Reset here to force compilation).
+ counters.clear()
+ torch._dynamo.reset()
+ res2 = compiled_fn(a, b)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+ self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+ self.assertEqual(res1, res2)
+
+ @config.patch({"fx_graph_cache": True})
+ @parametrize("device", ("cuda", "cpu"))
+ @parametrize("dtype", (torch.float32, torch.bfloat16))
+ def test_cache_load_with_guards_static_bounds(self, device, dtype):
+ """
+ Test caching the same graph, but under conditions that introduce guards
+ for static bounds.
+ """
+ if device == "cuda" and not HAS_CUDA:
+ raise unittest.SkipTest("requires CUDA")
+ if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+ raise unittest.SkipTest("requires SM80 or later")
+
+ # See lowering; for all of the pooling operators, we always guard and
+ # make the height/width static.
+ def fn(x):
+ return torch.nn.functional.adaptive_avg_pool2d(x, [5, 7])
+
+ compiled_fn = torch.compile(fn, dynamic=True)
+
+ # Iterate over different input shapes. Each new shape should cause
+ # a cache miss.
+ shapes = ((1, 64, 8, 9), (1, 64, 9, 10), (1, 64, 10, 11))
+ for shape in shapes:
+ x = torch.rand(shape, device=device, dtype=dtype)
+
+ # AVOID a dynamo reset here. For each cache hit, we expect guards
+ # to have been added that will be violated with each new shape.
+ # We should see a recompilation (along with a cache miss).
+ counters.clear()
+ res1 = compiled_fn(x)
+ self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+ # A second call should hit.
+ counters.clear()
+ torch._dynamo.reset()
+ res2 = compiled_fn(x)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+ self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+ self.assertEqual(res1, res2)
+
+ @config.patch({"fx_graph_cache": True})
+ def test_cache_clear(self):
+ """
+ Test clearing the cache.
+ """
+
+ def fn(x, y):
+ return (x * y,)
+
+ a = torch.rand(5, 5)
+ b = torch.rand(5, 5)
+
+ compiled_fn = torch.compile(fn)
+
+ # A first call should miss in the cache.
+ self.assertEqual(fn(a, b), compiled_fn(a, b))
+ self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+ # A second call should hit.
+ counters.clear()
+ torch._dynamo.reset()
+ self.assertEqual(fn(a, b), compiled_fn(a, b))
+ self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+ # Clear the cache; now we should miss.
+ counters.clear()
+ torch._dynamo.reset()
+ torch._inductor.codecache.FxGraphCache.clear()
+ self.assertEqual(fn(a, b), compiled_fn(a, b))
+ self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
class TestFxGraphCacheHashing(TestCase):
def test_tensor_constants(self):
@@ -296,16 +425,16 @@
ordering of the kwargs dict and any set arguments.
"""
# Dict order of the kwargs should not affect hashes.
- details1 = FxGraphHashDetails([], {"a": 0, "z": 1})
- details2 = FxGraphHashDetails([], {"z": 1, "a": 0})
+ details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1})
+ details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0})
self.assertEqual(
FxGraphCachePickler.dumps(details1),
FxGraphCachePickler.dumps(details2),
)
# Different kwarg values should affect hashes.
- details1 = FxGraphHashDetails([], {"a": 0})
- details2 = FxGraphHashDetails([], {"a": 1})
+ details1 = FxGraphHashDetails(None, [], {"a": 0})
+ details2 = FxGraphHashDetails(None, [], {"a": 1})
self.assertNotEqual(
FxGraphCachePickler.dumps(details1),
FxGraphCachePickler.dumps(details2),
@@ -315,16 +444,16 @@
# sorting and creating a new set seems to change the order.
set1 = {"a", "b", "c", "d", "e", "f", "g"}
set2 = set(sorted(set1)) # noqa: C414
- details1 = FxGraphHashDetails([], {"a": set1})
- details2 = FxGraphHashDetails([], {"a": set2})
+ details1 = FxGraphHashDetails(None, [], {"a": set1})
+ details2 = FxGraphHashDetails(None, [], {"a": set2})
self.assertEqual(
FxGraphCachePickler.dumps(details1),
FxGraphCachePickler.dumps(details2),
)
# But different set contents should affect hashes.
- details1 = FxGraphHashDetails([], {"a": {1, 2, 3}})
- details2 = FxGraphHashDetails([], {"a": {1, 2}})
+ details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}})
+ details2 = FxGraphHashDetails(None, [], {"a": {1, 2}})
self.assertNotEqual(
FxGraphCachePickler.dumps(details1),
FxGraphCachePickler.dumps(details2),
@@ -335,11 +464,11 @@
Test that different config settings affect hashes.
"""
with config.patch({"max_autotune": False}):
- details1 = FxGraphHashDetails([], {})
- details2 = FxGraphHashDetails([], {})
+ details1 = FxGraphHashDetails(None, [], {})
+ details2 = FxGraphHashDetails(None, [], {})
with config.patch({"max_autotune": True}):
- details3 = FxGraphHashDetails([], {})
+ details3 = FxGraphHashDetails(None, [], {})
self.assertEqual(
FxGraphCachePickler.dumps(details1),
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 428555b..c92fe42 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -51,6 +51,7 @@
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import developer_warning, is_linux
from torch._prims_common import suggest_memory_format
+from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
if TYPE_CHECKING:
from torch._inductor.graph import GraphLowering
@@ -326,16 +327,16 @@
return lock_dir
+def sha256_hash(data: bytes) -> str:
+ # [:51] to strip off the "Q====" suffix common to every hash value.
+ return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
+
+
def code_hash(code: Union[str, bytes], extra: str = ""):
hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
if extra != "":
hashing_str = hashing_str + b"||" + extra.encode("utf-8")
- return (
- "c"
- + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
- .decode("utf-8")
- .lower()
- )
+ return "c" + sha256_hash(hashing_str)
def get_path(
@@ -357,8 +358,6 @@
return code_hash(content, extra)
if hash_type == "cubin":
return code_hash(repr(content))
- if hash_type == "cg":
- return extra
raise AssertionError(f"Unknown hash type {hash_type}")
@@ -425,21 +424,10 @@
values: List[Any]
-class DynamicShapeError(BaseException):
- """
- TODO(masnesral): Support dynamic shapes and remove this.
- """
-
- pass
-
-
def extract_tensor_metadata(t: torch.Tensor) -> TensorMetadata:
"""
Extract the TensorMetadata of a tensor.
"""
- if not all(isinstance(e, int) for e in t.shape):
- raise DynamicShapeError()
-
memory_format = suggest_memory_format(t)
if not t.is_contiguous(memory_format=memory_format):
memory_format = None
@@ -491,6 +479,16 @@
return (_ident, (metadata,))
+def _reduce_symint(s):
+ """
+ See FxGraphCachePickler. Custom reducer to pickle SymInts.
+ """
+ # For hashing purposes, we only care about the name of the symbol and
+ # not the backed value. We evaluate guards stored with a cached graph
+ # to ensure a cached entity with SymInt args is safe to reuse.
+ return (_ident, (str(s),))
+
+
class FxGraphCachePickler(pickle.Pickler):
"""
Custom pickler to customize the pickling of some objects (Tensors), only for the
@@ -502,6 +500,7 @@
dispatch_table = copyreg.dispatch_table.copy()
dispatch_table[torch._subclasses.fake_tensor.FakeTensor] = _reduce_fake_tensor
dispatch_table[torch.Tensor] = _reduce_tensor
+ dispatch_table[torch.SymInt] = _reduce_symint
@staticmethod
def dumps(obj) -> bytes:
@@ -513,6 +512,15 @@
pickler.dump(obj)
return stream.getvalue()
+ @staticmethod
+ def get_hash(obj: Any) -> str:
+ """
+ Serialize an object using the FxGraphCachePickler and return a hash
+ of the pickled object.
+ """
+ serialized_data = FxGraphCachePickler.dumps(obj)
+ return sha256_hash(serialized_data)
+
@functools.lru_cache(None)
def get_inductor_code_hash() -> bytes:
@@ -553,8 +561,14 @@
# Excluded kwargs param that are not stable between runs
EXCLUDED_KWARGS = ["graph_id"]
- def __init__(self, fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> None:
- self.fx_args = fx_args
+ def __init__(
+ self,
+ gm: torch.fx.GraphModule,
+ example_inputs: List[torch.Tensor],
+ fx_kwargs: Dict[str, Any],
+ ):
+ self.gm = gm
+ self.example_inputs = example_inputs
# Order kwargs so hashing is stable to changes in kwarg order.
self.fx_kwargs = {}
@@ -575,79 +589,243 @@
self.inductor_config = config.save_config() # type: ignore[attr-defined]
self.inductor_code_hash = get_inductor_code_hash()
+ def debug_str(self) -> str:
+ """
+ Get a printable string describing in more detail all the attributes
+ comprising this object. Useful for debugging when one graph hashes
+ to a different value than another.
+ """
-def compiled_fx_graph_hash(fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> str:
+ def get_str(obj) -> str:
+ if isinstance(obj, torch.Tensor):
+ return str(extract_tensor_metadata(obj))
+ elif isinstance(obj, bytes):
+ return "<bytes>"
+ else:
+ return str(obj)
+
+ lines = []
+ for attr, obj in vars(self).items():
+ if isinstance(obj, list):
+ for ii in range(len(obj)):
+ h = FxGraphCachePickler.get_hash(obj[ii])
+ lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}")
+ elif isinstance(obj, dict):
+ for k, v in obj.items():
+ h = FxGraphCachePickler.get_hash(v)
+ lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
+ else:
+ h = FxGraphCachePickler.get_hash(obj)
+ lines.append(f"[{h}] {attr}: {get_str(obj)}")
+ return "\n".join(lines)
+
+
+def compiled_fx_graph_hash(
+ gm: torch.fx.GraphModule,
+ example_inputs: List[torch.Tensor],
+ fx_kwargs: Dict[str, Any],
+) -> str:
"""
Generate a unique hash of the FX graph for caching.
"""
- details = FxGraphHashDetails(fx_args, fx_kwargs)
- serialized_data = FxGraphCachePickler.dumps(details)
- return (
- "f"
- + base64.b32encode(hashlib.sha256(serialized_data).digest())[:51]
- .decode("utf-8")
- .lower()
- )
+ details = FxGraphHashDetails(gm, example_inputs, fx_kwargs)
+ # The prefix distinguishes among the other kinds of objects we
+ # cache in this module.
+ key = "f" + FxGraphCachePickler.get_hash(details)
+ log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str())
+ return key
class FxGraphCache:
"""
Supports caching and reusing compiled Fx graphs.
+
+ The overall strategy is as follows:
+ - This cache stores entries on disk. When saving an entry, we can't
+ serialize callables (that could be C++, Triton, etc.), so we serialize
+ their own disk cache location. We then recreate the compiled artifact
+ after fetching from disk.
+ - For indexing the cache, we gather the fields relevant to identifying an
+ FxGraph (the graph module, graph inputs, system settings etc.) into an
+ FxGraphCacheDetails object, pickle it, and compute a hash for the key.
+ See FxGraphCachePickler.
+ - Among the metadata we store, we also include a guards expression that's
+ appropriate for validating any symbols for Tensor arguments that have
+ symbolic bounds. On cache lookup then, we evaluate those guards in the
+ current context to validate that a cached entry can be served.
+ - A given graph could have multiple compiled versions, corresponding to
+ different sets of guards. Therefore, we store cache entries in the form:
+ <temp dir>/<fx graph hash>/<serialized metadata>
+ - On lookup, we compute the key from the graph details, iterate over all
+ leaf files in the corresponding subdirectory, deserialize the entry, and
+ evaluate its guards expression. If the evaluation succeeds, we have a
+ cache hit. If it fails, we compile the graph and store a new entry.
+ - Finally, on a cache hit, we need to make sure any guards that would
+ have been created during compilation are added to the current context.
"""
# TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
# in an in-memory cache after loading from disk.
- @classmethod
- def save_graph(cls, key: str, compiled_graph: CompiledFxGraph):
+ @staticmethod
+ def _get_tmp_dir() -> str:
+ """
+ Get the toplevel temporary directory for storing compiled graphs.
+ """
+ return os.path.join(cache_dir(), "fxgraph")
+
+ @staticmethod
+ def _get_tmp_dir_for_key(key: str) -> str:
+ """
+ Return the disk location for a given cache key.
+ """
+ return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
+
+ @staticmethod
+ def _filter_symints(inputs: List[Any]) -> List[torch.SymInt]:
+ """
+ Get the SymInt objects from the input list.
+ """
+ return [s for s in inputs if isinstance(s, torch.SymInt)]
+
+ @staticmethod
+ def _get_shape_env() -> ShapeEnv:
+ """
+ Helper to get the shape env from the tracing context.
+ """
+ tracing_context = torch._guards.TracingContext.get()
+ assert tracing_context is not None
+ return tracing_context.fake_mode.shape_env
+
+ @staticmethod
+ def _lookup_graph(
+ key: str,
+ example_inputs: List[torch.Tensor],
+ ) -> Optional[CompiledFxGraph]:
+ """
+ Lookup a compiled graph in the cache by key. On a hit, return the
+ deserialized CompiledFxGraph object. On a miss, return None.
+ """
+ subdir = FxGraphCache._get_tmp_dir_for_key(key)
+ if not os.path.exists(subdir):
+ return None
+
+ # Iterate over any entries in the subdir for this key and evaluate
+ # their guards to determine whether there's a hit.
+ for path in sorted(os.listdir(subdir)):
+ with open(os.path.join(subdir, path), "rb") as f:
+ graph: CompiledFxGraph = pickle.load(f)
+
+ guards_expr = graph.guards_expr
+ if not guards_expr:
+ # No guards to evaluate
+ return graph
+
+ # Evaluate the guard expression in the current context.
+ shape_env = FxGraphCache._get_shape_env()
+ symints = FxGraphCache._filter_symints(example_inputs)
+
+ # If there's not a cache hit, we don't want the evaluation to
+ # affect the current env, e.g., cause the creation of new guards,
+ # so we evaluate with the hints instead of the symbols.
+ assert all(has_hint(s) for s in symints)
+ hints = [hint_int(s) for s in symints]
+ hit = bool(shape_env.evaluate_guards_expression(guards_expr, hints))
+ log.debug(
+ "fx graph cache key %s evaluating guards for %s with values %s => %s",
+ key,
+ guards_expr,
+ hints,
+ hit,
+ )
+ if hit:
+ # Now re-evaluate with the symints to add any guards to the current env.
+ check = bool(shape_env.evaluate_guards_expression(guards_expr, symints))
+ assert check is True
+ log.debug(
+ "fx graph cache key %s post-load guards: %s",
+ key,
+ shape_env.guards,
+ )
+ return graph
+
+ return None
+
+ @staticmethod
+ def _save_graph(
+ key: str, compiled_graph: CompiledFxGraph, example_inputs: List[torch.Tensor]
+ ):
+ """
+ Store a serialized CompiledFxGraph on disk.
+ """
disk_compiled_graph = copy(compiled_graph)
- # Important as compiled models are not pickleable
- # TODO: Check status of PR #101651 as that might change the above statement
+ # Important as compiled models are not pickleable:
disk_compiled_graph.compiled_artifact = None
- write(pickle.dumps(disk_compiled_graph), "cg", extra=key, hash_type="cg")
- @classmethod
- def load_graph(cls, cg_path: str) -> CompiledFxGraph:
- with open(cg_path, "rb") as f:
- return pickle.load(f)
+ # Before serializing, compute the guard expression that will be used to
+ # ensure that a CompiledFxGraph is valid when loaded from the cache. It's
+ # sufficient to consider only the SymInt args to the fx graph since the
+ # Tensor shapes are already captured in the hash for the cache key. Any
+ # Tensor arg with a symbolic shape will have a SymInt arg for the graph.
+ shape_env = FxGraphCache._get_shape_env()
+ symints = FxGraphCache._filter_symints(example_inputs)
+ disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(symints)
- @classmethod
+ content = pickle.dumps(disk_compiled_graph)
+
+ subdir = FxGraphCache._get_tmp_dir_for_key(key)
+ if not os.path.exists(subdir):
+ os.makedirs(subdir, exist_ok=True)
+
+ # Use a hash of the serialized CompiledFxGraph to get a unique file
+ # name. The specific name doesn't matter since a lookup involves
+ # iterating over all entries in the parent subdir.
+ path = os.path.join(subdir, sha256_hash(content))
+ write_atomic(path, content)
+
+ @staticmethod
def load(
- cls,
compile_fx_fn: Callable[..., Any],
- fx_args: List[Any],
+ gm: torch.fx.GraphModule,
+ example_inputs: List[torch.Tensor],
fx_kwargs: Dict[str, Any],
):
+ """
+ Load a compiled graph from the cache. If a cached entry does not exist,
+ compile the graph and save it to the cache.
+ """
from filelock import FileLock
- try:
- key = compiled_fx_graph_hash(fx_args, fx_kwargs)
- except DynamicShapeError:
- # TODO(masnresral): support dynamic shapes
- log.debug("fx graph cache skip due to dynamic shapes")
- counters["inductor"]["fxgraph_cache_skip"] += 1
- return compile_fx_fn(*fx_args, **fx_kwargs)
+ key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs)
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
- _, _, cg_path = get_path(key, "cg")
- if not os.path.exists(cg_path):
+ compiled_graph = FxGraphCache._lookup_graph(key, example_inputs)
+ if compiled_graph is None:
log.debug("fx graph cache miss for key %s", key)
counters["inductor"]["fxgraph_cache_miss"] += 1
- compiled_graph: CompiledFxGraph = compile_fx_fn(*fx_args, **fx_kwargs)
- cls.save_graph(key, compiled_graph)
+ compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
+ FxGraphCache._save_graph(key, compiled_graph, example_inputs)
else:
- # Load required info from disk; recreation of compiled model will be on first run
log.debug("fx graph cache hit for key %s", key)
counters["inductor"]["fxgraph_cache_hit"] += 1
- compiled_graph = cls.load_graph(cg_path)
return compiled_graph
+ @staticmethod
+ def clear():
+ """
+ Clear out the on-disk cache.
+ """
+ shutil.rmtree(FxGraphCache._get_tmp_dir())
+
@dataclasses.dataclass
class CompiledFxGraph:
- """Class holding a compiled FX graph"""
+ """
+ Class holding a compiled FX graph. This is the object serialized on disk
+ to support FxGraph caching.
+ """
compiled_artifact: Optional[Callable[..., Any]] = None
current_callable: Optional[Callable[..., Any]] = None
@@ -660,9 +838,33 @@
mutated_input_idxs: Set[int] = field(default_factory=set)
constants: Dict[str, torch.Tensor] = field(default_factory=dict)
output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
+ # This is a string representation of an expression we serialize
+ # with the object so the guards can be evaluated in a different
+ # context in order to verify the validity of serving a cached
+ # fx graph. The expression must be generated by:
+ # ShapeEnv.produce_guards_expression()
+ guards_expr: Optional[str] = None
_boxed_call: Optional[bool] = None
+ def __init__(
+ self,
+ compiled_artifact: Optional[Callable[..., Any]],
+ graph: GraphLowering,
+ output_strides: List[Optional[Tuple[int, ...]]],
+ ):
+ self.compiled_artifact = compiled_artifact
+ self.cache_key = graph.cache_key
+ self.artifact_path = graph.cache_path
+ self.cache_linemap = graph.cache_linemap
+ self.device_types = graph.device_types
+ self.device_idxs = graph.device_idxs
+ self.mutated_inputs = graph.mutated_inputs
+ self.mutated_input_idxs = set(graph.mutated_input_idxs)
+ self.constants = graph.constants
+ self.output_strides = output_strides
+ self.guards_expr = None
+
def __call__(self, inputs: List[Any]) -> Any:
return self.get_current_callable()(inputs)
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 32e3b60..b6aa7f1 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -366,8 +366,7 @@
# Inputs to fx_codegen_and_compile
# Anything that affects codegen should go here, so if the signature
- # of fx_codegen_and_compile changes, the list and dict should be updated accordingly
- graph_args = [gm, example_inputs]
+ # of fx_codegen_and_compile changes, the dict should be updated accordingly
graph_kwargs = {
"cudagraphs": cudagraphs,
"num_fixed": num_fixed,
@@ -385,11 +384,11 @@
if config.fx_graph_cache and not aot_mode:
compiled_graph = FxGraphCache.load(
- fx_codegen_and_compile, graph_args, graph_kwargs
+ fx_codegen_and_compile, gm, example_inputs, graph_kwargs
)
else:
compiled_graph = fx_codegen_and_compile(
- *graph_args, **graph_kwargs # type: ignore[arg-type]
+ gm, example_inputs, **graph_kwargs # type: ignore[arg-type]
)
log.debug("FX codegen and compilation took %.3fs", time.time() - start)
@@ -620,18 +619,8 @@
if graph.disable_cudagraphs:
BoxedBool.disable(cudagraphs)
- compiled_graph = CompiledFxGraph(
- compiled_artifact=compiled_fn,
- cache_key=graph.cache_key,
- artifact_path=graph.cache_path,
- cache_linemap=graph.cache_linemap,
- device_types=graph.device_types,
- device_idxs=graph.device_idxs,
- mutated_inputs=graph.mutated_inputs,
- mutated_input_idxs=set(graph.mutated_input_idxs),
- constants=graph.constants,
- output_strides=output_strides,
- )
+ compiled_graph = CompiledFxGraph(compiled_fn, graph, output_strides)
+
return compiled_graph
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 7008b15..94b8297 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -2735,13 +2735,31 @@
self._check_translation_validate()
return exprs
- def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
+ def produce_guards_expression(self, placeholders, ignore_static=True):
+ """
+ Expected to be used with evaluate_guards_expression(). Produces the guards
+ for the given placeholders and returns a string expression to be evaluated
+ by evaluate_guards_expression given concrete values for the placeholders.
+ """
from torch._dynamo.source import LocalSource
- arg_names = [f"t{i}" for i in range(len(args))]
+ arg_names = [f"t{i}" for i in range(len(placeholders))]
guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static)
if guards:
- code = " and ".join(guards)
- return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
+ return " and ".join(guards)
+ return None
+
+ def evaluate_guards_expression(self, code, args):
+ """
+ Expected to be used with produce_guards_expression(). Evaluates an expression
+ generated by produce_guards_expression for the given concrete args.
+ """
+ arg_names = [f"t{i}" for i in range(len(args))]
+ return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
+
+ def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
+ code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
+ if code:
+ return self.evaluate_guards_expression(code, args)
return True
def bind_symbols(self, placeholders, args):