[dynamo][benchmarks] use fresh inductor cache and raise batch size wherever possible (#88044)

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88044
Approved by: https://github.com/ngimel
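
This renames fresh_triton_cache to fresh_inductor_cache: the context manager now redirects both the inductor cache (cache_dir() honors a new TORCHINDUCTOR_CACHE_DIR environment variable and is memoized with functools.lru_cache) and the triton cache (TRITON_CACHE_DIR) into a temporary directory, so cold-start runs really start from an empty cache. On the benchmark side, the HuggingFace suite replaces USE_HALF_BATCH_SIZE/USE_SMALL_BATCH_SIZE with per-model BATCH_SIZE_DIVISORS and the TIMM divisors are relaxed, so default batch sizes are only reduced for models that need it (batch_size = max(int(default / divisor), 1)). The dashboard Makefile additionally clones and builds torchaudio, and autotune_to_one_config is wrapped in dynamo_timed so autotuning time is captured by dynamo's timing instrumentation.

A minimal usage sketch of the renamed helper (the toy function, the torch._dynamo.optimize call, and the CUDA device below are illustrative assumptions, not part of this PR):

```python
import torch
import torch._dynamo
from torch._inductor.utils import fresh_inductor_cache

@torch._dynamo.optimize("inductor")
def fn(x):
    return torch.relu(x) + 1

cache_entries = {}  # must start empty; populated with {filename: size_in_bytes} on exit
with fresh_inductor_cache(cache_entries):
    fn(torch.randn(16, 16, device="cuda"))  # assumes a CUDA device is available

# Only files from the temporary triton cache dir are recorded; ".lock" files are skipped.
print(cache_entries)
```

Outside of this helper, the inductor cache location can now be pinned by setting TORCHINDUCTOR_CACHE_DIR before the first compilation; otherwise it falls back to /tmp/torchinductor_<user>.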
diff --git a/benchmarks/dynamo/Makefile_dashboard b/benchmarks/dynamo/Makefile_dashboard
index 1c75d60..559f9fe 100644
--- a/benchmarks/dynamo/Makefile_dashboard
+++ b/benchmarks/dynamo/Makefile_dashboard
@@ -7,6 +7,7 @@
 		&& (test -e torchvision || git clone --recursive https://github.com/pytorch/vision torchvision) \
 		&& (test -e torchdata || git clone --recursive https://github.com/pytorch/data.git torchdata) \
 		&& (test -e torchtext || git clone --recursive https://github.com/pytorch/text torchtext) \
+		&& (test -e torchaudio || git clone --recursive https://github.com/pytorch/audio torchaudio) \
 		&& (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \
 		&& (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \
 		&& (test -e triton || git clone --recursive https://github.com/openai/triton.git) \
@@ -17,6 +18,7 @@
 	(cd ../../../torchvision    && git pull && git submodule update --init --recursive)
 	(cd ../../../torchdata      && git pull && git submodule update --init --recursive)
 	(cd ../../../torchtext      && git pull && git submodule update --init --recursive)
+	(cd ../../../torchaudio     && git pull && git submodule update --init --recursive)
 	(cd ../../../detectron2     && git pull && git submodule update --init --recursive)
 	(cd ../../../torchbenchmark && git pull && git submodule update --init --recursive)
 	(cd ../../../triton         && git checkout master && git pull && git checkout $(TRITON_VERSION) && git submodule update --init --recursive)
@@ -32,6 +34,7 @@
 	(cd ../../../torchvision && python setup.py clean && python setup.py develop)
 	(cd ../../../torchdata && python setup.py install)
 	(cd ../../../torchtext   && python setup.py clean && python setup.py develop)
+	(cd ../../../torchaudio  && python setup.py clean && python setup.py develop)
 	(cd ../../../detectron2  && python setup.py clean && python setup.py develop)
 	(cd ../../../torchbenchmark && python install.py --continue_on_fail)
 	(cd ../../../triton/python && python setup.py clean && python setup.py develop)
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index d1fd51f..dcbdfa6 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -28,7 +28,7 @@
 from torch._dynamo.testing import dummy_fx_compile, format_speedup, same
 from torch._dynamo.utils import clone_inputs
 from torch._inductor import config as inductor_config
-from torch._inductor.utils import fresh_triton_cache
+from torch._inductor.utils import fresh_inductor_cache
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.utils._pytree import tree_map
 
@@ -765,7 +765,7 @@
         cache_minder = NullContext()
         if is_cold_start:
             cache_entries = {}
-            cache_minder = fresh_triton_cache(cache_entries)
+            cache_minder = fresh_inductor_cache(cache_entries)
 
         try:
             with cache_minder:
diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py
index b563c22..c7ecd5f 100755
--- a/benchmarks/dynamo/huggingface.py
+++ b/benchmarks/dynamo/huggingface.py
@@ -68,9 +68,6 @@
         exec(f"from transformers import {cls}")
 
 
-USE_HALF_BATCH_SIZE = True
-
-
 # These models contain the models present in huggingface_models_list. It is a
 # combination of models supported by HF Fx parser and some manually supplied
 # models. For these models, we already know the largest batch size that can fit
@@ -107,31 +104,27 @@
 }
 
 # TODO - Fails even after fake tensors
-USE_SMALL_BATCH_SIZE = {
+BATCH_SIZE_DIVISORS = {
     "AlbertForMaskedLM": 2,
-    "AlbertForPreTraining": 4,
     "AlbertForQuestionAnswering": 2,
-    "BartForCausalLM": 2,
-    "BartForConditionalGeneration": 1,
-    "BlenderbotSmallForConditionalGeneration": 32,
-    "DebertaForMaskedLM": 4,
+    "AllenaiLongformerBase": 2,
+    "BartForConditionalGeneration": 2,
+    "BertForMaskedLM": 2,
+    "BlenderbotSmallForCausalLM": 2,
+    "BlenderbotSmallForConditionalGeneration": 2,
+    "ElectraForCausalLM": 2,
+    "ElectraForQuestionAnswering": 2,
+    "GPT2ForSequenceClassification": 2,
+    "LayoutLMForMaskedLM": 2,
+    "LayoutLMForSequenceClassification": 2,
+    "RobertaForCausalLM": 2,
+    "T5ForConditionalGeneration": 2,
+    # Large footprint
+    "BartForCausalLM": 4,
     "DebertaForQuestionAnswering": 4,
-    "DebertaV2ForMaskedLM": 1,
-    "DebertaV2ForQuestionAnswering": 1,
-    "DistilBertForMaskedLM": 16,
-    "ElectraForCausalLM": 1,
-    "GPTNeoForCausalLM": 1,
-    "GPTNeoForSequenceClassification": 1,
-    "M2M100ForConditionalGeneration": 2,
-    "MT5ForConditionalGeneration": 2,
-    "MegatronBertForCausalLM": 2,
-    "OPTForCausalLM": 4,
-    "PegasusForCausalLM": 8,
-    "PegasusForConditionalGeneration": 4,
-    "RobertaForCausalLM": 4,
-    "TrOCRForCausalLM": 8,
-    "XGLMForCausalLM": 1,
     "XLNetLMHeadModel": 4,
+    # Very large footprint
+    "DebertaForMaskedLM": 8,
 }
 
 
@@ -369,13 +362,8 @@
 
         if batch_size is None:
             batch_size = batch_size_default
-            if model_name in USE_SMALL_BATCH_SIZE:
-                batch_size = USE_SMALL_BATCH_SIZE[model_name]
-                log.warning(
-                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
-                )
-            elif USE_HALF_BATCH_SIZE and batch_size >= 2:
-                batch_size = int(batch_size / 2)
+            if model_name in BATCH_SIZE_DIVISORS:
+                batch_size = max(int(batch_size / BATCH_SIZE_DIVISORS[model_name]), 1)
                 log.warning(
                     f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
                 )
diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py
index f7ff255..70d06ab 100755
--- a/benchmarks/dynamo/timm_models.py
+++ b/benchmarks/dynamo/timm_models.py
@@ -40,44 +40,30 @@
 
 
 # TODO - Figure out the reason of cold start memory spike
+
 BATCH_SIZE_DIVISORS = {
     "beit_base_patch16_224": 2,
-    "cait_m36_384": 4,
-    "convit_base": 4,
+    "cait_m36_384": 2,
+    "convit_base": 2,
     "convmixer_768_32": 2,
-    "convnext_base": 4,
-    "crossvit_9_240": 2,
+    "convnext_base": 2,
     "cspdarknet53": 2,
     "deit_base_distilled_patch16_224": 2,
-    "dla102": 2,
     "dpn107": 2,
-    "eca_botnext26ts_256": 2,
-    "eca_halonext26ts": 2,
-    "gluon_senet154": 2,
     "gluon_xception65": 2,
-    "gmixer_24_224": 2,
-    "gmlp_s16_224": 2,
-    "hrnet_w18": 64,
-    "jx_nest_base": 4,
-    "mixer_b16_224": 2,
-    "mixnet_l": 2,
-    "mobilevit_s": 4,
-    "nfnet_l0": 2,
+    "mobilevit_s": 2,
     "pit_b_224": 2,
     "pnasnet5large": 2,
     "poolformer_m36": 2,
     "res2net101_26w_4s": 2,
-    "res2net50_14w_8s": 64,
-    "res2next50": 64,
-    "resnest101e": 4,
+    "resnest101e": 2,
     "sebotnet33ts_256": 2,
     "swin_base_patch4_window7_224": 2,
     "swsl_resnext101_32x16d": 2,
-    "tf_mixnet_l": 2,
-    "tnt_s_patch16_224": 2,
-    "twins_pcpvt_base": 4,
+    "twins_pcpvt_base": 2,
     "vit_base_patch16_224": 2,
     "volo_d1_224": 2,
+    "jx_nest_base": 4,
     "xcit_large_24_p8_224": 4,
 }
 
@@ -230,9 +216,11 @@
         )
         input_size = data_config["input_size"]
         recorded_batch_size = TIMM_MODELS[model_name]
-        recorded_batch_size = max(
-            int(recorded_batch_size / BATCH_SIZE_DIVISORS.get(model_name, 1)), 1
-        )
+
+        if model_name in BATCH_SIZE_DIVISORS:
+            recorded_batch_size = max(
+                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
+            )
         batch_size = batch_size or recorded_batch_size
 
         # example_inputs = torch.randn(
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 1c97c26..7f1e8bc 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -50,8 +50,11 @@
 logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO)
 
 
+@functools.lru_cache(None)
 def cache_dir():
-    return f"/tmp/torchinductor_{getpass.getuser()}"
+    return os.environ.get(
+        "TORCHINDUCTOR_CACHE_DIR", f"/tmp/torchinductor_{getpass.getuser()}"
+    )
 
 
 def get_lock_dir():
diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py
index 59ee762..b6f1c5c 100644
--- a/torch/_inductor/triton_ops/autotune.py
+++ b/torch/_inductor/triton_ops/autotune.py
@@ -13,7 +13,7 @@
 from .. import config
 from ..ir import ReductionHint
 from ..triton_ops.mm_perf_model import estimate_matmul_time
-from ..utils import conditional_product, has_triton
+from ..utils import conditional_product, dynamo_utils, has_triton
 from .conv_perf_model import (
     early_config_prune as conv_early_config_prune,
     estimate_conv_time,
@@ -136,6 +136,7 @@
 
         return do_bench(kernel_call)
 
+    @dynamo_utils.dynamo_timed
     def autotune_to_one_config(self, *args, **kwargs):
         """Do the actual autotuning"""
         from ..compile_fx import clone_preserve_strides
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index e970f6a..8f08c01 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -242,23 +242,27 @@
 
 
 @contextlib.contextmanager
-def fresh_triton_cache(cache_entries=None):
+def fresh_inductor_cache(cache_entries=None):
     """
-    Contextmanager that provides a clean tmp cachedir for triton.
+    Context manager that provides a clean tmp cache dir for inductor.
 
     Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
     generated with this cache instance.
     """
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
-            yield
-            if isinstance(cache_entries, dict):
-                assert len(cache_entries) == 0, "expected empty cache_entries dict"
-                files = os.listdir(tmpdirname)
-                cache_entries.update(
-                    {
-                        f: os.path.getsize(os.path.join(tmpdirname, f))
-                        for f in files
-                        if ".lock" not in f
-                    }
-                )
+    with tempfile.TemporaryDirectory() as inductor_cache_dir:
+        with mock.patch.dict(
+            os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
+        ):
+            triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
+            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
+                yield
+                if isinstance(cache_entries, dict):
+                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
+                    files = os.listdir(triton_cache_dir)
+                    cache_entries.update(
+                        {
+                            f: os.path.getsize(os.path.join(triton_cache_dir, f))
+                            for f in files
+                            if ".lock" not in f
+                        }
+                    )