| /* |
| * Copyright 2000-2014 JetBrains s.r.o. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package com.jetbrains.python.refactoring.classes; |
| |
| import com.google.common.collect.Collections2; |
| import com.intellij.lang.ASTNode; |
| import com.intellij.lang.injection.InjectedLanguageManager; |
| import com.intellij.openapi.diagnostic.Logger; |
| import com.intellij.openapi.project.Project; |
| import com.intellij.openapi.util.Key; |
| import com.intellij.openapi.util.Pair; |
| import com.intellij.openapi.vfs.VirtualFile; |
| import com.intellij.psi.*; |
| import com.intellij.psi.util.PsiTreeUtil; |
| import com.intellij.psi.util.PsiUtilBase; |
| import com.intellij.psi.util.QualifiedName; |
| import com.intellij.util.ArrayUtil; |
| import com.jetbrains.NotNullPredicate; |
| import com.jetbrains.python.PyNames; |
| import com.jetbrains.python.codeInsight.PyCodeInsightSettings; |
| import com.jetbrains.python.codeInsight.imports.AddImportHelper; |
| import com.jetbrains.python.codeInsight.imports.PyImportOptimizer; |
| import com.jetbrains.python.psi.*; |
| import com.jetbrains.python.psi.impl.PyBuiltinCache; |
| import com.jetbrains.python.psi.impl.PyImportedModule; |
| import com.jetbrains.python.psi.impl.PyPsiUtils; |
| import com.jetbrains.python.psi.resolve.QualifiedNameFinder; |
| import org.jetbrains.annotations.NotNull; |
| import org.jetbrains.annotations.Nullable; |
| |
| import java.util.*; |
| |
| /** |
| * @author Dennis.Ushakov |
| */ |
| public final class PyClassRefactoringUtil { |
| private static final Logger LOG = Logger.getInstance(PyClassRefactoringUtil.class.getName()); |
| private static final Key<PsiNamedElement> ENCODED_IMPORT = Key.create("PyEncodedImport"); |
| private static final Key<Boolean> ENCODED_USE_FROM_IMPORT = Key.create("PyEncodedUseFromImport"); |
| private static final Key<String> ENCODED_IMPORT_AS = Key.create("PyEncodedImportAs"); |
| |
| |
| private PyClassRefactoringUtil() { |
| } |
| |
| |
| /** |
| * Copies class field declarations to some other place |
| * |
| * @param assignmentStatements list of class fields |
| * @param dequalifyIfDeclaredInClass If not null method will check if field declared in this class. |
| * If declared -- qualifier will be removed. |
| * For example: MyClass.Foo will become Foo it this param is MyClass. |
| * @return new (copied) fields |
| */ |
| @NotNull |
| public static List<PyAssignmentStatement> copyFieldDeclarationToStatement(@NotNull final Collection<PyAssignmentStatement> assignmentStatements, |
| @NotNull final PyStatementList superClassStatement, |
| @Nullable final PyClass dequalifyIfDeclaredInClass) { |
| final List<PyAssignmentStatement> declarations = new ArrayList<PyAssignmentStatement>(assignmentStatements.size()); |
| Collections.sort(declarations, PyDependenciesComparator.INSTANCE); |
| |
| |
| for (final PyAssignmentStatement pyAssignmentStatement : assignmentStatements) { |
| final PyElement value = pyAssignmentStatement.getAssignedValue(); |
| final PyAssignmentStatement newDeclaration = (PyAssignmentStatement)pyAssignmentStatement.copy(); |
| |
| if (value instanceof PyReferenceExpression && dequalifyIfDeclaredInClass != null) { |
| final String newValue = getNewValueToAssign((PyReferenceExpression)value, dequalifyIfDeclaredInClass); |
| |
| setNewAssigneeValue(newDeclaration, newValue); |
| |
| } |
| |
| declarations.add(PyUtil.addElementToStatementList(newDeclaration, superClassStatement)); |
| PyPsiUtils.removeRedundantPass(superClassStatement); |
| } |
| return declarations; |
| } |
| |
| /** |
| * Sets new value to assignment statement. |
| * @param assignmentStatement statement to change |
| * @param newValue new value |
| */ |
| private static void setNewAssigneeValue(@NotNull final PyAssignmentStatement assignmentStatement, @NotNull final String newValue) { |
| final PyExpression oldValue = assignmentStatement.getAssignedValue(); |
| final PyExpression newExpression = |
| PyElementGenerator.getInstance(assignmentStatement.getProject()).createExpressionFromText(LanguageLevel.forElement(assignmentStatement), newValue); |
| if (oldValue != null) { |
| oldValue.replace(newExpression); |
| } else { |
| assignmentStatement.add(newExpression); |
| } |
| } |
| |
| /** |
| * Checks if current value declared in provided class and removes class qualifier if true |
| * @param currentValue current value |
| * @param dequalifyIfDeclaredInClass class to check |
| * @return value as string |
| */ |
| @NotNull |
| private static String getNewValueToAssign(@NotNull final PyReferenceExpression currentValue, @NotNull final PyClass dequalifyIfDeclaredInClass) { |
| final PyExpression qualifier = currentValue.getQualifier(); |
| if ((qualifier instanceof PyReferenceExpression) && |
| ((PyReferenceExpression)qualifier).getReference().isReferenceTo(dequalifyIfDeclaredInClass)) { |
| final String name = currentValue.getName(); |
| return ((name != null) ? name : currentValue.getText()); |
| } |
| return currentValue.getText(); |
| } |
| |
| @NotNull |
| public static List<PyFunction> copyMethods(Collection<PyFunction> methods, PyClass superClass, boolean skipIfExist ) { |
| if (methods.isEmpty()) { |
| return Collections.emptyList(); |
| } |
| for (final PsiElement e : methods) { |
| rememberNamedReferences(e); |
| } |
| final PyFunction[] elements = methods.toArray(new PyFunction[methods.size()]); |
| return addMethods(superClass, skipIfExist, elements); |
| } |
| |
| /** |
| * Adds methods to class. |
| * |
| * @param destination where to add methods |
| * @param methods methods |
| * @param skipIfExist do not add anything if method already exists |
| * @return newly added methods or existing one (if skipIfExists is true and method already exists) |
| */ |
| @NotNull |
| public static List<PyFunction> addMethods(@NotNull final PyClass destination, final boolean skipIfExist, @NotNull final PyFunction... methods) { |
| |
| final PyStatementList destStatementList = destination.getStatementList(); |
| final List<PyFunction> result = new ArrayList<PyFunction>(methods.length); |
| |
| for (final PyFunction method : methods) { |
| |
| final PyFunction existingMethod = destination.findMethodByName(method.getName(), false); |
| if ((existingMethod != null) && skipIfExist) { |
| result.add(existingMethod); |
| continue; //We skip adding if class already has this method. |
| } |
| |
| |
| final PyFunction newMethod = insertMethodInProperPlace(destStatementList, method); |
| result.add(newMethod); |
| restoreNamedReferences(newMethod); |
| } |
| |
| PyPsiUtils.removeRedundantPass(destStatementList); |
| return result; |
| } |
| |
| /** |
| * Adds init methods before all other methods (but after class vars and docs). |
| * Adds all other methods to the bottom |
| * |
| * @param destStatementList where to add methods |
| * @param method method to add |
| * @return newlty added method |
| */ |
| @NotNull |
| private static PyFunction insertMethodInProperPlace( |
| @NotNull final PyStatementList destStatementList, |
| @NotNull final PyFunction method) { |
| boolean methodIsInit = PyUtil.isInit(method); |
| if (!methodIsInit) { |
| //Not init method could be inserted in the bottom |
| return (PyFunction)destStatementList.add(method); |
| } |
| |
| //We should find appropriate place for init |
| for (final PsiElement element : destStatementList.getChildren()) { |
| final boolean elementComment = element instanceof PyExpressionStatement; |
| final boolean elementClassField = element instanceof PyAssignmentStatement; |
| |
| if (!(elementComment || elementClassField)) { |
| return (PyFunction)destStatementList.addBefore(method, element); |
| } |
| } |
| return (PyFunction)destStatementList.add(method); |
| } |
| |
| |
| public static <T extends PyElement & PyStatementListContainer> void insertPassIfNeeded(@NotNull T element) { |
| final PyStatementList statements = element.getStatementList(); |
| if (statements.getStatements().length == 0) { |
| statements.add( |
| PyElementGenerator.getInstance(element.getProject()) |
| .createFromText(LanguageLevel.getDefault(), PyPassStatement.class, PyNames.PASS) |
| ); |
| } |
| } |
| |
| /** |
| * Restores references saved by {@link #rememberNamedReferences(com.intellij.psi.PsiElement, String...)}. |
| * |
| * @param element newly created element to restore references |
| * @see #rememberNamedReferences(com.intellij.psi.PsiElement, String...) |
| */ |
| public static void restoreNamedReferences(@NotNull final PsiElement element) { |
| restoreNamedReferences(element, null); |
| } |
| |
| public static void restoreNamedReferences(@NotNull final PsiElement newElement, @Nullable final PsiElement oldElement) { |
| newElement.acceptChildren(new PyRecursiveElementVisitor() { |
| @Override |
| public void visitPyReferenceExpression(PyReferenceExpression node) { |
| super.visitPyReferenceExpression(node); |
| restoreReference(node); |
| } |
| |
| @Override |
| public void visitPyStringLiteralExpression(PyStringLiteralExpression node) { |
| super.visitPyStringLiteralExpression(node); |
| for (PsiReference ref : node.getReferences()) { |
| if (ref.isReferenceTo(oldElement)) { |
| ref.bindToElement(newElement); |
| } |
| } |
| } |
| }); |
| } |
| |
| |
| private static void restoreReference(final PyReferenceExpression node) { |
| PsiNamedElement target = node.getCopyableUserData(ENCODED_IMPORT); |
| final String asName = node.getCopyableUserData(ENCODED_IMPORT_AS); |
| final Boolean useFromImport = node.getCopyableUserData(ENCODED_USE_FROM_IMPORT); |
| if (target instanceof PsiDirectory) { |
| target = (PsiNamedElement)PyUtil.turnDirIntoPackageElement((PsiDirectory)target, node); |
| } |
| if (target instanceof PyFunction) { |
| final PyFunction f = (PyFunction)target; |
| final PyClass c = f.getContainingClass(); |
| if (c != null && c.findInitOrNew(false) == f) { |
| target = c; |
| } |
| } |
| if (target == null) return; |
| if (PsiTreeUtil.isAncestor(node.getContainingFile(), target, false)) return; |
| if (target instanceof PyFile || target instanceof PsiDirectory) { |
| insertImport(node, target, asName, useFromImport != null ? useFromImport : true); |
| } |
| else { |
| insertImport(node, target, asName, true); |
| } |
| node.putCopyableUserData(ENCODED_IMPORT, null); |
| node.putCopyableUserData(ENCODED_IMPORT_AS, null); |
| node.putCopyableUserData(ENCODED_USE_FROM_IMPORT, null); |
| } |
| |
| public static void insertImport(PsiElement anchor, Collection<PsiNamedElement> elements) { |
| for (PsiNamedElement newClass : elements) { |
| insertImport(anchor, newClass); |
| } |
| } |
| |
| public static boolean isValidQualifiedName(QualifiedName name) { |
| if (name == null) { |
| return false; |
| } |
| final Collection<String> components = name.getComponents(); |
| if (components.isEmpty()) { |
| return false; |
| } |
| for (String s : components) { |
| if (!PyNames.isIdentifier(s) || PyNames.isReserved(s)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| public static boolean insertImport(@NotNull PsiElement anchor, @NotNull PsiNamedElement element) { |
| return insertImport(anchor, element, null); |
| } |
| |
| public static boolean insertImport(@NotNull PsiElement anchor, @NotNull PsiNamedElement element, @Nullable String asName) { |
| return insertImport(anchor, element, asName, PyCodeInsightSettings.getInstance().PREFER_FROM_IMPORT); |
| } |
| |
| public static boolean insertImport(@NotNull PsiElement anchor, |
| @NotNull PsiNamedElement element, |
| @Nullable String asName, |
| boolean preferFromImport) { |
| if (PyBuiltinCache.getInstance(element).isBuiltin(element)) return false; |
| final PsiFileSystemItem elementSource = element instanceof PsiDirectory? (PsiFileSystemItem)element : element.getContainingFile(); |
| final PsiFile file = anchor.getContainingFile(); |
| if (elementSource == file) return false; |
| final QualifiedName qname = QualifiedNameFinder.findCanonicalImportPath(element, anchor); |
| if (qname == null || !isValidQualifiedName(qname)) { |
| return false; |
| } |
| final QualifiedName containingQName; |
| final String importedName; |
| if (element instanceof PyFile || element instanceof PsiDirectory) { |
| containingQName = qname.removeLastComponent(); |
| importedName = qname.getLastComponent(); |
| } |
| else { |
| containingQName = qname; |
| importedName = getOriginalName(element); |
| } |
| final AddImportHelper.ImportPriority priority = AddImportHelper.getImportPriority(anchor, elementSource); |
| if (preferFromImport && !containingQName.getComponents().isEmpty()) { |
| return AddImportHelper.addOrUpdateFromImportStatement(file, containingQName.toString(), importedName, asName, priority, anchor); |
| } |
| else { |
| return AddImportHelper.addImportStatement(file, containingQName.append(importedName).toString(), asName, priority, anchor); |
| } |
| } |
| |
| /** |
| * Searches for references inside some element (like {@link com.jetbrains.python.psi.PyAssignmentStatement}, {@link com.jetbrains.python.psi.PyFunction} etc |
| * and stored them. |
| * After that you can add element to some new parent. Newly created element then should be processed via {@link #restoreNamedReferences(com.intellij.psi.PsiElement)} |
| * and all references would be restored. |
| * |
| * @param element element to store references for |
| * @param namesToSkip if reference inside of element has one of this names, it will not be saved. |
| */ |
| public static void rememberNamedReferences(@NotNull final PsiElement element, @NotNull final String... namesToSkip) { |
| element.acceptChildren(new PyRecursiveElementVisitor() { |
| @Override |
| public void visitPyReferenceExpression(PyReferenceExpression node) { |
| super.visitPyReferenceExpression(node); |
| if (PsiTreeUtil.getParentOfType(node, PyImportStatementBase.class) != null) { |
| return; |
| } |
| final NameDefiner importElement = getImportElement(node); |
| if (importElement != null && PsiTreeUtil.isAncestor(element, importElement, false)) { |
| return; |
| } |
| if (!ArrayUtil.contains(node.getText(), namesToSkip)) { //Do not remember name if it should be skipped |
| rememberReference(node, element); |
| } |
| } |
| }); |
| } |
| |
| private static void rememberReference(@NotNull PyReferenceExpression node, @NotNull PsiElement element) { |
| // We will remember reference in deepest node (except for references to PyImportedModules, as we need references to modules, not to |
| // their packages) |
| final PyExpression qualifier = node.getQualifier(); |
| if (qualifier != null && !(resolveExpression(qualifier) instanceof PyImportedModule)) { |
| return; |
| } |
| final PsiElement target = resolveExpression(node); |
| if (target instanceof PsiNamedElement && !PsiTreeUtil.isAncestor(element, target, false)) { |
| final NameDefiner importElement = getImportElement(node); |
| if (!PyUtil.inSameFile(element, target) && importElement == null && !(target instanceof PsiFileSystemItem)) { |
| return; |
| } |
| node.putCopyableUserData(ENCODED_IMPORT, (PsiNamedElement)target); |
| if (importElement instanceof PyImportElement) { |
| node.putCopyableUserData(ENCODED_IMPORT_AS, ((PyImportElement)importElement).getAsName()); |
| } |
| node.putCopyableUserData(ENCODED_USE_FROM_IMPORT, qualifier == null); |
| } |
| } |
| |
| @Nullable |
| private static NameDefiner getImportElement(PyReferenceExpression expr) { |
| for (ResolveResult result : expr.getReference().multiResolve(false)) { |
| final PsiElement e = result.getElement(); |
| if (e instanceof PyImportElement) { |
| return (PyImportElement)e; |
| } |
| if (e instanceof PyStarImportElement) { |
| return (PyStarImportElement)e; |
| } |
| } |
| return null; |
| } |
| |
| @Nullable |
| private static PsiElement resolveExpression(@NotNull PyExpression expr) { |
| if (expr instanceof PyReferenceExpression) { |
| return ((PyReferenceExpression)expr).getReference().resolve(); |
| } |
| return null; |
| } |
| |
| public static void updateImportOfElement(@NotNull PyImportStatementBase importStatement, @NotNull PsiNamedElement element) { |
| final String name = getOriginalName(element); |
| if (name != null) { |
| PyImportElement importElement = null; |
| for (PyImportElement e : importStatement.getImportElements()) { |
| if (name.equals(getOriginalName(e))) { |
| importElement = e; |
| } |
| } |
| if (importElement != null) { |
| final PsiFile file = importStatement.getContainingFile(); |
| final PsiFile newFile = element.getContainingFile(); |
| boolean deleteImportElement = false; |
| if (newFile == file) { |
| deleteImportElement = true; |
| } |
| else if (insertImport(importStatement, element, importElement.getAsName(), true)) { |
| deleteImportElement = true; |
| } |
| if (deleteImportElement) { |
| if (importStatement.getImportElements().length == 1) { |
| final boolean isInjected = |
| InjectedLanguageManager.getInstance(importElement.getProject()).isInjectedFragment(importElement.getContainingFile()); |
| if (!isInjected) { |
| importStatement.delete(); |
| } |
| else { |
| deleteImportStatementFromInjected(importStatement); |
| } |
| } |
| else { |
| importElement.delete(); |
| } |
| } |
| } |
| } |
| } |
| |
| private static void deleteImportStatementFromInjected(@NotNull final PyImportStatementBase importStatement) { |
| final PsiElement sibling = importStatement.getPrevSibling(); |
| importStatement.delete(); |
| if (sibling instanceof PsiWhiteSpace) sibling.delete(); |
| } |
| |
| @Nullable |
| public static String getOriginalName(@NotNull PsiNamedElement element) { |
| if (element instanceof PyFile) { |
| VirtualFile virtualFile = PsiUtilBase.asVirtualFile(PyUtil.turnInitIntoDir(element)); |
| if (virtualFile != null) { |
| return virtualFile.getNameWithoutExtension(); |
| } |
| return null; |
| } |
| return element.getName(); |
| } |
| |
| @Nullable |
| private static String getOriginalName(PyImportElement element) { |
| final QualifiedName qname = element.getImportedQName(); |
| if (qname != null && qname.getComponentCount() > 0) { |
| return qname.getComponents().get(0); |
| } |
| return null; |
| } |
| |
| /** |
| * Adds super classes to certain class. |
| * |
| * @param project project where refactoring takes place |
| * @param clazz destination |
| * @param superClasses classes to add |
| */ |
| public static void addSuperclasses(@NotNull final Project project, |
| @NotNull final PyClass clazz, |
| @NotNull final PyClass... superClasses) { |
| |
| final Collection<String> superClassNames = new ArrayList<String>(); |
| |
| |
| for (final PyClass superClass : Collections2.filter(Arrays.asList(superClasses), NotNullPredicate.INSTANCE)) { |
| if (superClass.getName() != null) { |
| superClassNames.add(superClass.getName()); |
| insertImport(clazz, superClass); |
| } |
| } |
| |
| addSuperClassExpressions(project, clazz, superClassNames, null); |
| } |
| |
| |
| /** |
| * Adds expressions to superclass list |
| * |
| * @param project project |
| * @param clazz class to add expressions to superclass list |
| * @param paramExpressions param expressions. Like "object" or "MySuperClass". Will not add any param exp. if null. |
| * @param keywordArguments keyword args like "metaclass=ABCMeta". key-value pairs. Will not add any keyword arg. if null. |
| */ |
| public static void addSuperClassExpressions(@NotNull final Project project, |
| @NotNull final PyClass clazz, |
| @Nullable final Collection<String> paramExpressions, |
| @Nullable final Collection<Pair<String, String>> keywordArguments) { |
| final PyElementGenerator generator = PyElementGenerator.getInstance(project); |
| final LanguageLevel languageLevel = LanguageLevel.forElement(clazz); |
| |
| PyArgumentList superClassExpressionList = clazz.getSuperClassExpressionList(); |
| boolean addExpression = false; |
| if (superClassExpressionList == null) { |
| superClassExpressionList = generator.createFromText(languageLevel, PyClass.class, "class foo():pass").getSuperClassExpressionList(); |
| assert superClassExpressionList != null : "expression not created"; |
| addExpression = true; |
| } |
| |
| |
| generator.createFromText(LanguageLevel.PYTHON34, PyClass.class, "class foo(object, metaclass=Foo): pass").getSuperClassExpressionList(); |
| if (paramExpressions != null) { |
| for (final String paramExpression : paramExpressions) { |
| superClassExpressionList.addArgument(generator.createParameter(paramExpression)); |
| } |
| } |
| |
| if (keywordArguments != null) { |
| for (final Pair<String, String> keywordArgument : keywordArguments) { |
| superClassExpressionList.addArgument(generator.createKeywordArgument(languageLevel, keywordArgument.first, keywordArgument.second)); |
| } |
| } |
| |
| // If class has no expression list, then we need to add it manually. |
| if (addExpression) { |
| final ASTNode classNameNode = clazz.getNameNode(); // For nameless classes we simply add expression list directly to them |
| final PsiElement elementToAddAfter = (classNameNode == null) ? clazz.getFirstChild() : classNameNode.getPsi(); |
| clazz.addAfter(superClassExpressionList, elementToAddAfter); |
| } |
| } |
| |
| /** |
| * Optimizes imports resorting them and removing unneeded |
| * |
| * @param file file to optimize imports |
| */ |
| public static void optimizeImports(@NotNull final PsiFile file) { |
| new PyImportOptimizer().processFile(file).run(); |
| } |
| |
| /** |
| * Adds class attributeName (field) if it does not exist. like __metaclass__ = ABCMeta. Or CLASS_FIELD = 42. |
| * |
| * @param aClass where to add |
| * @param attributeName attribute's name. Like __metaclass__ or CLASS_FIELD |
| * @param value it's value. Like ABCMeta or 42. |
| * @return newly inserted attribute |
| */ |
| @Nullable |
| public static PsiElement addClassAttributeIfNotExist( |
| @NotNull final PyClass aClass, |
| @NotNull final String attributeName, |
| @NotNull final String value) { |
| if (aClass.findClassAttribute(attributeName, false) != null) { |
| return null; //Do not add any if exist already |
| } |
| final PyElementGenerator generator = PyElementGenerator.getInstance(aClass.getProject()); |
| final String text = String.format("%s = %s", attributeName, value); |
| final LanguageLevel level = LanguageLevel.forElement(aClass); |
| |
| final PyAssignmentStatement assignmentStatement = generator.createFromText(level, PyAssignmentStatement.class, text); |
| //TODO: Add metaclass to the top. Add others between last attributeName and first method |
| return PyUtil.addElementToStatementList(assignmentStatement, aClass.getStatementList(), true); |
| } |
| } |