Better error message for mismatched dict key type (#22231)

Summary:
](https://our.intern.facebook.com/intern/diff/15993936/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22231

Pulled By: driazati

Differential Revision: D15993936

fbshipit-source-id: 6822ef01477a3b32beb8c037a621fa71abd022c8
diff --git a/test/test_jit.py b/test/test_jit.py
index 14ce7ff..95886ac 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15994,6 +15994,13 @@
 
         self.assertEqual(fn(), {'ok': 10})
 
+    def test_key_type(self):
+        with self.assertRaisesRegex(RuntimeError, "Expected key type 'None' to subtype"):
+            @torch.jit.script
+            def fn(a):
+                # type: (Dict[str, int]) -> int
+                return a[None]
+
     def test_loop(self):
         @torch.jit.script
         def fn(x):
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index d9e613e..d56ed48 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -1339,7 +1339,7 @@
       throw ErrorReport(stmt)
           << "List of iterables is not supported currently.";
     }
-    // Emit loop information for builtinFunction values like range(), zip(), 
+    // Emit loop information for builtinFunction values like range(), zip(),
     // enumerate() or SimpleValue like List, Tensor, Dict, etc.
     SugaredValuePtr sv = emitSugaredExpr(itrs[0], 1);
 
@@ -1357,7 +1357,7 @@
         sv = std::make_shared<SimpleValue>(
           graph->insert(aten::keys, {siv->getValue()}, {}, stmt.range()));
       }
-      emitLoopCommon(stmt.range(), body, sv, targets, {}); 
+      emitLoopCommon(stmt.range(), body, sv, targets, {});
       return;
     }
 
@@ -2260,7 +2260,7 @@
         if (input_size == 2) {
           start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method);
         }
-        
+
         if (input_size > 2) {
           throw ErrorReport(loc)
             << "enumerate expected at most 2 arguments, got " << input_size;
@@ -2805,7 +2805,15 @@
       Value* dict_val,
       Value* key_val) {
     auto dict_type = dict_val->type()->cast<DictType>();
-    AT_ASSERT(key_val->type()->isSubtypeOf(dict_type->getKeyType()));
+
+    if (!key_val->type()->isSubtypeOf(dict_type->getKeyType())) {
+      throw ErrorReport(loc)
+          << "Expected key type '" << key_val->type()->python_str()
+          << "' to subtype the key type '"
+          << dict_type->getKeyType()->python_str() << "' of the dict '"
+          << dict_type->python_str() << "'";
+    }
+
     return graph->insertNode(graph->createDictIndex(dict_val, key_val))
         ->output();
   }