teach inductor to handle floor (#94341)
Per title. This comes up when upsampling with a non-integer scale factor, which leaves a sympy `floor` in the size expressions that inductor has to print.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94341
Approved by: https://github.com/ezyang
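
A minimal sketch of the kind of call that hits this path (illustrative only: the scale factor, shapes, and the `dynamic=True` flag are assumptions, not taken from the PR; the key ingredients are a non-integer scale and a symbolic input size):

    import torch
    import torch.nn.functional as F

    def fn(x):
        # A non-integer scale factor makes the output size something like
        # floor(s1 * 1.5), which inductor's expression printers must emit
        # (tl.libdevice.floor in Triton kernels, math.floor in wrapper code).
        return F.interpolate(x, scale_factor=1.5, mode="nearest")

    compiled = torch.compile(fn, dynamic=True)
    compiled(torch.randn(1, 2, 37, 38))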
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index effe9b6..c0d8657 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -61,6 +61,7 @@
from torch._inductor import codecache, config, metrics, test_operators
from torch._inductor.codegen.cpp import cexpr, CppOverrides, CppVecOverrides
from torch._inductor.codegen.triton import texpr
+from torch._inductor.codegen.wrapper import pexpr
from torch._inductor.compile_fx import (
compile_fx,
@@ -506,6 +507,8 @@
example_inputs = list(map(downcast_fn, example_inputs))
if hasattr(model, "to"):
model = model.to(torch.half)
+ if rtol is not None:
+ rtol = 2e-3
check_model(
self,
model,
@@ -3655,7 +3658,7 @@
aten.upsample_bilinear2d(a, None, True, [2.0, 2.0]),
)
- self.common(fn, (torch.randn([2, 4, 37, 38]),))
+ self.common(fn, (torch.randn([2, 4, 37, 38]),), atol=2.5e-5, rtol=1.3e-6)
def test_upsample_bilinear2d_b(self):
def fn(a):
@@ -3666,6 +3669,8 @@
[
torch.randn([1, 2, 40, 59]),
],
+ atol=2.5e-5,
+ rtol=1.3e-6,
)
def test_reflection_pad2d(self):
@@ -5517,16 +5522,16 @@
"test_roi_align_dynamic_shapes": ("cpu", "cuda"),
"test_sizehint_issue1_dynamic_shapes": ("cpu", "cuda"),
"test_unroll_small_reduction_dynamic_shapes": ("cpu", "cuda"),
- "test_upsample_bilinear2d_a_dynamic_shapes": ("cpu", "cuda"),
- "test_upsample_bilinear2d_b_dynamic_shapes": ("cpu", "cuda"),
+ "test_upsample_bilinear2d_a_dynamic_shapes": ("cpu"),
+ "test_upsample_bilinear2d_b_dynamic_shapes": ("cpu"),
"test_upsample_cat_conv_dynamic_shapes": (
"cpu",
"cuda",
), # upsample does not support dynamic shapes yet (#92667)
- "test_upsample_nearest1d_dynamic_shapes": ("cpu", "cuda"),
+ "test_upsample_nearest1d_dynamic_shapes": ("cpu"),
"test_upsample_nearest2d_backward_dynamic_shapes": ("cpu", "cuda"),
- "test_upsample_nearest2d_dynamic_shapes": ("cpu", "cuda"),
- "test_upsample_nearest3d_dynamic_shapes": ("cpu", "cuda"),
+ "test_upsample_nearest2d_dynamic_shapes": ("cpu"),
+ "test_upsample_nearest3d_dynamic_shapes": ("cpu"),
}
@@ -7082,6 +7087,12 @@
self.assertEqual(cexpr(expr), result)
self.assertEqual(texpr(expr), result)
+ def test_print_floor(self):
+ s1 = sympy.Symbol("s1", integer=False)
+ expr = sympy.floor(s1)
+ self.assertEqual(texpr(expr), "tl.libdevice.floor(s1)")
+ self.assertEqual(pexpr(expr), "math.floor(s1)")
+
if HAS_CUDA and not TEST_WITH_ASAN:
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index d60aba0..601995e 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -72,6 +72,27 @@
return self._print_FloorDiv(expr)
+class PythonPrinter(ExprPrinter):
+ def _print_ModularIndexing(self, expr):
+ x, div, mod = expr.args
+ x = self.paren(self.doprint(x))
+ div = self.paren(self.doprint(div))
+ mod = self.paren(self.doprint(mod))
+ if div != "1":
+ x = f"({x} // {div})"
+ return f"{x} % {mod}"
+
+ def _print_FloorDiv(self, expr):
+ x, div = expr.args
+ x = self.paren(self.doprint(x))
+ div = self.paren(self.doprint(div))
+ return f"({x} // {div})"
+
+ def _print_floor(self, expr):
+ assert len(expr.args) == 1
+ return f"math.floor({self.paren(self._print(expr.args[0]))})"
+
+
class OpOverrides:
def __init__(self, parent):
super().__init__()
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 7d94abe..8ff5767 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -28,12 +28,12 @@
from .common import (
CSEVariable,
DeferredLine,
- ExprPrinter,
free_symbol_startswith,
IndentedBuffer,
index_prevent_reordering,
Kernel,
OpOverrides,
+ PythonPrinter,
SizeArg,
TensorArg,
)
@@ -74,24 +74,14 @@
return instance_descriptor(tuple(divisible_by_16), ())
-class TritonPrinter(ExprPrinter):
- def _print_ModularIndexing(self, expr):
- x, div, mod = expr.args
- x = self.paren(self.doprint(x))
- div = self.paren(self.doprint(div))
- mod = self.paren(self.doprint(mod))
- if div != "1":
- x = f"({x} // {div})"
- return f"{x} % {mod}"
-
- def _print_FloorDiv(self, expr):
- x, div = expr.args
- x = self.paren(self.doprint(x))
- div = self.paren(self.doprint(div))
- return f"({x} // {div})"
+class TritonPrinter(PythonPrinter):
+ def _print_floor(self, expr):
+ assert len(expr.args) == 1
+ return f"tl.libdevice.floor({self.paren(self._print(expr.args[0]))})"
texpr = TritonPrinter().doprint
+pexpr = PythonPrinter().doprint
def triton_compute_type(dtype):
@@ -552,7 +542,7 @@
class TritonKernel(Kernel):
overrides = TritonOverrides
- sexpr = texpr
+ sexpr = pexpr
def __init__(
self,
@@ -1228,10 +1218,10 @@
# TODO(jansel): if there are constants, we shouldn't bother passing them as args
for tree in self.range_trees:
if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
- expr = texpr(tree.numel)
+ expr = pexpr(tree.numel)
else:
expr = f"{name}_{tree.prefix}numel"
- code.writeline(f"{expr} = {texpr(tree.numel)}")
+ code.writeline(f"{expr} = {pexpr(tree.numel)}")
if tree.prefix != "r" or self.inside_reduction:
call_args.append(expr)
if tree.prefix != "r":
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 1e019d5..d69d19c 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -12,10 +12,9 @@
from ..codecache import cpp_compile_command, get_code_path
from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product
from ..virtualized import V
-from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel
-from .triton import texpr
+from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter
-pexpr = texpr
+pexpr = PythonPrinter().doprint
def buffer_reuse_key(node: ir.Buffer):
@@ -272,6 +271,7 @@
f"""
from ctypes import c_void_p, c_long
import torch
+ import math
import random
from torch import empty_strided, as_strided, device
from {codecache.__name__} import AsyncCompile
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a43fc31..38dd659 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -5,7 +5,7 @@
import torch
from torch import _VF
-from torch import sym_float as _sym_float, sym_int as _sym_int
+from torch import sym_int as _sym_int
from torch._C import _infer_size, _add_docstr
from torch._torch_docs import reproducibility_notes, tf32_notes, sparse_support_notes
# A workaround to support both TorchScript and MyPy:
@@ -3917,7 +3917,7 @@
for i in range(dim)]
else:
output_size = [
- _sym_int(math.floor(_sym_float(input.size(i + 2)) * scale_factors[i]))
+ _sym_int(input.size(i + 2) * scale_factors[i])
for i in range(dim)
]
scale_factors = None