// blob: e11f3e01e916939ce8393dcad1999e69facf8d87 [file] [log] [blame]
/*
* Copyright 2000-2013 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.jetbrains.python.codeInsight.override;
import com.google.common.collect.Lists;
import com.intellij.codeInsight.CodeInsightUtilCore;
import com.intellij.featureStatistics.FeatureUsageTracker;
import com.intellij.featureStatistics.ProductivityFeatureNames;
import com.intellij.ide.util.MemberChooser;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.application.Result;
import com.intellij.openapi.command.WriteCommandAction;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.ScrollType;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.ui.DialogWrapper;
import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiDocumentManager;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiWhiteSpace;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.ui.SpeedSearchComparator;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyFunctionBuilder;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.types.PyNoneType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
/**
* @author Alexey.Ivanov
*/
public class PyOverrideImplementUtil {
// Logger of this utility class; used below to assert that the class being edited is still valid.
private static final Logger LOG = Logger.getInstance("#com.jetbrains.python.codeInsight.override.PyOverrideImplementUtil");
// Not instantiable: every member of this class is static.
private PyOverrideImplementUtil() {
}
/**
 * Finds the Python class the editor caret is located in.
 * <p>
 * All documents of the project are committed first, then the PSI element at the caret offset is looked up.
 * Two fallbacks are applied: when there is no element at the offset, the last child of the file is used if
 * the offset lies within its text range (PY-440: caret in the whitespace after the last class); and when the
 * element is whitespace that has no enclosing class, the class that is its previous sibling is returned.
 *
 * @param project project whose documents are committed before the lookup
 * @param editor  editor providing the caret offset
 * @param file    PSI file that is searched
 * @return the class found for the caret position, or {@code null} if there is none
 */
@Nullable
public static PyClass getContextClass(@NotNull final Project project, @NotNull final Editor editor, @NotNull final PsiFile file) {
  PsiDocumentManager.getInstance(project).commitAllDocuments();
  final int caretOffset = editor.getCaretModel().getOffset();
  PsiElement atCaret = file.findElementAt(caretOffset);
  if (atCaret == null) {
    // are we in whitespace after last class? PY-440
    final PsiElement tail = file.getLastChild();
    if (tail != null) {
      final TextRange tailRange = tail.getTextRange();
      if (tailRange.getStartOffset() <= caretOffset && caretOffset <= tailRange.getEndOffset()) {
        atCaret = tail;
      }
    }
  }
  final PyClass enclosing = PsiTreeUtil.getParentOfType(atCaret, PyClass.class, false);
  if (enclosing != null) {
    return enclosing;
  }
  if (atCaret instanceof PsiWhiteSpace) {
    // Whitespace outside of any class: fall back to the class immediately preceding it.
    final PsiElement previous = atCaret.getPrevSibling();
    if (previous instanceof PyClass) {
      return (PyClass)previous;
    }
  }
  return null;
}
/**
 * Entry point of the "override methods" action: records usage of the override/implement productivity
 * feature and then lets the user pick the super methods to override in {@code pyClass}.
 *
 * @param project project the action runs in
 * @param editor  editor in which the new methods are inserted
 * @param pyClass class that receives the overriding methods
 */
public static void chooseAndOverrideMethods(final Project project, @NotNull final Editor editor, @NotNull final PyClass pyClass) {
  final FeatureUsageTracker usageTracker = FeatureUsageTracker.getInstance();
  usageTracker.triggerFeatureUsed(ProductivityFeatureNames.CODEASSISTS_OVERRIDE_IMPLEMENT);
  chooseAndOverrideOrImplementMethods(project, editor, pyClass);
}
/**
 * Offers every method collected by {@link #getAllSuperFunctions(PyClass)} for overriding in {@code pyClass}.
 * Asserts that the class is still valid and that read access is allowed before collecting the methods.
 *
 * @param project project the chooser dialog belongs to
 * @param editor  editor in which the new methods are inserted
 * @param pyClass class that receives the overriding methods
 */
private static void chooseAndOverrideOrImplementMethods(final Project project,
                                                        @NotNull final Editor editor,
                                                        @NotNull final PyClass pyClass) {
  LOG.assertTrue(pyClass.isValid());
  ApplicationManager.getApplication().assertReadAccessAllowed();
  chooseAndOverrideOrImplementMethods(project, editor, pyClass, getAllSuperFunctions(pyClass), "Select Methods to Override", false);
}
/**
 * Shows a member chooser with the given super functions and generates the selected ones in {@code pyClass}.
 * <p>
 * A function is not offered when its name is {@code null}, when {@code PyUtil.isClassPrivateName} accepts the
 * name, or when {@code pyClass.findMethodByName(name, false)} already finds a method. Nothing happens when no
 * candidate is left or when the dialog is closed with anything other than OK.
 *
 * @param project        project the chooser dialog belongs to
 * @param editor         editor in which the new methods are inserted
 * @param pyClass        class that receives the new methods
 * @param superFunctions functions to choose from
 * @param title          title of the chooser dialog
 * @param implement      passed on to {@link #overrideMethods}; selects "implement" instead of "override" bodies
 */
public static void chooseAndOverrideOrImplementMethods(@NotNull final Project project,
                                                       @NotNull final Editor editor,
                                                       @NotNull final PyClass pyClass,
                                                       @NotNull final Collection<PyFunction> superFunctions,
                                                       @NotNull final String title, final boolean implement) {
  final List<PyMethodMember> candidates = new ArrayList<PyMethodMember>();
  for (PyFunction superFunction : superFunctions) {
    final String methodName = superFunction.getName();
    final boolean selectable = methodName != null && !PyUtil.isClassPrivateName(methodName);
    if (selectable && pyClass.findMethodByName(methodName, false) == null) {
      candidates.add(new PyMethodMember(superFunction));
    }
  }
  if (candidates.isEmpty()) {
    return;
  }

  final PyMethodMember[] candidateArray = candidates.toArray(new PyMethodMember[candidates.size()]);
  final MemberChooser<PyMethodMember> chooser = new MemberChooser<PyMethodMember>(candidateArray, false, true, project) {
    @Override
    protected SpeedSearchComparator getSpeedSearchComparator() {
      return new SpeedSearchComparator(false) {
        @Nullable
        @Override
        public Iterable<TextRange> matchingFragments(String pattern, String text) {
          // The typed pattern is passed through PyMethodMember.trimUnderscores before matching.
          return super.matchingFragments(PyMethodMember.trimUnderscores(pattern), text);
        }
      };
    }
  };
  chooser.setTitle(title);
  chooser.setCopyJavadocVisible(false);
  chooser.show();

  if (chooser.getExitCode() == DialogWrapper.OK_EXIT_CODE) {
    overrideMethods(editor, pyClass, chooser.getSelectedElements(), implement);
  }
}
/**
 * Generates the given members in {@code pyClass} inside a write command on the class's containing file.
 * Does nothing when {@code membersToOverride} is {@code null}.
 *
 * @param editor            editor whose caret determines the insertion point and which is updated afterwards
 * @param pyClass           class that receives the new methods
 * @param membersToOverride members to generate; may be {@code null}
 * @param implement         {@code true} to generate "implement" bodies instead of calls to the base method
 */
public static void overrideMethods(final Editor editor, final PyClass pyClass, final List<PyMethodMember> membersToOverride,
                                   final boolean implement) {
  if (membersToOverride != null) {
    final WriteCommandAction insertCommand = new WriteCommandAction(pyClass.getProject(), pyClass.getContainingFile()) {
      @Override
      protected void run(@NotNull final Result result) throws Throwable {
        write(pyClass, membersToOverride, editor, implement);
      }
    };
    insertCommand.execute();
  }
}
/**
 * Inserts an overriding function for every member into the statement list of {@code pyClass} and then
 * places the caret and the selection on the last function that was inserted.
 * <p>
 * The insertion anchor is the last statement of the class body that either starts before the caret offset
 * or is an expression statement consisting of a string literal; it stays {@code null} when no statement
 * qualifies. Every function is added via {@code addFunctionAfter} with that same anchor.
 *
 * @param pyClass    class that receives the new functions
 * @param newMembers members wrapping the base functions to override
 * @param editor     editor providing the caret offset; its caret, scroll position and selection are updated
 * @param implement  forwarded to {@link #buildOverriddenFunction}
 */
private static void write(@NotNull final PyClass pyClass,
                          @NotNull final List<PyMethodMember> newMembers,
                          @NotNull final Editor editor, boolean implement) {
  final PyStatementList classBody = pyClass.getStatementList();
  final int caretOffset = editor.getCaretModel().getOffset();

  PsiElement insertAfter = null;
  for (PyStatement candidate : classBody.getStatements()) {
    boolean eligible = candidate.getTextRange().getStartOffset() < caretOffset;
    if (!eligible && candidate instanceof PyExpressionStatement) {
      eligible = ((PyExpressionStatement)candidate).getExpression() instanceof PyStringLiteralExpression;
    }
    if (eligible) {
      insertAfter = candidate;
    }
  }

  PyFunction lastInserted = null;
  for (PyMethodMember member : newMembers) {
    final PyFunction superMethod = (PyFunction)member.getPsiElement();
    final PyFunction inserted = buildOverriddenFunction(pyClass, superMethod, implement)
      .addFunctionAfter(classBody, insertAfter, LanguageLevel.forElement(classBody));
    lastInserted = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(inserted);
  }
  PyPsiUtils.removeRedundantPass(classBody);

  if (lastInserted == null) {
    return;
  }
  // Put the caret at the start of the generated body and select up to the end of the function.
  final int bodyStart = lastInserted.getStatementList().getTextRange().getStartOffset();
  editor.getCaretModel().moveToOffset(bodyStart);
  editor.getScrollingModel().scrollToCaret(ScrollType.RELATIVE);
  editor.getSelectionModel().setSelection(bodyStart, lastInserted.getTextRange().getEndOffset());
}
/**
 * Prepares a builder for a function in {@code pyClass} that overrides or implements {@code baseFunction}.
 * <p>
 * The builder receives the base function's name, a {@code classmethod} decorator if the base has one, the
 * base annotation text and the text of every base parameter. The body is a single statement:
 * <ul>
 * <li>{@code pass} when {@code implement} is set or the base class is named {@code PyNames.FAKE_OLD_BASE};</li>
 * <li>otherwise a call to the base implementation, prefixed with {@code return} unless the function is
 * {@code __init__} or its evaluated return type is {@code PyNoneType.INSTANCE}. New-style base classes are
 * called through {@code super(...)} (with explicit class and first-parameter arguments when the language
 * level is not Python 3), other base classes through a direct reference to the base class.</li>
 * </ul>
 * NOTE(review): only {@code @classmethod} is propagated, and the first call argument is dropped for every
 * new-style base; a {@code @staticmethod} base would presumably need different handling - confirm.
 *
 * @param pyClass      class the new function is generated for
 * @param baseFunction function being overridden; must belong to a class
 * @param implement    {@code true} to generate a {@code pass} body instead of a call to the base function
 * @return builder describing the new function
 */
private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass, PyFunction baseFunction, boolean implement) {
  final PyFunctionBuilder pyFunctionBuilder = new PyFunctionBuilder(baseFunction.getName());
  final PyDecoratorList decorators = baseFunction.getDecoratorList();
  if (decorators != null && decorators.findDecorator(PyNames.CLASSMETHOD) != null) {
    pyFunctionBuilder.decorate(PyNames.CLASSMETHOD);
  }
  final PyAnnotation anno = baseFunction.getAnnotation();
  if (anno != null) {
    pyFunctionBuilder.annotation(anno.getText());
  }
  final TypeEvalContext context = TypeEvalContext.userInitiated(baseFunction.getContainingFile());
  final List<PyParameter> baseParams = PyUtil.getParameters(baseFunction, context);
  for (PyParameter parameter : baseParams) {
    pyFunctionBuilder.parameter(parameter.getText());
  }

  final PyClass baseClass = baseFunction.getContainingClass();
  assert baseClass != null;

  final StringBuilder statementBody = new StringBuilder();
  if (PyNames.FAKE_OLD_BASE.equals(baseClass.getName()) || implement) {
    statementBody.append(PyNames.PASS);
  }
  else {
    final List<String> arguments = collectCallArguments(baseParams);
    if (!PyNames.INIT.equals(baseFunction.getName()) && context.getReturnType(baseFunction) != PyNoneType.INSTANCE) {
      statementBody.append("return ");
    }
    if (baseClass.isNewStyleClass()) {
      statementBody.append(PyNames.SUPER).append("(");
      // Resolve the language level the same way write() does instead of casting the containing file to PyFile.
      if (!LanguageLevel.forElement(pyClass).isPy3K()) {
        // Before Python 3, super() needs the class and the instance (first parameter) spelled out.
        final String baseFirstName = !baseParams.isEmpty() ? baseParams.get(0).getName() : null;
        final String firstName = baseFirstName != null ? baseFirstName : PyNames.CANONICAL_SELF;
        statementBody.append(buildNestedClassName(pyClass)).append(", ").append(firstName);
      }
      statementBody.append(").").append(baseFunction.getName()).append("(");
      // The first argument is not repeated in a call made through super().
      if (!arguments.isEmpty()) {
        arguments.remove(0);
      }
    }
    else {
      statementBody.append(getReferenceText(pyClass, baseClass)).append(".").append(baseFunction.getName()).append("(");
    }
    statementBody.append(StringUtil.join(arguments, ", ")).append(")");
  }
  pyFunctionBuilder.statement(statementBody.toString());
  return pyFunctionBuilder;
}

/**
 * Turns the parameters of the base function into the argument texts of a call that forwards them.
 * Named parameters that follow a positional container or a single star are passed as {@code name=repr},
 * except for the keyword container; single-star parameters contribute no argument of their own.
 */
private static List<String> collectCallArguments(List<PyParameter> baseParams) {
  final List<String> arguments = new ArrayList<String>();
  boolean hadStar = false;
  for (PyParameter parameter : baseParams) {
    final PyNamedParameter pyNamedParameter = parameter.getAsNamed();
    if (pyNamedParameter != null) {
      final String repr = pyNamedParameter.getRepr(false);
      arguments.add(hadStar && !pyNamedParameter.isKeywordContainer() ? pyNamedParameter.getName() + "=" + repr : repr);
      if (pyNamedParameter.isPositionalContainer()) {
        hadStar = true;
      }
    }
    else if (parameter instanceof PySingleStarParameter) {
      hadStar = true;
    }
    else {
      arguments.add(parameter.getText());
    }
  }
  return arguments;
}

/**
 * Builds the dotted name of {@code pyClass} by prepending the names of its enclosing classes,
 * stopping the upward search at the first enclosing function.
 */
private static String buildNestedClassName(PyClass pyClass) {
  final List<String> names = Lists.newArrayList(pyClass.getName());
  PyClass outerClass = PsiTreeUtil.getParentOfType(pyClass, PyClass.class, true, PyFunction.class);
  while (outerClass != null) {
    names.add(0, outerClass.getName());
    outerClass = PsiTreeUtil.getParentOfType(outerClass, PyClass.class, true, PyFunction.class);
  }
  return StringUtil.join(names, ".");
}
// TODO find a better place for this logic
/**
 * Returns the text by which {@code fromClass} refers to {@code toClass} in its list of superclass
 * expressions: the text of the first reference expression that resolves to {@code toClass}, or the
 * plain name of {@code toClass} when no such expression exists.
 */
private static String getReferenceText(PyClass fromClass, PyClass toClass) {
  for (PyExpression superExpression : fromClass.getSuperClassExpressions()) {
    if (!(superExpression instanceof PyReferenceExpression)) {
      continue;
    }
    final PsiElement resolved = ((PyReferenceExpression)superExpression).getReference().resolve();
    if (resolved == toClass) {
      return superExpression.getText();
    }
  }
  return toClass.getName();
}
/**
 * Collects the methods of all ancestor classes of {@code pyClass}, keeping one function per name.
 * <p>
 * Ancestors are visited in the order returned by {@code getAncestorClasses()}; when several ancestors
 * define a method with the same name, the first one visited is kept. The returned collection iterates in
 * that same visiting order (an insertion-ordered map is used), so callers such as the member chooser see
 * a stable order rather than a hash-dependent one.
 *
 * @param pyClass class whose ancestors are inspected
 * @return the collected functions, backed by the internal map
 */
@NotNull
public static Collection<PyFunction> getAllSuperFunctions(@NotNull final PyClass pyClass) {
  final Map<String, PyFunction> superFunctions = new LinkedHashMap<String, PyFunction>();
  for (PyClass aClass : pyClass.getAncestorClasses()) {
    for (PyFunction function : aClass.getMethods()) {
      final String name = function.getName();
      if (!superFunctions.containsKey(name)) {
        superFunctions.put(name, function);
      }
    }
  }
  return superFunctions.values();
}
}