| /* |
| * Copyright 2000-2013 JetBrains s.r.o. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package com.jetbrains.python.refactoring.introduce.field; |
| |
| import com.intellij.lang.ASTNode; |
| import com.intellij.openapi.actionSystem.DataContext; |
| import com.intellij.openapi.application.AccessToken; |
| import com.intellij.openapi.application.ApplicationManager; |
| import com.intellij.openapi.editor.CaretModel; |
| import com.intellij.openapi.editor.Document; |
| import com.intellij.openapi.editor.Editor; |
| import com.intellij.openapi.editor.SelectionModel; |
| import com.intellij.openapi.project.Project; |
| import com.intellij.psi.PsiElement; |
| import com.intellij.psi.PsiFile; |
| import com.intellij.psi.PsiReference; |
| import com.intellij.psi.search.LocalSearchScope; |
| import com.intellij.psi.search.searches.ReferencesSearch; |
| import com.intellij.psi.util.PsiTreeUtil; |
| import com.intellij.refactoring.RefactoringBundle; |
| import com.intellij.refactoring.introduce.inplace.InplaceVariableIntroducer; |
| import com.intellij.refactoring.util.CommonRefactoringUtil; |
| import com.intellij.util.Function; |
| import com.intellij.util.FunctionUtil; |
| import com.jetbrains.python.PyNames; |
| import com.jetbrains.python.inspections.quickfix.AddFieldQuickFix; |
| import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; |
| import com.jetbrains.python.psi.*; |
| import com.jetbrains.python.psi.impl.PyFunctionBuilder; |
| import com.jetbrains.python.refactoring.PyReplaceExpressionUtil; |
| import com.jetbrains.python.refactoring.introduce.IntroduceHandler; |
| import com.jetbrains.python.refactoring.introduce.IntroduceOperation; |
| import com.jetbrains.python.refactoring.introduce.variable.PyIntroduceVariableHandler; |
| import com.jetbrains.python.testing.PythonUnitTestUtil; |
| import org.jetbrains.annotations.NotNull; |
| import org.jetbrains.annotations.Nullable; |
| |
| import javax.swing.*; |
| import java.util.*; |
| |
| /** |
| * @author Dennis.Ushakov |
| */ |
| public class PyIntroduceFieldHandler extends IntroduceHandler { |
| |
| public PyIntroduceFieldHandler() { |
| super(new IntroduceFieldValidator(), RefactoringBundle.message("introduce.field.title")); |
| } |
| |
| public void invoke(@NotNull Project project, Editor editor, PsiFile file, DataContext dataContext) { |
| final IntroduceOperation operation = new IntroduceOperation(project, editor, file, null); |
| operation.addAvailableInitPlace(InitPlace.CONSTRUCTOR); |
| if (isTestClass(file, editor)) { |
| operation.addAvailableInitPlace(InitPlace.SET_UP); |
| } |
| performAction(operation); |
| } |
| |
| private static boolean isTestClass(PsiFile file, Editor editor) { |
| PsiElement element1 = null; |
| final SelectionModel selectionModel = editor.getSelectionModel(); |
| if (selectionModel.hasSelection()) { |
| element1 = file.findElementAt(selectionModel.getSelectionStart()); |
| } |
| else { |
| final CaretModel caretModel = editor.getCaretModel(); |
| final Document document = editor.getDocument(); |
| int lineNumber = document.getLineNumber(caretModel.getOffset()); |
| if ((lineNumber >= 0) && (lineNumber < document.getLineCount())) { |
| element1 = file.findElementAt(document.getLineStartOffset(lineNumber)); |
| } |
| } |
| if (element1 != null) { |
| final PyClass clazz = PyUtil.getContainingClassOrSelf(element1); |
| if (clazz != null && PythonUnitTestUtil.isTestCaseClass(clazz)) return true; |
| } |
| return false; |
| } |
| |
| @Override |
| protected PsiElement replaceExpression(PsiElement expression, PyExpression newExpression, IntroduceOperation operation) { |
| if (operation.getInitPlace() != InitPlace.SAME_METHOD) { |
| return PyReplaceExpressionUtil.replaceExpression(expression, newExpression); |
| } |
| return super.replaceExpression(expression, newExpression, operation); |
| } |
| |
| @Override |
| protected boolean checkEnabled(IntroduceOperation operation) { |
| if (PyUtil.getContainingClassOrSelf(operation.getElement()) == null) { |
| CommonRefactoringUtil.showErrorHint(operation.getProject(), operation.getEditor(), "Cannot introduce field: not in class", myDialogTitle, |
| getHelpId()); |
| return false; |
| } |
| if (dependsOnLocalScopeValues(operation.getElement())) { |
| operation.removeAvailableInitPlace(InitPlace.CONSTRUCTOR); |
| operation.removeAvailableInitPlace(InitPlace.SET_UP); |
| } |
| return true; |
| } |
| |
| private static boolean dependsOnLocalScopeValues(PsiElement initializer) { |
| ScopeOwner scope = PsiTreeUtil.getParentOfType(initializer, ScopeOwner.class); |
| ResolvingVisitor visitor = new ResolvingVisitor(scope); |
| initializer.accept(visitor); |
| return visitor.hasLocalScopeDependencies; |
| |
| } |
| |
| private static class ResolvingVisitor extends PyRecursiveElementVisitor { |
| private boolean hasLocalScopeDependencies = false; |
| private final ScopeOwner myScope; |
| |
| public ResolvingVisitor(ScopeOwner scope) { |
| myScope = scope; |
| } |
| |
| @Override |
| public void visitPyReferenceExpression(PyReferenceExpression node) { |
| super.visitPyReferenceExpression(node); |
| final PsiElement result = node.getReference().resolve(); |
| if (result != null && PsiTreeUtil.getParentOfType(result, ScopeOwner.class) == myScope) { |
| if (result instanceof PyParameter && myScope instanceof PyFunction) { |
| final PyFunction function = (PyFunction)myScope; |
| final PyParameter[] parameters = function.getParameterList().getParameters(); |
| if (parameters.length > 0 && result == parameters[0]) { |
| final PyFunction.Modifier modifier = function.getModifier(); |
| if (modifier != PyFunction.Modifier.STATICMETHOD) { |
| // 'self' is not a local scope dependency |
| return; |
| } |
| } |
| } |
| hasLocalScopeDependencies = true; |
| } |
| } |
| } |
| |
| @Nullable |
| @Override |
| protected PsiElement addDeclaration(@NotNull PsiElement expression, @NotNull PsiElement declaration, @NotNull IntroduceOperation operation) { |
| final PsiElement expr = expression instanceof PyClass ? expression : expression.getParent(); |
| PsiElement anchor = PyUtil.getContainingClassOrSelf(expr); |
| assert anchor instanceof PyClass; |
| final PyClass clazz = (PyClass)anchor; |
| final Project project = anchor.getProject(); |
| if (operation.getInitPlace() == InitPlace.CONSTRUCTOR && !inConstructor(expression)) { |
| return AddFieldQuickFix.addFieldToInit(project, clazz, "", new AddFieldDeclaration(declaration)); |
| } else if (operation.getInitPlace() == InitPlace.SET_UP) { |
| return addFieldToSetUp(clazz, new AddFieldDeclaration(declaration)); |
| } |
| return PyIntroduceVariableHandler.doIntroduceVariable(expression, declaration, operation.getOccurrences(), operation.isReplaceAll()); |
| } |
| |
| private static boolean inConstructor(@NotNull PsiElement expression) { |
| final PsiElement expr = expression instanceof PyClass ? expression : expression.getParent(); |
| PyClass clazz = PyUtil.getContainingClassOrSelf(expr); |
| PsiElement current = PyUtil.getConcealingParent(expression); |
| if (clazz != null && current != null && current instanceof PyFunction) { |
| PyFunction init = clazz.findMethodByName(PyNames.INIT, false); |
| if (current == init) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| @Nullable |
| private static PsiElement addFieldToSetUp(PyClass clazz, final Function<String, PyStatement> callback) { |
| final PyFunction init = clazz.findMethodByName(PythonUnitTestUtil.TESTCASE_SETUP_NAME, false); |
| if (init != null) { |
| return AddFieldQuickFix.appendToMethod(init, callback); |
| } |
| final PyFunctionBuilder builder = new PyFunctionBuilder(PythonUnitTestUtil.TESTCASE_SETUP_NAME); |
| builder.parameter(PyNames.CANONICAL_SELF); |
| PyFunction setUp = builder.buildFunction(clazz.getProject(), LanguageLevel.getDefault()); |
| final PyStatementList statements = clazz.getStatementList(); |
| final PsiElement anchor = statements.getFirstChild(); |
| setUp = (PyFunction)statements.addBefore(setUp, anchor); |
| return AddFieldQuickFix.appendToMethod(setUp, callback); |
| } |
| |
| @Override |
| protected List<PsiElement> getOccurrences(PsiElement element, @NotNull PyExpression expression) { |
| if (isAssignedLocalVariable(element)) { |
| PyFunction function = PsiTreeUtil.getParentOfType(element, PyFunction.class); |
| Collection<PsiReference> references = ReferencesSearch.search(element, new LocalSearchScope(function)).findAll(); |
| ArrayList<PsiElement> result = new ArrayList<PsiElement>(); |
| for (PsiReference reference : references) { |
| PsiElement refElement = reference.getElement(); |
| if (refElement != element) { |
| result.add(refElement); |
| } |
| } |
| return result; |
| } |
| return super.getOccurrences(element, expression); |
| } |
| |
| @Override |
| protected PyExpression createExpression(Project project, String name, PsiElement declaration) { |
| final String text = declaration.getText(); |
| final String self_name = text.substring(0, text.indexOf('.')); |
| return PyElementGenerator.getInstance(project).createExpressionFromText(self_name + "." + name); |
| } |
| |
| @Override |
| protected PyAssignmentStatement createDeclaration(Project project, String assignmentText, PsiElement anchor) { |
| final PyFunction container = PsiTreeUtil.getParentOfType(anchor, PyFunction.class); |
| String selfName = PyUtil.getFirstParameterName(container); |
| final LanguageLevel langLevel = LanguageLevel.forElement(anchor); |
| return PyElementGenerator.getInstance(project).createFromText(langLevel, PyAssignmentStatement.class, selfName + "." + assignmentText); |
| } |
| |
| @Override |
| protected void postRefactoring(PsiElement element) { |
| if (isAssignedLocalVariable(element)) { |
| element.getParent().delete(); |
| } |
| } |
| |
| private static boolean isAssignedLocalVariable(PsiElement element) { |
| if (element instanceof PyTargetExpression && element.getParent() instanceof PyAssignmentStatement && |
| PsiTreeUtil.getParentOfType(element, PyFunction.class) != null) { |
| PyAssignmentStatement stmt = (PyAssignmentStatement) element.getParent(); |
| if (stmt.getTargets().length == 1) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| @Override |
| protected String getHelpId() { |
| return "python.reference.introduceField"; |
| } |
| |
| @Override |
| protected boolean checkIntroduceContext(PsiFile file, Editor editor, PsiElement element) { |
| if (element != null && isInStaticMethod(element)) { |
| CommonRefactoringUtil.showErrorHint(file.getProject(), editor, "Introduce Field refactoring cannot be used in static methods", |
| RefactoringBundle.message("introduce.field.title"), |
| "refactoring.extractMethod"); |
| return false; |
| } |
| return super.checkIntroduceContext(file, editor, element); |
| } |
| |
| private static boolean isInStaticMethod(PsiElement element) { |
| PyFunction containingMethod = PsiTreeUtil.getParentOfType(element, PyFunction.class, false, PyClass.class); |
| if (containingMethod != null) { |
| final PyFunction.Modifier modifier = containingMethod.getModifier(); |
| return modifier == PyFunction.Modifier.STATICMETHOD; |
| } |
| return false; |
| } |
| |
| @Override |
| protected boolean isValidIntroduceContext(PsiElement element) { |
| return super.isValidIntroduceContext(element) && |
| PsiTreeUtil.getParentOfType(element, PyFunction.class, false, PyClass.class) != null && |
| PsiTreeUtil.getParentOfType(element, PyDecoratorList.class) == null && |
| !isInStaticMethod(element); |
| } |
| |
| private static class AddFieldDeclaration implements Function<String, PyStatement> { |
| private final PsiElement myDeclaration; |
| |
| private AddFieldDeclaration(PsiElement declaration) { |
| myDeclaration = declaration; |
| } |
| |
| public PyStatement fun(String self_name) { |
| if (PyNames.CANONICAL_SELF.equals(self_name)) { |
| return (PyStatement)myDeclaration; |
| } |
| final String text = myDeclaration.getText(); |
| final Project project = myDeclaration.getProject(); |
| return PyElementGenerator.getInstance(project).createFromText(LanguageLevel.getDefault(), PyStatement.class, |
| text.replaceFirst(PyNames.CANONICAL_SELF + "\\.", self_name + ".")); |
| } |
| } |
| |
| @Override |
| protected void performInplaceIntroduce(IntroduceOperation operation) { |
| final PsiElement statement = performRefactoring(operation); |
| // put caret on identifier after "self." |
| if (statement instanceof PyAssignmentStatement) { |
| final List<PsiElement> occurrences = operation.getOccurrences(); |
| final PsiElement occurrence = findOccurrenceUnderCaret(occurrences, operation.getEditor()); |
| PyTargetExpression target = (PyTargetExpression) ((PyAssignmentStatement)statement).getTargets() [0]; |
| putCaretOnFieldName(operation.getEditor(), occurrence != null ? occurrence : target); |
| final InplaceVariableIntroducer<PsiElement> introducer = new PyInplaceFieldIntroducer(target, operation, occurrences); |
| introducer.performInplaceRefactoring(new LinkedHashSet<String>(operation.getSuggestedNames())); |
| } |
| } |
| |
| private static void putCaretOnFieldName(Editor editor, PsiElement occurrence) { |
| PyQualifiedExpression qExpr = PsiTreeUtil.getParentOfType(occurrence, PyQualifiedExpression.class, false); |
| if (qExpr != null && !qExpr.isQualified()) { |
| qExpr = PsiTreeUtil.getParentOfType(qExpr, PyQualifiedExpression.class); |
| } |
| if (qExpr != null) { |
| final ASTNode nameElement = qExpr.getNameElement(); |
| if (nameElement != null) { |
| final int offset = nameElement.getTextRange().getStartOffset(); |
| editor.getCaretModel().moveToOffset(offset); |
| } |
| } |
| } |
| |
| private static class PyInplaceFieldIntroducer extends InplaceVariableIntroducer<PsiElement> { |
| private final PyTargetExpression myTarget; |
| private final IntroduceOperation myOperation; |
| private final PyIntroduceFieldPanel myPanel; |
| |
| public PyInplaceFieldIntroducer(PyTargetExpression target, |
| IntroduceOperation operation, |
| List<PsiElement> occurrences) { |
| super(target, operation.getEditor(), operation.getProject(), "Introduce Field", |
| occurrences.toArray(new PsiElement[occurrences.size()]), null); |
| myTarget = target; |
| myOperation = operation; |
| if (operation.getAvailableInitPlaces().size() > 1) { |
| myPanel = new PyIntroduceFieldPanel(myProject, operation.getAvailableInitPlaces()); |
| } |
| else { |
| myPanel = null; |
| } |
| } |
| |
| @Override |
| protected PsiElement checkLocalScope() { |
| return myTarget.getContainingFile(); |
| } |
| |
| @Override |
| protected JComponent getComponent() { |
| return myPanel == null ? null : myPanel.getRootPanel(); |
| } |
| |
| @Override |
| protected void moveOffsetAfter(boolean success) { |
| if (success && (myPanel != null && myPanel.getInitPlace() != InitPlace.SAME_METHOD) || myOperation.getInplaceInitPlace() != InitPlace.SAME_METHOD) { |
| final AccessToken accessToken = ApplicationManager.getApplication().acquireWriteActionLock(getClass()); |
| try { |
| final PyAssignmentStatement initializer = PsiTreeUtil.getParentOfType(myTarget, PyAssignmentStatement.class); |
| assert initializer != null; |
| final Function<String, PyStatement> callback = FunctionUtil.<String, PyStatement>constant(initializer); |
| final PyClass pyClass = PyUtil.getContainingClassOrSelf(initializer); |
| InitPlace initPlace = myPanel != null ? myPanel.getInitPlace() : myOperation.getInplaceInitPlace(); |
| if (initPlace == InitPlace.CONSTRUCTOR) { |
| AddFieldQuickFix.addFieldToInit(myProject, pyClass, "", callback); |
| } |
| else if (initPlace == InitPlace.SET_UP) { |
| addFieldToSetUp(pyClass, callback); |
| } |
| if (myOperation.getOccurrences().size() > 0) { |
| initializer.delete(); |
| } |
| else { |
| final PyExpression copy = |
| PyElementGenerator.getInstance(myProject).createExpressionFromText(LanguageLevel.forElement(myTarget), myTarget.getText()); |
| initializer.replace(copy); |
| } |
| initializer.delete(); |
| } |
| finally { |
| accessToken.finish(); |
| } |
| } |
| } |
| } |
| } |