enable scalar reduction with dim=-1 (#88628)
Tested with all samples for `sum`, but also fixes all samples errors on other reductions (amin, amax, any, all etc)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88628
Approved by: https://github.com/desertfire
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index b06d372..03b5138 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -141,11 +141,7 @@
"linalg.pinv.singular": {f32, f64},
"linalg.householder_product": {f32},
# These might be passing now?
- "T": {b8, f16, f32, f64, i32, i64},
- "H": {b8, f16, f32, f64, i32, i64},
"__getitem__": {b8, f16, f32, f64, i32, i64},
- "acos": {b8, f16, f32, f64, i32, i64},
- "acosh": {b8, f16, f32, f64, i32, i64},
"nn.functional.conv_transpose3d": {f16},
"max.reduction_with_dim": {i32, i64},
"min.reduction_with_dim": {i32, i64},
@@ -447,6 +443,7 @@
"select_scatter",
"squeeze",
"unsqueeze",
+ "sum",
}
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 71f038b..d83fbba 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3002,7 +3002,7 @@
axis = list(axis)
for i in range(len(axis)):
if axis[i] < 0:
- axis[i] += len(size)
+ axis[i] += len(size) if len(size) else 1
assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0)
assert len(set(axis)) == len(axis), "reduction axis not unique"
return axis