blob: ef96aa2fe0baa55cb53dd5d43934a13875b0c079 [file] [log] [blame]
# Copyright (c) 2019-2021 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
# Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class BrainNumpyCoreMultiarrayTest(unittest.TestCase):
    """
    Test the numpy core multiarray brain module

    Each ``numpy_functions_returning_*`` table lists numpy calls as tuples
    whose first item is the function name and whose remaining items are
    source snippets for the arguments (joined with commas in the call).
    Every test infers each call twice, once through ``import numpy as np``
    and once through ``import numpy``, and checks the ``pytype()`` of the
    single inferred value.
    """

    numpy_functions_returning_array = (
        ("array", "[1, 2]"),
        ("bincount", "[1, 2]"),
        ("busday_count", "('2011-01', '2011-02')"),
        ("busday_offset", "'2012-03', -1, roll='forward'"),
        ("concatenate", "([1, 2], [1, 2])"),
        ("datetime_as_string", "['2012-02', '2012-03']"),
        ("dot", "[1, 2]", "[1, 2]"),
        ("empty_like", "[1, 2]"),
        ("inner", "[1, 2]", "[1, 2]"),
        ("is_busday", "['2011-07-01', '2011-07-02', '2011-07-18']"),
        ("lexsort", "(('toto', 'tutu'), ('riri', 'fifi'))"),
        ("packbits", "np.array([1, 2])"),
        ("unpackbits", "np.array([[1], [2], [3]], dtype=np.uint8)"),
        ("vdot", "[1, 2]", "[1, 2]"),
        ("where", "[True, False]", "[1, 2]", "[2, 1]"),
        ("empty", "[1, 2]"),
        ("zeros", "[1, 2]"),
    )

    numpy_functions_returning_bool = (
        ("can_cast", "np.int32, np.int64"),
        ("may_share_memory", "np.array([1, 2])", "np.array([3, 4])"),
        ("shares_memory", "np.array([1, 2])", "np.array([3, 4])"),
    )

    # NOTE: this table is intentionally empty for now, so the dtype test
    # below iterates over nothing and passes vacuously.
    numpy_functions_returning_dtype = (
        # ("min_scalar_type", "10"),  # Not yet tested as it returns np.dtype
        # ("result_type", "'i4'", "'c8'"),  # Not yet tested as it returns np.dtype
    )

    numpy_functions_returning_none = (("copyto", "([1, 2], [1, 3])"),)

    numpy_functions_returning_tuple = (
        (
            "unravel_index",
            "[22, 33, 44]",
            "(6, 7)",
        ),
    )

    @staticmethod
    def _infer_call(import_statement, module_name, func_name, func_args):
        """
        Build a snippet calling ``<module_name>.<func_name>`` and infer it

        :param import_statement: import line placed at the top of the snippet
        :param module_name: name through which numpy is accessed in the snippet
        :param func_name: name of the numpy function to call
        :param func_args: iterable of argument source strings, joined with ","
        :return: the result of ``infer()`` on the node extracted by
            ``builder.extract_node`` from the snippet
        """
        node = builder.extract_node(
            f"""
        {import_statement}
        func = {module_name}.{func_name:s}
        func({','.join(func_args):s})
        """
        )
        return node.infer()

    def _inferred_numpy_func_call(self, func_name, *func_args):
        """
        Infer a call to ``np.<func_name>`` with numpy imported as ``np``
        """
        return self._infer_call("import numpy as np", "np", func_name, func_args)

    def _inferred_numpy_no_alias_func_call(self, func_name, *func_args):
        """
        Infer a call to ``numpy.<func_name>`` with numpy imported without alias
        """
        return self._infer_call("import numpy", "numpy", func_name, func_args)

    def _assert_calls_inferred_as(self, func_specs, expected_pytype):
        """
        Check that every call in ``func_specs`` infers to ``expected_pytype``

        :param func_specs: iterable of ``(func_name, *arg_sources)`` tuples
        :param expected_pytype: string the inferred value's ``pytype()`` must equal
        """
        for infer_wrapper in (
            self._inferred_numpy_func_call,
            self._inferred_numpy_no_alias_func_call,
        ):
            for func_ in func_specs:
                # The wrapper name is recorded so a failure report tells the
                # aliased and non-aliased import variants apart.
                with self.subTest(typ=func_, wrapper=infer_wrapper.__name__):
                    inferred_values = list(infer_wrapper(*func_))
                    self.assertEqual(
                        len(inferred_values),
                        1,
                        msg=f"Too much inferred values ({inferred_values}) "
                        f"for {func_[0]:s}",
                    )
                    inferred_pytype = inferred_values[-1].pytype()
                    self.assertEqual(
                        inferred_pytype,
                        expected_pytype,
                        msg=f"Illicit type for {func_[0]:s} ({inferred_pytype})",
                    )

    def test_numpy_function_calls_inferred_as_ndarray(self):
        """
        Test that calls to numpy functions are inferred as numpy.ndarray
        """
        self._assert_calls_inferred_as(
            self.numpy_functions_returning_array, ".ndarray"
        )

    def test_numpy_function_calls_inferred_as_bool(self):
        """
        Test that calls to numpy functions are inferred as bool
        """
        self._assert_calls_inferred_as(
            self.numpy_functions_returning_bool, "builtins.bool"
        )

    def test_numpy_function_calls_inferred_as_dtype(self):
        """
        Test that calls to numpy functions are inferred as numpy.dtype

        The table of dtype-returning functions is currently empty, so this
        test does not exercise any call yet.
        """
        self._assert_calls_inferred_as(
            self.numpy_functions_returning_dtype, "numpy.dtype"
        )

    def test_numpy_function_calls_inferred_as_none(self):
        """
        Test that calls to numpy functions are inferred as None
        """
        self._assert_calls_inferred_as(
            self.numpy_functions_returning_none, "builtins.NoneType"
        )

    def test_numpy_function_calls_inferred_as_tuple(self):
        """
        Test that calls to numpy functions are inferred as tuple
        """
        self._assert_calls_inferred_as(
            self.numpy_functions_returning_tuple, "builtins.tuple"
        )
# Run the tests with unittest's command-line runner when this file is
# executed directly as a script rather than imported.
if __name__ == "__main__":
    unittest.main()