[jit] Support properties on `Device` (#32953)
Summary:
Stacked PRs
* #32955 - [jit] Fix flipped PackedSequence outputs in script
* **#32953 - [jit] Support properties on `Device`**
PyTorch devices have a `index` and `type` property. This PR adds support for both to TorchScript
](https://our.intern.facebook.com/intern/diff/19849320/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32953
Pulled By: driazati
Differential Revision: D19849320
fbshipit-source-id: ce845258c6110058dd9ea1f759ef74b7ed2e786e
diff --git a/test/test_jit.py b/test/test_jit.py
index 19816ee..8d09919 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4343,6 +4343,21 @@
cu.define(dedent(inspect.getsource(ok)))
check(cu.ok)
+ def _test_device_type(self, dest):
+ def fn(x):
+ # type: (Device) -> Tuple[str, Optional[int]]
+ return x.type, x.index
+
+ device = torch.ones(2).to(dest).device
+ self.checkScript(fn, [device])
+
+ def test_device_type(self):
+ self._test_device_type('cpu')
+
+ @unittest.skipIf(not RUN_CUDA, "Requires CUDA")
+ def test_device_type_cuda(self):
+ self._test_device_type('cuda')
+
def test_eval_python(self):
def _test(m):
self.assertTrue(m(torch.ones(2, 2)))
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 68b687b..3c0ae7c 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -607,6 +607,28 @@
},
aliasAnalysisFromSchema()),
Operator(
+ "prim::type(Device self) -> str",
+ [](Stack& stack) {
+ auto d = pop(stack);
+ push(
+ stack,
+ DeviceTypeName(d.toDevice().type(), /* lower_case=*/true));
+ return 0;
+ },
+ aliasAnalysisFromSchema()),
+ Operator(
+ "prim::index(Device self) -> int?",
+ [](Stack& stack) {
+ auto d = pop(stack).toDevice();
+ if (d.has_index()) {
+ push(stack, d.index());
+ } else {
+ push(stack, IValue());
+ }
+ return 0;
+ },
+ aliasAnalysisFromSchema()),
+ Operator(
// TODO return generator object when torchscript supports RNG
// first-class
"aten::manual_seed(int seed) -> ()",
diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp
index 9fd00a6..8fd2f7e1 100644
--- a/torch/csrc/jit/script/sugared_value.cpp
+++ b/torch/csrc/jit/script/sugared_value.cpp
@@ -54,6 +54,16 @@
emitBuiltinCall(loc, *m.graph(), symbol, inputs, attributes, self));
}
+// older versions of gcc/clang have a bug where enums can't be used as keys
+// in a map by default
+// https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
+struct EnumClassHash {
+ template <typename T>
+ std::size_t operator()(T t) const {
+ return static_cast<std::size_t>(t);
+ }
+};
+
// support syntax sugar for x.foo(y, z) by allowing x.foo to return a
// callable value that will resolve to foo(x, y, z) when called.
std::shared_ptr<SugaredValue> SimpleValue::attr(
@@ -67,22 +77,31 @@
Symbol::aten(builtin_cast_methods().at(field)),
NamedValue(loc, "self", value_));
}
- // functions that are just direct property lookups on tensor
- // must be registered as prim::<name>(Tensor t) -> <return_type>
- static const std::unordered_set<std::string> fields = {
- "dtype",
- "device",
- "grad",
- "data",
- "shape",
- "is_cuda",
- "is_sparse",
- "is_mkldnn",
- "is_quantized",
- "requires_grad",
- "layout",
- };
- if (fields.count(field)) {
+ }
+ // accessing properties of Tensor and Device that are implemented as
+ // prim:: operators
+ using PropertiesLookup = std::
+ unordered_map<TypeKind, std::unordered_set<std::string>, EnumClassHash>;
+ static const PropertiesLookup builtin_properties = {
+ {TypeKind::TensorType,
+ {
+ "dtype",
+ "device",
+ "grad",
+ "data",
+ "shape",
+ "is_cuda",
+ "is_sparse",
+ "is_mkldnn",
+ "is_quantized",
+ "requires_grad",
+ "layout",
+ }},
+ {TypeKind::DeviceObjType, {"type", "index"}}};
+ auto kind = value_->type()->kind();
+ auto builtin_entry = builtin_properties.find(kind);
+ if (builtin_entry != builtin_properties.end()) {
+ if (builtin_entry->second.count(field) > 0) {
auto r =
m.graph()->insert(Symbol::fromQualString("prim::" + field), {value_});
return std::make_shared<SimpleValue>(r);