Introduce `mlc` device (ML Compute device) to PyTorch's device list (#50634)

Summary:
Apple recently announced ML Compute, a new framework available in macOS Big Sur that lets users accelerate the training of neural networks on Mac hardware. This PR is the first in a series that will enable integration with ML Compute. Most of the integration code will live in a separate repo named `mlc`.
The integration with `mlc` (ML Compute) will closely mirror that of XLA: we rely on registering our ops through:

```cpp
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl_UNBOXED(<op_schema_name>, &customized_op_kernel);
  ...
}
```
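
As a concrete illustration, below is a minimal sketch of what a single-op registration could look like once the `MLC` dispatch key added by this PR exists. The kernel name `mlc_add` and its CPU round-trip body are hypothetical placeholders (a real kernel would lower the op to ML Compute), and it uses the plain `m.impl` form rather than `impl_UNBOXED`:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical kernel for aten::add.Tensor on the MLC backend. The body
// just round-trips through CPU so the sketch compiles and runs; a real
// implementation would build and execute an ML Compute graph instead.
at::Tensor mlc_add(const at::Tensor& self,
                   const at::Tensor& other,
                   const at::Scalar& alpha) {
  return at::add(self.cpu(), other.cpu(), alpha);
}

TORCH_LIBRARY_IMPL(aten, MLC, m) {
  // Schema name as it appears in native_functions.yaml:
  // add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  m.impl("add.Tensor", &mlc_add);
}
```

With a registration like this loaded (and assuming the backend also registers the factory and memory plumbing), a call such as `torch.ones(2, device='mlc') + 1` would be routed to `mlc_add` by the dispatcher, the same way XLA intercepts ops today.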

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50634

Reviewed By: malfet

Differential Revision: D26614213

Pulled By: smessmer

fbshipit-source-id: 3b492b346c61cc3950ac880ac01a82fbdddbc07b
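
Even before the out-of-tree backend lands, the plumbing added in the diff below is observable from C++. A small illustrative snippet (not part of this PR) exercising the new device-string parsing and context hook:

```cpp
#include <ATen/Context.h>
#include <c10/core/Device.h>
#include <iostream>

int main() {
  // "mlc" now parses thanks to the new entry in c10/core/Device.cpp.
  c10::Device d("mlc");
  std::cout << d << "\n";  // prints "mlc"

  // at::hasMLC() reports whether an MLC DeviceGuard implementation has
  // been registered; it stays false until the `mlc` backend is loaded.
  std::cout << std::boolalpha << at::hasMLC() << "\n";
}
```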
diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py
index 683965e..bd50e43 100644
--- a/.circleci/cimodel/data/pytorch_build_data.py
+++ b/.circleci/cimodel/data/pytorch_build_data.py
@@ -157,6 +157,7 @@
         next_nodes = {
             "asan": AsanConfigNode,
             "xla": XlaConfigNode,
+            "mlc": MLCConfigNode,
             "vulkan": VulkanConfigNode,
             "parallel_tbb": ParallelTBBConfigNode,
             "parallel_native": ParallelNativeConfigNode,
@@ -193,6 +194,16 @@
     def child_constructor(self):
         return ImportantConfigNode
 
+class MLCConfigNode(TreeConfigNode):
+    def modify_label(self, label):
+        return "MLC=" + str(label)
+
+    def init2(self, node_name):
+        self.props["is_mlc"] = node_name
+
+    def child_constructor(self):
+        return ImportantConfigNode
+
 
 class AsanConfigNode(TreeConfigNode):
     def modify_label(self, label):
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 4dfb316..ecb58c5 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -73,6 +73,9 @@
   bool hasXLA() const {
     return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
   }
+  bool hasMLC() const {
+    return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
+  }
   // defined in header so that getNonVariableType has ability to inline
   // call_once check. getNonVariableType is called fairly frequently
   THCState* lazyInitCUDA() {
@@ -276,6 +279,10 @@
   return globalContext().hasXLA();
 }
 
+static inline bool hasMLC() {
+  return globalContext().hasMLC();
+}
+
 // Despite its name, this function returns the number of *CUDA* GPUs.
 static inline size_t getNumGPUs() {
   // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp
index ecc5070..fe03506 100644
--- a/aten/src/ATen/Version.cpp
+++ b/aten/src/ATen/Version.cpp
@@ -194,6 +194,7 @@
 
   // TODO: do HIP
   // TODO: do XLA
+  // TODO: do MLC
 
   return ss.str();
 }
diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp
index baf49ae..c69d5db 100644
--- a/aten/src/ATen/core/VariableFallbackKernel.cpp
+++ b/aten/src/ATen/core/VariableFallbackKernel.cpp
@@ -48,4 +48,8 @@
   m.fallback(torch::CppFunction::makeFallthrough());
 }
 
+TORCH_LIBRARY_IMPL(_, AutogradMLC, m) {
+  m.fallback(torch::CppFunction::makeFallthrough());
+}
+
 }
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 7298cbc..4b1105c 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -397,6 +397,7 @@
 _(aten, is_complex) \
 _(aten, is_contiguous) \
 _(aten, is_cuda) \
+_(aten, is_mlc) \
 _(aten, is_distributed) \
 _(aten, is_floating_point) \
 _(aten, is_nonzero) \
diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index b3ec12d..96867b0 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -364,6 +364,9 @@
   /// Returns if a `Tensor` is mkldnn tensor.
   bool is_mkldnn() const;
 
+  /// Returns if a `Tensor` is mlc tensor.
+  bool is_mlc() const;
+
   /// Returns if a `Tensor` is vulkan tensor.
   bool is_vulkan() const;
 
diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp
index 37b9fdd..29998e8 100644
--- a/aten/src/ATen/templates/TensorMethods.cpp
+++ b/aten/src/ATen/templates/TensorMethods.cpp
@@ -145,6 +145,15 @@
   return self.is_mkldnn();
 }
 
+bool Tensor::is_mlc() const {
+  // NB: this is not a native function to avoid dispatching overhead.
+  return impl_->is_mlc();
+}
+
+bool is_mlc(Tensor self) {
+  return self.is_mlc();
+}
+
 bool Tensor::is_vulkan() const {
   // NB: this is not a native function to avoid dispatching overhead.
   return impl_->is_vulkan();
diff --git a/c10/core/Backend.h b/c10/core/Backend.h
index 2746360..072a665 100644
--- a/c10/core/Backend.h
+++ b/c10/core/Backend.h
@@ -45,6 +45,7 @@
   QuantizedXPU,
   Undefined,
   MkldnnCPU,
+  MLC,
   NumOptions
 };
 
@@ -99,6 +100,8 @@
       return Backend::QuantizedCUDA;
     case Backend::QuantizedXPU:
       return Backend::QuantizedXPU;
+    case Backend::MLC:
+      return Backend::MLC;
     default:
       throw std::runtime_error("Unknown backend");
   }
@@ -117,6 +120,8 @@
     return Backend::MSNPU;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
     return Backend::XLA;
+  } else if (t == DispatchKey::MLC || t == DispatchKey::AutogradMLC) {
+    return Backend::MLC;
   } else if (t == DispatchKey::Vulkan) {
     return Backend::Vulkan;
   } else if (t == DispatchKey::Metal) {
@@ -182,6 +187,8 @@
       return DispatchKey::QuantizedCUDA;
     case Backend::Undefined:
       return DispatchKey::Undefined;
+    case Backend::MLC:
+      return DispatchKey::MLC;
     default:
       throw std::runtime_error("Unknown backend");
   }
@@ -220,6 +227,8 @@
       return DeviceType::Vulkan;
     case Backend::Metal:
       return DeviceType::Metal;
+    case Backend::MLC:
+      return DeviceType::MLC;
     case Backend::Undefined:
       AT_ERROR("Undefined backend is not a valid device type");
     default:
@@ -250,6 +259,8 @@
     case Backend::MSNPU:
     case Backend::XLA:
       return Backend::CPU;
+    case Backend::MLC:
+      return Backend::CPU;
     case Backend::MkldnnCPU:
       return Backend::MkldnnCPU;
     case Backend::QuantizedCPU:
@@ -302,6 +313,7 @@
     case Backend::FPGA:
     case Backend::MSNPU:
     case Backend::XLA:
+    case Backend::MLC:
       return Backend::CUDA;
     case Backend::SparseXPU:
     case Backend::SparseCPU:
@@ -324,6 +336,7 @@
     case Backend::FPGA:
     case Backend::MSNPU:
     case Backend::XLA:
+    case Backend::MLC:
       return Backend::HIP;
     case Backend::SparseXPU:
     case Backend::SparseCPU:
@@ -354,6 +367,8 @@
       return "MSNPU";
     case Backend::XLA:
       return "XLA";
+    case Backend::MLC:
+      return "MLC";
     case Backend::SparseCPU:
       return "SparseCPU";
     case Backend::SparseCUDA:
diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp
index 4be5701..a0182ff 100644
--- a/c10/core/Device.cpp
+++ b/c10/core/Device.cpp
@@ -46,6 +46,7 @@
           {"msnpu", DeviceType::MSNPU},
           {"xla", DeviceType::XLA},
           {"vulkan", DeviceType::Vulkan},
+          {"mlc", DeviceType::MLC},
       }};
   auto device = std::find_if(
       types.begin(),
@@ -57,7 +58,7 @@
     return device->second;
   }
   AT_ERROR(
-      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ",
+      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan device type at start of device string: ",
       device_string);
 }
 } // namespace
diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp
index 008d39d..611d0fe 100644
--- a/c10/core/DeviceType.cpp
+++ b/c10/core/DeviceType.cpp
@@ -27,6 +27,8 @@
       return lower_case ? "msnpu" : "MSNPU";
     case DeviceType::XLA:
       return lower_case ? "xla" : "XLA";
+    case DeviceType::MLC:
+      return lower_case ? "mlc" : "MLC";
     case DeviceType::Vulkan:
       return lower_case ? "vulkan" : "VULKAN";
     case DeviceType::Metal:
@@ -65,6 +67,7 @@
     case DeviceType::FPGA:
     case DeviceType::MSNPU:
     case DeviceType::XLA:
+    case DeviceType::MLC:
     case DeviceType::Vulkan:
     case DeviceType::Metal:
     case DeviceType::XPU:
diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h
index a7ea4ae..5bdb5ea 100644
--- a/c10/core/DeviceType.h
+++ b/c10/core/DeviceType.h
@@ -26,11 +26,12 @@
   Vulkan = 10, // Vulkan
   Metal = 11, // Metal
   XPU = 12, // XPU
+  MLC = 13, // ML Compute / Apple
   // NB: If you add more devices:
   //  - Change the implementations of DeviceTypeName and isValidDeviceType
   //    in DeviceType.cpp
   //  - Change the number below
-  COMPILE_TIME_MAX_DEVICE_TYPES = 13,
+  COMPILE_TIME_MAX_DEVICE_TYPES = 14,
 };
 
 constexpr DeviceType kCPU = DeviceType::CPU;
@@ -39,6 +40,7 @@
 constexpr DeviceType kFPGA = DeviceType::FPGA;
 constexpr DeviceType kMSNPU = DeviceType::MSNPU;
 constexpr DeviceType kXLA = DeviceType::XLA;
+constexpr DeviceType kMLC = DeviceType::MLC;
 constexpr DeviceType kVulkan = DeviceType::Vulkan;
 constexpr DeviceType kMetal = DeviceType::Metal;
 constexpr DeviceType kXPU = DeviceType::XPU;
diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index 9a2150f..2ae5a87 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -21,6 +21,8 @@
       return "MSNPU";
     case DispatchKey::XLA:
       return "XLA";
+    case DispatchKey::MLC:
+      return "MLC";
     case DispatchKey::Vulkan:
       return "Vulkan";
     case DispatchKey::Metal:
@@ -80,6 +82,8 @@
       return "AutogradCUDA";
     case DispatchKey::AutogradXLA:
       return "AutogradXLA";
+    case DispatchKey::AutogradMLC:
+      return "AutogradMLC";
     case DispatchKey::AutogradNestedTensor:
       return "AutogradNestedTensor";
     case DispatchKey::AutogradPrivateUse1:
@@ -143,6 +147,8 @@
       return DispatchKey::AutogradCUDA;
     case DispatchKey::XLA:
       return DispatchKey::AutogradXLA;
+    case DispatchKey::MLC:
+      return DispatchKey::AutogradMLC;
     case DispatchKey::NestedTensor:
       return DispatchKey::AutogradNestedTensor;
     case DispatchKey::PrivateUse1:
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index b353c3d..40d952c 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -62,6 +62,7 @@
   MSNPU, // unused externally, but tested at
   // test/cpp_extensions/msnpu_extension.cpp
   XLA, // lives out of tree at https://github.com/pytorch/xla
+  MLC, // lives out of tree at https://github.com/pytorch/MLCompute
   Vulkan,
   Metal,
   XPU, // For out of tree Intel's heterogeneous computing plug-in
@@ -224,9 +225,9 @@
   AutogradCPU,
   AutogradCUDA,
   AutogradXLA,
-  AutogradNestedTensor, // lives out of tree at
-                        // https://github.com/pytorch/nestedtensor
   AutogradXPU,
+  AutogradMLC,
+  AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
   // Here are some reserved pre-autograd keys for user-defined backends, see
   // Note [Private use DispatchKey]
   AutogradPrivateUse1,
diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
index 60f3ed3..9695545 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -14,6 +14,7 @@
         DispatchKey::PrivateUse1,
         DispatchKey::PrivateUse2,
         DispatchKey::PrivateUse3,
+        DispatchKey::MLC,
     });
 
 bool isBackendDispatchKey(DispatchKey t) {
@@ -48,6 +49,8 @@
       return DispatchKeySet(DispatchKey::CUDA);
     case DispatchKey::AutogradXLA:
       return DispatchKeySet(DispatchKey::XLA);
+    case DispatchKey::AutogradMLC:
+      return DispatchKeySet(DispatchKey::MLC);
     case DispatchKey::AutogradNestedTensor:
       return DispatchKeySet(DispatchKey::NestedTensor);
     case DispatchKey::AutogradXPU:
diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
index 4de1292..803b76f 100644
--- a/c10/core/DispatchKeySet.h
+++ b/c10/core/DispatchKeySet.h
@@ -195,6 +195,7 @@
     DispatchKey::AutogradCUDA,
     DispatchKey::AutogradXLA,
     DispatchKey::AutogradNestedTensor,
+    DispatchKey::AutogradMLC,
     DispatchKey::AutogradXPU,
     DispatchKey::AutogradPrivateUse1,
     DispatchKey::AutogradPrivateUse2,
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index c2cb4d5..06b3f4f 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -543,6 +543,10 @@
     return key_set_.has(DispatchKey::Metal);
   }
 
+  bool is_mlc() const {
+    return key_set_.has(DispatchKey::MLC);
+  }
+
   // TODO: remove this once we don't automatically enabled Autograd dispatch keys
   //       in TensorImpl constructor.
   // DON'T USE THIS API!! It's only created for testing purpose in
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index 1cb0508..41c5e90 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -629,6 +629,8 @@
             return DispatchKey::MSNPU;
           case DeviceType::XLA:
             return DispatchKey::XLA;
+          case DeviceType::MLC:
+            return DispatchKey::MLC;
           case DeviceType::Vulkan:
             return DispatchKey::Vulkan;
           case DeviceType::Metal:
@@ -687,6 +689,8 @@
     return DeviceType::MSNPU;
   } else if (tid == DispatchKey::XLA) {
     return DeviceType::XLA;
+  } else if (tid == DispatchKey::MLC) {
+    return DeviceType::MLC;
   } else if (tid == DispatchKey::SparseCPU) {
     return DeviceType::CPU;
   } else if (tid == DispatchKey::SparseCUDA) {
diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto
index 0b50b0a..9143758 100644
--- a/caffe2/proto/caffe2.proto
+++ b/caffe2/proto/caffe2.proto
@@ -198,8 +198,9 @@
   PROTO_FPGA = 7;                   // FPGA
   PROTO_MSNPU = 8;                  // MSNPU
   PROTO_XLA = 9;                    // XLA / TPU
+  PROTO_MLC = 10;                   // ML Compute
   // Change the following number if you add more devices in the code.
-  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
+  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11;
 }
 
 // Device-specific options. We do not distinguish DeviceOption protos for
diff --git a/torch/_utils.py b/torch/_utils.py
index e09f2d9..fff8d8a 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -179,6 +179,12 @@
     return tensor
 
 
+def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
+    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
+    tensor.requires_grad = requires_grad
+    return tensor
+
+
 def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
     qscheme = quantizer_params[0]
     if qscheme == torch.per_tensor_affine:
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index f3b23c6..e7e3f34 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -95,6 +95,7 @@
       .value("FPGA", c10::DeviceType::FPGA)
       .value("MSNPU", c10::DeviceType::MSNPU)
       .value("XLA", c10::DeviceType::XLA)
+      .value("MLC", c10::DeviceType::MLC)
       .value("Vulkan", c10::DeviceType::Vulkan)
       .value("Metal", c10::DeviceType::Metal);
 
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 27729be..9377be8 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -610,6 +610,17 @@
   END_HANDLE_TH_ERRORS
 }
 
+PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "is_mlc");
+  }
+  auto& self_ = self->cdata;
+  return torch::autograd::utils::wrap(self_.is_mlc());
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
 {
   HANDLE_TH_ERRORS
@@ -751,6 +762,7 @@
   {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
   {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
   {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
+  {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
   {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
   {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
   {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp
index 0164f2f..1652b4a 100644
--- a/torch/csrc/jit/frontend/sugared_value.cpp
+++ b/torch/csrc/jit/frontend/sugared_value.cpp
@@ -110,6 +110,7 @@
            {"is_xpu", "prim"},
            {"is_sparse", "prim"},
            {"is_mkldnn", "prim"},
+           {"is_mlc", "prim"},
            {"is_quantized", "prim"},
            {"is_vulkan", "prim"},
            {"is_meta", "prim"},
diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
index 882cb48..dbcdb33 100644
--- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
@@ -290,6 +290,14 @@
          },
          aliasAnalysisFromSchema()),
      Operator(
+         "prim::is_mlc(Tensor a) -> bool",
+         [](Stack* stack) {
+           at::Tensor a;
+           pop(stack, a);
+           push(stack, a.is_mlc());
+         },
+         aliasAnalysisFromSchema()),
+     Operator(
          "prim::is_vulkan(Tensor a) -> bool",
          [](Stack* stack) {
            at::Tensor a;
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 93f14bd..3151224 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -61,6 +61,9 @@
       return Backend::XLA;
     case DeviceType::XPU:
       return backendToXPU(b);
+    case DeviceType::MLC:
+      TORCH_CHECK(!isSparse(b), "Sparse not implemented for MLC");
+      return Backend::MLC;
     default:
       AT_ERROR("Unknown device type");
   }
diff --git a/torch/library.h b/torch/library.h
index 93236a9..8aa13ea 100644
--- a/torch/library.h
+++ b/torch/library.h
@@ -292,6 +292,8 @@
         return c10::DispatchKey::CUDA;
       case c10::DeviceType::XLA:
         return c10::DispatchKey::XLA;
+      case c10::DeviceType::MLC:
+        return c10::DispatchKey::MLC;
       case c10::DeviceType::HIP:
         return c10::DispatchKey::HIP;
       case c10::DeviceType::MSNPU:
diff --git a/torch/overrides.py b/torch/overrides.py
index 2f95c04..baec083 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -910,6 +910,7 @@
         Tensor.is_xpu.__get__: lambda self: -1,
         Tensor.is_leaf.__get__: lambda self: -1,
         Tensor.is_meta.__get__: lambda self: -1,
+        Tensor.is_mlc.__get__: lambda self: -1,
         Tensor.is_mkldnn.__get__: lambda self: -1,
         Tensor.is_quantized.__get__: lambda self: -1,
         Tensor.is_sparse.__get__: lambda self: -1,
diff --git a/torch/tensor.py b/torch/tensor.py
index 55dae4e..e2c9e0d 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -57,7 +57,7 @@
         if id(self) in memo:
             return memo[id(self)]
         with torch.no_grad():
-            if self.is_sparse or self.device.type == 'xla':
+            if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc':
                 new_tensor = self.clone()
             else:
                 new_storage = self.storage().__deepcopy__(memo)
@@ -123,6 +123,12 @@
                        str(self.device),
                        self.requires_grad)
             return (torch._utils._rebuild_xla_tensor, arg_xla)
+        if self.device.type == 'mlc':
+            arg_mlc = (self.cpu().numpy(),
+                       self.dtype,
+                       str(self.device),
+                       self.requires_grad)
+            return (torch._utils._rebuild_mlc_tensor, arg_mlc)
         if self.is_quantized:
             # quantizer_params can be different type based on torch attribute
             quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
index c4fc3a6..849382e 100644
--- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py
+++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -257,7 +257,7 @@
 
         with self.assertRaisesRegex(
             RuntimeError,
-            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
+            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan"
             " device type at start of device string",
         ):
             list(