Introduce `mlc` device (ML Compute device) to PyTorch's device list (#50634)
Summary:
Apple recently announced ML Compute, a new framework available in macOS Big Sur that lets users accelerate the training of neural networks on Mac hardware. This PR is the first in a series of PRs that will enable the integration with ML Compute. Most of the integration code will live in a separate subrepo named `mlc`.
The integration with `mlc` (ML Compute) will be very similar to that of XLA. We rely on registering our ops through:
    TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
      m.impl_UNBOXED(<op_schema_name>, &customized_op_kernel)
      ...
    }
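For concreteness, here is a minimal sketch of what such an out-of-tree registration could look like against the dedicated `MLC` dispatch key added in this PR. The op choice, the kernel name `mlc_add`, and the kernel body are hypothetical; the real kernels live in the `mlc` subrepo and would call into ML Compute:

    #include <torch/library.h>
    #include <ATen/ATen.h>

    // Hypothetical out-of-tree kernel. A real implementation would hand
    // the computation to ML Compute instead of returning a placeholder.
    at::Tensor mlc_add(const at::Tensor& self, const at::Tensor& other,
                       at::Scalar alpha) {
      TORCH_CHECK(self.is_mlc(), "mlc_add expects an MLC tensor");
      // ... build and run the ML Compute graph here ...
      return self;
    }

    // Route aten::add.Tensor to this kernel whenever an input carries the MLC key.
    TORCH_LIBRARY_IMPL(aten, MLC, m) {
      m.impl("add.Tensor", mlc_add);
    }

Autograd then falls through to the backend via the `AutogradMLC` fallthrough kernel registered in `VariableFallbackKernel.cpp` below, mirroring the XLA setup.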
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50634
Reviewed By: malfet
Differential Revision: D26614213
Pulled By: smessmer
fbshipit-source-id: 3b492b346c61cc3950ac880ac01a82fbdddbc07b
diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py
index 683965e..bd50e43 100644
--- a/.circleci/cimodel/data/pytorch_build_data.py
+++ b/.circleci/cimodel/data/pytorch_build_data.py
@@ -157,6 +157,7 @@
next_nodes = {
"asan": AsanConfigNode,
"xla": XlaConfigNode,
+ "mlc": MLCConfigNode,
"vulkan": VulkanConfigNode,
"parallel_tbb": ParallelTBBConfigNode,
"parallel_native": ParallelNativeConfigNode,
@@ -193,6 +194,16 @@
def child_constructor(self):
return ImportantConfigNode
+class MLCConfigNode(TreeConfigNode):
+ def modify_label(self, label):
+ return "MLC=" + str(label)
+
+ def init2(self, node_name):
+ self.props["is_mlc"] = node_name
+
+ def child_constructor(self):
+ return ImportantConfigNode
+
class AsanConfigNode(TreeConfigNode):
def modify_label(self, label):
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 4dfb316..ecb58c5 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -73,6 +73,9 @@
bool hasXLA() const {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
}
+ bool hasMLC() const {
+ return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
+ }
// defined in header so that getNonVariableType has ability to inline
// call_once check. getNonVariableType is called fairly frequently
THCState* lazyInitCUDA() {
@@ -276,6 +279,10 @@
return globalContext().hasXLA();
}
+static inline bool hasMLC() {
+ return globalContext().hasMLC();
+}
+
// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
// WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp
index ecc5070..fe03506 100644
--- a/aten/src/ATen/Version.cpp
+++ b/aten/src/ATen/Version.cpp
@@ -194,6 +194,7 @@
// TODO: do HIP
// TODO: do XLA
+ // TODO: do MLC
return ss.str();
}
diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp
index baf49ae..c69d5db 100644
--- a/aten/src/ATen/core/VariableFallbackKernel.cpp
+++ b/aten/src/ATen/core/VariableFallbackKernel.cpp
@@ -48,4 +48,8 @@
m.fallback(torch::CppFunction::makeFallthrough());
}
+TORCH_LIBRARY_IMPL(_, AutogradMLC, m) {
+ m.fallback(torch::CppFunction::makeFallthrough());
+}
+
}
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 7298cbc..4b1105c 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -397,6 +397,7 @@
_(aten, is_complex) \
_(aten, is_contiguous) \
_(aten, is_cuda) \
+_(aten, is_mlc) \
_(aten, is_distributed) \
_(aten, is_floating_point) \
_(aten, is_nonzero) \
diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index b3ec12d..96867b0 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -364,6 +364,9 @@
/// Returns if a `Tensor` is mkldnn tensor.
bool is_mkldnn() const;
+ /// Returns if a `Tensor` is mlc tensor.
+ bool is_mlc() const;
+
/// Returns if a `Tensor` is vulkan tensor.
bool is_vulkan() const;
diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp
index 37b9fdd..29998e8 100644
--- a/aten/src/ATen/templates/TensorMethods.cpp
+++ b/aten/src/ATen/templates/TensorMethods.cpp
@@ -145,6 +145,15 @@
return self.is_mkldnn();
}
+bool Tensor::is_mlc() const {
+ // NB: this is not a native function to avoid dispatching overhead.
+ return impl_->is_mlc();
+}
+
+bool is_mlc(Tensor self) {
+ return self.is_mlc();
+}
+
bool Tensor::is_vulkan() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_vulkan();
diff --git a/c10/core/Backend.h b/c10/core/Backend.h
index 2746360..072a665 100644
--- a/c10/core/Backend.h
+++ b/c10/core/Backend.h
@@ -45,6 +45,7 @@
QuantizedXPU,
Undefined,
MkldnnCPU,
+ MLC,
NumOptions
};
@@ -99,6 +100,8 @@
return Backend::QuantizedCUDA;
case Backend::QuantizedXPU:
return Backend::QuantizedXPU;
+ case Backend::MLC:
+ return Backend::MLC;
default:
throw std::runtime_error("Unknown backend");
}
@@ -117,6 +120,8 @@
return Backend::MSNPU;
} else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
return Backend::XLA;
+ } else if (t == DispatchKey::MLC || t == DispatchKey::AutogradMLC) {
+ return Backend::MLC;
} else if (t == DispatchKey::Vulkan) {
return Backend::Vulkan;
} else if (t == DispatchKey::Metal) {
@@ -182,6 +187,8 @@
return DispatchKey::QuantizedCUDA;
case Backend::Undefined:
return DispatchKey::Undefined;
+ case Backend::MLC:
+ return DispatchKey::MLC;
default:
throw std::runtime_error("Unknown backend");
}
@@ -220,6 +227,8 @@
return DeviceType::Vulkan;
case Backend::Metal:
return DeviceType::Metal;
+ case Backend::MLC:
+ return DeviceType::MLC;
case Backend::Undefined:
AT_ERROR("Undefined backend is not a valid device type");
default:
@@ -250,6 +259,8 @@
case Backend::MSNPU:
case Backend::XLA:
return Backend::CPU;
+ case Backend::MLC:
+ return Backend::CPU;
case Backend::MkldnnCPU:
return Backend::MkldnnCPU;
case Backend::QuantizedCPU:
@@ -302,6 +313,7 @@
case Backend::FPGA:
case Backend::MSNPU:
case Backend::XLA:
+ case Backend::MLC:
return Backend::CUDA;
case Backend::SparseXPU:
case Backend::SparseCPU:
@@ -324,6 +336,7 @@
case Backend::FPGA:
case Backend::MSNPU:
case Backend::XLA:
+ case Backend::MLC:
return Backend::HIP;
case Backend::SparseXPU:
case Backend::SparseCPU:
@@ -354,6 +367,8 @@
return "MSNPU";
case Backend::XLA:
return "XLA";
+ case Backend::MLC:
+ return "MLC";
case Backend::SparseCPU:
return "SparseCPU";
case Backend::SparseCUDA:
diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp
index 4be5701..a0182ff 100644
--- a/c10/core/Device.cpp
+++ b/c10/core/Device.cpp
@@ -46,6 +46,7 @@
{"msnpu", DeviceType::MSNPU},
{"xla", DeviceType::XLA},
{"vulkan", DeviceType::Vulkan},
+ {"mlc", DeviceType::MLC},
}};
auto device = std::find_if(
types.begin(),
@@ -57,7 +58,7 @@
return device->second;
}
AT_ERROR(
- "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ",
+ "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan device type at start of device string: ",
device_string);
}
} // namespace
diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp
index 008d39d..611d0fe 100644
--- a/c10/core/DeviceType.cpp
+++ b/c10/core/DeviceType.cpp
@@ -27,6 +27,8 @@
return lower_case ? "msnpu" : "MSNPU";
case DeviceType::XLA:
return lower_case ? "xla" : "XLA";
+ case DeviceType::MLC:
+ return lower_case ? "mlc" : "MLC";
case DeviceType::Vulkan:
return lower_case ? "vulkan" : "VULKAN";
case DeviceType::Metal:
@@ -65,6 +67,7 @@
case DeviceType::FPGA:
case DeviceType::MSNPU:
case DeviceType::XLA:
+ case DeviceType::MLC:
case DeviceType::Vulkan:
case DeviceType::Metal:
case DeviceType::XPU:
diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h
index a7ea4ae..5bdb5ea 100644
--- a/c10/core/DeviceType.h
+++ b/c10/core/DeviceType.h
@@ -26,11 +26,12 @@
Vulkan = 10, // Vulkan
Metal = 11, // Metal
XPU = 12, // XPU
+ MLC = 13, // ML Compute / Apple
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
- COMPILE_TIME_MAX_DEVICE_TYPES = 13,
+ COMPILE_TIME_MAX_DEVICE_TYPES = 14,
};
constexpr DeviceType kCPU = DeviceType::CPU;
@@ -39,6 +40,7 @@
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMSNPU = DeviceType::MSNPU;
constexpr DeviceType kXLA = DeviceType::XLA;
+constexpr DeviceType kMLC = DeviceType::MLC;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index 9a2150f..2ae5a87 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -21,6 +21,8 @@
return "MSNPU";
case DispatchKey::XLA:
return "XLA";
+ case DispatchKey::MLC:
+ return "MLC";
case DispatchKey::Vulkan:
return "Vulkan";
case DispatchKey::Metal:
@@ -80,6 +82,8 @@
return "AutogradCUDA";
case DispatchKey::AutogradXLA:
return "AutogradXLA";
+ case DispatchKey::AutogradMLC:
+ return "AutogradMLC";
case DispatchKey::AutogradNestedTensor:
return "AutogradNestedTensor";
case DispatchKey::AutogradPrivateUse1:
@@ -143,6 +147,8 @@
return DispatchKey::AutogradCUDA;
case DispatchKey::XLA:
return DispatchKey::AutogradXLA;
+ case DispatchKey::MLC:
+ return DispatchKey::AutogradMLC;
case DispatchKey::NestedTensor:
return DispatchKey::AutogradNestedTensor;
case DispatchKey::PrivateUse1:
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index b353c3d..40d952c 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -62,6 +62,7 @@
MSNPU, // unused externally, but tested at
// test/cpp_extensions/msnpu_extension.cpp
XLA, // lives out of tree at https://github.com/pytorch/xla
+ MLC, // lives out of tree at https://github.com/pytorch/MLCompute
Vulkan,
Metal,
XPU, // For out of tree Intel's heterogeneous computing plug-in
@@ -224,9 +225,9 @@
AutogradCPU,
AutogradCUDA,
AutogradXLA,
- AutogradNestedTensor, // lives out of tree at
- // https://github.com/pytorch/nestedtensor
AutogradXPU,
+ AutogradMLC,
+ AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
// Here are some reserved pre-autograd keys for user-defined backends, see
// Note [Private use DispatchKey]
AutogradPrivateUse1,
diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
index 60f3ed3..9695545 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -14,6 +14,7 @@
DispatchKey::PrivateUse1,
DispatchKey::PrivateUse2,
DispatchKey::PrivateUse3,
+ DispatchKey::MLC,
});
bool isBackendDispatchKey(DispatchKey t) {
@@ -48,6 +49,8 @@
return DispatchKeySet(DispatchKey::CUDA);
case DispatchKey::AutogradXLA:
return DispatchKeySet(DispatchKey::XLA);
+ case DispatchKey::AutogradMLC:
+ return DispatchKeySet(DispatchKey::MLC);
case DispatchKey::AutogradNestedTensor:
return DispatchKeySet(DispatchKey::NestedTensor);
case DispatchKey::AutogradXPU:
diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
index 4de1292..803b76f 100644
--- a/c10/core/DispatchKeySet.h
+++ b/c10/core/DispatchKeySet.h
@@ -195,6 +195,7 @@
DispatchKey::AutogradCUDA,
DispatchKey::AutogradXLA,
DispatchKey::AutogradNestedTensor,
+ DispatchKey::AutogradMLC,
DispatchKey::AutogradXPU,
DispatchKey::AutogradPrivateUse1,
DispatchKey::AutogradPrivateUse2,
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index c2cb4d5..06b3f4f 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -543,6 +543,10 @@
return key_set_.has(DispatchKey::Metal);
}
+ bool is_mlc() const {
+ return key_set_.has(DispatchKey::MLC);
+ }
+
// TODO: remove this once we don't automatically enabled Autograd dispatch keys
// in TensorImpl constructor.
// DON'T USE THIS API!! It's only created for testing purpose in
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index 1cb0508..41c5e90 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -629,6 +629,8 @@
return DispatchKey::MSNPU;
case DeviceType::XLA:
return DispatchKey::XLA;
+ case DeviceType::MLC:
+ return DispatchKey::MLC;
case DeviceType::Vulkan:
return DispatchKey::Vulkan;
case DeviceType::Metal:
@@ -687,6 +689,8 @@
return DeviceType::MSNPU;
} else if (tid == DispatchKey::XLA) {
return DeviceType::XLA;
+ } else if (tid == DispatchKey::MLC) {
+ return DeviceType::MLC;
} else if (tid == DispatchKey::SparseCPU) {
return DeviceType::CPU;
} else if (tid == DispatchKey::SparseCUDA) {
diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto
index 0b50b0a..9143758 100644
--- a/caffe2/proto/caffe2.proto
+++ b/caffe2/proto/caffe2.proto
@@ -198,8 +198,9 @@
PROTO_FPGA = 7; // FPGA
PROTO_MSNPU = 8; // MSNPU
PROTO_XLA = 9; // XLA / TPU
+ PROTO_MLC = 10; // ML Compute
// Change the following number if you add more devices in the code.
- PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
+ PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11;
}
// Device-specific options. We do not distinguish DeviceOption protos for
diff --git a/torch/_utils.py b/torch/_utils.py
index e09f2d9..fff8d8a 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -179,6 +179,12 @@
return tensor
+def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
+ tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
+ tensor.requires_grad = requires_grad
+ return tensor
+
+
def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
qscheme = quantizer_params[0]
if qscheme == torch.per_tensor_affine:
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index f3b23c6..e7e3f34 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -95,6 +95,7 @@
.value("FPGA", c10::DeviceType::FPGA)
.value("MSNPU", c10::DeviceType::MSNPU)
.value("XLA", c10::DeviceType::XLA)
+ .value("MLC", c10::DeviceType::MLC)
.value("Vulkan", c10::DeviceType::Vulkan)
.value("Metal", c10::DeviceType::Metal);
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 27729be..9377be8 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -610,6 +610,17 @@
END_HANDLE_TH_ERRORS
}
+PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
+{
+ HANDLE_TH_ERRORS
+ if (check_has_torch_function((PyObject *)self)) {
+ return handle_torch_function_getter(self, "is_mlc");
+ }
+ auto& self_ = self->cdata;
+ return torch::autograd::utils::wrap(self_.is_mlc());
+ END_HANDLE_TH_ERRORS
+}
+
PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
@@ -751,6 +762,7 @@
{"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
+ {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
{"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
{"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp
index 0164f2f..1652b4a 100644
--- a/torch/csrc/jit/frontend/sugared_value.cpp
+++ b/torch/csrc/jit/frontend/sugared_value.cpp
@@ -110,6 +110,7 @@
{"is_xpu", "prim"},
{"is_sparse", "prim"},
{"is_mkldnn", "prim"},
+ {"is_mlc", "prim"},
{"is_quantized", "prim"},
{"is_vulkan", "prim"},
{"is_meta", "prim"},
diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
index 882cb48..dbcdb33 100644
--- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
@@ -290,6 +290,14 @@
},
aliasAnalysisFromSchema()),
Operator(
+ "prim::is_mlc(Tensor a) -> bool",
+ [](Stack* stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, a.is_mlc());
+ },
+ aliasAnalysisFromSchema()),
+ Operator(
"prim::is_vulkan(Tensor a) -> bool",
[](Stack* stack) {
at::Tensor a;
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 93f14bd..3151224 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -61,6 +61,9 @@
return Backend::XLA;
case DeviceType::XPU:
return backendToXPU(b);
+ case DeviceType::MLC:
+ TORCH_CHECK(!isSparse(b), "Sparse not implemented for MLC");
+ return Backend::MLC;
default:
AT_ERROR("Unknown device type");
}
diff --git a/torch/library.h b/torch/library.h
index 93236a9..8aa13ea 100644
--- a/torch/library.h
+++ b/torch/library.h
@@ -292,6 +292,8 @@
return c10::DispatchKey::CUDA;
case c10::DeviceType::XLA:
return c10::DispatchKey::XLA;
+ case c10::DeviceType::MLC:
+ return c10::DispatchKey::MLC;
case c10::DeviceType::HIP:
return c10::DispatchKey::HIP;
case c10::DeviceType::MSNPU:
diff --git a/torch/overrides.py b/torch/overrides.py
index 2f95c04..baec083 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -910,6 +910,7 @@
Tensor.is_xpu.__get__: lambda self: -1,
Tensor.is_leaf.__get__: lambda self: -1,
Tensor.is_meta.__get__: lambda self: -1,
+ Tensor.is_mlc.__get__: lambda self: -1,
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,
diff --git a/torch/tensor.py b/torch/tensor.py
index 55dae4e..e2c9e0d 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -57,7 +57,7 @@
if id(self) in memo:
return memo[id(self)]
with torch.no_grad():
- if self.is_sparse or self.device.type == 'xla':
+ if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc':
new_tensor = self.clone()
else:
new_storage = self.storage().__deepcopy__(memo)
@@ -123,6 +123,12 @@
str(self.device),
self.requires_grad)
return (torch._utils._rebuild_xla_tensor, arg_xla)
+ if self.device.type == 'mlc':
+ arg_mlc = (self.cpu().numpy(),
+ self.dtype,
+ str(self.device),
+ self.requires_grad)
+ return (torch._utils._rebuild_mlc_tensor, arg_mlc)
if self.is_quantized:
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
index c4fc3a6..849382e 100644
--- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py
+++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -257,7 +257,7 @@
with self.assertRaisesRegex(
RuntimeError,
- r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
+ r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan"
" device type at start of device string",
):
list(
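As a quick end-to-end illustration of the plumbing above: device strings now parse to the new device type, and the pretty-printer round-trips it. A minimal sketch exercising only the c10 changes in this PR (no out-of-tree `mlc` backend is needed for this part; allocating actual MLC tensors would additionally require the backend to register its kernels and device guard):

    #include <c10/core/Device.h>
    #include <c10/core/DeviceType.h>
    #include <iostream>

    int main() {
      // "mlc" is now accepted by the device-string parser (c10/core/Device.cpp).
      c10::Device d("mlc:0");
      std::cout << d << "\n";  // prints "mlc:0"
      // DeviceTypeName handles the new enumerator (c10/core/DeviceType.cpp).
      std::cout << c10::DeviceTypeName(c10::DeviceType::MLC, /*lower_case=*/true) << "\n";
      return 0;
    }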