[dynamo][benchmarks] use fresh inductor cache and raise batch size wherever possible (#88044)
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88044
Approved by: https://github.com/ngimel
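Summary: the dashboard Makefile now also clones and builds torchaudio; the cold-start path in the benchmark runner switches from `fresh_triton_cache` to a new `fresh_inductor_cache` context manager that redirects both the inductor cache and the nested Triton cache to a temporary directory; the per-model "small batch size" tables in the HuggingFace and TIMM suites are replaced with `BATCH_SIZE_DIVISORS`, so each model runs at the largest batch size that still fits rather than a blanket halving; inductor's `cache_dir()` now honors a `TORCHINDUCTOR_CACHE_DIR` environment variable; and Triton autotuning is timed via `dynamo_utils.dynamo_timed`.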
diff --git a/benchmarks/dynamo/Makefile_dashboard b/benchmarks/dynamo/Makefile_dashboard
index 1c75d60..559f9fe 100644
--- a/benchmarks/dynamo/Makefile_dashboard
+++ b/benchmarks/dynamo/Makefile_dashboard
@@ -7,6 +7,7 @@
&& (test -e torchvision || git clone --recursive https://github.com/pytorch/vision torchvision) \
&& (test -e torchdata || git clone --recursive https://github.com/pytorch/data.git torchdata) \
&& (test -e torchtext || git clone --recursive https://github.com/pytorch/text torchtext) \
+ && (test -e torchaudio || git clone --recursive https://github.com/pytorch/audio torchaudio) \
&& (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \
&& (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \
&& (test -e triton || git clone --recursive https://github.com/openai/triton.git) \
@@ -17,6 +18,7 @@
(cd ../../../torchvision && git pull && git submodule update --init --recursive)
(cd ../../../torchdata && git pull && git submodule update --init --recursive)
(cd ../../../torchtext && git pull && git submodule update --init --recursive)
+ (cd ../../../torchaudio && git pull && git submodule update --init --recursive)
(cd ../../../detectron2 && git pull && git submodule update --init --recursive)
(cd ../../../torchbenchmark && git pull && git submodule update --init --recursive)
(cd ../../../triton && git checkout master && git pull && git checkout $(TRITON_VERSION) && git submodule update --init --recursive)
@@ -32,6 +34,7 @@
(cd ../../../torchvision && python setup.py clean && python setup.py develop)
(cd ../../../torchdata && python setup.py install)
(cd ../../../torchtext && python setup.py clean && python setup.py develop)
+ (cd ../../../torchaudio && python setup.py clean && python setup.py develop)
(cd ../../../detectron2 && python setup.py clean && python setup.py develop)
(cd ../../../torchbenchmark && python install.py --continue_on_fail)
(cd ../../../triton/python && python setup.py clean && python setup.py develop)
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index d1fd51f..dcbdfa6 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -28,7 +28,7 @@
from torch._dynamo.testing import dummy_fx_compile, format_speedup, same
from torch._dynamo.utils import clone_inputs
from torch._inductor import config as inductor_config
-from torch._inductor.utils import fresh_triton_cache
+from torch._inductor.utils import fresh_inductor_cache
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._pytree import tree_map
@@ -765,7 +765,7 @@
cache_minder = NullContext()
if is_cold_start:
cache_entries = {}
- cache_minder = fresh_triton_cache(cache_entries)
+ cache_minder = fresh_inductor_cache(cache_entries)
try:
with cache_minder:
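
For reference, a minimal sketch of how the cold-start path above uses the new context manager; `run_one_model` is a hypothetical stand-in for the actual benchmark call and is not part of this diff:

```python
from torch._inductor.utils import fresh_inductor_cache

# Compile and run with a throwaway inductor (and Triton) cache so the
# measurement starts cold; cache_entries is filled with
# {triton_cache_file_name: size_in_bytes} when the context exits.
cache_entries = {}
with fresh_inductor_cache(cache_entries):
    run_one_model()  # hypothetical stand-in for the benchmarked call

print(sum(cache_entries.values()), "bytes of Triton cache artifacts generated")
```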
diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py
index b563c22..c7ecd5f 100755
--- a/benchmarks/dynamo/huggingface.py
+++ b/benchmarks/dynamo/huggingface.py
@@ -68,9 +68,6 @@
exec(f"from transformers import {cls}")
-USE_HALF_BATCH_SIZE = True
-
-
# These models contain the models present in huggingface_models_list. It is a
# combination of models supported by HF Fx parser and some manually supplied
# models. For these models, we already know the largest batch size that can fit
@@ -107,31 +104,27 @@
}
# TODO - Fails even after fake tensors
-USE_SMALL_BATCH_SIZE = {
+BATCH_SIZE_DIVISORS = {
"AlbertForMaskedLM": 2,
- "AlbertForPreTraining": 4,
"AlbertForQuestionAnswering": 2,
- "BartForCausalLM": 2,
- "BartForConditionalGeneration": 1,
- "BlenderbotSmallForConditionalGeneration": 32,
- "DebertaForMaskedLM": 4,
+ "AllenaiLongformerBase": 2,
+ "BartForConditionalGeneration": 2,
+ "BertForMaskedLM": 2,
+ "BlenderbotSmallForCausalLM": 2,
+ "BlenderbotSmallForConditionalGeneration": 2,
+ "ElectraForCausalLM": 2,
+ "ElectraForQuestionAnswering": 2,
+ "GPT2ForSequenceClassification": 2,
+ "LayoutLMForMaskedLM": 2,
+ "LayoutLMForSequenceClassification": 2,
+ "RobertaForCausalLM": 2,
+ "T5ForConditionalGeneration": 2,
+ # Large footprint
+ "BartForCausalLM": 4,
"DebertaForQuestionAnswering": 4,
- "DebertaV2ForMaskedLM": 1,
- "DebertaV2ForQuestionAnswering": 1,
- "DistilBertForMaskedLM": 16,
- "ElectraForCausalLM": 1,
- "GPTNeoForCausalLM": 1,
- "GPTNeoForSequenceClassification": 1,
- "M2M100ForConditionalGeneration": 2,
- "MT5ForConditionalGeneration": 2,
- "MegatronBertForCausalLM": 2,
- "OPTForCausalLM": 4,
- "PegasusForCausalLM": 8,
- "PegasusForConditionalGeneration": 4,
- "RobertaForCausalLM": 4,
- "TrOCRForCausalLM": 8,
- "XGLMForCausalLM": 1,
"XLNetLMHeadModel": 4,
+ # Very large footprint
+ "DebertaForMaskedLM": 8,
}
@@ -369,13 +362,8 @@
if batch_size is None:
batch_size = batch_size_default
- if model_name in USE_SMALL_BATCH_SIZE:
- batch_size = USE_SMALL_BATCH_SIZE[model_name]
- log.warning(
- f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
- )
- elif USE_HALF_BATCH_SIZE and batch_size >= 2:
- batch_size = int(batch_size / 2)
+ if model_name in BATCH_SIZE_DIVISORS:
+ batch_size = max(int(batch_size / BATCH_SIZE_DIVISORS[model_name]), 1)
log.warning(
f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
)
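
As a quick illustration of the divisor lookup (the default batch size below is made up; only the divisor pattern comes from the table above):

```python
# Hypothetical: a model whose recorded default batch size is 16 and whose
# divisor is 2 ends up running at batch size 8.
batch_size = max(int(16 / 2), 1)  # -> 8; max(..., 1) keeps small defaults from hitting 0
```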
diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py
index f7ff255..70d06ab 100755
--- a/benchmarks/dynamo/timm_models.py
+++ b/benchmarks/dynamo/timm_models.py
@@ -40,44 +40,30 @@
# TODO - Figure out the reason of cold start memory spike
+
BATCH_SIZE_DIVISORS = {
"beit_base_patch16_224": 2,
- "cait_m36_384": 4,
- "convit_base": 4,
+ "cait_m36_384": 2,
+ "convit_base": 2,
"convmixer_768_32": 2,
- "convnext_base": 4,
- "crossvit_9_240": 2,
+ "convnext_base": 2,
"cspdarknet53": 2,
"deit_base_distilled_patch16_224": 2,
- "dla102": 2,
"dpn107": 2,
- "eca_botnext26ts_256": 2,
- "eca_halonext26ts": 2,
- "gluon_senet154": 2,
"gluon_xception65": 2,
- "gmixer_24_224": 2,
- "gmlp_s16_224": 2,
- "hrnet_w18": 64,
- "jx_nest_base": 4,
- "mixer_b16_224": 2,
- "mixnet_l": 2,
- "mobilevit_s": 4,
- "nfnet_l0": 2,
+ "mobilevit_s": 2,
"pit_b_224": 2,
"pnasnet5large": 2,
"poolformer_m36": 2,
"res2net101_26w_4s": 2,
- "res2net50_14w_8s": 64,
- "res2next50": 64,
- "resnest101e": 4,
+ "resnest101e": 2,
"sebotnet33ts_256": 2,
"swin_base_patch4_window7_224": 2,
"swsl_resnext101_32x16d": 2,
- "tf_mixnet_l": 2,
- "tnt_s_patch16_224": 2,
- "twins_pcpvt_base": 4,
+ "twins_pcpvt_base": 2,
"vit_base_patch16_224": 2,
"volo_d1_224": 2,
+ "jx_nest_base": 4,
"xcit_large_24_p8_224": 4,
}
@@ -230,9 +216,11 @@
)
input_size = data_config["input_size"]
recorded_batch_size = TIMM_MODELS[model_name]
- recorded_batch_size = max(
- int(recorded_batch_size / BATCH_SIZE_DIVISORS.get(model_name, 1)), 1
- )
+
+ if model_name in BATCH_SIZE_DIVISORS:
+ recorded_batch_size = max(
+ int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
+ )
batch_size = batch_size or recorded_batch_size
# example_inputs = torch.randn(
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 1c97c26..7f1e8bc 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -50,8 +50,11 @@
logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO)
+@functools.lru_cache(None)
def cache_dir():
- return f"/tmp/torchinductor_{getpass.getuser()}"
+ return os.environ.get(
+ "TORCHINDUCTOR_CACHE_DIR", f"/tmp/torchinductor_{getpass.getuser()}"
+ )
def get_lock_dir():
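
Because `cache_dir()` now reads `TORCHINDUCTOR_CACHE_DIR` and is memoized with `functools.lru_cache`, the variable has to be set before the first call in the process. A minimal sketch (the path is an arbitrary example):

```python
import os

# Redirect inductor's on-disk cache for this run; this must happen before the
# first call to cache_dir(), since its result is lru_cache'd.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/my_inductor_cache"

from torch._inductor.codecache import cache_dir
assert cache_dir() == "/tmp/my_inductor_cache"
```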
diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py
index 59ee762..b6f1c5c 100644
--- a/torch/_inductor/triton_ops/autotune.py
+++ b/torch/_inductor/triton_ops/autotune.py
@@ -13,7 +13,7 @@
from .. import config
from ..ir import ReductionHint
from ..triton_ops.mm_perf_model import estimate_matmul_time
-from ..utils import conditional_product, has_triton
+from ..utils import conditional_product, dynamo_utils, has_triton
from .conv_perf_model import (
early_config_prune as conv_early_config_prune,
estimate_conv_time,
@@ -136,6 +136,7 @@
return do_bench(kernel_call)
+ @dynamo_utils.dynamo_timed
def autotune_to_one_config(self, *args, **kwargs):
"""Do the actual autotuning"""
from ..compile_fx import clone_preserve_strides
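
`dynamo_timed` records the time spent in the wrapped function under that function's name, so autotuning now shows up alongside dynamo's other compilation timings. A minimal sketch of the decorator pattern (the function name is illustrative):

```python
from torch._dynamo import utils as dynamo_utils

@dynamo_utils.dynamo_timed
def tune_kernel():
    # Illustrative only: time spent in this body is attributed to
    # "tune_kernel" in dynamo's timing stats.
    ...
```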
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index e970f6a..8f08c01 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -242,23 +242,27 @@
@contextlib.contextmanager
-def fresh_triton_cache(cache_entries=None):
+def fresh_inductor_cache(cache_entries=None):
"""
- Contextmanager that provides a clean tmp cachedir for triton.
+ Contextmanager that provides a clean tmp cachedir for inductor.
Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
generated with this cache instance.
"""
- with tempfile.TemporaryDirectory() as tmpdirname:
- with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
- yield
- if isinstance(cache_entries, dict):
- assert len(cache_entries) == 0, "expected empty cache_entries dict"
- files = os.listdir(tmpdirname)
- cache_entries.update(
- {
- f: os.path.getsize(os.path.join(tmpdirname, f))
- for f in files
- if ".lock" not in f
- }
- )
+ with tempfile.TemporaryDirectory() as inductor_cache_dir:
+ with mock.patch.dict(
+ os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
+ ):
+ triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
+ with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
+ yield
+ if isinstance(cache_entries, dict):
+ assert len(cache_entries) == 0, "expected empty cache_entries dict"
+ files = os.listdir(triton_cache_dir)
+ cache_entries.update(
+ {
+ f: os.path.getsize(os.path.join(triton_cache_dir, f))
+ for f in files
+ if ".lock" not in f
+ }
+ )