[BE] Enable ruff's UP rules and autoformat utils/ (#105424)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105424
Approved by: https://github.com/ezyang, https://github.com/malfet
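For context, ruff's UP rules are its port of pyupgrade: mechanical rewrites of legacy Python idioms into their modern equivalents. The most frequent fix in this patch is UP032, which folds simple `str.format` calls into f-strings. A minimal sketch of that rewrite (illustrative only, not part of the patch):

    device_id = 0

    # Before (UP032 flags this) and after: the rendered strings are identical.
    old = "in pin memory thread for device {}".format(device_id)
    new = f"in pin memory thread for device {device_id}"
    assert old == new == "in pin memory thread for device 0"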
diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py
index 5245ac0..6590ff4 100644
--- a/torch/utils/_freeze.py
+++ b/torch/utils/_freeze.py
@@ -237,7 +237,7 @@
module_mangled_name = "__".join(module_qualname)
c_name = "M_" + module_mangled_name
- with open(path, "r") as src_file:
+ with open(path) as src_file:
co = self.compile_string(src_file.read())
bytecode = marshal.dumps(co)
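The `open(path, "r")` -> `open(path)` change here (and the `"rt"`/`"wt"` changes below) is ruff's UP015: `r` is the default mode, and `t` (text) is implied unless `b` is given. A self-contained sketch of the equivalence, using a throwaway temp file:

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "owner.pid")

    # "w" and "wt" open the same text-mode writer ("t" is implied)...
    with open(path, "w") as f:
        f.write("1234")

    # ...and "r", "rt", and no mode at all open the same text-mode reader.
    with open(path) as a, open(path, "r") as b, open(path, "rt") as c:
        assert a.read() == b.read() == c.read() == "1234"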
diff --git a/torch/utils/benchmark/examples/end_to_end.py b/torch/utils/benchmark/examples/end_to_end.py
index 5e0f4271..a6d05a9 100644
--- a/torch/utils/benchmark/examples/end_to_end.py
+++ b/torch/utils/benchmark/examples/end_to_end.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""End-to-end example to test a PR for regressions:
$ python -m examples.end_to_end --pr 39850
@@ -111,7 +110,7 @@
def construct_stmt_and_label(pr, params):
if pr == "39850":
- k0, k1, k2, dim = [params[i] for i in ["k0", "k1", "k2", "dim"]]
+ k0, k1, k2, dim = (params[i] for i in ["k0", "k1", "k2", "dim"])
state = np.random.RandomState(params["random_value"])
topk_dim = state.randint(low=0, high=dim)
dim_size = [k0, k1, k2][topk_dim]
@@ -291,7 +290,7 @@
)
_, result_log_file = tempfile.mkstemp(suffix=".log")
- with open(result_log_file, "wt") as f:
+ with open(result_log_file, "w") as f:
f.write(f"{device_str}\n\n{column_labels}\n")
print(f"\n{column_labels}\n[First twenty omitted (these tend to be noisy) ]")
for key, (r_ref, r_pr), rel_diff in results:
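Two fixes appear in this file. The deleted coding cookie is UP009: UTF-8 is already the default source encoding in Python 3. The bracket-to-parenthesis change is UP027: tuple unpacking consumes a generator expression just as it does a list comprehension, so the temporary list can be skipped. A standalone sketch of the latter:

    params = {"k0": 2, "k1": 4, "k2": 8, "dim": 3}

    # Unpacking works the same whether the right-hand side is a list...
    k0, k1, k2, dim = [params[i] for i in ["k0", "k1", "k2", "dim"]]
    # ...or a generator expression, which avoids building the temporary list.
    g0, g1, g2, gdim = (params[i] for i in ["k0", "k1", "k2", "dim"])

    assert (k0, k1, k2, dim) == (g0, g1, g2, gdim) == (2, 4, 8, 3)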
diff --git a/torch/utils/benchmark/examples/op_benchmark.py b/torch/utils/benchmark/examples/op_benchmark.py
index 65b69d8..b7536b9 100644
--- a/torch/utils/benchmark/examples/op_benchmark.py
+++ b/torch/utils/benchmark/examples/op_benchmark.py
@@ -37,13 +37,13 @@
assert_dicts_equal(float_params, int_params)
assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"])
- float_measurement, int_measurement = [
+ float_measurement, int_measurement = (
Timer(
stmt,
globals=tensors,
).blocked_autorange(min_run_time=_MEASURE_TIME)
for tensors in (float_tensors, int_tensors)
- ]
+ )
descriptions = []
for name in float_tensors:
diff --git a/torch/utils/benchmark/examples/sparse/op_benchmark.py b/torch/utils/benchmark/examples/sparse/op_benchmark.py
index f9ee17d..d7e97d3 100644
--- a/torch/utils/benchmark/examples/sparse/op_benchmark.py
+++ b/torch/utils/benchmark/examples/sparse/op_benchmark.py
@@ -32,13 +32,13 @@
assert_dicts_equal(float_params, int_params)
assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"])
- float_measurement, int_measurement = [
+ float_measurement, int_measurement = (
Timer(
stmt,
globals=tensors,
).blocked_autorange(min_run_time=_MEASURE_TIME)
for tensors in (float_tensors, int_tensors)
- ]
+ )
descriptions = []
for name in float_tensors:
diff --git a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py
index d8284ee..c703955 100644
--- a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py
+++ b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py
@@ -27,7 +27,7 @@
results = []
for tensors, tensor_params, params in spectral_fuzzer.take(samples):
shape = [params['k0'], params['k1'], params['k2']][:params['ndim']]
- str_shape = ' x '.join(["{:<4}".format(s) for s in shape])
+ str_shape = ' x '.join([f"{s:<4}" for s in shape])
sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
for dim in _dim_options(params['ndim']):
for nthreads in (1, 4, 16) if not cuda else (1,):
diff --git a/torch/utils/benchmark/utils/common.py b/torch/utils/benchmark/utils/common.py
index a8bbef3..c1636dd 100644
--- a/torch/utils/benchmark/utils/common.py
+++ b/torch/utils/benchmark/utils/common.py
@@ -325,7 +325,7 @@
if not os.path.exists(owner_file):
continue
- with open(owner_file, "rt") as f:
+ with open(owner_file) as f:
owner_pid = int(f.read())
if owner_pid == os.getpid():
@@ -349,7 +349,7 @@
os.makedirs(path, exist_ok=False)
if use_dev_shm:
- with open(os.path.join(path, "owner.pid"), "wt") as f:
+ with open(os.path.join(path, "owner.pid"), "w") as f:
f.write(str(os.getpid()))
return path
diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py
index 65b8c70..a09f1a0 100644
--- a/torch/utils/benchmark/utils/cpp_jit.py
+++ b/torch/utils/benchmark/utils/cpp_jit.py
@@ -137,7 +137,7 @@
os.makedirs(build_dir, exist_ok=True)
src_path = os.path.join(build_dir, "timer_src.cpp")
- with open(src_path, "wt") as f:
+ with open(src_path, "w") as f:
f.write(src)
# `cpp_extension` has its own locking scheme, so we don't need our lock.
@@ -154,7 +154,7 @@
def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp")
- with open(template_path, "rt") as f:
+ with open(template_path) as f:
src: str = f.read()
module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False)
@@ -164,7 +164,7 @@
def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp")
- with open(template_path, "rt") as f:
+ with open(template_path) as f:
src: str = f.read()
target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True)
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
index 71753bd..61e4348 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -28,7 +28,10 @@
CompletedProcessType = subprocess.CompletedProcess
-FunctionCount = NamedTuple("FunctionCount", [("count", int), ("function", str)])
+class FunctionCount(NamedTuple):
+ # TODO(#105471): Rename the count field
+ count: int # type: ignore[assignment]
+ function: str
@dataclasses.dataclass(repr=False, eq=False, frozen=True)
@@ -598,7 +601,7 @@
stderr=subprocess.STDOUT,
**kwargs,
)
- with open(stdout_stderr_log, "rt") as f:
+ with open(stdout_stderr_log) as f:
return invocation, f.read()
finally:
f_stdout_stderr.close()
@@ -612,7 +615,7 @@
)
script_file = os.path.join(working_dir, "timer_callgrind.py")
- with open(script_file, "wt") as f:
+ with open(script_file, "w") as f:
f.write(self._construct_script(
task_spec,
globals=GlobalsBridge(globals, data_dir),
@@ -652,7 +655,7 @@
if valgrind_invocation.returncode:
error_report = ""
if os.path.exists(error_log):
- with open(error_log, "rt") as f:
+ with open(error_log) as f:
error_report = f.read()
if not error_report:
error_report = "Unknown error.\n" + valgrind_invocation_output
@@ -724,7 +727,7 @@
fpath = f"{callgrind_out}.{i + 1}" # Callgrind one-indexes files.
callgrind_out_contents: Optional[str] = None
if retain_out_file:
- with open(fpath, "rt") as f:
+ with open(fpath) as f:
callgrind_out_contents = f.read()
return (
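Converting the functional `NamedTuple(...)` call to class syntax (ruff's UP014) preserves tuple behavior while allowing per-field annotations and comments; the `type: ignore` is needed because the `count` field shadows `tuple.count`, hence the TODO above. A sketch of the equivalence (the `Old`/`New` names are illustrative):

    from typing import NamedTuple

    # Functional form (legacy):
    FunctionCountOld = NamedTuple("FunctionCountOld", [("count", int), ("function", str)])

    # Class form, as produced by UP014:
    class FunctionCountNew(NamedTuple):
        count: int
        function: str

    old, new = FunctionCountOld(3, "foo"), FunctionCountNew(3, "foo")
    assert tuple(old) == tuple(new) == (3, "foo") and new.function == "foo"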
diff --git a/torch/utils/bottleneck/__main__.py b/torch/utils/bottleneck/__main__.py
index 86c1af0..f7fd209 100644
--- a/torch/utils/bottleneck/__main__.py
+++ b/torch/utils/bottleneck/__main__.py
@@ -16,7 +16,7 @@
def compiled_with_cuda(sysinfo):
if sysinfo.cuda_compiled_version:
- return 'compiled w/ CUDA {}'.format(sysinfo.cuda_compiled_version)
+ return f'compiled w/ CUDA {sysinfo.cuda_compiled_version}'
return 'not compiled w/ CUDA'
@@ -59,7 +59,7 @@
'debug_str': debug_str,
'pytorch_version': info.torch_version,
'cuda_compiled': compiled_with_cuda(info),
- 'py_version': '{}.{}'.format(sys.version_info[0], sys.version_info[1]),
+ 'py_version': f'{sys.version_info[0]}.{sys.version_info[1]}',
'cuda_runtime': cuda_avail,
'pip_version': pip_version,
'pip_list_output': pip_list_output,
@@ -138,7 +138,7 @@
result = {
'mode': mode,
- 'description': 'top {} events sorted by {}'.format(topk, sortby),
+ 'description': f'top {topk} events sorted by {sortby}',
'output': torch.autograd.profiler_util._build_table(topk_events),
'cuda_warning': cuda_warning
}
diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py
index 4ae3973..ad34e15 100644
--- a/torch/utils/bundled_inputs.py
+++ b/torch/utils/bundled_inputs.py
@@ -261,11 +261,11 @@
if input_list is not None and not isinstance(input_list, Sequence):
- raise TypeError("Error inputs for function {0} is not a Sequence".format(function_name))
+ raise TypeError(f"Error inputs for function {function_name} is not a Sequence")
function_arg_types = [arg.type for arg in function.schema.arguments[1:]] # type: ignore[attr-defined]
deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
- model._c._register_attribute("_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs_type, [])
+ model._c._register_attribute(f"_bundled_inputs_deflated_{function_name}", deflated_inputs_type, [])
if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
if input_list is not None:
@@ -290,7 +290,7 @@
for inp_idx, args in enumerate(input_list):
if not isinstance(args, Tuple) and not isinstance(args, List): # type: ignore[arg-type]
raise TypeError(
- "Error bundled input for function {0} idx: {1} is not a Tuple or a List".format(function_name, inp_idx)
+ f"Error bundled input for function {function_name} idx: {inp_idx} is not a Tuple or a List"
)
deflated_args = []
parts.append("(")
@@ -314,7 +314,7 @@
# Back-channel return this expr for debugging.
if _receive_inflate_expr is not None:
_receive_inflate_expr.append(expr)
- setattr(model, "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs)
+ setattr(model, f"_bundled_inputs_deflated_{function_name}", deflated_inputs)
definition = textwrap.dedent("""
def _generate_bundled_inputs_for_{name}(self):
deflated = self._bundled_inputs_deflated_{name}
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 8c023c7..4da281d 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -66,7 +66,7 @@
return device_module
-class DefaultDeviceType(object):
+class DefaultDeviceType:
r"""
A class that manages the default device type for checkpointing.
If no non-CPU tensors are present, the default device type will
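`class DefaultDeviceType(object):` -> `class DefaultDeviceType:` is ruff's UP004: every Python 3 class already inherits from `object`, so the explicit base adds nothing (the related UP039 removes empty parentheses, as in `class CaptureControl():` further down). A sketch:

    class WithBase(object):
        pass

    class WithoutBase:
        pass

    # Identical semantics: both MROs end at object.
    assert WithBase.__mro__[-1] is object
    assert WithoutBase.__mro__[-1] is object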
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index ee3c61c..323a391 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -150,11 +150,11 @@
only once we need to get any ROCm-specific path.
'''
if ROCM_HOME is None:
- raise EnvironmentError('ROCM_HOME environment variable is not set. '
- 'Please set it to your ROCm install root.')
+ raise OSError('ROCM_HOME environment variable is not set. '
+ 'Please set it to your ROCm install root.')
elif IS_WINDOWS:
- raise EnvironmentError('Building PyTorch extensions using '
- 'ROCm and Windows is not supported.')
+ raise OSError('Building PyTorch extensions using '
+ 'ROCm and Windows is not supported.')
return os.path.join(ROCM_HOME, *paths)
@@ -264,7 +264,7 @@
if it already had the right content (to avoid triggering recompile).
'''
if os.path.exists(filename):
- with open(filename, 'r') as f:
+ with open(filename) as f:
content = f.read()
if content == new_content:
@@ -2247,8 +2247,8 @@
only once we need to get any CUDA-specific path.
'''
if CUDA_HOME is None:
- raise EnvironmentError('CUDA_HOME environment variable is not set. '
- 'Please set it to your CUDA install root.')
+ raise OSError('CUDA_HOME environment variable is not set. '
+ 'Please set it to your CUDA install root.')
return os.path.join(CUDA_HOME, *paths)
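The `EnvironmentError` -> `OSError` substitutions are ruff's UP024: since Python 3.3, `EnvironmentError` and `IOError` are mere aliases of `OSError`, so raising or catching either spelling is interchangeable. A quick sketch:

    assert EnvironmentError is OSError
    assert IOError is OSError

    try:
        raise OSError("CUDA_HOME environment variable is not set.")
    except EnvironmentError as e:
        assert isinstance(e, OSError)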
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
index 074b89b..cdd53c2 100644
--- a/torch/utils/data/_utils/pin_memory.py
+++ b/torch/utils/data/_utils/pin_memory.py
@@ -37,7 +37,7 @@
data = pin_memory(data, device)
except Exception:
data = ExceptionWrapper(
- where="in pin memory thread for device {}".format(device_id))
+ where=f"in pin memory thread for device {device_id}")
r = (idx, data)
while not done_event.is_set():
try:
diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py
index b4fc8e0..0d43f63 100644
--- a/torch/utils/data/_utils/worker.py
+++ b/torch/utils/data/_utils/worker.py
@@ -76,13 +76,13 @@
def __setattr__(self, key, val):
if self.__initialized:
- raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
+ raise RuntimeError(f"Cannot assign attributes to {self.__class__.__name__} objects")
return super().__setattr__(key, val)
def __repr__(self):
items = []
for k in self.__keys:
- items.append('{}={}'.format(k, getattr(self, k)))
+ items.append(f'{k}={getattr(self, k)}')
return '{}({})'.format(self.__class__.__name__, ', '.join(items))
@@ -252,7 +252,7 @@
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
except Exception:
init_exception = ExceptionWrapper(
- where="in DataLoader worker process {}".format(worker_id))
+ where=f"in DataLoader worker process {worker_id}")
# When using Iterable mode, some worker can exit earlier than others due
# to the IterableDataset behaving differently for different workers.
@@ -318,7 +318,7 @@
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(
- where="in DataLoader worker process {}".format(worker_id))
+ where=f"in DataLoader worker process {worker_id}")
data_queue.put((idx, data))
del data, idx, index, r # save memory
except KeyboardInterrupt:
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index ec86f77..1c33592 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -604,7 +604,7 @@
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
- self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
+ self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
def __iter__(self) -> '_BaseDataLoaderIter':
return self
@@ -1145,7 +1145,7 @@
self._mark_worker_as_unavailable(worker_id)
if len(failed_workers) > 0:
pids_str = ', '.join(str(w.pid) for w in failed_workers)
- raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
+ raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
if isinstance(e, queue.Empty):
return (False, None)
import tempfile
@@ -1281,7 +1281,7 @@
if success:
return data
else:
- raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
+ raise RuntimeError(f'DataLoader timed out after {self._timeout} seconds')
elif self._pin_memory:
while self._pin_memory_thread.is_alive():
success, data = self._try_get_data()
diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py
index e4cc9e4..96b7e00 100644
--- a/torch/utils/data/datapipes/_decorator.py
+++ b/torch/utils/data/datapipes/_decorator.py
@@ -80,7 +80,7 @@
elif isinstance(arg, Callable): # type:ignore[arg-type]
self.deterministic_fn = arg # type: ignore[assignment, misc]
else:
- raise TypeError("{} can not be decorated by non_deterministic".format(arg))
+ raise TypeError(f"{arg} can not be decorated by non_deterministic")
def __call__(self, *args, **kwargs):
global _determinism
diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py
index 6377a2e..68049ba 100644
--- a/torch/utils/data/datapipes/_typing.py
+++ b/torch/utils/data/datapipes/_typing.py
@@ -234,7 +234,7 @@
return issubtype(self.param, other.param)
if isinstance(other, type):
return issubtype(self.param, other)
- raise TypeError("Expected '_DataPipeType' or 'type', but found {}".format(type(other)))
+ raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}")
def issubtype_of_instance(self, other):
return issubinstance(other, self.param)
@@ -279,13 +279,13 @@
@_tp_cache
def _getitem_(self, params):
if params is None:
- raise TypeError('{}[t]: t can not be None'.format(self.__name__))
+ raise TypeError(f'{self.__name__}[t]: t can not be None')
if isinstance(params, str):
params = ForwardRef(params)
if not isinstance(params, tuple):
params = (params, )
- msg = "{}[t]: t must be a type".format(self.__name__)
+ msg = f"{self.__name__}[t]: t must be a type"
params = tuple(_type_check(p, msg) for p in params)
if isinstance(self.type.param, _GenericAlias):
@@ -303,7 +303,7 @@
'__type_class__': True})
if len(params) > 1:
- raise TypeError('Too many parameters for {} actual {}, expected 1'.format(self, len(params)))
+ raise TypeError(f'Too many parameters for {self} actual {len(params)}, expected 1')
t = _DataPipeType(params[0])
diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py
index 06029e0..72d93cd 100644
--- a/torch/utils/data/datapipes/dataframe/dataframes.py
+++ b/torch/utils/data/datapipes/dataframe/dataframes.py
@@ -36,7 +36,7 @@
CaptureControl.disabled = True
-class CaptureControl():
+class CaptureControl:
disabled = False
@@ -184,7 +184,7 @@
return value
-class CaptureLikeMock():
+class CaptureLikeMock:
def __init__(self, name):
import unittest.mock as mock
# TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
@@ -232,7 +232,7 @@
def __str__(self):
variable = self.kwargs['variable']
value = self.kwargs['value']
- return "{variable} = {value}".format(variable=variable, value=value)
+ return f"{variable} = {value}"
def execute(self):
self.kwargs['variable'].calculated_value = self.kwargs['value'].execute()
@@ -272,7 +272,7 @@
self.key = key
def __str__(self):
- return "%s[%s]" % (self.left, get_val(self.key))
+ return f"{self.left}[{get_val(self.key)}]"
def execute(self):
left = self.left.execute()
@@ -287,7 +287,7 @@
self.value = value
def __str__(self):
- return "%s[%s] = %s" % (self.left, get_val(self.key), self.value)
+ return f"{self.left}[{get_val(self.key)}] = {self.value}"
def execute(self):
left = self.left.execute()
@@ -302,7 +302,7 @@
self.right = right
def __str__(self):
- return "%s + %s" % (self.left, self.right)
+ return f"{self.left} + {self.right}"
def execute(self):
return get_val(self.left) + get_val(self.right)
@@ -315,7 +315,7 @@
self.right = right
def __str__(self):
- return "%s * %s" % (self.left, self.right)
+ return f"{self.left} * {self.right}"
def execute(self):
return get_val(self.left) * get_val(self.right)
@@ -328,7 +328,7 @@
self.right = right
def __str__(self):
- return "%s - %s" % (self.left, self.right)
+ return f"{self.left} - {self.right}"
def execute(self):
return get_val(self.left) - get_val(self.right)
@@ -341,7 +341,7 @@
self.name = name
def __str__(self):
- return "%s.%s" % (self.src, self.name)
+ return f"{self.src}.{self.name}"
def execute(self):
val = get_val(self.src)
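The `__str__` rewrites in this file are ruff's UP031: printf-style `%` interpolation becomes an f-string, with identical output for plain `%s` substitutions. A minimal sketch:

    left, key, value = "df", "col", 42

    assert "%s[%s] = %s" % (left, key, value) \
        == f"{left}[{key}] = {value}" \
        == "df[col] = 42"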
diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py
index 445400e..1017b52 100644
--- a/torch/utils/data/datapipes/datapipe.py
+++ b/torch/utils/data/datapipes/datapipe.py
@@ -126,7 +126,7 @@
functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
return function
else:
- raise AttributeError("'{0}' object has no attribute '{1}".format(self.__class__.__name__, attribute_name))
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'")
@classmethod
def register_function(cls, function_name, function):
@@ -135,7 +135,7 @@
@classmethod
def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
if function_name in cls.functions:
- raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
+ raise Exception(f"Unable to add DataPipe function name {function_name} as it is already taken")
def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
result_pipe = cls(source_dp, *args, **kwargs)
@@ -265,7 +265,7 @@
functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
return function
else:
- raise AttributeError("'{0}' object has no attribute '{1}".format(self.__class__.__name__, attribute_name))
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'")
@classmethod
def register_function(cls, function_name, function):
@@ -274,7 +274,7 @@
@classmethod
def register_datapipe_as_function(cls, function_name, cls_to_register):
if function_name in cls.functions:
- raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
+ raise Exception(f"Unable to add DataPipe function name {function_name} as it is already taken")
def class_function(cls, source_dp, *args, **kwargs):
result_pipe = cls(source_dp, *args, **kwargs)
@@ -363,7 +363,7 @@
return len(self._datapipe)
except Exception as e:
raise TypeError(
- "{} instance doesn't have valid length".format(type(self).__name__)
+ f"{type(self).__name__} instance doesn't have valid length"
) from e
diff --git a/torch/utils/data/datapipes/gen_pyi.py b/torch/utils/data/datapipes/gen_pyi.py
index 1b77fbf..ed3e75b 100644
--- a/torch/utils/data/datapipes/gen_pyi.py
+++ b/torch/utils/data/datapipes/gen_pyi.py
@@ -19,7 +19,7 @@
template_path = os.path.join(dir, template_name)
output_path = os.path.join(dir, output_name)
- with open(template_path, "r") as f:
+ with open(template_path) as f:
content = f.read()
for placeholder, lines, indentation in replacements:
with open(output_path, "w") as f:
diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py
index 4e3dce4..9916b09 100644
--- a/torch/utils/data/datapipes/iter/callable.py
+++ b/torch/utils/data/datapipes/iter/callable.py
@@ -126,7 +126,7 @@
if isinstance(self.datapipe, Sized):
return len(self.datapipe)
raise TypeError(
- "{} instance doesn't have valid length".format(type(self).__name__)
+ f"{type(self).__name__} instance doesn't have valid length"
)
diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py
index 30b569e..4d2973b 100644
--- a/torch/utils/data/datapipes/iter/combinatorics.py
+++ b/torch/utils/data/datapipes/iter/combinatorics.py
@@ -48,7 +48,7 @@
# Dataset has been tested as `Sized`
if isinstance(self.sampler, Sized):
return len(self.sampler)
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
@functional_datapipe('shuffle')
@@ -137,7 +137,7 @@
def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
return len(self.datapipe)
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
def reset(self) -> None:
self._buffer = []
diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py
index 7c76e98..4fe05ea 100644
--- a/torch/utils/data/datapipes/iter/combining.py
+++ b/torch/utils/data/datapipes/iter/combining.py
@@ -56,7 +56,7 @@
if all(isinstance(dp, Sized) for dp in self.datapipes):
return sum(len(dp) for dp in self.datapipes)
else:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
@functional_datapipe('fork')
@@ -567,7 +567,7 @@
if all(isinstance(dp, Sized) for dp in self.datapipes):
return min(len(dp) for dp in self.datapipes) * len(self.datapipes)
else:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
def reset(self) -> None:
self.buffer = []
@@ -627,4 +627,4 @@
if all(isinstance(dp, Sized) for dp in self.datapipes):
return min(len(dp) for dp in self.datapipes)
else:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
diff --git a/torch/utils/data/datapipes/iter/filelister.py b/torch/utils/data/datapipes/iter/filelister.py
index b2ecd71..22e2cd4 100644
--- a/torch/utils/data/datapipes/iter/filelister.py
+++ b/torch/utils/data/datapipes/iter/filelister.py
@@ -61,5 +61,5 @@
def __len__(self):
if self.length == -1:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
return self.length
diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py
index 03d5761..50737d9 100644
--- a/torch/utils/data/datapipes/iter/fileopener.py
+++ b/torch/utils/data/datapipes/iter/fileopener.py
@@ -51,7 +51,7 @@
self.encoding: Optional[str] = encoding
if self.mode not in ('b', 't', 'rb', 'rt', 'r'):
- raise ValueError("Invalid mode {}".format(mode))
+ raise ValueError(f"Invalid mode {mode}")
# TODO: enforce typing for each instance based on mode, otherwise
# `argument_validation` with this DataPipe may be potentially broken
@@ -68,5 +68,5 @@
def __len__(self):
if self.length == -1:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
return self.length
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index c83bd27..b26847d 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -83,7 +83,7 @@
else:
return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
else:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
@functional_datapipe('unbatch')
diff --git a/torch/utils/data/datapipes/iter/routeddecoder.py b/torch/utils/data/datapipes/iter/routeddecoder.py
index 8bfbe14..5e68ae1 100644
--- a/torch/utils/data/datapipes/iter/routeddecoder.py
+++ b/torch/utils/data/datapipes/iter/routeddecoder.py
@@ -62,4 +62,4 @@
def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
return len(self.datapipe)
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
diff --git a/torch/utils/data/datapipes/iter/sharding.py b/torch/utils/data/datapipes/iter/sharding.py
index 730caea..1f4a3a2 100644
--- a/torch/utils/data/datapipes/iter/sharding.py
+++ b/torch/utils/data/datapipes/iter/sharding.py
@@ -80,4 +80,4 @@
if isinstance(self.source_datapipe, Sized):
return len(self.source_datapipe) // self.num_of_instances +\
(1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py
index 85146f8..4a4a785 100644
--- a/torch/utils/data/datapipes/map/combining.py
+++ b/torch/utils/data/datapipes/map/combining.py
@@ -47,7 +47,7 @@
return dp[index - offset]
else:
offset += len(dp)
- raise IndexError("Index {} is out of range.".format(index))
+ raise IndexError(f"Index {index} is out of range.")
def __len__(self) -> int:
return sum(len(dp) for dp in self.datapipes)
diff --git a/torch/utils/data/datapipes/map/grouping.py b/torch/utils/data/datapipes/map/grouping.py
index da3cf56..65b30d8 100644
--- a/torch/utils/data/datapipes/map/grouping.py
+++ b/torch/utils/data/datapipes/map/grouping.py
@@ -64,4 +64,4 @@
else:
return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
else:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py
index e39d67e..99ae0cb 100644
--- a/torch/utils/data/datapipes/utils/common.py
+++ b/torch/utils/data/datapipes/utils/common.py
@@ -305,7 +305,7 @@
self.closed = False
if parent_stream is not None:
if not isinstance(parent_stream, StreamWrapper):
- raise RuntimeError('Parent stream should be StreamWrapper, {} was given'.format(type(parent_stream)))
+ raise RuntimeError(f'Parent stream should be StreamWrapper, {type(parent_stream)} was given')
parent_stream.child_counter += 1
self.parent_stream = parent_stream
if StreamWrapper.debug_unclosed_streams:
diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py
index 4da810c..8a7cb71 100644
--- a/torch/utils/data/datapipes/utils/decoder.py
+++ b/torch/utils/data/datapipes/utils/decoder.py
@@ -137,7 +137,7 @@
- pilrgba: pil None rgba
"""
def __init__(self, imagespec):
- assert imagespec in list(imagespecs.keys()), "unknown image specification: {}".format(imagespec)
+ assert imagespec in list(imagespecs.keys()), f"unknown image specification: {imagespec}"
self.imagespec = imagespec.lower()
def __call__(self, extension, data):
@@ -167,14 +167,14 @@
return img
elif atype == "numpy":
result = np.asarray(img)
- assert result.dtype == np.uint8, "numpy image array should be type uint8, but got {}".format(result.dtype)
+ assert result.dtype == np.uint8, f"numpy image array should be type uint8, but got {result.dtype}"
if etype == "uint8":
return result
else:
return result.astype("f") / 255.0
elif atype == "torch":
result = np.asarray(img)
- assert result.dtype == np.uint8, "numpy image array should be type uint8, but got {}".format(result.dtype)
+ assert result.dtype == np.uint8, f"numpy image array should be type uint8, but got {result.dtype}"
if etype == "uint8":
result = np.array(result.transpose(2, 0, 1))
diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py
index 2769e32..7fc95d5 100644
--- a/torch/utils/data/graph.py
+++ b/torch/utils/data/graph.py
@@ -130,7 +130,7 @@
# Add cache here to prevent infinite recursion on DataPipe
def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
- raise RuntimeError("Expected `IterDataPipe` or `MapDataPipe`, but {} is found".format(type(datapipe)))
+ raise RuntimeError(f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found")
dp_id = id(datapipe)
if dp_id in cache:
diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py
index f903de9..a987bca 100644
--- a/torch/utils/dlpack.py
+++ b/torch/utils/dlpack.py
@@ -102,7 +102,7 @@
# device is either CUDA or ROCm, we need to pass the current
# stream
if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
- stream = torch.cuda.current_stream('cuda:{}'.format(device[1]))
+ stream = torch.cuda.current_stream(f'cuda:{device[1]}')
# cuda_stream is the pointer to the stream and it is a public
# attribute, but it is not documented
# The array API specify that the default legacy stream must be passed
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index 3b583db..163f364 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -46,7 +46,7 @@
RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
major, minor, patch = 0, 0, 0
- for line in open(rocm_version_h, "r"):
+ for line in open(rocm_version_h):
match = RE_MAJOR.search(line)
if match:
major = int(match.group(1))
diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py
index 34a0667..fa80065 100755
--- a/torch/utils/hipify/hipify_python.py
+++ b/torch/utils/hipify/hipify_python.py
@@ -219,13 +219,13 @@
unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}
# Print the number of unsupported calls
- print("Total number of unsupported CUDA function calls: {0:d}".format(len(unsupported_calls)))
+ print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}")
# Print the list of unsupported calls
print(", ".join(unsupported_calls))
# Print the number of kernel launches
- print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))
+ print("\nTotal number of replaced kernel launches: {:d}".format(len(stats["kernel_launches"])))
def add_dim3(kernel_string, cuda_kernel):
@@ -254,8 +254,8 @@
first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")
- first_arg_dim3 = "dim3({})".format(first_arg_clean)
- second_arg_dim3 = "dim3({})".format(second_arg_clean)
+ first_arg_dim3 = f"dim3({first_arg_clean})"
+ second_arg_dim3 = f"dim3({second_arg_clean})"
first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
@@ -269,7 +269,7 @@
def processKernelLaunches(string, stats):
""" Replace the CUDA style Kernel launches with the HIP style kernel launches."""
# Concat the namespace with the kernel names. (Find cleaner way of doing this later).
- string = RE_KERNEL_LAUNCH.sub(lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)
+ string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string)
def grab_method_and_template(in_kernel):
# The positions for relevant kernel components.
@@ -482,7 +482,7 @@
"""
output_string = input_string
for func in MATH_TRANSPILATIONS:
- output_string = output_string.replace(r'{}('.format(func), '{}('.format(MATH_TRANSPILATIONS[func]))
+ output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(')
return output_string
@@ -531,7 +531,7 @@
"""
output_string = input_string
output_string = RE_EXTERN_SHARED.sub(
- lambda inp: "HIP_DYNAMIC_SHARED({0} {1}, {2})".format(
+ lambda inp: "HIP_DYNAMIC_SHARED({} {}, {})".format(
inp.group(1) or "", inp.group(2), inp.group(3)), output_string)
return output_string
@@ -657,7 +657,7 @@
# Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
-class Trie():
+class Trie:
"""Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
The corresponding Regex should match much faster than a simple Regex union."""
@@ -750,7 +750,7 @@
CAFFE2_TRIE.add(src)
CAFFE2_MAP[src] = dst
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
-RE_PYTORCH_PREPROCESSOR = re.compile(r'(?<=\W)({0})(?=\W)'.format(PYTORCH_TRIE.pattern()))
+RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.pattern()})(?=\W)')
RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
@@ -789,7 +789,7 @@
rel_filepath = os.path.relpath(filepath, output_directory)
- with open(fin_path, 'r', encoding='utf-8') as fin:
+ with open(fin_path, encoding='utf-8') as fin:
if fin.readline() == HIPIFY_C_BREADCRUMB:
hipify_result.hipified_path = None
hipify_result.status = "[ignored, input is hipified output]"
@@ -929,7 +929,7 @@
do_write = True
if os.path.exists(fout_path):
- with open(fout_path, 'r', encoding='utf-8') as fout_old:
+ with open(fout_path, encoding='utf-8') as fout_old:
do_write = fout_old.read() != output_source
if do_write:
try:
@@ -956,7 +956,7 @@
with openf(filepath, "r+") as f:
contents = f.read()
if strict:
- contents = re.sub(r'\b({0})\b'.format(re.escape(search_string)), lambda x: replace_string, contents)
+ contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents)
else:
contents = contents.replace(search_string, replace_string)
f.seek(0)
@@ -968,8 +968,8 @@
with openf(filepath, "r+") as f:
contents = f.read()
if header[0] != "<" and header[-1] != ">":
- header = '"{0}"'.format(header)
- contents = ('#include {0} \n'.format(header)) + contents
+ header = f'"{header}"'
+ contents = (f'#include {header} \n') + contents
f.seek(0)
f.write(contents)
f.truncate()
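hipify combines two rewrites: UP030 drops explicit positional indices (`"{0}{1}"` -> `"{}{}"`) when arguments are consumed in order, and UP032 folds `.format` into f-strings, preserving format specs such as `:d`. A sketch of both:

    calls = ["cudaMalloc", "cudaFree"]

    # UP030: implicit positional fields are equivalent when used in order.
    assert "{0}: {1}".format("calls", len(calls)) == "{}: {}".format("calls", len(calls))

    # UP032: the same message as an f-string, ":d" spec intact.
    assert f"Total number of unsupported CUDA function calls: {len(calls):d}" \
        == "Total number of unsupported CUDA function calls: {0:d}".format(len(calls))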
diff --git a/torch/utils/jit/log_extract.py b/torch/utils/jit/log_extract.py
index d9d0e44..2e89a76 100644
--- a/torch/utils/jit/log_extract.py
+++ b/torch/utils/jit/log_extract.py
@@ -11,7 +11,7 @@
pfx = None
current = ""
graphs = []
- with open(filename, "r") as f:
+ with open(filename) as f:
split_strs = f.read().split(BEGIN)
for i, split_str in enumerate(split_strs):
if i == 0:
diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py
index ec20042..66d57a2 100644
--- a/torch/utils/mobile_optimizer.py
+++ b/torch/utils/mobile_optimizer.py
@@ -31,7 +31,7 @@
"""
if not isinstance(script_module, torch.jit.ScriptModule):
raise TypeError(
- 'Got {}, but ScriptModule is expected.'.format(type(script_module)))
+ f'Got {type(script_module)}, but ScriptModule is expected.')
if optimization_blocklist is None:
optimization_blocklist = set()
@@ -86,7 +86,7 @@
"""
if not isinstance(script_module, torch.jit.ScriptModule):
raise TypeError(
- 'Got {}, but ScriptModule is expected.'.format(type(script_module)))
+ f'Got {type(script_module)}, but ScriptModule is expected.')
lint_list = []
diff --git a/torch/utils/tensorboard/_caffe2_graph.py b/torch/utils/tensorboard/_caffe2_graph.py
index 8bba2ae..2aa162a 100644
--- a/torch/utils/tensorboard/_caffe2_graph.py
+++ b/torch/utils/tensorboard/_caffe2_graph.py
@@ -232,7 +232,7 @@
def f(name):
if "_grad" in name:
- return "GRADIENTS/{}".format(name)
+ return f"GRADIENTS/{name}"
else:
return name
@@ -317,7 +317,7 @@
):
return "/cpu:*"
if device_option.device_type == caffe2_pb2.CUDA:
- return "/gpu:{}".format(device_option.device_id)
+ return f"/gpu:{device_option.device_id}"
raise Exception("Unhandled device", device_option)
diff --git a/torch/utils/tensorboard/_embedding.py b/torch/utils/tensorboard/_embedding.py
index f172e09..afbe681 100644
--- a/torch/utils/tensorboard/_embedding.py
+++ b/torch/utils/tensorboard/_embedding.py
@@ -62,7 +62,7 @@
def get_embedding_info(metadata, label_img, subdir, global_step, tag):
info = EmbeddingInfo()
- info.tensor_name = "{}:{}".format(tag, str(global_step).zfill(5))
+ info.tensor_name = f"{tag}:{str(global_step).zfill(5)}"
info.tensor_path = _gfile_join(subdir, "tensors.tsv")
if metadata is not None:
info.metadata_path = _gfile_join(subdir, "metadata.tsv")
diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py
index f03812b..280b503 100644
--- a/torch/utils/tensorboard/_pytorch_graph.py
+++ b/torch/utils/tensorboard/_pytorch_graph.py
@@ -275,7 +275,7 @@
parent_scope, attr_scope, attr_name
)
else:
- attr_to_scope[attr_key] = "__module.{}".format(attr_name)
+ attr_to_scope[attr_key] = f"__module.{attr_name}"
# We don't need classtype nodes; scope will provide this information
if node.output().type().kind() != CLASSTYPE_KIND:
node_py = NodePyOP(node)
@@ -286,7 +286,7 @@
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
node_pyio = NodePyIO(node, "output")
- node_pyio.debugName = "output.{}".format(i + 1)
+ node_pyio.debugName = f"output.{i + 1}"
node_pyio.inputs = [node.debugName()]
nodes_py.append(node_pyio)
@@ -302,7 +302,7 @@
for name, module in trace.named_modules(prefix="__module"):
mod_name = parse_traced_name(module)
attr_name = name.split(".")[-1]
- alias_to_name[name] = "{}[{}]".format(mod_name, attr_name)
+ alias_to_name[name] = f"{mod_name}[{attr_name}]"
for node in nodes_py.nodes_op:
module_aliases = node.scopeName.split("/")
diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py
index 2fa34d0..b592707 100644
--- a/torch/utils/tensorboard/writer.py
+++ b/torch/utils/tensorboard/writer.py
@@ -953,7 +953,7 @@
# Maybe we should encode the tag so slashes don't trip us up?
# I don't think this will mess us up, but better safe than sorry.
- subdir = "%s/%s" % (str(global_step).zfill(5), self._encode(tag))
+ subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}"
save_path = os.path.join(self._get_file_writer().get_logdir(), subdir)
fs = tf.io.gfile
diff --git a/torch/utils/throughput_benchmark.py b/torch/utils/throughput_benchmark.py
index 8b2fd1a..2dc3ce8 100644
--- a/torch/utils/throughput_benchmark.py
+++ b/torch/utils/throughput_benchmark.py
@@ -18,10 +18,10 @@
raise AssertionError("Shouldn't reach here :)")
if time_us >= US_IN_SECOND:
- return '{:.3f}s'.format(time_us / US_IN_SECOND)
+ return f'{time_us / US_IN_SECOND:.3f}s'
if time_us >= US_IN_MS:
- return '{:.3f}ms'.format(time_us / US_IN_MS)
- return '{:.3f}us'.format(time_us)
+ return f'{time_us / US_IN_MS:.3f}ms'
+ return f'{time_us:.3f}us'
class ExecutionStats:
@@ -52,8 +52,8 @@
def __str__(self):
return '\n'.join([
"Average latency per example: " + format_time(time_ms=self.latency_avg_ms),
- "Total number of iterations: {}".format(self.num_iters),
- "Total number of iterations per second (across all threads): {:.2f}".format(self.iters_per_second),
+ f"Total number of iterations: {self.num_iters}",
+ f"Total number of iterations per second (across all threads): {self.iters_per_second:.2f}",
"Total time: " + format_time(time_s=self.total_time_seconds)
])
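As `format_time` shows, f-strings accept arbitrary format specs, so `'{:.3f}s'.format(x)` and `f'{x:.3f}s'` render identically. A sketch:

    US_IN_SECOND = 1_000_000
    time_us = 2_500_000.0

    assert '{:.3f}s'.format(time_us / US_IN_SECOND) \
        == f'{time_us / US_IN_SECOND:.3f}s' \
        == '2.500s'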
diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py
index a64d5e9..13a425c 100644
--- a/torch/utils/viz/_cycles.py
+++ b/torch/utils/viz/_cycles.py
@@ -220,29 +220,29 @@
if isinstance(obj, BASE_TYPES):
return repr(obj)
if type(obj).__name__ == 'function':
- return "function\n{}".format(obj.__name__)
+ return f"function\n{obj.__name__}"
elif isinstance(obj, types.MethodType):
try:
func_name = obj.__func__.__qualname__
except AttributeError:
func_name = "<anonymous>"
- return "instancemethod\n{}".format(func_name)
+ return f"instancemethod\n{func_name}"
elif isinstance(obj, list):
return f"[{format_sequence(obj)}]"
elif isinstance(obj, tuple):
return f"({format_sequence(obj)})"
elif isinstance(obj, dict):
- return "dict[{}]".format(len(obj))
+ return f"dict[{len(obj)}]"
elif isinstance(obj, types.ModuleType):
- return "module\n{}".format(obj.__name__)
+ return f"module\n{obj.__name__}"
elif isinstance(obj, type):
- return "type\n{}".format(obj.__name__)
+ return f"type\n{obj.__name__}"
elif isinstance(obj, weakref.ref):
referent = obj()
if referent is None:
return "weakref (dead referent)"
else:
- return "weakref to id 0x{:x}".format(id(referent))
+ return f"weakref to id 0x{id(referent):x}"
elif isinstance(obj, types.FrameType):
filename = obj.f_code.co_filename
if len(filename) > FRAME_FILENAME_LIMIT:
diff --git a/torch/utils/weak.py b/torch/utils/weak.py
index 2a7d597..bcd3025 100644
--- a/torch/utils/weak.py
+++ b/torch/utils/weak.py
@@ -4,7 +4,6 @@
from weakref import ref
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
from collections.abc import MutableMapping, Mapping
-from typing import Dict
from torch import Tensor
import collections.abc as _collections_abc
@@ -83,7 +82,7 @@
# This is directly adapted from cpython/Lib/weakref.py
class WeakIdKeyDictionary(MutableMapping):
- data: Dict[WeakIdRef, object]
+ data: dict[WeakIdRef, object]
def __init__(self, dict=None):
self.data = {}
@@ -144,7 +143,7 @@
return len(self.data) - len(self._pending_removals)
def __repr__(self):
- return "<%s at %#x>" % (self.__class__.__name__, id(self))
+ return f"<{self.__class__.__name__} at {id(self):#x}>"
def __setitem__(self, key, value):
self.data[WeakIdRef(key, self._remove)] = value # CHANGED
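The last change leans on PEP 585: from Python 3.9, builtin containers are subscriptable in annotations, so `dict[WeakIdRef, object]` needs no `typing.Dict` import. A sketch, assuming Python >= 3.9:

    # PEP 585 (Python 3.9+): builtin generics replace their typing aliases.
    data: dict[str, object] = {"owner.pid": 1234}

    def lookup(table: dict[str, object], key: str) -> object:
        return table[key]

    assert lookup(data, "owner.pid") == 1234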