Fall back to eager mode when viewing with differing bitwidths (#120998) (#121786)
The Inductor lowering code for viewing a tensor as a dtype with a different bitwidth currently doesn't generate valid Triton code. This change compares the source and destination dtypes and, if their bitwidths differ, falls back to the eager-mode ATen implementation. Prior to this change, this condition raised a NotImplementedError.
Fixes #120998.
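A minimal repro sketch of the failure mode (the `torch.compile` driver and the `assert_close` check here are my own framing; the shapes, dtypes, and round-trip view follow the new test below):

```python
import torch

def fn(x):
    # Bitcast uint8 -> int32 (a wider dtype), then bitcast back to uint8.
    return x.view(torch.int32).view(torch.uint8)

x = torch.randint(0, 2**4, (4096, 4096), dtype=torch.uint8)
compiled = torch.compile(fn)
# Before this change, Inductor raised NotImplementedError while lowering the
# differing-bitwidth bitcast; with it, aten.view.dtype runs via the eager fallback.
torch.testing.assert_close(compiled(x), fn(x))
```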
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121786
Approved by: https://github.com/peterbell10, https://github.com/bertmaher
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 34c9627..955bbd6 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -9190,6 +9190,24 @@
self.common(fn, args, check_lowp=check_lowp)
+ # codegen check fails in dynamic-shape tests: the generated code has no dynamic for loop
+ @expectedFailureCodegenDynamic
+ def test_view_uint8_through_differing_bitwidths(self):
+ # https://github.com/pytorch/pytorch/issues/120998
+ def fn(x, view_dtype):
+ return x.view(view_dtype).view(torch.uint8)
+
+ view_dtypes = [torch.int16, torch.int32, torch.int64]
+ for dtype in view_dtypes:
+ x = torch.randint(0, 2**4, [4096, 4096], dtype=torch.uint8)
+ self.common(
+ fn,
+ (
+ x,
+ dtype,
+ ),
+ )
+
@dataclasses.dataclass
class TestFailure:
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 0b602a9..6264035 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -566,9 +566,8 @@
src_bits = _get_primitive_bitwidth(x_dtype)
dst_bits = _get_primitive_bitwidth(dtype)
if src_bits != dst_bits:
- raise NotImplementedError(
- f"bitcast {x_dtype} to different bitwidth type {dtype} is not supported yet."
- )
+ # fall back to the eager ATen implementation for bitcasts with differing bitwidths
+ return fallback_handler(aten.view.dtype)(x, dtype)
def _to_dtype_bitcast(x):
# Because we may promote tensor type from float16 or bfloat16