Revert "[inductor] Implement Fx graph caching to improve warm compilation time. (#103453)"
This reverts commit fc1105b2827ee2febc85a3c353470edfd70a66ed.
Reverted https://github.com/pytorch/pytorch/pull/103453 on behalf of https://github.com/kit1980 due to Same issue unfortunately, the newly added test fails on internal builds ([comment](https://github.com/pytorch/pytorch/pull/103453#issuecomment-1760202365))
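For context: the reverted change gated the new FX graph cache behind the `fx_graph_cache` inductor config (set from the `TORCHINDUCTOR_FX_GRAPH_CACHE` environment variable in the `config.py` hunk below) and tracked hits and misses via dynamo counters. The snippet below is a minimal sketch of how the removed test exercised the cache; `fn`, `a`, and `b` are illustrative names, not part of the change itself:

```python
# Sketch only, mirroring the removed test_codecache.py test.
import torch
from torch._dynamo.utils import counters
from torch._inductor import config

def fn(x, y):
    return (x * 2, y @ y)

a = torch.rand(25)
b = torch.rand(5, 5)

with config.patch({"fx_graph_cache": True}):
    compiled_fn = torch.compile(fn, dynamic=False)

    compiled_fn(a, b)      # first call: expected fxgraph_cache_miss
    torch._dynamo.reset()  # reset so in-memory guards don't skip recompilation
    compiled_fn(a, b)      # second call: expected fxgraph_cache_hit (loaded from disk)

print(counters["inductor"]["fxgraph_cache_miss"],
      counters["inductor"]["fxgraph_cache_hit"])
```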
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index 8403d0d..e293afb 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -1,32 +1,12 @@
# Owner(s): ["module: inductor"]
import functools
-import pickle
-import tempfile
import unittest
-from unittest.mock import patch
import torch
-from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.utils import counters
-from torch._inductor import config
-from torch._inductor.codecache import (
- AsyncCompile,
- FxGraphCachePickler,
- FxGraphHashDetails,
- TensorMetadata,
- TensorMetadataAndValues,
-)
-from torch.testing._internal.common_utils import (
- instantiate_parametrized_tests,
- parametrize,
-)
+from torch._inductor.codecache import AsyncCompile
from torch.testing._internal.inductor_utils import HAS_CUDA
-from torch.utils._triton import has_triton
-
-HAS_TRITON = has_triton()
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
-requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")
class MyModel(torch.nn.Module):
@@ -57,274 +37,3 @@
@requires_cuda()
def test_codecache_fork():
_run_codecache_test("fork")
-
-
-@instantiate_parametrized_tests
-class TestFxGraphCache(TestCase):
- @classmethod
- def setUpClass(cls):
- # Reroute all cache disk activity to a clean temporary directory to
- # ensure isolation (and initial cache misses). Deliberately create the
- # temp dir in setUpClass, however, so that individual test runs reuse
- # the same location. Different tests aren't expected to share cache
- # entries, so keeping the temp dir across tests also exercises that.
- cls.tmpdir = tempfile.TemporaryDirectory()
- cls.cache_dir_patch = patch("torch._inductor.codecache.cache_dir")
- cls.cache_dir_patch.start().return_value = cls.tmpdir.name
-
- @classmethod
- def tearDownClass(cls):
- cls.cache_dir_patch.stop()
- cls.tmpdir.cleanup()
-
- def setUp(self):
- counters.clear()
-
- @requires_triton()
- @config.patch({"fx_graph_cache": True})
- @parametrize("device", ("cuda", "cpu"))
- @parametrize("dtype", (torch.float, torch.bfloat16))
- def test_cache_load_function(self, device, dtype):
- """
- Verify that we can populate and load functions from the cache.
- """
- if device == "cuda" and not HAS_CUDA:
- raise unittest.SkipTest("requires CUDA")
-
- def fn(x, y):
- return (x * 2, y @ y)
-
- a = torch.rand(25, dtype=dtype, device=device)
- b = torch.rand(5, 5, dtype=dtype, device=device)
- c = a.view(5, 5)
-
- compiled_fn = torch.compile(fn, dynamic=False)
-
- # A first call should miss in the cache.
- self.assertEqual(fn(a, b), compiled_fn(a, b))
- self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
- self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-
- # A second call should hit. (First reset so in-memory guards
- # don't prevent compilation).
- torch._dynamo.reset()
- self.assertEqual(fn(a, b), compiled_fn(a, b))
- self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
- self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
- # But we expect different code if the tensors are aliased.
- torch._dynamo.reset()
- self.assertEqual(fn(a, c), compiled_fn(a, c))
- self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
- self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
- @requires_triton()
- @config.patch({"fx_graph_cache": True})
- @parametrize("device", ("cuda", "cpu"))
- @parametrize("dtype", (torch.float, torch.float16))
- def test_cache_load_model(self, device, dtype):
- """
- Verify that we can populate and load models from the cache.
- """
- if device == "cuda" and not HAS_CUDA:
- raise unittest.SkipTest("requires CUDA")
-
- model = MyModel().to(dtype=dtype, device=device)
-
- a = torch.rand(10, 10, dtype=dtype, device=device)
-
- compiled_model = torch.compile(model, dynamic=False)
-
- # A first call should miss in the cache.
- self.assertEqual(model(a), compiled_model(a))
- self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
- self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-
- # A second call should hit. (First reset so in-memory guards
- # don't prevent compilation).
- torch._dynamo.reset()
- self.assertEqual(model(a), compiled_model(a))
- self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
- self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
-
-class TestFxGraphCacheHashing(TestCase):
- def test_tensor_constants(self):
- """
- Test the handling of small vs. large tensor constants.
- """
- data = FxGraphCachePickler.dumps(torch.tensor(list(range(9))))
- self.assertIsInstance(pickle.loads(data), TensorMetadata)
-
- data = FxGraphCachePickler.dumps(torch.tensor(list(range(8))))
- self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
-
- def test_hash_fake_tensors(self):
- """
- Test hashing (pickling) FakeTensors with various characteristics.
- """
- with torch._subclasses.FakeTensorMode():
- # Verify that FakeTensors get pickled into a TensorMetadata:
- data = FxGraphCachePickler.dumps(torch.randn(1))
- self.assertIsInstance(pickle.loads(data), TensorMetadata)
-
- # Different shapes:
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3)),
- FxGraphCachePickler.dumps(torch.randn(3)),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3)),
- FxGraphCachePickler.dumps(torch.randn(4)),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3)),
- FxGraphCachePickler.dumps(torch.randn(3, 3)),
- )
-
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3, 3)),
- FxGraphCachePickler.dumps(torch.randn(3, 3)),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3, 3)),
- FxGraphCachePickler.dumps(torch.randn(3, 4)),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3, 3)),
- FxGraphCachePickler.dumps(torch.randn(4, 3)),
- )
-
- # Different strides:
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3, 3)),
- FxGraphCachePickler.dumps(
- torch.randn(3, 3).transpose(0, 1).transpose(0, 1)
- ),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3, 3)),
- FxGraphCachePickler.dumps(torch.randn(3, 3).transpose(0, 1)),
- )
-
- # Different storage offsets:
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3)[1:]),
- FxGraphCachePickler.dumps(torch.randn(3)[1:]),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3)[1:]),
- FxGraphCachePickler.dumps(torch.randn(2)),
- )
-
- # Different dtypes:
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
- FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
- FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float64)),
- )
-
- # Different 'requires_grad':
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
- FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
- FxGraphCachePickler.dumps(torch.randn(3, requires_grad=False)),
- )
-
- # Different memory formats:
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(1, 2, 3, 4)),
- FxGraphCachePickler.dumps(
- torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)
- ),
- )
-
- # Different devices:
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
- FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
- FxGraphCachePickler.dumps(torch.randn(3, device="cpu")),
- )
-
- if HAS_CUDA and torch.cuda.device_count() >= 2:
- self.assertEqual(
- FxGraphCachePickler.dumps(torch.randn(3, device="cuda:1")),
- FxGraphCachePickler.dumps(torch.randn(3, device="cuda:1")),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(torch.randn(3, device="cuda:0")),
- FxGraphCachePickler.dumps(torch.randn(3, device="cuda:1")),
- )
-
- def test_hash_kwargs(self):
- """
- Test the special handling of the kwargs when hashing, i.e.,
- ordering of the kwargs dict and any set arguments.
- """
- # Dict order of the kwargs should not affect hashes.
- details1 = FxGraphHashDetails([], {"a": 0, "z": 1})
- details2 = FxGraphHashDetails([], {"z": 1, "a": 0})
- self.assertEqual(
- FxGraphCachePickler.dumps(details1),
- FxGraphCachePickler.dumps(details2),
- )
-
- # Different kwarg values should affect hashes.
- details1 = FxGraphHashDetails([], {"a": 0})
- details2 = FxGraphHashDetails([], {"a": 1})
- self.assertNotEqual(
- FxGraphCachePickler.dumps(details1),
- FxGraphCachePickler.dumps(details2),
- )
-
- # Set order should not affect hashes. Sets are unordered, but building a
- # new set from the sorted elements can change the iteration order.
- set1 = {"a", "b", "c", "d", "e", "f", "g"}
- set2 = set(sorted(set1)) # noqa: C414
- details1 = FxGraphHashDetails([], {"a": set1})
- details2 = FxGraphHashDetails([], {"a": set2})
- self.assertEqual(
- FxGraphCachePickler.dumps(details1),
- FxGraphCachePickler.dumps(details2),
- )
-
- # But different set contents should affect hashes.
- details1 = FxGraphHashDetails([], {"a": {1, 2, 3}})
- details2 = FxGraphHashDetails([], {"a": {1, 2}})
- self.assertNotEqual(
- FxGraphCachePickler.dumps(details1),
- FxGraphCachePickler.dumps(details2),
- )
-
- def test_hash_config_changes(self):
- """
- Test that different config settings affect hashes.
- """
- with config.patch({"max_autotune": False}):
- details1 = FxGraphHashDetails([], {})
- details2 = FxGraphHashDetails([], {})
-
- with config.patch({"max_autotune": True}):
- details3 = FxGraphHashDetails([], {})
-
- self.assertEqual(
- FxGraphCachePickler.dumps(details1),
- FxGraphCachePickler.dumps(details2),
- )
- self.assertNotEqual(
- FxGraphCachePickler.dumps(details1),
- FxGraphCachePickler.dumps(details3),
- )
-
-
-if __name__ == "__main__":
- run_tests()
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 926e45f..4dde4c4 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -1,20 +1,16 @@
from __future__ import annotations
import base64
-import copyreg
import dataclasses
import functools
import getpass
import hashlib
import importlib
-import io
import json
import logging
import multiprocessing
import os
import pathlib
-import pickle
-import pkgutil
import platform
import re
import shlex
@@ -29,7 +25,6 @@
import weakref
from bisect import bisect_right
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
-from copy import copy
from ctypes import c_void_p, cdll, CDLL
from dataclasses import field
from functools import partial
@@ -46,11 +41,9 @@
get_interface_for_device,
get_registered_device_interfaces,
)
-from torch._dynamo.utils import counters
from torch._inductor import config, exc
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import developer_warning, is_linux
-from torch._prims_common import suggest_memory_format
if TYPE_CHECKING:
from torch._inductor.graph import GraphLowering
@@ -357,8 +350,6 @@
return code_hash(content, extra)
if hash_type == "cubin":
return code_hash(repr(content))
- if hash_type == "cg":
- return extra
raise AssertionError(f"Unknown hash type {hash_type}")
@@ -393,259 +384,6 @@
@dataclasses.dataclass
-class TensorMetadata:
- """
- The Tensor metadata relevant when hashing FxGraph cache keys.
- """
-
- dtype: torch.dtype
- shape: torch.Size
- stride: Tuple[Any, ...]
- device: torch.device
- layout: torch.layout
- memory_format: Optional[torch.memory_format]
- storage_offset: int
- requires_grad: bool
- is_quantized: bool
- is_conj: bool
- is_neg: bool
- is_coalesced: bool
- dense_dim: int
- sparse_dim: int
-
-
-@dataclasses.dataclass
-class TensorMetadataAndValues:
- """
- TensorMetadata plus the elements as a list of raw values.
- Used for hashing inlined constants.
- """
-
- tensor_metadata: TensorMetadata
- values: List[Any]
-
-
-class DynamicShapeError(BaseException):
- """
- TODO(masnesral): Support dynamic shapes and remove this.
- """
-
- pass
-
-
-def extract_tensor_metadata(t: torch.Tensor) -> TensorMetadata:
- """
- Extract the TensorMetadata of a tensor.
- """
- if not all(isinstance(e, int) for e in t.shape):
- raise DynamicShapeError()
-
- memory_format = suggest_memory_format(t)
- if not t.is_contiguous(memory_format=memory_format):
- memory_format = None
-
- return TensorMetadata(
- dtype=t.dtype,
- shape=t.shape,
- stride=t.stride() if t.layout == torch.strided else (),
- device=t.device,
- layout=t.layout,
- memory_format=memory_format,
- storage_offset=t.storage_offset(),
- requires_grad=t.requires_grad,
- is_quantized=t.is_quantized,
- is_conj=t.is_conj(),
- is_neg=t.is_neg(),
- is_coalesced=t.is_coalesced() if t.is_sparse else False,
- dense_dim=t.dense_dim() if t.is_sparse else 0,
- sparse_dim=t.sparse_dim() if t.is_sparse else 0,
- )
-
-
-def _ident(x: Any) -> Any:
- return x
-
-
-def _reduce_fake_tensor(t):
- """
- See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
- """
- metadata = extract_tensor_metadata(t)
- return (_ident, (metadata,))
-
-
-def _reduce_tensor(t):
- """
- See FxGraphCachePickler. Custom reducer to pickle Tensors.
- """
- # If we see tensors, we know they're constants stored as attributes on
- # the GraphModule. See tensor lowering; small constants are inlined, so
- # no reference to a small tensor ultimately remains in the generated
- # code and we need to include its value in the cache key. Large
- # constants are effectively treated as inputs and we consider only
- # their metadata.
- metadata = extract_tensor_metadata(t)
- if len(t.shape) == 0 or torch._inductor.graph.GraphLowering.can_inline_constant(t):
- return (_ident, (TensorMetadataAndValues(metadata, t.tolist()),))
- else:
- return (_ident, (metadata,))
-
-
-class FxGraphCachePickler(pickle.Pickler):
- """
- Custom pickler that specializes the pickling of certain objects (Tensors), solely
- for computing a hash to key into the FxGraphCache. Tensors contain objects that
- don't pickle and/or vary between runs, so we capture only the data that allows
- us to compute a stable yet safe hash.
- """
-
- dispatch_table = copyreg.dispatch_table.copy()
- dispatch_table[torch._subclasses.fake_tensor.FakeTensor] = _reduce_fake_tensor
- dispatch_table[torch.Tensor] = _reduce_tensor
-
- @staticmethod
- def dumps(obj) -> bytes:
- """
- Pickle an object using the FxGraphCachePickler.
- """
- with io.BytesIO() as stream:
- pickler = FxGraphCachePickler(stream)
- pickler.dump(obj)
- return stream.getvalue()
-
-
-@functools.lru_cache(None)
-def get_inductor_code_hash() -> bytes:
- """
- Compute a hash of all inductor code modules. Used by the FxGraph cache
- so that any change to inductor code results in new cache keys.
- """
- inductor_root = os.path.dirname(__file__)
-
- contents: Dict[str, bytes] = {}
- for lib in pkgutil.iter_modules([inductor_root]):
- spec = lib.module_finder.find_spec(lib.name, None)
- assert spec is not None
- module = spec.origin
- assert module is not None
- with open(module, "rb") as f:
- contents[module] = f.read()
-
- return hashlib.sha256(pickle.dumps(contents)).digest()
-
-
-@dataclasses.dataclass
-class OrderedSetHolder:
- """
- See FxGraphHashDetails. Holds a sorted list to support stable hashing
- of set kwargs.
- """
-
- items: List[Any]
-
-
-class FxGraphHashDetails:
- """
- Object to capture all the details for a compiled FX graph relevant to computing
- a safe and stable cache key.
- """
-
- # Excluded kwargs params that are not stable between runs
- EXCLUDED_KWARGS = ["graph_id"]
-
- def __init__(self, fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> None:
- self.fx_args = fx_args
-
- # Order kwargs so hashing is stable to changes in kwarg order.
- self.fx_kwargs = {}
- for k in sorted(fx_kwargs):
- if k not in self.EXCLUDED_KWARGS:
- if type(fx_kwargs[k]) is set:
- # Special case to handle set params. Python sets can't be
- # ordered, so sort the elements and store them in a proxy.
- self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
- else:
- self.fx_kwargs[k] = fx_kwargs[k]
-
- # Also hash on various system info (including the triton compiler version), as
- # well as the inductor configuration and code.
- self.torch_version = torch.__version__
- self.system_info = CacheBase.get_system()
-
- self.inductor_config = config.save_config() # type: ignore[attr-defined]
- self.inductor_code_hash = get_inductor_code_hash()
-
-
-def compiled_fx_graph_hash(fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> str:
- """
- Generate a unique hash of the FX graph for caching.
- """
- details = FxGraphHashDetails(fx_args, fx_kwargs)
- serialized_data = FxGraphCachePickler.dumps(details)
- return (
- "f"
- + base64.b32encode(hashlib.sha256(serialized_data).digest())[:51]
- .decode("utf-8")
- .lower()
- )
-
-
-class FxGraphCache:
- """
- Supports caching and reusing compiled Fx graphs.
- """
-
- # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
- # in an in-memory cache after loading from disk.
- @classmethod
- def save_graph(cls, key: str, compiled_graph: CompiledFxGraph):
- disk_compiled_graph = copy(compiled_graph)
- # Important as compiled models are not pickleable
- # TODO: Check status of PR #101651 as that might change the above statement
- disk_compiled_graph.compiled_artifact = None
- write(pickle.dumps(disk_compiled_graph), "cg", extra=key, hash_type="cg")
-
- @classmethod
- def load_graph(cls, cg_path: str) -> CompiledFxGraph:
- with open(cg_path, "rb") as f:
- return pickle.load(f)
-
- @classmethod
- def load(
- cls,
- compile_fx_fn: Callable[..., Any],
- fx_args: List[Any],
- fx_kwargs: Dict[str, Any],
- ):
- from filelock import FileLock
-
- try:
- key = compiled_fx_graph_hash(fx_args, fx_kwargs)
- except DynamicShapeError:
- # TODO(masnesral): support dynamic shapes
- log.debug("fx graph cache skip due to dynamic shapes")
- counters["inductor"]["fxgraph_cache_skip"] += 1
- return compile_fx_fn(*fx_args, **fx_kwargs)
-
- lock_dir = get_lock_dir()
- lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
- with lock:
- _, _, cg_path = get_path(key, "cg")
- if not os.path.exists(cg_path):
- log.debug("fx graph cache miss for key %s", key)
- counters["inductor"]["fxgraph_cache_miss"] += 1
- compiled_graph: CompiledFxGraph = compile_fx_fn(*fx_args, **fx_kwargs)
- cls.save_graph(key, compiled_graph)
- else:
- # Load the required info from disk; the compiled model is recreated on first run
- log.debug("fx graph cache hit for key %s", key)
- counters["inductor"]["fxgraph_cache_hit"] += 1
- compiled_graph = cls.load_graph(cg_path)
-
- return compiled_graph
-
-
-@dataclasses.dataclass
class CompiledFxGraph:
"""Class holding a compiled FX graph"""
@@ -658,7 +396,6 @@
device_idxs: Set[int] = field(default_factory=set)
mutated_inputs: Set[str] = field(default_factory=set)
mutated_input_idxs: Set[int] = field(default_factory=set)
- constants: Dict[str, torch.Tensor] = field(default_factory=dict)
_boxed_call: Optional[bool] = None
@@ -688,7 +425,6 @@
compiled_graph.cache_key,
compiled_graph.artifact_path,
compiled_graph.cache_linemap,
- compiled_graph.constants,
).call
return compiled_graph.compiled_artifact(inputs)
@@ -1589,10 +1325,9 @@
source_code: str,
extra: str = "",
linemap: Optional[List[Tuple[int, str]]] = None,
- attrs: Optional[Dict[str, Any]] = None,
) -> ModuleType:
key, path = write(source_code, "py", extra=extra)
- return cls.load_by_key_path(key, path, linemap, attrs)
+ return cls.load_by_key_path(key, path, linemap)
@classmethod
def load_by_key_path(
@@ -1600,7 +1335,6 @@
key: str,
path: str,
linemap: Optional[List[Tuple[int, str]]] = None,
- attrs: Optional[Dict[str, Any]] = None,
) -> ModuleType:
if linemap is None:
linemap = []
@@ -1622,10 +1356,6 @@
# unzip into separate lines/nodes lists
cls.linemaps[path] = list(zip(*linemap))
- if attrs is not None:
- for k, v in attrs.items():
- setattr(mod, k, v)
-
return cls.cache[key]
@classmethod
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index f25b9f5..8429b5a 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -4,7 +4,6 @@
import itertools
import logging
import sys
-import time
import warnings
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
@@ -23,7 +22,7 @@
)
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import make_boxed_func
-from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
+from torch._inductor.codecache import code_hash, CompiledFxGraph
from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._ops import OpOverload
@@ -324,8 +323,6 @@
cudagraphs = BoxedBool(config.triton.cudagraphs)
# Inputs to fx_codegen_and_compile
- # Anything that affects codegen should go here, so if the signature
- # of fx_codegen_and_compile changes, the list and dict should be updated accordingly
graph_args = [gm, example_inputs]
graph_kwargs = {
"cudagraphs": cudagraphs,
@@ -340,18 +337,9 @@
"extern_node_serializer": extern_node_serializer,
}
- start = time.time()
-
- if config.fx_graph_cache:
- compiled_graph: CompiledFxGraph = FxGraphCache.load(
- fx_codegen_and_compile, graph_args, graph_kwargs
- )
- else:
- compiled_graph = fx_codegen_and_compile(
- *graph_args, **graph_kwargs # type: ignore[arg-type]
- )
-
- log.debug("FX codegen and compilation took %.3fs", time.time() - start)
+ compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
+ *graph_args, **graph_kwargs # type: ignore[arg-type]
+ )
if aot_mode:
return compiled_graph
@@ -559,7 +547,6 @@
)
else:
context.output_strides.append(None)
-
compiled_fn = graph.compile_to_fn()
if V.aot_compilation is True:
@@ -577,7 +564,6 @@
device_idxs=graph.device_idxs,
mutated_inputs=graph.mutated_inputs,
mutated_input_idxs=set(graph.mutated_input_idxs),
- constants=graph.constants,
)
return compiled_graph
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index eba3029..b354009 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -15,9 +15,6 @@
# Whether to enable printing the source code for each future
verbose_progress = False
-# use fx aot graph codegen cache
-fx_graph_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1"
-
# use cpp wrapper instead of python wrapper
cpp_wrapper = False
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 0d207ca..5e09e5c 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -639,13 +639,6 @@
e.__traceback__
) from None
- @staticmethod
- def can_inline_constant(t: torch.Tensor) -> bool:
- """
- True if this is a small constant attr that will be inlined.
- """
- return len(t.shape) == 1 and t.shape[0] <= 8
-
def get_attr(self, target, args, kwargs):
# this is a constant
value = getattr(self.module, target)
@@ -656,7 +649,7 @@
with no_dispatch():
if value.shape == ():
return Constant(value.item(), value.dtype, value.device)
- if self.can_inline_constant(value):
+ if len(value.shape) == 1 and value.shape[0] <= 8:
# tensor lowering has constant inlining logic
from .lowering import tensor
@@ -974,13 +967,14 @@
code, linemap = self.codegen()
linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
key, path = PyCodeCache.write(code)
- mod = PyCodeCache.load_by_key_path(
- key, path, linemap=linemap, attrs=self.constants
- )
+ mod = PyCodeCache.load_by_key_path(key, path, linemap=linemap)
self.cache_key = key
self.cache_path = path
self.cache_linemap = linemap
+ for name, value in self.constants.items():
+ setattr(mod, name, value)
+
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
# TODO. Revisit this once the logging API is more mature
assert mod.__file__ is not None
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 3bfffb7..7aa09a7 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -106,9 +106,7 @@
import_strs: Set[str] = set()
for name, obj in globals.items():
import_strs.add(_format_import_statement(name, obj, importer))
- # Sort the imports so we have a stable import block that allows us to
- # hash the graph module and get a consistent key for use in a cache.
- return "\n".join(sorted(import_strs))
+ return "\n".join(import_strs)
@compatibility(is_backward_compatible=True)