Fix triton import time cycles (#122059)
Summary: `has_triton` causes some import time cycles. Lets use `has_triton_package` which is enough.
Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test -- --exact 'fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test - test_collect_features_from_graph_module_nodes (fblearner.flow.projects.model_processing.pytorch_model_export_utils.logical_transformations.tests.filter_inference_feature_metadata_test.FilterInferenceFromFeatureMetadataTest)'
```
now passes
Differential Revision: D55001430
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122059
Approved by: https://github.com/aakhundov
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 083c0d5..ac99fd2 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -384,14 +384,14 @@
def _gen_python_code(
self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False,
) -> PythonCode:
- from torch.utils._triton import has_triton
+ from torch.utils._triton import has_triton_package
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {}
- if has_triton():
+ if has_triton_package():
import triton
globals_[triton.__name__] = triton
from torch.utils._triton import patch_triton_dtype_repr
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index e0c80c8..7140a00 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -241,9 +241,9 @@
Can be override to support more trace-specific types.
"""
- from torch.utils._triton import has_triton
+ from torch.utils._triton import has_triton_package
- if has_triton():
+ if has_triton_package():
import triton
if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
@@ -281,7 +281,7 @@
elif isinstance(a, torch._ops.OpOverload):
return a
- elif has_triton() and isinstance(a, triton.language.dtype):
+ elif has_triton_package() and isinstance(a, triton.language.dtype):
return a
if isinstance(a, Proxy):