blob: c28bc29455c90b243ba8b3d32b42303998fdf9d0 [file] [log] [blame]
/*
* Copyright 2000-2013 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.jetbrains.python.codeInsight.stdlib;
import com.google.common.collect.ImmutableSet;
import com.intellij.openapi.extensions.Extensions;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.QualifiedName;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyBuiltinCache;
import com.jetbrains.python.psi.impl.PyTypeProvider;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.QualifiedNameFinder;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.jetbrains.python.psi.PyUtil.as;
/**
 * Supplies types for selected Python standard library constructs: {@code collections.namedtuple} definitions,
 * {@code enum.Enum} members and attributes, open-like functions ({@code open}, {@code io.open}, {@code os.fdopen},
 * {@code pathlib.Path.open}), {@code tuple.__add__}, and the variables bound by a few context managers.
 *
 * @author yole
 */
public class PyStdlibTypeProvider extends PyTypeProviderBase {
  /** Qualified names of functions whose return type is derived from their {@code mode} argument. */
  private static final Set<String> OPEN_FUNCTIONS = ImmutableSet.of("__builtin__.open", "io.open", "os.fdopen",
                                                                    "pathlib.Path.open");
  /** Type returned by open-like functions in binary mode, and in text mode on Python 2 (except {@code io.open}). */
  private static final String BINARY_FILE_TYPE = "io.FileIO[bytes]";
  /** Type returned by open-like functions in text mode on Python 3, and by {@code io.open} in text mode. */
  private static final String TEXT_FILE_TYPE = "io.TextIOWrapper[unicode]";

  /**
   * Finds the registered instance of this provider among the {@link PyTypeProvider} extensions.
   *
   * @return the registered provider, or {@code null} if none is registered
   */
  @Nullable
  public static PyStdlibTypeProvider getInstance() {
    for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) {
      if (typeProvider instanceof PyStdlibTypeProvider) {
        return (PyStdlibTypeProvider)typeProvider;
      }
    }
    return null;
  }

  /**
   * Returns the type of a reference target, trying namedtuple handling first and enum handling second.
   *
   * @param referenceTarget the element a reference resolved to
   * @param context         type evaluation context
   * @param anchor          the referring element, if known
   * @return the inferred type, or {@code null} if this provider has nothing to say
   */
  @Nullable
  @Override
  public PyType getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
    final PyType namedTupleType = getNamedTupleType(referenceTarget, anchor);
    if (namedTupleType != null) {
      return namedTupleType;
    }
    return getEnumType(referenceTarget, context, anchor);
  }

  /**
   * Infers types related to {@code enum.Enum}.
   * <ul>
   *   <li>A target assigned in the body of a class having {@code enum.Enum} among its ancestors gets the instance type
   *   of that class.</li>
   *   <li>{@code enum.Enum.name} is typed as {@code str}.</li>
   *   <li>{@code enum.Enum.value}, accessed as {@code Member.value} where {@code Member} resolves to a target
   *   expression, gets the type of the value assigned to that target (only if switching to AST is allowed).</li>
   *   <li>{@code enum.EnumMeta.__members__} is typed as {@code dict[str, unknown]}.</li>
   * </ul>
   *
   * @return the inferred type, or {@code null} if none of the cases above apply
   */
  @Nullable
  private static PyType getEnumType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context,
                                    @Nullable PsiElement anchor) {
    if (referenceTarget instanceof PyTargetExpression) {
      final PyTargetExpression target = (PyTargetExpression)referenceTarget;
      final ScopeOwner owner = ScopeUtil.getScopeOwner(target);
      if (owner instanceof PyClass) {
        final PyClass cls = (PyClass)owner;
        final List<PyClassLikeType> types = cls.getAncestorTypes(context);
        for (PyClassLikeType type : types) {
          // Ancestor types may contain nulls for unresolved superclasses
          if (type != null && "enum.Enum".equals(type.getClassQName())) {
            final PyType classType = context.getType(cls);
            if (classType instanceof PyClassType) {
              return ((PyClassType)classType).toInstance();
            }
          }
        }
      }
    }
    if (referenceTarget instanceof PyQualifiedNameOwner) {
      final PyQualifiedNameOwner qualifiedNameOwner = (PyQualifiedNameOwner)referenceTarget;
      final String name = qualifiedNameOwner.getQualifiedName();
      if ("enum.Enum.name".equals(name)) {
        return PyBuiltinCache.getInstance(referenceTarget).getStrType();
      }
      else if ("enum.Enum.value".equals(name) && anchor instanceof PyReferenceExpression && context.maySwitchToAST(anchor)) {
        final PyReferenceExpression anchorExpr = (PyReferenceExpression)anchor;
        final PyExpression qualifier = anchorExpr.getQualifier();
        if (qualifier instanceof PyReferenceExpression) {
          final PyReferenceExpression qualifierExpr = (PyReferenceExpression)qualifier;
          final PsiElement resolvedQualifier = qualifierExpr.getReference().resolve();
          if (resolvedQualifier instanceof PyTargetExpression) {
            final PyTargetExpression qualifierTarget = (PyTargetExpression)resolvedQualifier;
            // Requires switching to AST, we cannot use getType(qualifierTarget) here, because its type is overridden by this type provider
            if (context.maySwitchToAST(qualifierTarget)) {
              final PyExpression value = qualifierTarget.findAssignedValue();
              if (value != null) {
                return context.getType(value);
              }
            }
          }
        }
      }
      else if ("enum.EnumMeta.__members__".equals(name)) {
        return PyTypeParser.getTypeByName(referenceTarget, "dict[str, unknown]");
      }
    }
    return null;
  }

  /**
   * Infers the result type of calls to open-like functions (based on their {@code mode} argument) and of
   * {@code tuple.__add__} expressed as a binary expression (concatenating both operands' element types).
   *
   * @param function the called function
   * @param callSite the call site, if known
   * @param context  type evaluation context
   * @return the inferred call type, or {@code null} if this provider has nothing to say
   */
  @Nullable
  @Override
  public PyType getCallType(@NotNull PyFunction function, @Nullable PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
    final String qname = getQualifiedName(function, callSite);
    if (qname == null) {
      return null;
    }
    if (OPEN_FUNCTIONS.contains(qname) && callSite != null) {
      final PyTypeChecker.AnalyzeCallResults results = PyTypeChecker.analyzeCallSite(callSite, context);
      if (results != null) {
        final PyType type = getOpenFunctionType(qname, results.getArguments(), callSite);
        if (type != null) {
          return type;
        }
      }
    }
    else if ("__builtin__.tuple.__add__".equals(qname) && callSite instanceof PyBinaryExpression) {
      return getTupleConcatenationType(function, (PyBinaryExpression)callSite, context);
    }
    return null;
  }

  /**
   * Builds the tuple type of {@code left + right} when both operands are known tuple types: the left operand's element
   * types followed by the right operand's.
   *
   * @return the concatenated tuple type, or {@code null} if an operand is missing or is not a tuple type
   */
  @Nullable
  private static PyType getTupleConcatenationType(@NotNull PyFunction function, @NotNull PyBinaryExpression expression,
                                                  @NotNull TypeEvalContext context) {
    final PyExpression rightExpression = expression.getRightExpression();
    // Incomplete code such as "t +" has no right operand
    if (rightExpression == null) {
      return null;
    }
    final PyTupleType leftTupleType = as(context.getType(expression.getLeftExpression()), PyTupleType.class);
    final PyTupleType rightTupleType = as(context.getType(rightExpression), PyTupleType.class);
    if (leftTupleType == null || rightTupleType == null) {
      return null;
    }
    final int leftCount = leftTupleType.getElementCount();
    final int rightCount = rightTupleType.getElementCount();
    final PyType[] elementTypes = new PyType[leftCount + rightCount];
    for (int i = 0; i < leftCount; i++) {
      elementTypes[i] = leftTupleType.getElementType(i);
    }
    for (int i = 0; i < rightCount; i++) {
      elementTypes[leftCount + i] = rightTupleType.getElementType(i);
    }
    return PyTupleType.create(function, elementTypes);
  }

  /**
   * Infers the type of the variable bound by a {@code with} statement.
   * <ul>
   *   <li>For {@code contextlib.closing(x)}, the type of {@code x}.</li>
   *   <li>For classes named {@code FileIO}, {@code TextIOWrapper}, {@code IOBase} or {@code _IOBase}, the type of the
   *   {@code with} expression itself.</li>
   * </ul>
   *
   * @return the inferred type, or {@code null} if this provider has nothing to say
   */
  @Nullable
  @Override
  public PyType getContextManagerVariableType(@NotNull PyClass contextManager, @NotNull PyExpression withExpression, @NotNull TypeEvalContext context) {
    if ("contextlib.closing".equals(contextManager.getQualifiedName()) && withExpression instanceof PyCallExpression) {
      final PyExpression closee = ((PyCallExpression)withExpression).getArgument(0, PyExpression.class);
      if (closee != null) {
        return context.getType(closee);
      }
    }
    final String name = contextManager.getName();
    if ("FileIO".equals(name) || "TextIOWrapper".equals(name) || "IOBase".equals(name) || "_IOBase".equals(name)) {
      return context.getType(withExpression);
    }
    return null;
  }

  /**
   * Infers namedtuple types.
   * <ul>
   *   <li>For a target assigned from a call that resolves to {@code collections.namedtuple}, the namedtuple type
   *   described by that call.</li>
   *   <li>For the {@code collections.namedtuple} function itself referenced from a call, the namedtuple type described by
   *   that call.</li>
   * </ul>
   * NOTE(review): the second argument of {@code PyNamedTupleType.fromCall} (1 vs 2) is passed through unchanged;
   * presumably it is the definition level of the produced type — confirm in {@code PyNamedTupleType}.
   *
   * @return the namedtuple type, or {@code null} if the target is not namedtuple-related
   */
  @Nullable
  private static PyType getNamedTupleType(@NotNull PsiElement referenceTarget, @Nullable PsiElement anchor) {
    if (referenceTarget instanceof PyTargetExpression) {
      final PyTargetExpression target = (PyTargetExpression)referenceTarget;
      final QualifiedName calleeName = target.getCalleeName();
      if (calleeName != null && PyNames.NAMEDTUPLE.equals(calleeName.toString())) {
        // TODO: Create stubs for namedtuple for preventing switch from stub to AST
        final PyExpression value = target.findAssignedValue();
        if (value instanceof PyCallExpression) {
          final PyCallExpression call = (PyCallExpression)value;
          final PyCallExpression.PyMarkedCallee callee = call.resolveCallee(PyResolveContext.noImplicits());
          if (callee != null) {
            final Callable callable = callee.getCallable();
            if (PyNames.COLLECTIONS_NAMEDTUPLE.equals(callable.getQualifiedName())) {
              return PyNamedTupleType.fromCall(call, 1);
            }
          }
        }
      }
    }
    else if (referenceTarget instanceof PyFunction && anchor instanceof PyCallExpression) {
      final PyFunction function = (PyFunction)referenceTarget;
      if (PyNames.NAMEDTUPLE.equals(function.getName()) && PyNames.COLLECTIONS_NAMEDTUPLE.equals(function.getQualifiedName())) {
        return PyNamedTupleType.fromCall((PyCallExpression)anchor, 2);
      }
    }
    return null;
  }

  /**
   * Determines the file type returned by an open-like function.
   * <ul>
   *   <li>The mode is taken from the first string literal argument mapped to a parameter named {@code mode}, and
   *   defaults to {@code "r"}.</li>
   *   <li>A mode containing {@code 'b'} yields {@link #BINARY_FILE_TYPE}.</li>
   *   <li>Otherwise the result is {@link #TEXT_FILE_TYPE} on Python 3 or for {@code io.open}, and
   *   {@link #BINARY_FILE_TYPE} in all other cases.</li>
   * </ul>
   *
   * @param callQName qualified name of the called function
   * @param arguments mapping of call arguments to the parameters they bind
   * @param anchor    element used to resolve type names and the language level
   * @return the parsed file type, or {@code null} if the type name cannot be resolved
   */
  @Nullable
  private static PyType getOpenFunctionType(@NotNull String callQName,
                                            @NotNull Map<PyExpression, PyNamedParameter> arguments,
                                            @NotNull PsiElement anchor) {
    String mode = "r";
    for (Map.Entry<PyExpression, PyNamedParameter> entry : arguments.entrySet()) {
      final PyNamedParameter parameter = entry.getValue();
      if ("mode".equals(parameter.getName())) {
        PyExpression argument = entry.getKey();
        if (argument instanceof PyKeywordArgument) {
          argument = ((PyKeywordArgument)argument).getValueExpression();
        }
        if (argument instanceof PyStringLiteralExpression) {
          mode = ((PyStringLiteralExpression)argument).getStringValue();
          break;
        }
      }
    }
    if (mode.contains("b")) {
      return PyTypeParser.getTypeByName(anchor, BINARY_FILE_TYPE);
    }
    // Text mode: Python 2's builtin open/os.fdopen return byte-oriented files, io.open always returns a text wrapper
    final LanguageLevel level = LanguageLevel.forElement(anchor);
    final boolean textWrapper = level.isPy3K() || "io.open".equals(callQName);
    return PyTypeParser.getTypeByName(anchor, textWrapper ? TEXT_FILE_TYPE : BINARY_FILE_TYPE);
  }

  /**
   * Builds a qualified name {@code module.[Class.]function} for a function.
   * <ul>
   *   <li>The module is the shortest importable name of the function's file, as seen from {@code callSite} (or from the
   *   function itself).</li>
   *   <li>{@code builtins} is renamed to {@code __builtin__} so the same names match on Python 2 and 3.</li>
   *   <li>The result is mapped to its canonical stdlib path when one is known.</li>
   * </ul>
   *
   * @param f        the function
   * @param callSite the call site used as the import resolution context, if any
   * @return the qualified name; the bare function name if its file or module cannot be determined; {@code null} if
   * {@code f} is invalid
   */
  @Nullable
  private static String getQualifiedName(@NotNull PyFunction f, @Nullable PsiElement callSite) {
    if (!f.isValid()) {
      return null;
    }
    final String name = f.getName();
    final VirtualFile vfile = f.getContainingFile().getVirtualFile();
    if (vfile == null) {
      return name;
    }
    String module = QualifiedNameFinder.findShortestImportableName(callSite != null ? callSite : f, vfile);
    if (module == null) {
      // Previously this produced a bogus "null.<name>" qualified name
      return name;
    }
    if ("builtins".equals(module)) {
      module = "__builtin__";
    }
    final PyClass c = f.getContainingClass();
    final String result = String.format("%s.%s%s",
                                        module,
                                        c != null ? c.getName() + "." : "",
                                        name);
    final QualifiedName qname = PyStdlibCanonicalPathProvider.restoreStdlibCanonicalPath(QualifiedName.fromDottedString(result));
    if (qname != null) {
      return qname.toString();
    }
    return result;
  }
}