Type torch._inductor.codegen.wrapper (#100657)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100657
Approved by: https://github.com/voznesenskym
diff --git a/.lintrunner.toml b/.lintrunner.toml
index a3f6433..05a089a 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -178,6 +178,7 @@
'torch/_dynamo/debug_utils.py',
'torch/_dynamo/repro/**/*.py',
'torch/_inductor/graph.py',
+ 'torch/_inductor/codegen/wrapper.py',
'torch/_C/_dynamo/**/*.py',
'test/test_utils.py', # used to by in MYPY but after importing op_db it took 10+ minutes
]
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 7f76a6b..f8a31d1 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -225,7 +225,7 @@
# maps from reusing buffer to reused buffer
self.reuses = dict()
- self.write_get_cuda_stream = functools.lru_cache(None)(
+ self.write_get_cuda_stream = functools.lru_cache(None)( # type: ignore[assignment]
self.write_get_cuda_stream
)
@@ -389,7 +389,8 @@
while (
self.lines
and isinstance(self.lines[-1], MemoryPlanningLine)
- and self.lines[-1].node.name not in out_names
+ # TODO: this seems legit, NullLine has no node
+ and self.lines[-1].node.name not in out_names # type: ignore[attr-defined]
):
# these lines will be pointless
self.lines.pop()
@@ -998,7 +999,7 @@
return "true" if s else "false"
elif isinstance(s, str):
return f'"{s}"'
- elif isinstance(s, (List, Tuple)):
+ elif isinstance(s, (list, tuple)):
vals = ", ".join(list(map(self.val_to_str, s)))
return f"{{{vals}}}"
else:
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 4de7815..de1796f 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -165,7 +165,7 @@
self.constants: Dict[str, torch.Tensor] = {}
self.removed_buffers: Set[str] = set()
self.inplaced_to_remove: Set[str] = set()
- self.wrapper_code = None
+ self.wrapper_code: Optional[WrapperCodeGen] = None
self.num_static_inputs = num_static_inputs
self.mutated_inputs: Set[str] = set()
self.unaligned_buffers: Set[str] = set()