[BE][Easy] use `pathlib.Path` instead of `dirname` / `".."` / `pardir` (#129374)
Changes by apply order:
1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first.
`.parent{...}.absolute()` -> `.absolute().parent{...}`
4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)
`.parent.parent.parent.parent` -> `.parents[3]`
5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~
~`.parents[3]` -> `.parents[4 - 1]`~
6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
diff --git a/.circleci/codegen_validation/normalize_yaml_fragment.py b/.circleci/codegen_validation/normalize_yaml_fragment.py
index 32899cc..232eaa8 100755
--- a/.circleci/codegen_validation/normalize_yaml_fragment.py
+++ b/.circleci/codegen_validation/normalize_yaml_fragment.py
@@ -5,8 +5,9 @@
import yaml
+
# Need to import modules that lie on an upward-relative path
-sys.path.append(os.path.join(sys.path[0], ".."))
+sys.path.append(os.path.dirname(sys.path[0]))
import cimodel.lib.miniyaml as miniyaml
diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py
index c2676ae..e28d33c 100644
--- a/.github/scripts/delete_old_branches.py
+++ b/.github/scripts/delete_old_branches.py
@@ -9,6 +9,7 @@
from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo
+
SEC_IN_DAY = 24 * 60 * 60
CLOSED_PR_RETENTION = 30 * SEC_IN_DAY
NO_PR_RETENTION = 1.5 * 365 * SEC_IN_DAY
@@ -21,7 +22,7 @@
if not TOKEN:
raise Exception("GITHUB_TOKEN is not set") # noqa: TRY002
-REPO_ROOT = Path(__file__).parent.parent.parent
+REPO_ROOT = Path(__file__).parents[2]
# Query for all PRs instead of just closed/merged because it's faster
GRAPHQL_ALL_PRS_BY_UPDATED_AT = """
diff --git a/.github/scripts/ensure_actions_will_cancel.py b/.github/scripts/ensure_actions_will_cancel.py
index 7e0cbc1..2c76f09 100755
--- a/.github/scripts/ensure_actions_will_cancel.py
+++ b/.github/scripts/ensure_actions_will_cancel.py
@@ -1,13 +1,12 @@
#!/usr/bin/env python3
import sys
-
from pathlib import Path
import yaml
-REPO_ROOT = Path(__file__).resolve().parent.parent.parent
+REPO_ROOT = Path(__file__).resolve().parents[2]
WORKFLOWS = REPO_ROOT / ".github" / "workflows"
EXPECTED_GROUP_PREFIX = (
"${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}"
diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py
index cf1135b..a797a22 100644
--- a/.github/scripts/generate_binary_build_matrix.py
+++ b/.github/scripts/generate_binary_build_matrix.py
@@ -13,6 +13,7 @@
import os
from typing import Dict, List, Optional, Tuple
+
CUDA_ARCHES = ["11.8", "12.1", "12.4"]
@@ -85,7 +86,7 @@
from pathlib import Path
nccl_version_mk = (
- Path(__file__).absolute().parent.parent.parent
+ Path(__file__).absolute().parents[2]
/ "third_party"
/ "nccl"
/ "nccl"
diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py
index 1640e43..ecc678f 100644
--- a/.github/scripts/gitutils.py
+++ b/.github/scripts/gitutils.py
@@ -19,6 +19,7 @@
Union,
)
+
T = TypeVar("T")
RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$")
@@ -31,7 +32,7 @@
def get_git_repo_dir() -> str:
from pathlib import Path
- return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent))
+ return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))
def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
diff --git a/.github/scripts/lint_native_functions.py b/.github/scripts/lint_native_functions.py
index 4dfe9fd..07504d7 100755
--- a/.github/scripts/lint_native_functions.py
+++ b/.github/scripts/lint_native_functions.py
@@ -26,7 +26,7 @@
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
-with open(Path(__file__).parent.parent.parent / fn(".")) as f:
+with open(Path(__file__).parents[2] / fn(".")) as f:
contents = f.read()
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
diff --git a/.github/scripts/test_gitutils.py b/.github/scripts/test_gitutils.py
index c4137ba..b269cac 100644
--- a/.github/scripts/test_gitutils.py
+++ b/.github/scripts/test_gitutils.py
@@ -68,7 +68,7 @@
class TestGitRepo(TestCase):
def setUp(self) -> None:
- repo_dir = BASE_DIR.parent.parent.absolute()
+ repo_dir = BASE_DIR.absolute().parent.parent
if not (repo_dir / ".git").is_dir():
raise SkipTest(
"Can't find git directory, make sure to run this test on real repo checkout"
diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py
index 5d73cf6..daecbb4 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py
+++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py
@@ -18,23 +18,24 @@
import argparse
import json
import os
-import pathlib
import subprocess
import sys
import urllib
from io import BytesIO
from itertools import product
+from pathlib import Path
from urllib.request import urlopen
from zipfile import ZipFile
import pandas as pd
import requests
+
# Note: the public query url targets this rockset lambda:
# https://console.rockset.com/lambdas/details/commons.artifacts
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
CSV_LINTER = str(
- pathlib.Path(__file__).absolute().parent.parent.parent.parent
+ Path(__file__).absolute().parents[3]
/ "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)
diff --git a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
index 5d73cf6..daecbb4 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
+++ b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py
@@ -18,23 +18,24 @@
import argparse
import json
import os
-import pathlib
import subprocess
import sys
import urllib
from io import BytesIO
from itertools import product
+from pathlib import Path
from urllib.request import urlopen
from zipfile import ZipFile
import pandas as pd
import requests
+
# Note: the public query url targets this rockset lambda:
# https://console.rockset.com/lambdas/details/commons.artifacts
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
CSV_LINTER = str(
- pathlib.Path(__file__).absolute().parent.parent.parent.parent
+ Path(__file__).absolute().parents[3]
/ "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)
diff --git a/docs/source/scripts/build_opsets.py b/docs/source/scripts/build_opsets.py
index f0c4d3b..c752ade 100644
--- a/docs/source/scripts/build_opsets.py
+++ b/docs/source/scripts/build_opsets.py
@@ -4,12 +4,12 @@
import torch
import torch._prims as prims
-
from torchgen.gen import parse_native_yaml
-ROOT = Path(__file__).absolute().parent.parent.parent.parent
-NATIVE_FUNCTION_YAML_PATH = ROOT / Path("aten/src/ATen/native/native_functions.yaml")
-TAGS_YAML_PATH = ROOT / Path("aten/src/ATen/native/tags.yaml")
+
+ROOT = Path(__file__).absolute().parents[3]
+NATIVE_FUNCTION_YAML_PATH = ROOT / "aten/src/ATen/native/native_functions.yaml"
+TAGS_YAML_PATH = ROOT / "aten/src/ATen/native/tags.yaml"
BUILD_DIR = "build/ir"
ATEN_OPS_CSV_FILE = "aten_ops.csv"
diff --git a/docs/source/scripts/build_quantization_configs.py b/docs/source/scripts/build_quantization_configs.py
index bf40566..5d1f445 100644
--- a/docs/source/scripts/build_quantization_configs.py
+++ b/docs/source/scripts/build_quantization_configs.py
@@ -15,7 +15,7 @@
# Create a directory for the images, if it doesn't exist
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
- os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
+ os.path.realpath(os.path.dirname(__file__)), "quantization_backend_configs"
)
if not os.path.exists(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH):
diff --git a/docs/source/scripts/exportdb/generate_example_rst.py b/docs/source/scripts/exportdb/generate_example_rst.py
index b9f68ad..58f8b14 100644
--- a/docs/source/scripts/exportdb/generate_example_rst.py
+++ b/docs/source/scripts/exportdb/generate_example_rst.py
@@ -5,16 +5,15 @@
import torch
import torch._dynamo as torchdynamo
-
from torch._export.db.case import ExportCase, normalize_inputs
from torch._export.db.examples import all_examples
from torch.export import export
PWD = Path(__file__).absolute().parent
-ROOT = Path(__file__).absolute().parent.parent.parent.parent
-SOURCE = ROOT / Path("source")
-EXPORTDB_SOURCE = SOURCE / Path("generated") / Path("exportdb")
+ROOT = Path(__file__).absolute().parents[3]
+SOURCE = ROOT / "source"
+EXPORTDB_SOURCE = SOURCE / "generated" / "exportdb"
def generate_example_rst(example_case: ExportCase):
diff --git a/scripts/compile_tests/update_failures.py b/scripts/compile_tests/update_failures.py
index 929ed9f..c786362 100755
--- a/scripts/compile_tests/update_failures.py
+++ b/scripts/compile_tests/update_failures.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python3
import argparse
import os
-import pathlib
import subprocess
+from pathlib import Path
from common import (
get_testcases,
@@ -12,9 +12,9 @@
key,
open_test_results,
)
-
from download_reports import download_reports
+
"""
Usage: update_failures.py /path/to/dynamo_test_failures.py /path/to/test commit_sha
@@ -194,7 +194,7 @@
"filename",
nargs="?",
default=str(
- pathlib.Path(__file__).absolute().parent.parent.parent
+ Path(__file__).absolute().parents[2]
/ "torch/testing/_internal/dynamo_test_failures.py"
),
help="Optional path to dynamo_test_failures.py",
@@ -203,7 +203,7 @@
parser.add_argument(
"test_dir",
nargs="?",
- default=str(pathlib.Path(__file__).absolute().parent.parent.parent / "test"),
+ default=str(Path(__file__).absolute().parents[2] / "test"),
help="Optional path to test folder",
)
parser.add_argument(
@@ -219,7 +219,7 @@
action="store_true",
)
args = parser.parse_args()
- assert pathlib.Path(args.filename).exists(), args.filename
- assert pathlib.Path(args.test_dir).exists(), args.test_dir
+ assert Path(args.filename).exists(), args.filename
+ assert Path(args.test_dir).exists(), args.test_dir
dynamo38, dynamo311 = download_reports(args.commit, ("dynamo38", "dynamo311"))
update(args.filename, args.test_dir, dynamo38, dynamo311, args.also_remove_skips)
diff --git a/test/jit/fixtures_srcs/generate_models.py b/test/jit/fixtures_srcs/generate_models.py
index 973a31e..1de9bb3 100644
--- a/test/jit/fixtures_srcs/generate_models.py
+++ b/test/jit/fixtures_srcs/generate_models.py
@@ -5,12 +5,13 @@
from pathlib import Path
from typing import Set
-import torch
-
# Use asterisk symbol so developer doesn't need to import here when they add tests for upgraders.
from test.jit.fixtures_srcs.fixtures_src import * # noqa: F403
+
+import torch
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
diff --git a/test/jit/test_backend_nnapi.py b/test/jit/test_backend_nnapi.py
index 289827e..9f47716 100644
--- a/test/jit/test_backend_nnapi.py
+++ b/test/jit/test_backend_nnapi.py
@@ -9,6 +9,7 @@
import torch._C
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
+
# hacky way to skip these tests in fbcode:
# during test execution in fbcode, test_nnapi is available during test discovery,
# but not during test execution. So we can't try-catch here, otherwise it'll think
@@ -40,7 +41,7 @@
without the delegate API.
"""
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
-torch_root = Path(__file__).resolve().parent.parent.parent
+torch_root = Path(__file__).resolve().parents[2]
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"
diff --git a/test/mobile/test_bytecode.py b/test/mobile/test_bytecode.py
index 14faa13..307921d 100644
--- a/test/mobile/test_bytecode.py
+++ b/test/mobile/test_bytecode.py
@@ -20,6 +20,7 @@
)
from torch.testing._internal.common_utils import run_tests, TestCase
+
pytorch_test_dir = Path(__file__).resolve().parents[1]
# script_module_v4.ptl and script_module_v5.ptl source code
diff --git a/test/mobile/test_upgrader_codegen.py b/test/mobile/test_upgrader_codegen.py
index 51acebc..033cb26 100644
--- a/test/mobile/test_upgrader_codegen.py
+++ b/test/mobile/test_upgrader_codegen.py
@@ -6,9 +6,9 @@
from torch.jit.generate_bytecode import generate_upgraders_bytecode
from torch.testing._internal.common_utils import run_tests, TestCase
-
from torchgen.operator_versions.gen_mobile_upgraders import sort_upgrader, write_cpp
+
pytorch_caffe2_dir = Path(__file__).resolve().parents[2]
diff --git a/test/mobile/test_upgraders.py b/test/mobile/test_upgraders.py
index 910c2bf..5ebf9a2 100644
--- a/test/mobile/test_upgraders.py
+++ b/test/mobile/test_upgraders.py
@@ -6,10 +6,10 @@
import torch
import torch.utils.bundled_inputs
-
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.common_utils import run_tests, TestCase
+
pytorch_test_dir = Path(__file__).resolve().parents[1]
diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index 0cb97f4..4d9d807 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -3,7 +3,6 @@
from __future__ import annotations
import contextlib
-
import copy
import dataclasses
import io
@@ -26,7 +25,6 @@
)
import numpy as np
-
import onnxruntime
import pytest
import pytorch_test_common
@@ -40,6 +38,7 @@
from torch.testing._internal.opinfo import core as opinfo_core
from torch.types import Number
+
_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
_InputArgsType = Optional[
@@ -48,8 +47,7 @@
_OutputsType = Sequence[_NumericType]
onnx_model_dir = os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- os.pardir,
+ os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
"repos",
"onnx",
"onnx",
@@ -57,11 +55,7 @@
"test",
"data",
)
-
-
pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
-
-
pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")
diff --git a/test/quantization/core/test_docs.py b/test/quantization/core/test_docs.py
index 6e5a7cc..4244fae 100644
--- a/test/quantization/core/test_docs.py
+++ b/test/quantization/core/test_docs.py
@@ -50,7 +50,7 @@
"been updated to have the correct relative path between "
"test_docs.py and the docs."
)
- pytorch_root = core_dir.parent.parent.parent
+ pytorch_root = core_dir.parents[2]
return pytorch_root / path_from_pytorch
path_to_file = get_correct_path(path_from_pytorch)
diff --git a/test/test_typing.py b/test/test_typing.py
index 3793700..67b96a5 100644
--- a/test/test_typing.py
+++ b/test/test_typing.py
@@ -5,7 +5,6 @@
import os
import re
import shutil
-
import unittest
from collections import defaultdict
from threading import Lock
@@ -18,6 +17,7 @@
TestCase,
)
+
try:
from mypy import api
except ImportError:
@@ -30,7 +30,7 @@
REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
PASS_DIR = os.path.join(DATA_DIR, "pass")
FAIL_DIR = os.path.join(DATA_DIR, "fail")
-MYPY_INI = os.path.join(DATA_DIR, os.pardir, os.pardir, "mypy.ini")
+MYPY_INI = os.path.join(os.path.dirname(os.path.dirname(DATA_DIR)), "mypy.ini")
CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")
diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py
index 96047e6..e05d834 100755
--- a/tools/amd_build/build_amd.py
+++ b/tools/amd_build/build_amd.py
@@ -4,17 +4,15 @@
import argparse
import os
import sys
+from pathlib import Path
-sys.path.append(
- os.path.realpath(
- os.path.join(
- __file__, os.path.pardir, os.path.pardir, os.path.pardir, "torch", "utils"
- )
- )
-)
+
+REPO_ROOT = Path(__file__).parents[2].resolve()
+sys.path.append(str(REPO_ROOT / "torch" / "utils"))
from hipify import hipify_python # type: ignore[import]
+
parser = argparse.ArgumentParser(
description="Top-level script for HIPifying, filling in most common parameters"
)
@@ -52,7 +50,7 @@
args = parser.parse_args()
amd_build_dir = os.path.dirname(os.path.realpath(__file__))
-proj_dir = os.path.join(os.path.dirname(os.path.dirname(amd_build_dir)))
+proj_dir = os.path.dirname(os.path.dirname(amd_build_dir))
if args.project_directory:
proj_dir = args.project_directory
diff --git a/tools/build_libtorch.py b/tools/build_libtorch.py
index 3b85d41..07acdd1 100644
--- a/tools/build_libtorch.py
+++ b/tools/build_libtorch.py
@@ -1,16 +1,18 @@
import argparse
import sys
-from os.path import abspath, dirname
+from pathlib import Path
-# By appending pytorch_root to sys.path, this module can import other torch
+
+# By appending REPO_ROOT to sys.path, this module can import other torch
# modules even when run as a standalone script. i.e., it's okay either you
# do `python build_libtorch.py` or `python -m tools.build_libtorch`.
-pytorch_root = dirname(dirname(abspath(__file__)))
-sys.path.append(pytorch_root)
+REPO_ROOT = Path(__file__).absolute().parent.parent
+sys.path.append(str(REPO_ROOT))
from tools.build_pytorch_libs import build_caffe2
from tools.setup_helpers.cmake import CMake
+
if __name__ == "__main__":
# Placeholder for future interface. For now just gives a nice -h.
parser = argparse.ArgumentParser(description="Build libtorch")
diff --git a/tools/code_coverage/package/oss/utils.py b/tools/code_coverage/package/oss/utils.py
index 1cb67dc..6028283 100644
--- a/tools/code_coverage/package/oss/utils.py
+++ b/tools/code_coverage/package/oss/utils.py
@@ -42,9 +42,7 @@
def get_pytorch_folder() -> str:
# TOOLS_FOLDER in oss: pytorch/tools/code_coverage
return os.path.abspath(
- os.environ.get(
- "PYTORCH_FOLDER", os.path.join(TOOLS_FOLDER, os.path.pardir, os.path.pardir)
- )
+ os.environ.get("PYTORCH_FOLDER", os.path.dirname(os.path.dirname(TOOLS_FOLDER)))
)
diff --git a/tools/code_coverage/package/util/setting.py b/tools/code_coverage/package/util/setting.py
index ed5efc3..bb8029f 100644
--- a/tools/code_coverage/package/util/setting.py
+++ b/tools/code_coverage/package/util/setting.py
@@ -1,13 +1,12 @@
import os
from enum import Enum
+from pathlib import Path
from typing import Dict, List, Set
# <project folder>
HOME_DIR = os.environ["HOME"]
-TOOLS_FOLDER = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), os.path.pardir, os.path.pardir
-)
+TOOLS_FOLDER = str(Path(__file__).resolve().parents[2])
# <profile folder>
diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py
index 653a18a..334fb20 100644
--- a/tools/gen_vulkan_spv.py
+++ b/tools/gen_vulkan_spv.py
@@ -8,24 +8,28 @@
import io
import os
import re
-import sys
-from itertools import product
-
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
+import sys
import textwrap
from dataclasses import dataclass
+from itertools import product
+from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import yaml
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
+
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[assignment, misc]
+
+REPO_ROOT = Path(__file__).absolute().parent.parent
+sys.path.append(str(REPO_ROOT))
+
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"
diff --git a/tools/linter/adapters/s3_init.py b/tools/linter/adapters/s3_init.py
index 260a3d4..80e61ef 100644
--- a/tools/linter/adapters/s3_init.py
+++ b/tools/linter/adapters/s3_init.py
@@ -11,6 +11,7 @@
import urllib.request
from pathlib import Path
+
# String representing the host platform (e.g. Linux, Darwin).
HOST_PLATFORM = platform.system()
HOST_PLATFORM_ARCH = platform.system() + "-" + platform.processor()
@@ -25,10 +26,7 @@
PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
except subprocess.CalledProcessError:
# If git is not installed, compute repo root as 3 folders up from this file
- path_ = os.path.abspath(__file__)
- for _ in range(4):
- path_ = os.path.dirname(path_)
- PYTORCH_ROOT = path_
+ PYTORCH_ROOT = str(Path(__file__).absolute().parents[3])
DRY_RUN = False
diff --git a/tools/onnx/update_default_opset_version.py b/tools/onnx/update_default_opset_version.py
index a6446b1..88a98e5 100755
--- a/tools/onnx/update_default_opset_version.py
+++ b/tools/onnx/update_default_opset_version.py
@@ -12,10 +12,10 @@
import argparse
import datetime
import os
-import pathlib
import re
import subprocess
import sys
+from pathlib import Path
from subprocess import DEVNULL
from typing import Any
@@ -30,7 +30,7 @@
def main(args: Any) -> None:
- pytorch_dir = pathlib.Path(__file__).parent.parent.parent.resolve()
+ pytorch_dir = Path(__file__).parents[2].resolve()
onnx_dir = pytorch_dir / "third_party" / "onnx"
os.chdir(onnx_dir)
diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py
index 4d10b3d..b9677e0 100644
--- a/tools/setup_helpers/cmake.py
+++ b/tools/setup_helpers/cmake.py
@@ -7,6 +7,7 @@
import sys
import sysconfig
from distutils.version import LooseVersion
+from pathlib import Path
from subprocess import CalledProcessError, check_call, check_output
from typing import Any, cast, Dict, List, Optional
@@ -172,9 +173,7 @@
toolset_expr = ",".join([f"{k}={v}" for k, v in toolset_dict.items()])
args.append("-T" + toolset_expr)
- base_dir = os.path.dirname(
- os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- )
+ base_dir = str(Path(__file__).absolute().parents[2])
install_dir = os.path.join(base_dir, "torch")
_mkdir_p(install_dir)
diff --git a/tools/setup_helpers/gen.py b/tools/setup_helpers/gen.py
index 3ca9a87..fb3b21f 100644
--- a/tools/setup_helpers/gen.py
+++ b/tools/setup_helpers/gen.py
@@ -1,11 +1,13 @@
# Little stub file to get BUILD.bazel to play along
-import os.path
import sys
+from pathlib import Path
-root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-sys.path.insert(0, root)
+
+REPO_ROOT = Path(__file__).absolute().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
import torchgen.gen
+
torchgen.gen.main()
diff --git a/tools/setup_helpers/gen_unboxing.py b/tools/setup_helpers/gen_unboxing.py
index b876b72..6e733d7 100644
--- a/tools/setup_helpers/gen_unboxing.py
+++ b/tools/setup_helpers/gen_unboxing.py
@@ -1,11 +1,13 @@
# Little stub file to get BUILD.bazel to play along
-import os.path
import sys
+from pathlib import Path
-root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-sys.path.insert(0, root)
+
+REPO_ROOT = Path(__file__).absolute().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
import tools.jit.gen_unboxing
+
tools.jit.gen_unboxing.main(sys.argv[1:])
diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py
index 6c939fe..86e98b2 100644
--- a/tools/setup_helpers/generate_code.py
+++ b/tools/setup_helpers/generate_code.py
@@ -1,23 +1,25 @@
import argparse
import os
-import pathlib
import sys
+from pathlib import Path
from typing import Any, cast, Optional
import yaml
+
try:
# use faster C loader if available
from yaml import CSafeLoader as YamlLoader
except ImportError:
from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]
+
NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
TAGS_PATH = "aten/src/ATen/native/tags.yaml"
def generate_code(
- gen_dir: pathlib.Path,
+ gen_dir: Path,
native_functions_path: Optional[str] = None,
tags_path: Optional[str] = None,
install_dir: Optional[str] = None,
@@ -28,6 +30,7 @@
) -> None:
from tools.autograd.gen_annotated_fn_args import gen_annotated
from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
+
from torchgen.selective_build.selector import SelectiveBuilder
# Build ATen based Variable classes
@@ -39,7 +42,7 @@
autograd_gen_dir = os.path.join(install_dir, "autograd", "generated")
for d in (autograd_gen_dir, python_install_dir):
os.makedirs(d, exist_ok=True)
- autograd_dir = os.fspath(pathlib.Path(__file__).parent.parent / "autograd")
+ autograd_dir = os.fspath(Path(__file__).parent.parent / "autograd")
if subset == "pybindings" or not subset:
gen_autograd_python(
@@ -106,8 +109,9 @@
operators_yaml_path: Optional[str],
) -> Any:
# cwrap depends on pyyaml, so we can't import it earlier
- root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- sys.path.insert(0, root)
+ REPO_ROOT = Path(__file__).absolute().parents[2]
+ sys.path.insert(0, str(REPO_ROOT))
+
from torchgen.selective_build.selector import SelectiveBuilder
assert not (
@@ -131,8 +135,8 @@
parser.add_argument("--tags-path")
parser.add_argument(
"--gen-dir",
- type=pathlib.Path,
- default=pathlib.Path("."),
+ type=Path,
+ default=Path("."),
help="Root directory where to install files. Defaults to the current working directory.",
)
parser.add_argument(
diff --git a/tools/stats/export_test_times.py b/tools/stats/export_test_times.py
index 2b9c0c4..ae87718 100644
--- a/tools/stats/export_test_times.py
+++ b/tools/stats/export_test_times.py
@@ -1,8 +1,10 @@
-import pathlib
import sys
+from pathlib import Path
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.append(str(REPO_ROOT))
+
from tools.stats.import_test_stats import get_test_class_times, get_test_times
diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py
index 513edb1..bf61dfa 100644
--- a/tools/stats/import_test_stats.py
+++ b/tools/stats/import_test_stats.py
@@ -3,12 +3,13 @@
import datetime
import json
import os
-import pathlib
import shutil
+from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Union
from urllib.request import urlopen
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
def get_disabled_issues() -> List[str]:
@@ -20,7 +21,7 @@
SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
-ADDITIONAL_CI_FILES_FOLDER = pathlib.Path(".additional_ci_files")
+ADDITIONAL_CI_FILES_FOLDER = Path(".additional_ci_files")
TEST_TIMES_FILE = "test-times.json"
TEST_CLASS_TIMES_FILE = "test-class-times.json"
TEST_FILE_RATINGS_FILE = "test-file-ratings.json"
@@ -34,7 +35,7 @@
def fetch_and_cache(
- dirpath: Union[str, pathlib.Path],
+ dirpath: Union[str, Path],
name: str,
url: str,
process_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
@@ -42,7 +43,7 @@
"""
This fetch and cache utils allows sharing between different process.
"""
- pathlib.Path(dirpath).mkdir(exist_ok=True)
+ Path(dirpath).mkdir(exist_ok=True)
path = os.path.join(dirpath, name)
print(f"Downloading {url} to {path}")
@@ -50,7 +51,7 @@
def is_cached_file_valid() -> bool:
# Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check
# could make a HEAD request and check/store the file's ETag
- fname = pathlib.Path(path)
+ fname = Path(path)
now = datetime.datetime.now()
mtime = datetime.datetime.fromtimestamp(fname.stat().st_mtime)
diff = now - mtime
diff --git a/tools/test/heuristics/test_heuristics.py b/tools/test/heuristics/test_heuristics.py
index a1d1534..8de72d1 100644
--- a/tools/test/heuristics/test_heuristics.py
+++ b/tools/test/heuristics/test_heuristics.py
@@ -1,14 +1,16 @@
# For testing specific heuristics
import io
import json
-import pathlib
import sys
import unittest
+from pathlib import Path
from typing import Any, Dict, List, Set
from unittest import mock
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.append(str(REPO_ROOT))
+
from tools.test.heuristics.test_interface import TestTD
from tools.testing.target_determination.determinator import TestPrioritizations
from tools.testing.target_determination.heuristics.filepath import (
@@ -23,6 +25,7 @@
)
from tools.testing.test_run import TestRun
+
sys.path.remove(str(REPO_ROOT))
HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation."
diff --git a/tools/test/heuristics/test_interface.py b/tools/test/heuristics/test_interface.py
index df122ab..5fe78c1 100644
--- a/tools/test/heuristics/test_interface.py
+++ b/tools/test/heuristics/test_interface.py
@@ -1,13 +1,16 @@
-import pathlib
import sys
import unittest
+from pathlib import Path
from typing import Any, Dict, List
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.append(str(REPO_ROOT))
+
import tools.testing.target_determination.heuristics.interface as interface
from tools.testing.test_run import TestRun
+
sys.path.remove(str(REPO_ROOT))
diff --git a/tools/test/heuristics/test_utils.py b/tools/test/heuristics/test_utils.py
index 9348972..e270336 100644
--- a/tools/test/heuristics/test_utils.py
+++ b/tools/test/heuristics/test_utils.py
@@ -1,14 +1,16 @@
-import pathlib
import sys
import unittest
+from pathlib import Path
from typing import Any, Dict
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
+REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.append(str(REPO_ROOT))
+
import tools.testing.target_determination.heuristics.utils as utils
from tools.testing.test_run import TestRun
+
sys.path.remove(str(REPO_ROOT))
diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 34bd2fe..bf1050f 100644
--- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -1,6 +1,5 @@
# Owner(s): ["module: codegen"]
-import os
import tempfile
import unittest
from typing import Optional
@@ -8,12 +7,8 @@
import expecttest
from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401
-
from torchgen.gen_backend_stubs import run
-path = os.path.dirname(os.path.realpath(__file__))
-gen_backend_stubs_path = os.path.join(path, "../torchgen/gen_backend_stubs.py")
-
# gen_backend_stubs.py is an integration point that is called directly by external backends.
# The tests here are to confirm that badly formed inputs result in reasonable error messages.
diff --git a/tools/test/test_test_run.py b/tools/test/test_test_run.py
index 25fa810..c3fc273 100644
--- a/tools/test/test_test_run.py
+++ b/tools/test/test_test_run.py
@@ -1,8 +1,9 @@
-import pathlib
import sys
import unittest
+from pathlib import Path
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py
index cc9bf5f..5de7a7a 100644
--- a/tools/test/test_test_selections.py
+++ b/tools/test/test_test_selections.py
@@ -1,12 +1,13 @@
import functools
-import pathlib
import random
import sys
import unittest
from collections import defaultdict
+from pathlib import Path
from typing import Dict, List, Tuple
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
diff --git a/tools/test/test_upload_stats_lib.py b/tools/test/test_upload_stats_lib.py
index 0baf323..c5447dc 100644
--- a/tools/test/test_upload_stats_lib.py
+++ b/tools/test/test_upload_stats_lib.py
@@ -1,17 +1,19 @@
import decimal
import inspect
-import pathlib
import sys
import unittest
+from pathlib import Path
from typing import Any, Dict
from unittest import mock
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
-sys.path.insert(0, str(REPO_ROOT))
-from tools.stats.upload_metrics import add_global_metric, emit_metric
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+from tools.stats.upload_metrics import add_global_metric, emit_metric
from tools.stats.upload_stats_lib import BATCH_SIZE, upload_to_rockset
+
sys.path.remove(str(REPO_ROOT))
# default values
diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py
index 5c37fdd..0ee20ff 100644
--- a/tools/testing/discover_tests.py
+++ b/tools/testing/discover_tests.py
@@ -4,10 +4,11 @@
from pathlib import Path
from typing import List, Optional, Union
+
CPP_TEST_PREFIX = "cpp"
CPP_TEST_PATH = "build/bin"
CPP_TESTS_DIR = os.path.abspath(os.getenv("CPP_TESTS_DIR", default=CPP_TEST_PATH))
-REPO_ROOT = Path(__file__).resolve().parent.parent.parent
+REPO_ROOT = Path(__file__).resolve().parents[2]
def parse_test_module(test: str) -> str:
diff --git a/tools/testing/do_target_determination_for_s3.py b/tools/testing/do_target_determination_for_s3.py
index 32ea85b..27a0fbb 100644
--- a/tools/testing/do_target_determination_for_s3.py
+++ b/tools/testing/do_target_determination_for_s3.py
@@ -1,10 +1,10 @@
import json
import os
-import pathlib
import sys
+from pathlib import Path
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
+REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.import_test_stats import (
@@ -18,7 +18,6 @@
get_test_times,
)
from tools.stats.upload_metrics import emit_metric
-
from tools.testing.discover_tests import TESTS
from tools.testing.target_determination.determinator import (
AggregatedHeuristics,
@@ -26,6 +25,7 @@
TestPrioritizations,
)
+
sys.path.remove(str(REPO_ROOT))
diff --git a/tools/testing/explicit_ci_jobs.py b/tools/testing/explicit_ci_jobs.py
index 594e00d..4a3acd6 100755
--- a/tools/testing/explicit_ci_jobs.py
+++ b/tools/testing/explicit_ci_jobs.py
@@ -2,16 +2,15 @@
import argparse
import fnmatch
-import pathlib
import subprocess
import textwrap
-
+from pathlib import Path
from typing import Any, Dict, List
import yaml
-REPO_ROOT = pathlib.Path(__file__).parent.parent.parent
+REPO_ROOT = Path(__file__).parents[2]
CONFIG_YML = REPO_ROOT / ".circleci" / "config.yml"
WORKFLOWS_DIR = REPO_ROOT / ".github" / "workflows"
diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py
index ba58d75..856111a 100644
--- a/tools/testing/modulefinder_determinator.py
+++ b/tools/testing/modulefinder_determinator.py
@@ -1,11 +1,12 @@
import modulefinder
import os
-import pathlib
import sys
import warnings
+from pathlib import Path
from typing import Any, Dict, List, Set
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
# These tests are slow enough that it's worth calculating whether the patch
# touched any related files first. This list was manually generated, but for every
diff --git a/tools/testing/target_determination/gen_artifact.py b/tools/testing/target_determination/gen_artifact.py
index c5165cb..794e572 100644
--- a/tools/testing/target_determination/gen_artifact.py
+++ b/tools/testing/target_determination/gen_artifact.py
@@ -1,9 +1,10 @@
import json
import os
-import pathlib
+from pathlib import Path
from typing import Any, List
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
def gen_ci_artifact(included: List[Any], excluded: List[Any]) -> None:
diff --git a/tools/testing/target_determination/heuristics/filepath.py b/tools/testing/target_determination/heuristics/filepath.py
index 066d13f..cd51296 100644
--- a/tools/testing/target_determination/heuristics/filepath.py
+++ b/tools/testing/target_determination/heuristics/filepath.py
@@ -8,14 +8,14 @@
HeuristicInterface,
TestPrioritizations,
)
-
from tools.testing.target_determination.heuristics.utils import (
normalize_ratings,
query_changed_files,
)
from tools.testing.test_run import TestRun
-REPO_ROOT = Path(__file__).parent.parent.parent.parent
+
+REPO_ROOT = Path(__file__).parents[3]
keyword_synonyms: Dict[str, List[str]] = {
"amp": ["mixed_precision"],
diff --git a/tools/testing/target_determination/heuristics/llm.py b/tools/testing/target_determination/heuristics/llm.py
index d3021d9..3a94ceb 100644
--- a/tools/testing/target_determination/heuristics/llm.py
+++ b/tools/testing/target_determination/heuristics/llm.py
@@ -14,7 +14,7 @@
from tools.testing.test_run import TestRun
-REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
+REPO_ROOT = Path(__file__).resolve().parents[4]
class LLM(HeuristicInterface):
diff --git a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py
index 26439f2..b479de2 100644
--- a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py
+++ b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py
@@ -8,7 +8,6 @@
TD_HEURISTIC_PREVIOUSLY_FAILED,
TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL,
)
-
from tools.testing.target_determination.heuristics.interface import (
HeuristicInterface,
TestPrioritizations,
@@ -18,7 +17,8 @@
)
from tools.testing.test_run import TestRun
-REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[4]
class PreviouslyFailedInPR(HeuristicInterface):
diff --git a/tools/testing/target_determination/heuristics/utils.py b/tools/testing/target_determination/heuristics/utils.py
index 7d8297d..f5ad74f 100644
--- a/tools/testing/target_determination/heuristics/utils.py
+++ b/tools/testing/target_determination/heuristics/utils.py
@@ -11,7 +11,8 @@
from tools.testing.test_run import TestRun
-REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[4]
def python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py
index 3e43edd..b7f2c8d 100644
--- a/tools/testing/test_selections.py
+++ b/tools/testing/test_selections.py
@@ -2,13 +2,13 @@
import os
import subprocess
from pathlib import Path
-
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Tuple
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
from tools.testing.test_run import ShardedTest, TestRun
-REPO_ROOT = Path(__file__).resolve().parent.parent.parent
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
diff --git a/tools/vscode_settings.py b/tools/vscode_settings.py
index 21fddf6..edfaec3 100755
--- a/tools/vscode_settings.py
+++ b/tools/vscode_settings.py
@@ -2,6 +2,7 @@
from pathlib import Path
+
try:
# VS Code settings allow comments and trailing commas, which are not valid JSON.
import json5 as json # type: ignore[import]
diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py
index 17788ab..3a96268 100644
--- a/torch/_inductor/runtime/compile_tasks.py
+++ b/torch/_inductor/runtime/compile_tasks.py
@@ -5,6 +5,7 @@
import os
import sys
import warnings
+from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Dict
@@ -51,15 +52,13 @@
def _set_triton_ptxas_path() -> None:
if os.environ.get("TRITON_PTXAS_PATH") is not None:
return
- ptxas_path = os.path.abspath(
- os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
- )
- if not os.path.exists(ptxas_path):
+ ptxas = Path(__file__).absolute().parents[1] / "bin" / "ptxas"
+ if not ptxas.exists():
return
- if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
- os.environ["TRITON_PTXAS_PATH"] = ptxas_path
+ if ptxas.is_file() and os.access(ptxas, os.X_OK):
+ os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
else:
- warnings.warn(f"{ptxas_path} exists but is not an executable")
+ warnings.warn(f"{ptxas} exists but is not an executable")
def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 8daeefd..a463471 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -4579,7 +4579,7 @@
path = torch_root / 'lib' / lib_name
if os.path.exists(path):
return path
- torch_root = Path(__file__).resolve().parent.parent.parent
+ torch_root = Path(__file__).resolve().parents[2]
return torch_root / 'build' / 'lib' / lib_name
def skip_but_pass_in_sandcastle(reason):
diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index 8f2b07e..208e153 100644
--- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -1,8 +1,8 @@
import argparse
import os
-import pathlib
import re
from collections import Counter, defaultdict, namedtuple
+from pathlib import Path
from typing import Dict, List, Optional, Sequence, Set, Union
import yaml
@@ -527,7 +527,7 @@
source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[str] = None
) -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
- pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
+ pytorch_root = Path(__file__).absolute().parent.parent
template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
def make_file_manager(install_dir: str) -> FileManager:
diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py
index 6dd6f45..52b4209 100644
--- a/torchgen/gen_lazy_tensor.py
+++ b/torchgen/gen_lazy_tensor.py
@@ -1,7 +1,7 @@
import argparse
import os
-import pathlib
from collections import namedtuple
+from pathlib import Path
from typing import (
Any,
Callable,
@@ -261,7 +261,7 @@
options = parser.parse_args()
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
- torch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+ torch_root = Path(__file__).absolute().parents[2]
aten_path = str(torch_root / "aten" / "src" / "ATen")
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator
if options.gen_ts_lowerings: