Revert "XPUHooksInterface inherits from AcceleratorHooksInterface (#129463)"
This reverts commit 6353a12e6a80f06217645b10fb69cffeac08a8d0.
Reverted https://github.com/pytorch/pytorch/pull/129463 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/129463#issuecomment-2207529072))
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index e91927c..e5bacba 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -65,8 +65,6 @@
: at::getAccelerator(true).value();
if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks();
- } else if (device_type == at::kXPU) {
- return at::detail::getXPUHooks();
} else if (device_type == at::kMPS) {
return at::detail::getMPSHooks();
} else if (device_type == at::kPrivateUse1) {
diff --git a/aten/src/ATen/detail/XPUHooksInterface.h b/aten/src/ATen/detail/XPUHooksInterface.h
index 3537a98..89128a4 100644
--- a/aten/src/ATen/detail/XPUHooksInterface.h
+++ b/aten/src/ATen/detail/XPUHooksInterface.h
@@ -2,21 +2,26 @@
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
-#include <c10/util/Registry.h>
-
#include <ATen/core/Generator.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
-
+#include <c10/util/Registry.h>
namespace at {
-struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
+constexpr const char* XPU_HELP =
+ "The XPU backend requires Intel Extension for Pytorch;"
+ "this error has occurred because you are trying "
+ "to use some XPU's functionality, but the Intel Extension for Pytorch has not been "
+ "loaded for some reason. The Intel Extension for Pytorch MUST "
+ "be loaded, EVEN IF you don't directly use any symbols from that!";
+
+struct TORCH_API XPUHooksInterface {
virtual ~XPUHooksInterface() = default;
virtual void initXPU() const {
TORCH_CHECK(
false,
- "Cannot initialize XPU without ATen_xpu library.");
+ "Cannot initialize XPU without Intel Extension for Pytorch.",
+ XPU_HELP);
}
virtual bool hasXPU() const {
@@ -26,7 +31,8 @@
virtual std::string showConfig() const {
TORCH_CHECK(
false,
- "Cannot query detailed XPU version without ATen_xpu library.");
+ "Cannot query detailed XPU version without Intel Extension for Pytorch. ",
+ XPU_HELP);
}
virtual int32_t getGlobalIdxFromDevice(const Device& device) const {
@@ -34,11 +40,11 @@
}
virtual Generator getXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
- TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
+ TORCH_CHECK(false, "Cannot get XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
}
virtual const Generator& getDefaultXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
- TORCH_CHECK(false, "Cannot get default XPU generator without ATen_xpu library.");
+ TORCH_CHECK(false, "Cannot get default XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
}
virtual DeviceIndex getNumGPUs() const {
@@ -64,10 +70,6 @@
virtual bool isPinnedPtr(const void* /*data*/) const {
return false;
}
-
- virtual bool hasPrimaryContext(DeviceIndex /*device_index*/) const override{
- TORCH_CHECK(false, "Cannot query primary context without ATen_xpu library.");
- }
};
struct TORCH_API XPUHooksArgs {};
diff --git a/aten/src/ATen/xpu/detail/XPUHooks.cpp b/aten/src/ATen/xpu/detail/XPUHooks.cpp
index 589e792..61bc19f 100644
--- a/aten/src/ATen/xpu/detail/XPUHooks.cpp
+++ b/aten/src/ATen/xpu/detail/XPUHooks.cpp
@@ -80,11 +80,6 @@
sycl::get_pointer_type(data, c10::xpu::get_device_context());
}
-bool XPUHooks::hasPrimaryContext(DeviceIndex device_index) const {
- // The default context is utilized for each device. So it always returns true.
- return true;
-}
-
REGISTER_XPU_HOOKS(XPUHooks);
} // namespace at::xpu::detail
diff --git a/aten/src/ATen/xpu/detail/XPUHooks.h b/aten/src/ATen/xpu/detail/XPUHooks.h
index b417f50..3027958 100644
--- a/aten/src/ATen/xpu/detail/XPUHooks.h
+++ b/aten/src/ATen/xpu/detail/XPUHooks.h
@@ -20,7 +20,6 @@
void deviceSynchronize(DeviceIndex device_index) const override;
Allocator* getPinnedMemoryAllocator() const override;
bool isPinnedPtr(const void* data) const override;
- bool hasPrimaryContext(DeviceIndex device_index) const override;
};
} // namespace at::xpu::detail