blob: 1a417b85eacafc940791c3a43ecd747dd28dd5be [file] [log] [blame]
# Copyright (c) 2017-2021 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2017-2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
# Copyright (c) 2021 Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>
# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder, nodes
from astroid.brain.brain_numpy_utils import (
NUMPY_VERSION_TYPE_HINTS_SUPPORT,
numpy_supports_type_hints,
)
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainNdarrayTest(unittest.TestCase):
    """
    Test that calls to some ``numpy.ndarray`` methods, and accesses to some
    ``numpy.ndarray`` attributes, are inferred as ``ndarray`` instances.
    """

    # ndarray methods whose zero-argument call is expected to be inferred as
    # an ndarray instance.
    ndarray_returning_ndarray_methods = (
        "__abs__",
        "__add__",
        "__and__",
        "__array__",
        "__array_wrap__",
        "__copy__",
        "__deepcopy__",
        "__eq__",
        "__floordiv__",
        "__ge__",
        "__gt__",
        "__iadd__",
        "__iand__",
        "__ifloordiv__",
        "__ilshift__",
        "__imod__",
        "__imul__",
        "__invert__",
        "__ior__",
        "__ipow__",
        "__irshift__",
        "__isub__",
        "__itruediv__",
        "__ixor__",
        "__le__",
        "__lshift__",
        "__lt__",
        "__matmul__",
        "__mod__",
        "__mul__",
        "__ne__",
        "__neg__",
        "__or__",
        "__pos__",
        "__pow__",
        "__rshift__",
        "__sub__",
        "__truediv__",
        "__xor__",
        "all",
        "any",
        "argmax",
        "argmin",
        "argpartition",
        "argsort",
        "astype",
        "byteswap",
        "choose",
        "clip",
        "compress",
        "conj",
        "conjugate",
        "copy",
        "cumprod",
        "cumsum",
        "diagonal",
        "dot",
        "flatten",
        "getfield",
        "max",
        "mean",
        "min",
        "newbyteorder",
        "prod",
        "ptp",
        "ravel",
        "repeat",
        "reshape",
        "round",
        "searchsorted",
        "squeeze",
        "std",
        "sum",
        "swapaxes",
        "take",
        "trace",
        "transpose",
        "var",
        "view",
    )

    # ndarray attributes whose access is expected to be inferred as an
    # ndarray instance.
    ndarray_returning_ndarray_attributes = ("real", "imag", "shape", "T")

    # Accepted ``pytype()`` values for an inferred ndarray instance. This is a
    # tuple so that ``in`` is an exact membership test. Against a plain string,
    # ``in`` would be a substring test and accept e.g. "" or "ndarray".
    _LICIT_ARRAY_TYPES = (".ndarray",)

    @staticmethod
    def _infer_from_ndarray(expression):
        """Infer ``test_array.<expression>`` where ``test_array`` is a 2x2 ndarray.

        :param expression: text appended after ``test_array.``, e.g. ``"sum()"``
            for a method call or ``"T"`` for an attribute access.
        :return: the iterator returned by ``node.infer()`` for that expression.
        """
        node = builder.extract_node(
            f"""
        import numpy as np
        test_array = np.ndarray((2, 2))
        test_array.{expression:s}
        """
        )
        return node.infer()

    def _inferred_ndarray_method_call(self, func_name):
        """Infer a zero-argument call of method ``func_name`` on an ndarray."""
        return self._infer_from_ndarray(f"{func_name:s}()")

    def _inferred_ndarray_attribute(self, attr_name):
        """Infer an access to attribute ``attr_name`` of an ndarray."""
        return self._infer_from_ndarray(f"{attr_name:s}")

    def _assert_single_ndarray(self, inferred, name):
        """Assert that ``inferred`` yields exactly one value, an ndarray instance.

        :param inferred: iterable of inferred values.
        :param name: method or attribute name, used in failure messages.
        """
        inferred_values = list(inferred)
        # assertEqual reports the actual count; zero values is a failure too,
        # not only "too many".
        self.assertEqual(
            len(inferred_values),
            1,
            msg=(
                f"Expected exactly one inferred value for {name:s}, "
                f"got {len(inferred_values)}"
            ),
        )
        pytype = inferred_values[-1].pytype()
        self.assertIn(
            pytype,
            self._LICIT_ARRAY_TYPES,
            msg=f"Illicit type for {name:s} ({pytype})",
        )

    def test_numpy_function_calls_inferred_as_ndarray(self):
        """
        Test that calls to some ndarray methods are inferred as numpy.ndarray
        """
        for func_ in self.ndarray_returning_ndarray_methods:
            with self.subTest(typ=func_):
                self._assert_single_ndarray(
                    self._inferred_ndarray_method_call(func_), func_
                )

    def test_numpy_ndarray_attribute_inferred_as_ndarray(self):
        """
        Test that some numpy ndarray attributes are inferred as numpy.ndarray
        """
        for attr_ in self.ndarray_returning_ndarray_attributes:
            with self.subTest(typ=attr_):
                self._assert_single_ndarray(
                    self._inferred_ndarray_attribute(attr_), attr_
                )

    @unittest.skipUnless(
        HAS_NUMPY and numpy_supports_type_hints(),
        f"This test requires the numpy library with a version above {NUMPY_VERSION_TYPE_HINTS_SUPPORT}",
    )
    def test_numpy_ndarray_class_support_type_indexing(self):
        """
        Test that numpy ndarray class can be subscripted (type hints)
        """
        src = """
        import numpy as np
        np.ndarray[int]
        """
        node = builder.extract_node(src)
        cls_node = node.inferred()[0]
        self.assertIsInstance(cls_node, nodes.ClassDef)
        self.assertEqual(cls_node.name, "ndarray")
if __name__ == "__main__":
    # Executed only when this file is run as a script, not when it is
    # imported by a test runner; unittest.main() then runs this module's tests.
    unittest.main()