Better error message for mismatched dict key type (#22231)
Summary:
](https://our.intern.facebook.com/intern/diff/15993936/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22231
Pulled By: driazati
Differential Revision: D15993936
fbshipit-source-id: 6822ef01477a3b32beb8c037a621fa71abd022c8
diff --git a/test/test_jit.py b/test/test_jit.py
index 14ce7ff..95886ac 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15994,6 +15994,13 @@
self.assertEqual(fn(), {'ok': 10})
+ def test_key_type(self):
+ with self.assertRaisesRegex(RuntimeError, "Expected key type 'None' to subtype"):
+ @torch.jit.script
+ def fn(a):
+ # type: (Dict[str, int]) -> int
+ return a[None]
+
def test_loop(self):
@torch.jit.script
def fn(x):
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index d9e613e..d56ed48 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -1339,7 +1339,7 @@
throw ErrorReport(stmt)
<< "List of iterables is not supported currently.";
}
- // Emit loop information for builtinFunction values like range(), zip(),
+ // Emit loop information for builtinFunction values like range(), zip(),
// enumerate() or SimpleValue like List, Tensor, Dict, etc.
SugaredValuePtr sv = emitSugaredExpr(itrs[0], 1);
@@ -1357,7 +1357,7 @@
sv = std::make_shared<SimpleValue>(
graph->insert(aten::keys, {siv->getValue()}, {}, stmt.range()));
}
- emitLoopCommon(stmt.range(), body, sv, targets, {});
+ emitLoopCommon(stmt.range(), body, sv, targets, {});
return;
}
@@ -2260,7 +2260,7 @@
if (input_size == 2) {
start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method);
}
-
+
if (input_size > 2) {
throw ErrorReport(loc)
<< "enumerate expected at most 2 arguments, got " << input_size;
@@ -2805,7 +2805,15 @@
Value* dict_val,
Value* key_val) {
auto dict_type = dict_val->type()->cast<DictType>();
- AT_ASSERT(key_val->type()->isSubtypeOf(dict_type->getKeyType()));
+
+ if (!key_val->type()->isSubtypeOf(dict_type->getKeyType())) {
+ throw ErrorReport(loc)
+ << "Expected key type '" << key_val->type()->python_str()
+ << "' to subtype the key type '"
+ << dict_type->getKeyType()->python_str() << "' of the dict '"
+ << dict_type->python_str() << "'";
+ }
+
return graph->insertNode(graph->createDictIndex(dict_val, key_val))
->output();
}