A real fix for issue #250 (failure with mock) (#295)
Fixes #250
The main idea here is optimizing generics for cases where a type information is added to existing code.
Now:
* ``Node[int]`` and ``Node`` have identical ``__bases__`` and identical ``__mro__[1:]`` (except for the first item, since it is the class itself).
* After addition of typing information (i.e. making some classes generic), ``__mro__`` is changed very little, at most one bare ``Generic`` appears in ``__mro__``.
* Consequently, only non-parameterized generics appear in ``__bases__`` and ``__mro__[1:]``.
Interestingly, this could be achieved in few lines of code and no existing test break.
On the positive side of this approach, there is very little chance that existing code (even with sophisticated "magic") will break after addition of typing information.
On the negative side, it will be more difficult for _runtime_ type-checkers to perform decorator-based type checks (e.g. enforce method overriding only by consistent methods). Essentially, now type erasure happens partially at the class creation time (all bases are reduced to origin).
(We have __orig_class__ and __orig_bases__ to help runtime checkers.)
diff --git a/python2/test_typing.py b/python2/test_typing.py
index 866f1b5..4b0da3a 100644
--- a/python2/test_typing.py
+++ b/python2/test_typing.py
@@ -627,6 +627,62 @@
class MM2(collections_abc.MutableMapping, MutableMapping[str, str]):
pass
+ def test_orig_bases(self):
+ T = TypeVar('T')
+ class C(typing.Dict[str, T]): pass
+ self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],))
+
+ def test_naive_runtime_checks(self):
+ def naive_dict_check(obj, tp):
+ # Check if a dictionary conforms to Dict type
+ if len(tp.__parameters__) > 0:
+ raise NotImplementedError
+ if tp.__args__:
+ KT, VT = tp.__args__
+ return all(isinstance(k, KT) and isinstance(v, VT)
+ for k, v in obj.items())
+ self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[typing.Text, int]))
+ self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[typing.Text, int]))
+ with self.assertRaises(NotImplementedError):
+ naive_dict_check({1: 'x'}, typing.Dict[typing.Text, T])
+
+ def naive_generic_check(obj, tp):
+ # Check if an instance conforms to the generic class
+ if not hasattr(obj, '__orig_class__'):
+ raise NotImplementedError
+ return obj.__orig_class__ == tp
+ class Node(Generic[T]): pass
+ self.assertTrue(naive_generic_check(Node[int](), Node[int]))
+ self.assertFalse(naive_generic_check(Node[str](), Node[int]))
+ self.assertFalse(naive_generic_check(Node[str](), List))
+ with self.assertRaises(NotImplementedError):
+ naive_generic_check([1,2,3], Node[int])
+
+ def naive_list_base_check(obj, tp):
+ # Check if list conforms to a List subclass
+ return all(isinstance(x, tp.__orig_bases__[0].__args__[0])
+ for x in obj)
+ class C(List[int]): pass
+ self.assertTrue(naive_list_base_check([1, 2, 3], C))
+ self.assertFalse(naive_list_base_check(['a', 'b'], C))
+
+ def test_multi_subscr_base(self):
+ T = TypeVar('T')
+ U = TypeVar('U')
+ V = TypeVar('V')
+ class C(List[T][U][V]): pass
+ class D(C, List[T][U][V]): pass
+ self.assertEqual(C.__parameters__, (V,))
+ self.assertEqual(D.__parameters__, (V,))
+ self.assertEqual(C[int].__parameters__, ())
+ self.assertEqual(D[int].__parameters__, ())
+ self.assertEqual(C[int].__args__, (int,))
+ self.assertEqual(D[int].__args__, (int,))
+ self.assertEqual(C.__bases__, (List,))
+ self.assertEqual(D.__bases__, (C, List))
+ self.assertEqual(C.__orig_bases__, (List[T][U][V],))
+ self.assertEqual(D.__orig_bases__, (C, List[T][U][V]))
+
def test_pickle(self):
global C # pickle wants to reference the class by name
T = TypeVar('T')
diff --git a/python2/typing.py b/python2/typing.py
index 1cd7cb8..6bd4c3d 100644
--- a/python2/typing.py
+++ b/python2/typing.py
@@ -1069,13 +1069,7 @@
"""Metaclass for generic types."""
def __new__(cls, name, bases, namespace,
- tvars=None, args=None, origin=None, extra=None):
- if extra is None:
- extra = namespace.get('__extra__')
- if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
- bases = (extra,) + bases
- self = super(GenericMeta, cls).__new__(cls, name, bases, namespace)
-
+ tvars=None, args=None, origin=None, extra=None, orig_bases=None):
if tvars is not None:
# Called from __getitem__() below.
assert origin is not None
@@ -1116,12 +1110,27 @@
", ".join(str(g) for g in gvars)))
tvars = gvars
+ initial_bases = bases
+ if extra is None:
+ extra = namespace.get('__extra__')
+ if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
+ bases = (extra,) + bases
+ bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases)
+
+ # remove bare Generic from bases if there are other generic bases
+ if any(isinstance(b, GenericMeta) and b is not Generic for b in bases):
+ bases = tuple(b for b in bases if b is not Generic)
+ self = super(GenericMeta, cls).__new__(cls, name, bases, namespace)
+
self.__parameters__ = tvars
self.__args__ = args
self.__origin__ = origin
self.__extra__ = extra
# Speed hack (https://github.com/python/typing/issues/196).
self.__next_in_mro__ = _next_in_mro(self)
+ # Preserve base classes on subclassing (__bases__ are type erased now).
+ if orig_bases is None:
+ self.__orig_bases__ = initial_bases
# This allows unparameterized generic collections to be used
# with issubclass() and isinstance() in the same way as their
@@ -1216,12 +1225,13 @@
tvars = _type_vars(params)
args = params
return self.__class__(self.__name__,
- (self,) + self.__bases__,
+ self.__bases__,
dict(self.__dict__),
tvars=tvars,
args=args,
origin=self,
- extra=self.__extra__)
+ extra=self.__extra__,
+ orig_bases=self.__orig_bases__)
def __instancecheck__(self, instance):
# Since we extend ABC.__subclasscheck__ and
@@ -1268,6 +1278,10 @@
else:
origin = _gorg(cls)
obj = cls.__next_in_mro__.__new__(origin)
+ try:
+ obj.__orig_class__ = cls
+ except AttributeError:
+ pass
obj.__init__(*args, **kwds)
return obj
@@ -1438,6 +1452,7 @@
attr != '__next_in_mro__' and
attr != '__parameters__' and
attr != '__origin__' and
+ attr != '__orig_bases__' and
attr != '__extra__' and
attr != '__module__'):
attrs.add(attr)
diff --git a/src/test_typing.py b/src/test_typing.py
index 052e8bc..9159149 100644
--- a/src/test_typing.py
+++ b/src/test_typing.py
@@ -654,6 +654,63 @@
class MM2(collections_abc.MutableMapping, MutableMapping[str, str]):
pass
+ def test_orig_bases(self):
+ T = TypeVar('T')
+ class C(typing.Dict[str, T]): ...
+ self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],))
+
+ def test_naive_runtime_checks(self):
+ def naive_dict_check(obj, tp):
+ # Check if a dictionary conforms to Dict type
+ if len(tp.__parameters__) > 0:
+ raise NotImplementedError
+ if tp.__args__:
+ KT, VT = tp.__args__
+ return all(isinstance(k, KT) and isinstance(v, VT)
+ for k, v in obj.items())
+ self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[str, int]))
+ self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[str, int]))
+ with self.assertRaises(NotImplementedError):
+ naive_dict_check({1: 'x'}, typing.Dict[str, T])
+
+ def naive_generic_check(obj, tp):
+ # Check if an instance conforms to the generic class
+ if not hasattr(obj, '__orig_class__'):
+ raise NotImplementedError
+ return obj.__orig_class__ == tp
+ class Node(Generic[T]): ...
+ self.assertTrue(naive_generic_check(Node[int](), Node[int]))
+ self.assertFalse(naive_generic_check(Node[str](), Node[int]))
+ self.assertFalse(naive_generic_check(Node[str](), List))
+ with self.assertRaises(NotImplementedError):
+ naive_generic_check([1,2,3], Node[int])
+
+ def naive_list_base_check(obj, tp):
+ # Check if list conforms to a List subclass
+ return all(isinstance(x, tp.__orig_bases__[0].__args__[0])
+ for x in obj)
+ class C(List[int]): ...
+ self.assertTrue(naive_list_base_check([1, 2, 3], C))
+ self.assertFalse(naive_list_base_check(['a', 'b'], C))
+
+ def test_multi_subscr_base(self):
+ T = TypeVar('T')
+ U = TypeVar('U')
+ V = TypeVar('V')
+ class C(List[T][U][V]): ...
+ class D(C, List[T][U][V]): ...
+ self.assertEqual(C.__parameters__, (V,))
+ self.assertEqual(D.__parameters__, (V,))
+ self.assertEqual(C[int].__parameters__, ())
+ self.assertEqual(D[int].__parameters__, ())
+ self.assertEqual(C[int].__args__, (int,))
+ self.assertEqual(D[int].__args__, (int,))
+ self.assertEqual(C.__bases__, (List,))
+ self.assertEqual(D.__bases__, (C, List))
+ self.assertEqual(C.__orig_bases__, (List[T][U][V],))
+ self.assertEqual(D.__orig_bases__, (C, List[T][U][V]))
+
+
def test_pickle(self):
global C # pickle wants to reference the class by name
T = TypeVar('T')
diff --git a/src/typing.py b/src/typing.py
index 1f95a5d..188c94e 100644
--- a/src/typing.py
+++ b/src/typing.py
@@ -959,11 +959,7 @@
"""Metaclass for generic types."""
def __new__(cls, name, bases, namespace,
- tvars=None, args=None, origin=None, extra=None):
- if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
- bases = (extra,) + bases
- self = super().__new__(cls, name, bases, namespace, _root=True)
-
+ tvars=None, args=None, origin=None, extra=None, orig_bases=None):
if tvars is not None:
# Called from __getitem__() below.
assert origin is not None
@@ -1004,12 +1000,25 @@
", ".join(str(g) for g in gvars)))
tvars = gvars
+ initial_bases = bases
+ if extra is not None and type(extra) is abc.ABCMeta and extra not in bases:
+ bases = (extra,) + bases
+ bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases)
+
+ # remove bare Generic from bases if there are other generic bases
+ if any(isinstance(b, GenericMeta) and b is not Generic for b in bases):
+ bases = tuple(b for b in bases if b is not Generic)
+ self = super().__new__(cls, name, bases, namespace, _root=True)
+
self.__parameters__ = tvars
self.__args__ = args
self.__origin__ = origin
self.__extra__ = extra
# Speed hack (https://github.com/python/typing/issues/196).
self.__next_in_mro__ = _next_in_mro(self)
+ # Preserve base classes on subclassing (__bases__ are type erased now).
+ if orig_bases is None:
+ self.__orig_bases__ = initial_bases
# This allows unparameterized generic collections to be used
# with issubclass() and isinstance() in the same way as their
@@ -1104,12 +1113,13 @@
tvars = _type_vars(params)
args = params
return self.__class__(self.__name__,
- (self,) + self.__bases__,
+ self.__bases__,
dict(self.__dict__),
tvars=tvars,
args=args,
origin=self,
- extra=self.__extra__)
+ extra=self.__extra__,
+ orig_bases=self.__orig_bases__)
def __instancecheck__(self, instance):
# Since we extend ABC.__subclasscheck__ and
@@ -1153,6 +1163,10 @@
else:
origin = _gorg(cls)
obj = cls.__next_in_mro__.__new__(origin)
+ try:
+ obj.__orig_class__ = cls
+ except AttributeError:
+ pass
obj.__init__(*args, **kwds)
return obj
@@ -1521,6 +1535,7 @@
attr != '__next_in_mro__' and
attr != '__parameters__' and
attr != '__origin__' and
+ attr != '__orig_bases__' and
attr != '__extra__' and
attr != '__module__'):
attrs.add(attr)