blob: 63ce0c00e5cd8974f49ff65708efa1622ad13f96 [file] [log] [blame]
/*
* Copyright 2000-2014 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.jetbrains.python.codeInsight.imports;
import com.intellij.codeInspection.LocalInspectionToolSession;
import com.intellij.lang.ImportOptimizer;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiFileSystemItem;
import com.intellij.util.ArrayUtil;
import com.jetbrains.python.formatter.PyBlock;
import com.jetbrains.python.inspections.unresolvedReference.PyUnresolvedReferencesInspection;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Import optimizer for Python files.
 * <p>
 * It first runs {@link PyUnresolvedReferencesInspection.Visitor} over the file and then calls its
 * {@code optimizeImports()} (presumably removing unused imports — see that visitor). For {@link PyFile}s it then
 * regroups the file's import block into builtin, third-party and project imports.
 *
 * @author yole
 */
public class PyImportOptimizer implements ImportOptimizer {
  @Override
  public boolean supports(PsiFile file) {
    return true;
  }

  /**
   * Analyses {@code file} and returns the action that rewrites its imports.
   * <p>
   * The analysis runs immediately: every {@link PyElement} of the file is visited with the
   * unresolved-references visitor. The PSI modifications happen only when the returned {@link Runnable} is run.
   *
   * @param file the file whose imports should be optimized
   * @return an action that optimizes the imports and, for {@link PyFile}s, sorts the import block
   */
  @Override
  @NotNull
  public Runnable processFile(@NotNull final PsiFile file) {
    final LocalInspectionToolSession session = new LocalInspectionToolSession(file, 0, file.getTextLength());
    final PyUnresolvedReferencesInspection.Visitor visitor = new PyUnresolvedReferencesInspection.Visitor(null,
                                                                                                          session,
                                                                                                          Collections.<String>emptyList());
    file.accept(new PyRecursiveElementVisitor() {
      @Override
      public void visitPyElement(PyElement node) {
        super.visitPyElement(node);
        node.accept(visitor);
      }
    });
    return new Runnable() {
      @Override
      public void run() {
        visitor.optimizeImports();
        if (file instanceof PyFile) {
          // The sorter reads the import block only after optimizeImports() has modified it.
          new ImportSorter((PyFile)file).run();
        }
      }
    };
  }

  /**
   * Splits the import block of a {@link PyFile} into builtin, third-party and project groups. When needed, it
   * rewrites the block so the groups appear in that order.
   * <p>
   * Instances are single-use: {@link #run()} accumulates state in the group lists.
   */
  private static class ImportSorter {
    private final PyFile myFile;
    private final List<PyImportStatementBase> myBuiltinImports = new ArrayList<PyImportStatementBase>();
    private final List<PyImportStatementBase> myThirdPartyImports = new ArrayList<PyImportStatementBase>();
    private final List<PyImportStatementBase> myProjectImports = new ArrayList<PyImportStatementBase>();
    /** The import statements returned by {@link PyFile#getImportBlock()} when this sorter was created. */
    private final List<PyImportStatementBase> myImportBlock;
    private final PyElementGenerator myGenerator;
    /** Set when the block must be rewritten: statements are out of group order or had to be split. */
    private boolean myMissorted = false;

    private ImportSorter(PyFile file) {
      myFile = file;
      myImportBlock = myFile.getImportBlock();
      myGenerator = PyElementGenerator.getInstance(myFile.getProject());
    }

    /**
     * Classifies every non-{@code __future__} import of the block into a group.
     * <p>
     * The block is rewritten when it is missorted or when more than one group is non-empty.
     */
    public void run() {
      if (myImportBlock.isEmpty()) {
        return;
      }
      final LanguageLevel langLevel = LanguageLevel.forElement(myFile);
      for (PyImportStatementBase importStatement : myImportBlock) {
        if (isFromFutureImport(importStatement)) {
          // __future__ imports stay where they are and are never regrouped.
          continue;
        }
        final PyImportElement[] importElements = importStatement.getImportElements();
        if (importStatement instanceof PyImportStatement && importElements.length > 1) {
          // "import a, b" is split into one statement per element, so the block always has to be rewritten.
          myMissorted = true;
          for (PyImportElement importElement : importElements) {
            final PsiElement toImport = importElement.resolve();
            final PyImportStatement splitImport = myGenerator.createImportStatement(langLevel, importElement.getText(), null);
            prioritize(splitImport, toImport);
          }
        }
        else {
          prioritize(importStatement, resolveImportTarget(importStatement));
        }
      }
      if (myMissorted || needBlankLinesBetweenGroups()) {
        applyResults();
      }
    }

    /** Returns {@code true} for {@code from __future__ import ...} statements. */
    private static boolean isFromFutureImport(PyImportStatementBase importStatement) {
      return importStatement instanceof PyFromImportStatement && ((PyFromImportStatement)importStatement).isFromFuture();
    }

    /**
     * Resolves what an unsplit import statement refers to.
     * <p>
     * For {@code from ... import} statements this is the import source. For other statements it is the first
     * import element. Returns {@code null} when the statement has no elements or the target does not resolve.
     */
    @Nullable
    private static PsiElement resolveImportTarget(PyImportStatementBase importStatement) {
      if (importStatement instanceof PyFromImportStatement) {
        return ((PyFromImportStatement)importStatement).resolveImportSource();
      }
      final PyImportElement firstImportElement = ArrayUtil.getFirstElement(importStatement.getImportElements());
      return firstImportElement != null ? firstImportElement.resolve() : null;
    }

    /** Returns {@code true} when imports from more than one group are present. */
    private boolean needBlankLinesBetweenGroups() {
      int nonEmptyGroups = 0;
      if (!myBuiltinImports.isEmpty()) nonEmptyGroups++;
      if (!myThirdPartyImports.isEmpty()) nonEmptyGroups++;
      if (!myProjectImports.isEmpty()) nonEmptyGroups++;
      return nonEmptyGroups > 1;
    }

    /**
     * Appends {@code importStatement} to the group chosen by {@link AddImportHelper#getImportPriority}.
     * <p>
     * Unresolved imports go to the project group. The block is flagged as missorted if a statement belongs to a
     * group that should precede one already seen.
     *
     * @param importStatement the statement to classify
     * @param toImport        what the statement resolves to, or {@code null} if unresolved
     */
    private void prioritize(PyImportStatementBase importStatement, @Nullable PsiElement toImport) {
      if (toImport != null && !(toImport instanceof PsiFileSystemItem)) {
        toImport = toImport.getContainingFile();
      }
      final AddImportHelper.ImportPriority priority = toImport == null
                                                      ? AddImportHelper.ImportPriority.PROJECT
                                                      : AddImportHelper.getImportPriority(myFile, (PsiFileSystemItem)toImport);
      if (priority == AddImportHelper.ImportPriority.BUILTIN) {
        myBuiltinImports.add(importStatement);
        if (!myThirdPartyImports.isEmpty() || !myProjectImports.isEmpty()) {
          myMissorted = true;
        }
      }
      else if (priority == AddImportHelper.ImportPriority.THIRD_PARTY) {
        myThirdPartyImports.add(importStatement);
        if (!myProjectImports.isEmpty()) {
          myMissorted = true;
        }
      }
      else {
        myProjectImports.add(importStatement);
      }
    }

    /**
     * Replaces the non-{@code __future__} part of the import block with the grouped imports.
     * <p>
     * The groups are written in the order builtin, third-party, project.
     */
    private void applyResults() {
      // myImportBlock is a fixed list of the original statements, so the insertion anchor never changes.
      final PyImportStatementBase firstNonFutureImport = findFirstNonFutureImport();
      if (firstNonFutureImport == null) {
        return;
      }
      // The group-begin marker is copyable user data, so it travels with copies of the statements. Markers left
      // over from an earlier run would otherwise produce spurious group breaks; reset them before marking the
      // current group heads.
      clearGroupBegin(myBuiltinImports);
      clearGroupBegin(myThirdPartyImports);
      clearGroupBegin(myProjectImports);
      markGroupBegin(myThirdPartyImports);
      markGroupBegin(myProjectImports);
      addImports(myBuiltinImports, firstNonFutureImport);
      addImports(myThirdPartyImports, firstNonFutureImport);
      addImports(myProjectImports, firstNonFutureImport);
      final PsiElement lastElement = myImportBlock.get(myImportBlock.size() - 1);
      myFile.deleteChildRange(firstNonFutureImport, lastElement);
    }

    /** Returns the first statement of the block that is not a {@code __future__} import, or {@code null}. */
    @Nullable
    private PyImportStatementBase findFirstNonFutureImport() {
      for (PyImportStatementBase importStatement : myImportBlock) {
        if (!isFromFutureImport(importStatement)) {
          return importStatement;
        }
      }
      return null;
    }

    /** Removes any {@link PyBlock#IMPORT_GROUP_BEGIN} marker from the given imports. */
    private static void clearGroupBegin(List<PyImportStatementBase> imports) {
      for (PyImportStatementBase anImport : imports) {
        anImport.putCopyableUserData(PyBlock.IMPORT_GROUP_BEGIN, null);
      }
    }

    /**
     * Marks the first import of a non-empty group with {@link PyBlock#IMPORT_GROUP_BEGIN}.
     * <p>
     * Presumably PyBlock reads this marker to separate groups with a blank line — confirm in the formatter.
     */
    private static void markGroupBegin(List<PyImportStatementBase> imports) {
      if (!imports.isEmpty()) {
        imports.get(0).putCopyableUserData(PyBlock.IMPORT_GROUP_BEGIN, true);
      }
    }

    /** Adds each of {@code imports} to the file, in order, before {@code anchor}. */
    private void addImports(List<PyImportStatementBase> imports, PsiElement anchor) {
      for (PyImportStatementBase newImport : imports) {
        myFile.addBefore(newImport, anchor);
      }
    }
  }
}