blob: c748731559b0078e85e7a566f8bf42a0ddbbd9fb [file] [log] [blame]
/*
* Copyright 2000-2014 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.jetbrains.python.refactoring.classes;
import com.google.common.collect.Collections2;
import com.intellij.lang.ASTNode;
import com.intellij.lang.injection.InjectedLanguageManager;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.Key;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.*;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.PsiUtilBase;
import com.intellij.psi.util.QualifiedName;
import com.intellij.util.ArrayUtil;
import com.jetbrains.NotNullPredicate;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.codeInsight.PyCodeInsightSettings;
import com.jetbrains.python.codeInsight.imports.AddImportHelper;
import com.jetbrains.python.codeInsight.imports.PyImportOptimizer;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyBuiltinCache;
import com.jetbrains.python.psi.impl.PyImportedModule;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.resolve.QualifiedNameFinder;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
/**
* @author Dennis.Ushakov
*/
public final class PyClassRefactoringUtil {
private static final Logger LOG = Logger.getInstance(PyClassRefactoringUtil.class.getName());
private static final Key<PsiNamedElement> ENCODED_IMPORT = Key.create("PyEncodedImport");
private static final Key<Boolean> ENCODED_USE_FROM_IMPORT = Key.create("PyEncodedUseFromImport");
private static final Key<String> ENCODED_IMPORT_AS = Key.create("PyEncodedImportAs");
private PyClassRefactoringUtil() {
}
/**
* Copies class field declarations to some other place
*
* @param assignmentStatements list of class fields
* @param dequalifyIfDeclaredInClass If not null method will check if field declared in this class.
* If declared -- qualifier will be removed.
* For example: MyClass.Foo will become Foo it this param is MyClass.
* @return new (copied) fields
*/
@NotNull
public static List<PyAssignmentStatement> copyFieldDeclarationToStatement(@NotNull final Collection<PyAssignmentStatement> assignmentStatements,
@NotNull final PyStatementList superClassStatement,
@Nullable final PyClass dequalifyIfDeclaredInClass) {
final List<PyAssignmentStatement> declarations = new ArrayList<PyAssignmentStatement>(assignmentStatements.size());
Collections.sort(declarations, PyDependenciesComparator.INSTANCE);
for (final PyAssignmentStatement pyAssignmentStatement : assignmentStatements) {
final PyElement value = pyAssignmentStatement.getAssignedValue();
final PyAssignmentStatement newDeclaration = (PyAssignmentStatement)pyAssignmentStatement.copy();
if (value instanceof PyReferenceExpression && dequalifyIfDeclaredInClass != null) {
final String newValue = getNewValueToAssign((PyReferenceExpression)value, dequalifyIfDeclaredInClass);
setNewAssigneeValue(newDeclaration, newValue);
}
declarations.add(PyUtil.addElementToStatementList(newDeclaration, superClassStatement));
PyPsiUtils.removeRedundantPass(superClassStatement);
}
return declarations;
}
/**
* Sets new value to assignment statement.
* @param assignmentStatement statement to change
* @param newValue new value
*/
private static void setNewAssigneeValue(@NotNull final PyAssignmentStatement assignmentStatement, @NotNull final String newValue) {
final PyExpression oldValue = assignmentStatement.getAssignedValue();
final PyExpression newExpression =
PyElementGenerator.getInstance(assignmentStatement.getProject()).createExpressionFromText(LanguageLevel.forElement(assignmentStatement), newValue);
if (oldValue != null) {
oldValue.replace(newExpression);
} else {
assignmentStatement.add(newExpression);
}
}
/**
* Checks if current value declared in provided class and removes class qualifier if true
* @param currentValue current value
* @param dequalifyIfDeclaredInClass class to check
* @return value as string
*/
@NotNull
private static String getNewValueToAssign(@NotNull final PyReferenceExpression currentValue, @NotNull final PyClass dequalifyIfDeclaredInClass) {
final PyExpression qualifier = currentValue.getQualifier();
if ((qualifier instanceof PyReferenceExpression) &&
((PyReferenceExpression)qualifier).getReference().isReferenceTo(dequalifyIfDeclaredInClass)) {
final String name = currentValue.getName();
return ((name != null) ? name : currentValue.getText());
}
return currentValue.getText();
}
@NotNull
public static List<PyFunction> copyMethods(Collection<PyFunction> methods, PyClass superClass, boolean skipIfExist ) {
if (methods.isEmpty()) {
return Collections.emptyList();
}
for (final PsiElement e : methods) {
rememberNamedReferences(e);
}
final PyFunction[] elements = methods.toArray(new PyFunction[methods.size()]);
return addMethods(superClass, skipIfExist, elements);
}
/**
* Adds methods to class.
*
* @param destination where to add methods
* @param methods methods
* @param skipIfExist do not add anything if method already exists
* @return newly added methods or existing one (if skipIfExists is true and method already exists)
*/
@NotNull
public static List<PyFunction> addMethods(@NotNull final PyClass destination, final boolean skipIfExist, @NotNull final PyFunction... methods) {
final PyStatementList destStatementList = destination.getStatementList();
final List<PyFunction> result = new ArrayList<PyFunction>(methods.length);
for (final PyFunction method : methods) {
final PyFunction existingMethod = destination.findMethodByName(method.getName(), false);
if ((existingMethod != null) && skipIfExist) {
result.add(existingMethod);
continue; //We skip adding if class already has this method.
}
final PyFunction newMethod = insertMethodInProperPlace(destStatementList, method);
result.add(newMethod);
restoreNamedReferences(newMethod);
}
PyPsiUtils.removeRedundantPass(destStatementList);
return result;
}
/**
* Adds init methods before all other methods (but after class vars and docs).
* Adds all other methods to the bottom
*
* @param destStatementList where to add methods
* @param method method to add
* @return newlty added method
*/
@NotNull
private static PyFunction insertMethodInProperPlace(
@NotNull final PyStatementList destStatementList,
@NotNull final PyFunction method) {
boolean methodIsInit = PyUtil.isInit(method);
if (!methodIsInit) {
//Not init method could be inserted in the bottom
return (PyFunction)destStatementList.add(method);
}
//We should find appropriate place for init
for (final PsiElement element : destStatementList.getChildren()) {
final boolean elementComment = element instanceof PyExpressionStatement;
final boolean elementClassField = element instanceof PyAssignmentStatement;
if (!(elementComment || elementClassField)) {
return (PyFunction)destStatementList.addBefore(method, element);
}
}
return (PyFunction)destStatementList.add(method);
}
public static <T extends PyElement & PyStatementListContainer> void insertPassIfNeeded(@NotNull T element) {
final PyStatementList statements = element.getStatementList();
if (statements.getStatements().length == 0) {
statements.add(
PyElementGenerator.getInstance(element.getProject())
.createFromText(LanguageLevel.getDefault(), PyPassStatement.class, PyNames.PASS)
);
}
}
/**
* Restores references saved by {@link #rememberNamedReferences(com.intellij.psi.PsiElement, String...)}.
*
* @param element newly created element to restore references
* @see #rememberNamedReferences(com.intellij.psi.PsiElement, String...)
*/
public static void restoreNamedReferences(@NotNull final PsiElement element) {
restoreNamedReferences(element, null);
}
public static void restoreNamedReferences(@NotNull final PsiElement newElement, @Nullable final PsiElement oldElement) {
newElement.acceptChildren(new PyRecursiveElementVisitor() {
@Override
public void visitPyReferenceExpression(PyReferenceExpression node) {
super.visitPyReferenceExpression(node);
restoreReference(node);
}
@Override
public void visitPyStringLiteralExpression(PyStringLiteralExpression node) {
super.visitPyStringLiteralExpression(node);
for (PsiReference ref : node.getReferences()) {
if (ref.isReferenceTo(oldElement)) {
ref.bindToElement(newElement);
}
}
}
});
}
private static void restoreReference(final PyReferenceExpression node) {
PsiNamedElement target = node.getCopyableUserData(ENCODED_IMPORT);
final String asName = node.getCopyableUserData(ENCODED_IMPORT_AS);
final Boolean useFromImport = node.getCopyableUserData(ENCODED_USE_FROM_IMPORT);
if (target instanceof PsiDirectory) {
target = (PsiNamedElement)PyUtil.turnDirIntoPackageElement((PsiDirectory)target, node);
}
if (target instanceof PyFunction) {
final PyFunction f = (PyFunction)target;
final PyClass c = f.getContainingClass();
if (c != null && c.findInitOrNew(false) == f) {
target = c;
}
}
if (target == null) return;
if (PsiTreeUtil.isAncestor(node.getContainingFile(), target, false)) return;
if (target instanceof PyFile || target instanceof PsiDirectory) {
insertImport(node, target, asName, useFromImport != null ? useFromImport : true);
}
else {
insertImport(node, target, asName, true);
}
node.putCopyableUserData(ENCODED_IMPORT, null);
node.putCopyableUserData(ENCODED_IMPORT_AS, null);
node.putCopyableUserData(ENCODED_USE_FROM_IMPORT, null);
}
public static void insertImport(PsiElement anchor, Collection<PsiNamedElement> elements) {
for (PsiNamedElement newClass : elements) {
insertImport(anchor, newClass);
}
}
public static boolean isValidQualifiedName(QualifiedName name) {
if (name == null) {
return false;
}
final Collection<String> components = name.getComponents();
if (components.isEmpty()) {
return false;
}
for (String s : components) {
if (!PyNames.isIdentifier(s) || PyNames.isReserved(s)) {
return false;
}
}
return true;
}
public static boolean insertImport(@NotNull PsiElement anchor, @NotNull PsiNamedElement element) {
return insertImport(anchor, element, null);
}
public static boolean insertImport(@NotNull PsiElement anchor, @NotNull PsiNamedElement element, @Nullable String asName) {
return insertImport(anchor, element, asName, PyCodeInsightSettings.getInstance().PREFER_FROM_IMPORT);
}
public static boolean insertImport(@NotNull PsiElement anchor,
@NotNull PsiNamedElement element,
@Nullable String asName,
boolean preferFromImport) {
if (PyBuiltinCache.getInstance(element).isBuiltin(element)) return false;
final PsiFileSystemItem elementSource = element instanceof PsiDirectory? (PsiFileSystemItem)element : element.getContainingFile();
final PsiFile file = anchor.getContainingFile();
if (elementSource == file) return false;
final QualifiedName qname = QualifiedNameFinder.findCanonicalImportPath(element, anchor);
if (qname == null || !isValidQualifiedName(qname)) {
return false;
}
final QualifiedName containingQName;
final String importedName;
if (element instanceof PyFile || element instanceof PsiDirectory) {
containingQName = qname.removeLastComponent();
importedName = qname.getLastComponent();
}
else {
containingQName = qname;
importedName = getOriginalName(element);
}
final AddImportHelper.ImportPriority priority = AddImportHelper.getImportPriority(anchor, elementSource);
if (preferFromImport && !containingQName.getComponents().isEmpty()) {
return AddImportHelper.addOrUpdateFromImportStatement(file, containingQName.toString(), importedName, asName, priority, anchor);
}
else {
return AddImportHelper.addImportStatement(file, containingQName.append(importedName).toString(), asName, priority, anchor);
}
}
/**
* Searches for references inside some element (like {@link com.jetbrains.python.psi.PyAssignmentStatement}, {@link com.jetbrains.python.psi.PyFunction} etc
* and stored them.
* After that you can add element to some new parent. Newly created element then should be processed via {@link #restoreNamedReferences(com.intellij.psi.PsiElement)}
* and all references would be restored.
*
* @param element element to store references for
* @param namesToSkip if reference inside of element has one of this names, it will not be saved.
*/
public static void rememberNamedReferences(@NotNull final PsiElement element, @NotNull final String... namesToSkip) {
element.acceptChildren(new PyRecursiveElementVisitor() {
@Override
public void visitPyReferenceExpression(PyReferenceExpression node) {
super.visitPyReferenceExpression(node);
if (PsiTreeUtil.getParentOfType(node, PyImportStatementBase.class) != null) {
return;
}
final NameDefiner importElement = getImportElement(node);
if (importElement != null && PsiTreeUtil.isAncestor(element, importElement, false)) {
return;
}
if (!ArrayUtil.contains(node.getText(), namesToSkip)) { //Do not remember name if it should be skipped
rememberReference(node, element);
}
}
});
}
private static void rememberReference(@NotNull PyReferenceExpression node, @NotNull PsiElement element) {
// We will remember reference in deepest node (except for references to PyImportedModules, as we need references to modules, not to
// their packages)
final PyExpression qualifier = node.getQualifier();
if (qualifier != null && !(resolveExpression(qualifier) instanceof PyImportedModule)) {
return;
}
final PsiElement target = resolveExpression(node);
if (target instanceof PsiNamedElement && !PsiTreeUtil.isAncestor(element, target, false)) {
final NameDefiner importElement = getImportElement(node);
if (!PyUtil.inSameFile(element, target) && importElement == null && !(target instanceof PsiFileSystemItem)) {
return;
}
node.putCopyableUserData(ENCODED_IMPORT, (PsiNamedElement)target);
if (importElement instanceof PyImportElement) {
node.putCopyableUserData(ENCODED_IMPORT_AS, ((PyImportElement)importElement).getAsName());
}
node.putCopyableUserData(ENCODED_USE_FROM_IMPORT, qualifier == null);
}
}
@Nullable
private static NameDefiner getImportElement(PyReferenceExpression expr) {
for (ResolveResult result : expr.getReference().multiResolve(false)) {
final PsiElement e = result.getElement();
if (e instanceof PyImportElement) {
return (PyImportElement)e;
}
if (e instanceof PyStarImportElement) {
return (PyStarImportElement)e;
}
}
return null;
}
@Nullable
private static PsiElement resolveExpression(@NotNull PyExpression expr) {
if (expr instanceof PyReferenceExpression) {
return ((PyReferenceExpression)expr).getReference().resolve();
}
return null;
}
public static void updateImportOfElement(@NotNull PyImportStatementBase importStatement, @NotNull PsiNamedElement element) {
final String name = getOriginalName(element);
if (name != null) {
PyImportElement importElement = null;
for (PyImportElement e : importStatement.getImportElements()) {
if (name.equals(getOriginalName(e))) {
importElement = e;
}
}
if (importElement != null) {
final PsiFile file = importStatement.getContainingFile();
final PsiFile newFile = element.getContainingFile();
boolean deleteImportElement = false;
if (newFile == file) {
deleteImportElement = true;
}
else if (insertImport(importStatement, element, importElement.getAsName(), true)) {
deleteImportElement = true;
}
if (deleteImportElement) {
if (importStatement.getImportElements().length == 1) {
final boolean isInjected =
InjectedLanguageManager.getInstance(importElement.getProject()).isInjectedFragment(importElement.getContainingFile());
if (!isInjected) {
importStatement.delete();
}
else {
deleteImportStatementFromInjected(importStatement);
}
}
else {
importElement.delete();
}
}
}
}
}
private static void deleteImportStatementFromInjected(@NotNull final PyImportStatementBase importStatement) {
final PsiElement sibling = importStatement.getPrevSibling();
importStatement.delete();
if (sibling instanceof PsiWhiteSpace) sibling.delete();
}
@Nullable
public static String getOriginalName(@NotNull PsiNamedElement element) {
if (element instanceof PyFile) {
VirtualFile virtualFile = PsiUtilBase.asVirtualFile(PyUtil.turnInitIntoDir(element));
if (virtualFile != null) {
return virtualFile.getNameWithoutExtension();
}
return null;
}
return element.getName();
}
@Nullable
private static String getOriginalName(PyImportElement element) {
final QualifiedName qname = element.getImportedQName();
if (qname != null && qname.getComponentCount() > 0) {
return qname.getComponents().get(0);
}
return null;
}
/**
* Adds super classes to certain class.
*
* @param project project where refactoring takes place
* @param clazz destination
* @param superClasses classes to add
*/
public static void addSuperclasses(@NotNull final Project project,
@NotNull final PyClass clazz,
@NotNull final PyClass... superClasses) {
final Collection<String> superClassNames = new ArrayList<String>();
for (final PyClass superClass : Collections2.filter(Arrays.asList(superClasses), NotNullPredicate.INSTANCE)) {
if (superClass.getName() != null) {
superClassNames.add(superClass.getName());
insertImport(clazz, superClass);
}
}
addSuperClassExpressions(project, clazz, superClassNames, null);
}
/**
* Adds expressions to superclass list
*
* @param project project
* @param clazz class to add expressions to superclass list
* @param paramExpressions param expressions. Like "object" or "MySuperClass". Will not add any param exp. if null.
* @param keywordArguments keyword args like "metaclass=ABCMeta". key-value pairs. Will not add any keyword arg. if null.
*/
public static void addSuperClassExpressions(@NotNull final Project project,
@NotNull final PyClass clazz,
@Nullable final Collection<String> paramExpressions,
@Nullable final Collection<Pair<String, String>> keywordArguments) {
final PyElementGenerator generator = PyElementGenerator.getInstance(project);
final LanguageLevel languageLevel = LanguageLevel.forElement(clazz);
PyArgumentList superClassExpressionList = clazz.getSuperClassExpressionList();
boolean addExpression = false;
if (superClassExpressionList == null) {
superClassExpressionList = generator.createFromText(languageLevel, PyClass.class, "class foo():pass").getSuperClassExpressionList();
assert superClassExpressionList != null : "expression not created";
addExpression = true;
}
generator.createFromText(LanguageLevel.PYTHON34, PyClass.class, "class foo(object, metaclass=Foo): pass").getSuperClassExpressionList();
if (paramExpressions != null) {
for (final String paramExpression : paramExpressions) {
superClassExpressionList.addArgument(generator.createParameter(paramExpression));
}
}
if (keywordArguments != null) {
for (final Pair<String, String> keywordArgument : keywordArguments) {
superClassExpressionList.addArgument(generator.createKeywordArgument(languageLevel, keywordArgument.first, keywordArgument.second));
}
}
// If class has no expression list, then we need to add it manually.
if (addExpression) {
final ASTNode classNameNode = clazz.getNameNode(); // For nameless classes we simply add expression list directly to them
final PsiElement elementToAddAfter = (classNameNode == null) ? clazz.getFirstChild() : classNameNode.getPsi();
clazz.addAfter(superClassExpressionList, elementToAddAfter);
}
}
/**
* Optimizes imports resorting them and removing unneeded
*
* @param file file to optimize imports
*/
public static void optimizeImports(@NotNull final PsiFile file) {
new PyImportOptimizer().processFile(file).run();
}
/**
* Adds class attributeName (field) if it does not exist. like __metaclass__ = ABCMeta. Or CLASS_FIELD = 42.
*
* @param aClass where to add
* @param attributeName attribute's name. Like __metaclass__ or CLASS_FIELD
* @param value it's value. Like ABCMeta or 42.
* @return newly inserted attribute
*/
@Nullable
public static PsiElement addClassAttributeIfNotExist(
@NotNull final PyClass aClass,
@NotNull final String attributeName,
@NotNull final String value) {
if (aClass.findClassAttribute(attributeName, false) != null) {
return null; //Do not add any if exist already
}
final PyElementGenerator generator = PyElementGenerator.getInstance(aClass.getProject());
final String text = String.format("%s = %s", attributeName, value);
final LanguageLevel level = LanguageLevel.forElement(aClass);
final PyAssignmentStatement assignmentStatement = generator.createFromText(level, PyAssignmentStatement.class, text);
//TODO: Add metaclass to the top. Add others between last attributeName and first method
return PyUtil.addElementToStatementList(assignmentStatement, aClass.getStatementList(), true);
}
}