blob: 96dddd286c7c2c803dc59623361a6b5a5dc8888e [file] [log] [blame]
# Copyright (c) 2021 hippo91 <guillaume.peillex@gmail.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
import pytest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
@pytest.mark.skipif(not HAS_NUMPY, reason="This test requires the numpy library.")
class TestBrainNumpyMa:
    """
    Test the numpy ma brain module.

    The whole class is skipped when numpy could not be imported
    (``HAS_NUMPY`` is false). Each test feeds a small source snippet to
    ``builder.extract_node`` and checks the ``pytype()`` of the first
    inferred value of the resulting node.
    """

    @staticmethod
    def test_numpy_ma_masked_where_returns_maskedarray():
        """
        Test that calls to numpy ma masked_where return a MaskedArray object.

        The "masked_where" node is an Attribute (``np.ma.masked_where``).
        """
        src = """
        import numpy as np
        data = np.ndarray((1,2))
        np.ma.masked_where([1, 0, 0], data)
        """
        node = builder.extract_node(src)
        cls_node = node.inferred()[0]
        assert cls_node.pytype() == "numpy.ma.core.MaskedArray"

    @staticmethod
    def test_numpy_ma_masked_where_returns_maskedarray_bis():
        """
        Test that calls to numpy ma masked_where return a MaskedArray object.

        The "masked_where" node is a Name (imported with ``from numpy.ma``).
        """
        # ``np`` has to be bound inside the snippet itself: without the
        # ``import numpy as np`` line, ``np.ndarray`` would refer to an
        # undefined name in the analysed code.
        src = """
        import numpy as np
        from numpy.ma import masked_where
        data = np.ndarray((1,2))
        masked_where([1, 0, 0], data)
        """
        node = builder.extract_node(src)
        cls_node = node.inferred()[0]
        assert cls_node.pytype() == "numpy.ma.core.MaskedArray"