Add unsupported types to schema type parser (#28181)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28181
These types are needed to parse the schemas from native_functions.yaml.
Note: This doesn't actually add JIT support for these types; it only makes the schema parser accept them.
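For illustration only (not part of this change), a minimal sketch of what now parses. The include path below is an assumption about where the schema parser lived at the time of this change, and the snippet only exercises parsing, not JIT execution of these types:

    // Sketch only: assumes torch::jit::parseSchema() and this include path.
    #include <torch/csrc/jit/script/function_schema_parser.h>
    #include <iostream>

    int main() {
      // "Dimname" maps to StringType in the parser, so a schema using
      // Dimname[] parses instead of hitting an unknown-type error.
      auto schema = torch::jit::parseSchema(
          "aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)");
      std::cout << schema.name() << "\n";
      return 0;
    }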
ghstack-source-id: 92436989
Test Plan: waitforsandcastle
Differential Revision: D17969014
fbshipit-source-id: 41ebe256baec81ed8fb165e7b7cffa5160d285c3
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index db341d9..beee630 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -66,11 +66,11 @@
   variants: method
   supports_named_tensor: True
 
-- func: align_to(Tensor(a) self, DimnameList names) -> Tensor(a)
+- func: align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)
   variants: method
   supports_named_tensor: True
 
-- func: align_to(Tensor(a) self, DimnameList order, int ellipsis_idx) -> Tensor(a)
+- func: align_to(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)
   variants: method
   supports_named_tensor: True
 
@@ -83,15 +83,15 @@
   use_c10_dispatcher: unboxed_only
   supports_named_tensor: True
 
-- func: refine_names(Tensor(a) self, DimnameList names) -> Tensor(a)
+- func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
   variants: method
   supports_named_tensor: True
 
-- func: unflatten(Tensor self, Dimname dim, int[] sizes, DimnameList names) -> Tensor
+- func: unflatten(Tensor self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor
   variants: method
   supports_named_tensor: True
 
-- func: unflatten(Tensor self, int dim, int[] sizes, DimnameList names) -> Tensor
+- func: unflatten(Tensor self, int dim, int[] sizes, Dimname[] names) -> Tensor
   variants: method
   supports_named_tensor: True
 
@@ -1241,7 +1241,7 @@
   variants: function, method
   supports_named_tensor: True
 
-- func: flatten.DimnameList(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor
+- func: flatten.DimnameList(Tensor self, Dimname[] dims, Dimname out_dim) -> Tensor
   variants: function, method
   supports_named_tensor: True
 
diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py
index b493b5c..be2b69d 100644
--- a/aten/src/ATen/native_parse.py
+++ b/aten/src/ATen/native_parse.py
@@ -69,6 +69,9 @@
     elif t == 'int64_t?':
         raise RuntimeError("Please use int? and not int64_t?. "
                            "See [temp translations] for details.")
+    # Enables Dimname[] by translating to legacy DimnameList.
+    elif t == 'Dimname[]':
+        t = 'DimnameList'
     elif t == 'Dimname[]?':
         t = 'DimnameList?'
     # Enables float by translating to legacy double.
diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp
index 718f555..1473aaa 100644
--- a/torch/csrc/jit/script/schema_type_parser.cpp
+++ b/torch/csrc/jit/script/schema_type_parser.cpp
@@ -33,10 +33,13 @@
 TypeAndAlias SchemaTypeParser::parseBaseType() {
   static std::unordered_map<std::string, TypePtr> type_map = {
       {"Generator", GeneratorType::get()},
+      {"Dimname", StringType::get()},
       {"ScalarType", IntType::get()},
       {"Layout", IntType::get()},
       {"MemoryFormat", IntType::get()},
+      {"Storage", IntType::get()},
       {"QScheme", IntType::get()},
+      {"ConstQuantizerPtr", IntType::get()}, // TODO This type should be removed from the schema parser, it should use the custom class mechanism instead. @jerryzh
       {"Device", DeviceObjType::get()},
       {"Scalar", NumberType::get()},
       {"str", StringType::get()},