Add support for XPU accumulate type (#128579)
Provide an accumulate-type interface for XPU, mirroring the existing MPS support: introduce an XPU_ACC_TYPE macro and specialize AccumulateTypeDevice for c10::DeviceType::XPU so that each scalar type maps to its accumulation type on XPU devices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128579
Approved by: https://github.com/EikanWang, https://github.com/albanD
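
As a usage sketch (not part of this diff, and assuming the acc_type_device alias already defined in AccumulateType.h), the new specializations let code query the XPU accumulation type at compile time:

    #include <ATen/AccumulateType.h>
    #include <type_traits>

    // With the XPU_ACC_TYPE specializations added below, Half accumulates in
    // float on XPU, matching the CUDA and MPS behavior.
    using xpu_acc_t = at::acc_type_device<at::Half, c10::DeviceType::XPU>;
    static_assert(std::is_same_v<xpu_acc_t, float>,
                  "Half is expected to accumulate as float on XPU");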
diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h
index 0275ef0..b1f120e 100644
--- a/aten/src/ATen/AccumulateType.h
+++ b/aten/src/ATen/AccumulateType.h
@@ -82,6 +82,7 @@
using type = acc_t; \
};
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
+#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
@@ -104,6 +105,25 @@
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
+XPU_ACC_TYPE(BFloat16, float);
+XPU_ACC_TYPE(Half, float);
+XPU_ACC_TYPE(Float8_e5m2, float);
+XPU_ACC_TYPE(Float8_e4m3fn, float);
+XPU_ACC_TYPE(Float8_e5m2fnuz, float);
+XPU_ACC_TYPE(Float8_e4m3fnuz, float);
+XPU_ACC_TYPE(float, float);
+XPU_ACC_TYPE(double, double);
+XPU_ACC_TYPE(int8_t, int64_t);
+XPU_ACC_TYPE(uint8_t, int64_t);
+XPU_ACC_TYPE(char, int64_t);
+XPU_ACC_TYPE(int16_t, int64_t);
+XPU_ACC_TYPE(int32_t, int64_t);
+XPU_ACC_TYPE(int64_t, int64_t);
+XPU_ACC_TYPE(bool, bool);
+XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
+XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
+XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
+
#if defined(__CUDACC__) || defined(__HIPCC__)
CUDA_ACC_TYPE(half, float);
#endif
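
For illustration only, a hypothetical host-side helper (not part of this change; the name sum_with_xpu_acc is made up) that uses the new mapping to accumulate narrow inputs in the wider XPU accumulate type:

    #include <ATen/AccumulateType.h>
    #include <vector>

    // Sums values while accumulating in the XPU accumulate type, so that
    // intermediate results for e.g. Half inputs are kept in float.
    template <typename scalar_t>
    auto sum_with_xpu_acc(const std::vector<scalar_t>& values) {
      using acc_t = at::acc_type_device<scalar_t, c10::DeviceType::XPU>;
      acc_t total = acc_t(0);
      for (const auto& v : values) {
        total += static_cast<acc_t>(v);  // e.g. Half -> float before adding
      }
      return total;
    }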