[JIT] Support device as Dict key (#65079)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65079
This is required to use RPC DeviceMap aka Dict[torch.device, torch.device] in torchscript
Test Plan: Imported from OSS
Reviewed By: malfet
Differential Revision: D31072626
Pulled By: pbelevich
fbshipit-source-id: 51cfa5653db86de73b624e9157d68d1b319bfc64
diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h
index 9e74355..a215d94 100644
--- a/aten/src/ATen/core/Dict_inl.h
+++ b/aten/src/ATen/core/Dict_inl.h
@@ -50,6 +50,8 @@
return std::hash<bool>()(ivalue.toBool());
} else if (ivalue.isTensor()) {
return std::hash<TensorImpl*>()(ivalue.toTensor().unsafeGetTensorImpl());
+ } else if (ivalue.isDevice()) {
+ return std::hash<Device>()(ivalue.toDevice());
} else {
throw std::runtime_error(
"Can't hash IValues with tag '" + ivalue.tagKind() + "'");
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 4284e29..ad33a33 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -824,12 +824,13 @@
case TypeKind::ComplexType:
case TypeKind::StringType:
case TypeKind::TensorType:
+ case TypeKind::DeviceObjType:
return DictTypePtr(new DictType(key, value));
default:
AT_ERROR(
"Cannot create dict for key type '",
key->str(),
- "', only int, float, complex, Tensor and string keys are supported");
+ "', only int, float, complex, Tensor, device and string keys are supported");
}
}
diff --git a/test/jit/test_union.py b/test/jit/test_union.py
index fb53d53..bf1894a 100644
--- a/test/jit/test_union.py
+++ b/test/jit/test_union.py
@@ -370,7 +370,7 @@
return x[1]
with self.assertRaisesRegex(RuntimeError, "only int, float, "
- "complex, Tensor and string keys "
+ "complex, Tensor, device and string keys "
"are supported"):
torch.jit.script(fn)