| /* |
| * Copyright 2000-2013 JetBrains s.r.o. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package com.jetbrains.python.psi.types; |
| |
| import com.intellij.psi.PsiElement; |
| import com.intellij.psi.PsiPolyVariantReference; |
| import com.intellij.psi.PsiReference; |
| import com.intellij.psi.ResolveResult; |
| import com.jetbrains.python.PyNames; |
| import com.jetbrains.python.psi.*; |
| import com.jetbrains.python.psi.resolve.PyResolveContext; |
| import com.jetbrains.python.psi.resolve.RatedResolveResult; |
| import org.jetbrains.annotations.NotNull; |
| import org.jetbrains.annotations.Nullable; |
| |
| import java.util.*; |
| |
| /** |
| * @author vlan |
| */ |
| public class PyTypeChecker { |
// Not instantiable: this class only exposes static helpers.
private PyTypeChecker() {}
| |
| public static boolean match(@Nullable PyType expected, @Nullable PyType actual, @NotNull TypeEvalContext context) { |
| return match(expected, actual, context, null, true); |
| } |
| |
| /** |
| * Checks whether a type *actual* can be placed where *expected* is expected. |
| * For example int matches object, while str doesn't match int. |
| * Work for builtin types, classes, tuples etc. |
| * |
| * @param expected expected type |
| * @param actual type to be matched against expected |
| * @param context |
| * @param substitutions |
| * @return |
| */ |
| public static boolean match(@Nullable PyType expected, @Nullable PyType actual, @NotNull TypeEvalContext context, |
| @Nullable Map<PyGenericType, PyType> substitutions) { |
| return match(expected, actual, context, substitutions, true); |
| } |
| |
| private static boolean match(@Nullable PyType expected, @Nullable PyType actual, @NotNull TypeEvalContext context, |
| @Nullable Map<PyGenericType, PyType> substitutions, boolean recursive) { |
| // TODO: subscriptable types?, module types?, etc. |
| if (expected instanceof PyGenericType && substitutions != null) { |
| final PyGenericType generic = (PyGenericType)expected; |
| final PyType subst = substitutions.get(generic); |
| final PyType bound = generic.getBound(); |
| if (!match(bound, actual, context, substitutions, recursive)) { |
| return false; |
| } |
| else if (subst != null) { |
| if (expected.equals(actual)) { |
| return true; |
| } |
| else if (recursive) { |
| return match(subst, actual, context, substitutions, false); |
| } |
| else { |
| return false; |
| } |
| } |
| else if (actual != null) { |
| substitutions.put(generic, actual); |
| } |
| else if (bound != null) { |
| substitutions.put(generic, bound); |
| } |
| return true; |
| } |
| if (expected == null || actual == null) { |
| return true; |
| } |
| if (expected instanceof PyClassType) { |
| final PyClass c = ((PyClassType)expected).getPyClass(); |
| if ("object".equals(c.getName())) { |
| return true; |
| } |
| } |
| if (isUnknown(actual)) { |
| return true; |
| } |
| if (actual instanceof PyUnionType) { |
| for (PyType m : ((PyUnionType)actual).getMembers()) { |
| if (match(expected, m, context, substitutions, recursive)) { |
| return true; |
| } |
| } |
| return false; |
| } |
| if (expected instanceof PyUnionType) { |
| for (PyType t : ((PyUnionType)expected).getMembers()) { |
| if (match(t, actual, context, substitutions, recursive)) { |
| return true; |
| } |
| } |
| return false; |
| } |
| if (expected instanceof PyClassType && actual instanceof PyClassType) { |
| final PyClass superClass = ((PyClassType)expected).getPyClass(); |
| final PyClass subClass = ((PyClassType)actual).getPyClass(); |
| if (expected instanceof PyCollectionType && actual instanceof PyCollectionType) { |
| if (!matchClasses(superClass, subClass, context)) { |
| return false; |
| } |
| final PyType superElementType = ((PyCollectionType)expected).getElementType(context); |
| final PyType subElementType = ((PyCollectionType)actual).getElementType(context); |
| return match(superElementType, subElementType, context, substitutions, recursive); |
| } |
| else if (expected instanceof PyTupleType && actual instanceof PyTupleType) { |
| final PyTupleType superTupleType = (PyTupleType)expected; |
| final PyTupleType subTupleType = (PyTupleType)actual; |
| if (superTupleType.getElementCount() != subTupleType.getElementCount()) { |
| return false; |
| } |
| else { |
| for (int i = 0; i < superTupleType.getElementCount(); i++) { |
| if (!match(superTupleType.getElementType(i), subTupleType.getElementType(i), context, substitutions, recursive)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| } |
| else if (matchClasses(superClass, subClass, context)) { |
| return true; |
| } |
| else if (((PyClassType)actual).isDefinition() && PyNames.CALLABLE.equals(expected.getName())) { |
| return true; |
| } |
| if (expected.equals(actual)) { |
| return true; |
| } |
| } |
| if (actual instanceof PyFunctionType && expected instanceof PyClassType) { |
| final PyClass superClass = ((PyClassType)expected).getPyClass(); |
| if (PyNames.CALLABLE.equals(superClass.getName())) { |
| return true; |
| } |
| } |
| if (actual instanceof PyCallableType && expected instanceof PyCallableType) { |
| final PyCallableType expectedCallable = (PyCallableType)expected; |
| final PyCallableType actualCallable = (PyCallableType)actual; |
| if (expectedCallable.isCallable() && actualCallable.isCallable()) { |
| final List<PyCallableParameter> expectedParameters = expectedCallable.getParameters(context); |
| final List<PyCallableParameter> actualParameters = actualCallable.getParameters(context); |
| if (expectedParameters != null && actualParameters != null) { |
| final int size = Math.min(expectedParameters.size(), actualParameters.size()); |
| for (int i = 0; i < size; i++) { |
| final PyCallableParameter expectedParam = expectedParameters.get(i); |
| final PyCallableParameter actualParam = actualParameters.get(i); |
| // TODO: Check named and star params, not only positional ones |
| if (!match(expectedParam.getType(context), actualParam.getType(context), context, substitutions, recursive)) { |
| return false; |
| } |
| } |
| } |
| if (!match(expectedCallable.getReturnType(context), actualCallable.getReturnType(context), context, substitutions, recursive)) { |
| return false; |
| } |
| return true; |
| } |
| } |
| return matchNumericTypes(expected, actual); |
| } |
| |
| private static boolean matchNumericTypes(PyType expected, PyType actual) { |
| final String superName = expected.getName(); |
| final String subName = actual.getName(); |
| final boolean subIsBool = "bool".equals(subName); |
| final boolean subIsInt = "int".equals(subName); |
| final boolean subIsLong = "long".equals(subName); |
| final boolean subIsFloat = "float".equals(subName); |
| final boolean subIsComplex = "complex".equals(subName); |
| if (superName == null || subName == null || |
| superName.equals(subName) || |
| ("int".equals(superName) && subIsBool) || |
| (("long".equals(superName) || PyNames.ABC_INTEGRAL.equals(superName)) && (subIsBool || subIsInt)) || |
| (("float".equals(superName) || PyNames.ABC_REAL.equals(superName)) && (subIsBool || subIsInt || subIsLong)) || |
| (("complex".equals(superName) || PyNames.ABC_COMPLEX.equals(superName)) && (subIsBool || subIsInt || subIsLong || subIsFloat)) || |
| (PyNames.ABC_NUMBER.equals(superName) && (subIsBool || subIsInt || subIsLong || subIsFloat || subIsComplex))) { |
| return true; |
| } |
| return false; |
| } |
| |
| public static boolean isUnknown(@Nullable PyType type) { |
| if (type == null || type instanceof PyGenericType) { |
| return true; |
| } |
| if (type instanceof PyUnionType) { |
| final PyUnionType union = (PyUnionType)type; |
| for (PyType t : union.getMembers()) { |
| if (isUnknown(t)) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| @Nullable |
| public static PyType toNonWeakType(@Nullable PyType type, @NotNull TypeEvalContext context) { |
| if (type instanceof PyUnionType) { |
| final PyUnionType unionType = (PyUnionType)type; |
| if (unionType.isWeak()) { |
| return unionType.excludeNull(context); |
| } |
| } |
| return type; |
| } |
| |
| public static boolean hasGenerics(@Nullable PyType type, @NotNull TypeEvalContext context) { |
| final Set<PyGenericType> collected = new HashSet<PyGenericType>(); |
| collectGenerics(type, context, collected, new HashSet<PyType>()); |
| return !collected.isEmpty(); |
| } |
| |
| private static void collectGenerics(@Nullable PyType type, @NotNull TypeEvalContext context, @NotNull Set<PyGenericType> collected, |
| @NotNull Set<PyType> visited) { |
| if (visited.contains(type)) { |
| return; |
| } |
| visited.add(type); |
| if (type instanceof PyGenericType) { |
| collected.add((PyGenericType)type); |
| } |
| else if (type instanceof PyUnionType) { |
| final PyUnionType union = (PyUnionType)type; |
| for (PyType t : union.getMembers()) { |
| collectGenerics(t, context, collected, visited); |
| } |
| } |
| else if (type instanceof PyCollectionType) { |
| final PyCollectionType collection = (PyCollectionType)type; |
| collectGenerics(collection.getElementType(context), context, collected, visited); |
| } |
| else if (type instanceof PyTupleType) { |
| final PyTupleType tuple = (PyTupleType)type; |
| final int n = tuple.getElementCount(); |
| for (int i = 0; i < n; i++) { |
| collectGenerics(tuple.getElementType(i), context, collected, visited); |
| } |
| } |
| else if (type instanceof PyCallableType) { |
| final PyCallableType callable = (PyCallableType)type; |
| final List<PyCallableParameter> parameters = callable.getParameters(context); |
| if (parameters != null) { |
| for (PyCallableParameter parameter : parameters) { |
| if (parameter != null) { |
| collectGenerics(parameter.getType(context), context, collected, visited); |
| } |
| } |
| } |
| collectGenerics(callable.getReturnType(context), context, collected, visited); |
| } |
| } |
| |
| @Nullable |
| public static PyType substitute(@Nullable PyType type, @NotNull Map<PyGenericType, PyType> substitutions, |
| @NotNull TypeEvalContext context) { |
| if (hasGenerics(type, context)) { |
| if (type instanceof PyGenericType) { |
| return substitutions.get((PyGenericType)type); |
| } |
| else if (type instanceof PyUnionType) { |
| final PyUnionType union = (PyUnionType)type; |
| final List<PyType> results = new ArrayList<PyType>(); |
| for (PyType t : union.getMembers()) { |
| final PyType subst = substitute(t, substitutions, context); |
| results.add(subst); |
| } |
| return PyUnionType.union(results); |
| } |
| else if (type instanceof PyCollectionTypeImpl) { |
| final PyCollectionTypeImpl collection = (PyCollectionTypeImpl)type; |
| final PyType elem = collection.getElementType(context); |
| final PyType subst = substitute(elem, substitutions, context); |
| return new PyCollectionTypeImpl(collection.getPyClass(), collection.isDefinition(), subst); |
| } |
| else if (type instanceof PyTupleType) { |
| final PyTupleType tuple = (PyTupleType)type; |
| final int n = tuple.getElementCount(); |
| final List<PyType> results = new ArrayList<PyType>(); |
| for (int i = 0; i < n; i++) { |
| final PyType subst = substitute(tuple.getElementType(i), substitutions, context); |
| results.add(subst); |
| } |
| return new PyTupleType((PyTupleType)type, results.toArray(new PyType[results.size()])); |
| } |
| else if (type instanceof PyCallableType) { |
| final PyCallableType callable = (PyCallableType)type; |
| List<PyCallableParameter> substParams = null; |
| final List<PyCallableParameter> parameters = callable.getParameters(context); |
| if (parameters != null) { |
| substParams = new ArrayList<PyCallableParameter>(); |
| for (PyCallableParameter parameter : parameters) { |
| final PyType substType = substitute(parameter.getType(context), substitutions, context); |
| final PyCallableParameter subst = parameter.getParameter() != null ? |
| new PyCallableParameterImpl(parameter.getParameter()) : |
| new PyCallableParameterImpl(parameter.getName(), substType); |
| substParams.add(subst); |
| } |
| } |
| final PyType substResult = substitute(callable.getReturnType(context), substitutions, context); |
| return new PyCallableTypeImpl(substParams, substResult); |
| } |
| } |
| return type; |
| } |
| |
| @Nullable |
| public static Map<PyGenericType, PyType> unifyGenericCall(@Nullable PyExpression receiver, |
| @NotNull Map<PyExpression, PyNamedParameter> arguments, |
| @NotNull TypeEvalContext context) { |
| final Map<PyGenericType, PyType> substitutions = unifyReceiver(receiver, context); |
| for (Map.Entry<PyExpression, PyNamedParameter> entry : arguments.entrySet()) { |
| final PyNamedParameter p = entry.getValue(); |
| if (p.isPositionalContainer() || p.isKeywordContainer()) { |
| continue; |
| } |
| final PyType argType = context.getType(entry.getKey()); |
| final PyType paramType = context.getType(p); |
| if (!match(paramType, argType, context, substitutions)) { |
| return null; |
| } |
| } |
| return substitutions; |
| } |
| |
| @NotNull |
| public static Map<PyGenericType, PyType> unifyReceiver(@Nullable PyExpression receiver, @NotNull TypeEvalContext context) { |
| final Map<PyGenericType, PyType> substitutions = new LinkedHashMap<PyGenericType, PyType>(); |
| // Collect generic params of object type |
| final Set<PyGenericType> generics = new LinkedHashSet<PyGenericType>(); |
| final PyType qualifierType = receiver != null ? context.getType(receiver) : null; |
| collectGenerics(qualifierType, context, generics, new HashSet<PyType>()); |
| for (PyGenericType t : generics) { |
| substitutions.put(t, t); |
| } |
| // Unify generics in constructor |
| if (qualifierType != null) { |
| final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context); |
| // TODO: Resolve to __new__ as well |
| final List<? extends RatedResolveResult> results = qualifierType.resolveMember(PyNames.INIT, null, AccessDirection.READ, |
| resolveContext); |
| if (results != null && !results.isEmpty()) { |
| final PsiElement init = results.get(0).getElement(); |
| if (init instanceof PyTypedElement) { |
| final PyType initType = context.getType((PyTypedElement)init); |
| if (initType instanceof PyCallableType) { |
| final PyType initReturnType = ((PyCallableType)initType).getReturnType(context); |
| if (initReturnType != null) { |
| match(initReturnType, qualifierType, context, substitutions); |
| } |
| } |
| } |
| } |
| } |
| return substitutions; |
| } |
| |
| private static boolean matchClasses(@Nullable PyClass superClass, @Nullable PyClass subClass, @NotNull TypeEvalContext context) { |
| if (superClass == null || subClass == null || subClass.isSubclass(superClass) || PyABCUtil.isSubclass(subClass, superClass)) { |
| return true; |
| } |
| else if (PyUtil.hasUnresolvedAncestors(subClass, context)) { |
| return true; |
| } |
| else { |
| final String superName = superClass.getName(); |
| return superName != null && superName.equals(subClass.getName()); |
| } |
| } |
| |
| @Nullable |
| public static AnalyzeCallResults analyzeCall(@NotNull PyCallExpression call, @NotNull TypeEvalContext context) { |
| final PyExpression callee = call.getCallee(); |
| if (callee instanceof PyQualifiedExpression) { |
| final PyQualifiedExpression qualified = (PyQualifiedExpression)callee; |
| if (isResolvedToSeveralMethods(qualified, context)) { |
| return null; |
| } |
| } |
| final PyArgumentList args = call.getArgumentList(); |
| if (args != null) { |
| final CallArgumentsMapping mapping = args.analyzeCall(PyResolveContext.noImplicits().withTypeEvalContext(context)); |
| final Map<PyExpression, PyNamedParameter> arguments = mapping.getPlainMappedParams(); |
| final PyCallExpression.PyMarkedCallee markedCallee = mapping.getMarkedCallee(); |
| if (markedCallee != null) { |
| final Callable callable = markedCallee.getCallable(); |
| if (callable instanceof PyFunction) { |
| final PyFunction function = (PyFunction)callable; |
| final PyExpression receiver; |
| if (function.getModifier() == PyFunction.Modifier.STATICMETHOD) { |
| receiver = null; |
| } |
| else if (callee instanceof PyQualifiedExpression) { |
| receiver = ((PyQualifiedExpression)callee).getQualifier(); |
| } |
| else { |
| receiver = null; |
| } |
| return new AnalyzeCallResults(callable, receiver, arguments); |
| } |
| } |
| } |
| return null; |
| } |
| |
| @Nullable |
| public static AnalyzeCallResults analyzeCall(@NotNull PyBinaryExpression expr, @NotNull TypeEvalContext context) { |
| final PsiPolyVariantReference ref = expr.getReference(PyResolveContext.noImplicits().withTypeEvalContext(context)); |
| final ResolveResult[] resolveResult; |
| resolveResult = ref.multiResolve(false); |
| AnalyzeCallResults firstResults = null; |
| for (ResolveResult result : resolveResult) { |
| final PsiElement resolved = result.getElement(); |
| if (resolved instanceof PyTypedElement) { |
| final PyTypedElement typedElement = (PyTypedElement)resolved; |
| final PyType type = context.getType(typedElement); |
| if (!(type instanceof PyFunctionType)) { |
| return null; |
| } |
| final Callable callable = ((PyFunctionType)type).getCallable(); |
| final boolean isRight = PyNames.isRightOperatorName(typedElement.getName()); |
| final PyExpression arg = isRight ? expr.getLeftExpression() : expr.getRightExpression(); |
| final PyExpression receiver = isRight ? expr.getRightExpression() : expr.getLeftExpression(); |
| final PyParameter[] parameters = callable.getParameterList().getParameters(); |
| if (parameters.length >= 2) { |
| final PyNamedParameter param = parameters[1].getAsNamed(); |
| if (arg != null && param != null) { |
| final Map<PyExpression, PyNamedParameter> arguments = new LinkedHashMap<PyExpression, PyNamedParameter>(); |
| arguments.put(arg, param); |
| final AnalyzeCallResults results = new AnalyzeCallResults(callable, receiver, arguments); |
| if (firstResults == null) { |
| firstResults = results; |
| } |
| if (match(context.getType(param), context.getType(arg), context)) { |
| return results; |
| } |
| } |
| } |
| } |
| } |
| if (firstResults != null) { |
| return firstResults; |
| } |
| return null; |
| } |
| |
| @Nullable |
| public static AnalyzeCallResults analyzeCall(@NotNull PySubscriptionExpression expr, @NotNull TypeEvalContext context) { |
| final PsiReference ref = expr.getReference(PyResolveContext.noImplicits().withTypeEvalContext(context)); |
| final PsiElement resolved; |
| resolved = ref.resolve(); |
| if (resolved instanceof PyTypedElement) { |
| final PyType type = context.getType((PyTypedElement)resolved); |
| if (type instanceof PyFunctionType) { |
| final Callable callable = ((PyFunctionType)type).getCallable(); |
| final PyParameter[] parameters = callable.getParameterList().getParameters(); |
| if (parameters.length == 2) { |
| final PyNamedParameter param = parameters[1].getAsNamed(); |
| if (param != null) { |
| final Map<PyExpression, PyNamedParameter> arguments = new LinkedHashMap<PyExpression, PyNamedParameter>(); |
| final PyExpression arg = expr.getIndexExpression(); |
| if (arg != null) { |
| arguments.put(arg, param); |
| return new AnalyzeCallResults(callable, expr.getOperand(), arguments); |
| } |
| } |
| } |
| } |
| } |
| return null; |
| } |
| |
| @Nullable |
| public static AnalyzeCallResults analyzeCallSite(@Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) { |
| if (callSite instanceof PyCallExpression) { |
| return analyzeCall((PyCallExpression)callSite, context); |
| } |
| else if (callSite instanceof PyBinaryExpression) { |
| return analyzeCall((PyBinaryExpression)callSite, context); |
| } |
| else if (callSite instanceof PySubscriptionExpression) { |
| return analyzeCall((PySubscriptionExpression)callSite, context); |
| } |
| return null; |
| } |
| |
| @Nullable |
| public static Boolean isCallable(@Nullable PyType type) { |
| if (type == null) { |
| return null; |
| } |
| else if (type instanceof PyUnionType) { |
| Boolean result = true; |
| for (PyType member : ((PyUnionType)type).getMembers()) { |
| final Boolean callable = isCallable(member); |
| if (callable == null) { |
| return null; |
| } |
| else if (!callable) { |
| result = false; |
| } |
| } |
| return result; |
| } |
| else if (type instanceof PyCallableType) { |
| return ((PyCallableType) type).isCallable(); |
| } |
| return false; |
| } |
| |
| /** |
| * Hack for skipping type checking for method calls of union members if there are several call alternatives. |
| * |
| * TODO: Multi-resolve callees when analysing calls. This requires multi-resolving in followAssignmentsChain. |
| */ |
| public static boolean isResolvedToSeveralMethods(@NotNull PyQualifiedExpression callee, @NotNull TypeEvalContext context) { |
| final PyExpression qualifier = callee.getQualifier(); |
| if (qualifier != null) { |
| final PyType qualifierType = context.getType(qualifier); |
| if (qualifierType instanceof PyUnionType) { |
| final PyUnionType unionType = (PyUnionType)qualifierType; |
| final String name = callee.getName(); |
| if (name == null) { |
| return false; |
| } |
| int sameNameCount = 0; |
| for (PyType member : unionType.getMembers()) { |
| if (member != null) { |
| final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context); |
| final List<? extends RatedResolveResult> results = member.resolveMember(name, callee, AccessDirection.READ, resolveContext |
| ); |
| if (results != null && !results.isEmpty()) { |
| sameNameCount++; |
| } |
| } |
| } |
| if (sameNameCount > 1) { |
| return true; |
| } |
| } |
| final PyExpression qualifierExpr = qualifier instanceof PyCallExpression ? ((PyCallExpression)qualifier).getCallee() : qualifier; |
| if (qualifierExpr instanceof PyQualifiedExpression) { |
| return isResolvedToSeveralMethods((PyQualifiedExpression)qualifierExpr, context); |
| } |
| } |
| return false; |
| } |
| |
| public static class AnalyzeCallResults { |
| @NotNull private final Callable myCallable; |
| @Nullable private final PyExpression myReceiver; |
| @NotNull private final Map<PyExpression, PyNamedParameter> myArguments; |
| |
| public AnalyzeCallResults(@NotNull Callable callable, @Nullable PyExpression receiver, |
| @NotNull Map<PyExpression, PyNamedParameter> arguments) { |
| myCallable = callable; |
| myReceiver = receiver; |
| myArguments = arguments; |
| } |
| |
| @NotNull |
| public Callable getCallable() { |
| return myCallable; |
| } |
| |
| @Nullable |
| public PyExpression getReceiver() { |
| return myReceiver; |
| } |
| |
| @NotNull |
| public Map<PyExpression, PyNamedParameter> getArguments() { |
| return myArguments; |
| } |
| } |
| } |