[JIT] Support device as Dict key (#65079)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65079

This is required to use RPC DeviceMap aka Dict[torch.device, torch.device] in torchscript

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D31072626

Pulled By: pbelevich

fbshipit-source-id: 51cfa5653db86de73b624e9157d68d1b319bfc64
diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h
index 9e74355..a215d94 100644
--- a/aten/src/ATen/core/Dict_inl.h
+++ b/aten/src/ATen/core/Dict_inl.h
@@ -50,6 +50,8 @@
     return std::hash<bool>()(ivalue.toBool());
   } else if (ivalue.isTensor()) {
     return std::hash<TensorImpl*>()(ivalue.toTensor().unsafeGetTensorImpl());
+  } else if (ivalue.isDevice()) {
+    return std::hash<Device>()(ivalue.toDevice());
   } else {
     throw std::runtime_error(
         "Can't hash IValues with tag '" + ivalue.tagKind() + "'");
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 4284e29..ad33a33 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -824,12 +824,13 @@
       case TypeKind::ComplexType:
       case TypeKind::StringType:
       case TypeKind::TensorType:
+      case TypeKind::DeviceObjType:
         return DictTypePtr(new DictType(key, value));
       default:
         AT_ERROR(
             "Cannot create dict for key type '",
             key->str(),
-            "', only int, float, complex, Tensor and string keys are supported");
+            "', only int, float, complex, Tensor, device and string keys are supported");
     }
   }
 
diff --git a/test/jit/test_union.py b/test/jit/test_union.py
index fb53d53..bf1894a 100644
--- a/test/jit/test_union.py
+++ b/test/jit/test_union.py
@@ -370,7 +370,7 @@
             return x[1]
 
         with self.assertRaisesRegex(RuntimeError, "only int, float, "
-                                    "complex, Tensor and string keys "
+                                    "complex, Tensor, device and string keys "
                                     "are supported"):
             torch.jit.script(fn)