blob: c257143450055bb6db2e4e894fe206d40008c5c6 [file] [log] [blame]
/*
* Copyright 2000-2013 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.jetbrains.python.inspections.quickfix;
import com.intellij.codeInspection.LocalQuickFix;
import com.intellij.codeInspection.ProblemDescriptor;
import com.intellij.openapi.project.Project;
import com.intellij.psi.util.PsiTreeUtil;
import com.jetbrains.python.PyBundle;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NonNls;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.List;
/**
 * Quick fix that inserts a call to the superclass {@code __init__} into the
 * {@code __init__} of a subclass.
 * <p>
 * For:
 * <pre>
 * class B(A):
 *     def __init__(self):
 *         A.__init__(self)                  # inserted
 *         print "Constructor B was called"
 * </pre>
 * The form of the inserted call depends on the class:
 * <ul>
 * <li>new-style class, Python 3: {@code super().__init__(...)}</li>
 * <li>new-style class, otherwise: {@code super(B, self).__init__(...)}</li>
 * <li>old-style class: {@code A.__init__(self, ...)}</li>
 * </ul>
 * The whole function is regenerated from text and substituted for the original one.
 *
 * User: catherine
 */
public class AddCallSuperQuickFix implements LocalQuickFix {
  /**
   * @return the quick fix name, read from the bundle key {@code QFIX.add.super}
   */
  @NotNull
  public String getName() {
    return PyBundle.message("QFIX.add.super");
  }

  /**
   * @return the family name; identical to {@link #getName()}
   */
  @NonNls
  @NotNull
  public String getFamilyName() {
    return getName();
  }

  /**
   * Rebuilds the function enclosing the problem element so that its body starts
   * (after an optional docstring) with a call to the first superclass' {@code __init__}.
   * <p>
   * Does nothing when the enclosing function, its containing class, a superclass, the
   * superclass {@code __init__}, or a class name required for the generated call
   * cannot be determined.
   *
   * @param project    project used to obtain the {@link PyElementGenerator}
   * @param descriptor descriptor whose PSI element lies inside the function to fix
   */
  public void applyFix(@NotNull final Project project, @NotNull final ProblemDescriptor descriptor) {
    final PyFunction problemFunction = PsiTreeUtil.getParentOfType(descriptor.getPsiElement(), PyFunction.class);
    if (problemFunction == null) {
      return;
    }
    final PyClass klass = problemFunction.getContainingClass();
    if (klass == null) {
      return;
    }
    final PyClass[] superClasses = klass.getSuperClasses();
    if (superClasses.length == 0) {
      return;
    }
    // Only the first listed superclass is considered.
    final PyClass superClass = superClasses[0];
    final PyFunction superInit = superClass.findMethodByName(PyNames.INIT, true);
    if (superInit == null) {
      return;
    }

    final StringBuilder superCall = new StringBuilder();
    // Whether the next argument appended to the super call needs a leading comma.
    // The old-style form already contains "self" as its first argument.
    boolean addComma = true;
    if (klass.isNewStyleClass()) {
      addComma = false;
      if (LanguageLevel.forElement(klass).isPy3K()) {
        superCall.append("super().__init__(");
      }
      else {
        final String className = klass.getName();
        if (className == null) {
          // Without a name we would generate "super(null, self)"; better to do nothing.
          return;
        }
        superCall.append("super(").append(className).append(", self).__init__(");
      }
    }
    else {
      final String superClassName = superClass.getName();
      if (superClassName == null) {
        // Without a name we would generate "null.__init__(self"; better to do nothing.
        return;
      }
      superCall.append(superClassName).append(".__init__(self");
    }

    final StringBuilder newFunction = new StringBuilder("def __init__(self");
    buildParameterList(problemFunction, superInit, superCall, newFunction, addComma);
    superCall.append(")");

    final PyStatement[] statements = problemFunction.getStatementList().getStatements();
    // A leading string literal expression statement is treated as the docstring.
    PyExpression docstring = null;
    if (statements.length != 0 && statements[0] instanceof PyExpressionStatement) {
      final PyExpression expression = ((PyExpressionStatement)statements[0]).getExpression();
      if (expression instanceof PyStringLiteralExpression) {
        docstring = expression;
      }
    }

    newFunction.append("):\n\t");
    if (docstring != null) {
      // Keep the docstring first so that it remains the docstring of the new function.
      newFunction.append(docstring.getText()).append("\n\t");
    }
    newFunction.append(superCall).append("\n\t");

    // The docstring statement has already been emitted above, so skip it here.
    final int firstBodyStatement = docstring != null ? 1 : 0;
    for (int i = firstBodyStatement; i < statements.length; i++) {
      final PyStatement statement = statements[i];
      if (statement instanceof PyPassStatement) {
        // "pass" is redundant now that the body contains the super call.
        continue;
      }
      newFunction.append(statement.getText()).append("\n\t");
    }

    problemFunction.replace(
      PyElementGenerator.getInstance(project).createFromText(LanguageLevel.forElement(problemFunction), PyFunction.class,
                                                             newFunction.toString()));
  }

  /**
   * Classifies the parameters of {@code problemFunction} (all but the first one) and
   * delegates to {@link #addParametersFromSuper} to finish both the new signature and
   * the argument list of the super call.
   * <p>
   * Parameters without a default value are appended to {@code newFunction} immediately;
   * parameters with a default value and the {@code *}/{@code **} parameters are handed
   * over so they can be emitted after the parameters taken from the super constructor.
   *
   * @param problemFunction the {@code __init__} being fixed
   * @param superInit       the superclass {@code __init__}
   * @param superCall       text of the super call built so far (arguments are appended)
   * @param newFunction     text of the new signature built so far (parameters are appended)
   * @param addComma        whether the first argument appended to {@code superCall} needs a leading comma
   */
  private static void buildParameterList(@NotNull final PyFunction problemFunction,
                                         @NotNull final PyFunction superInit,
                                         @NotNull final StringBuilder superCall,
                                         @NotNull final StringBuilder newFunction, boolean addComma) {
    final PyParameter[] parameters = problemFunction.getParameterList().getParameters();
    final List<String> problemParams = new ArrayList<String>();
    final List<String> functionParams = new ArrayList<String>();
    String starName = null;
    String doubleStarName = null;
    // Index 0 is skipped: "self" is already part of the generated signature.
    // "<" rather than "!=" so that an empty parameter list does not overrun the array.
    for (int i = 1; i < parameters.length; i++) {
      final PyParameter parameter = parameters[i];
      final String text = parameter.getText();
      functionParams.add(parameter.getName());
      if (text.startsWith("**")) {
        doubleStarName = text;
      }
      else if (text.startsWith("*")) {
        starName = text;
      }
      else if (parameter.getDefaultValue() != null) {
        problemParams.add(text);
      }
      else {
        newFunction.append(",").append(text);
      }
    }
    addParametersFromSuper(superInit, superCall, newFunction, addComma, problemParams, functionParams, starName, doubleStarName);
  }

  /**
   * Appends the parameters of {@code superInit} (all but the first one) that have no
   * default value to the super call, adds those not yet declared by the fixed function
   * to its signature, and then completes the signature with the deferred parameters.
   * <p>
   * Resulting order in the new signature: parameters already appended by the caller,
   * missing super parameters, parameters with default values, the {@code *} parameter,
   * the {@code **} parameter. The {@code *}/{@code **} parameters are passed on to the
   * super call only when the super constructor declares such a parameter itself.
   *
   * @param superInit      the superclass {@code __init__}
   * @param superCall      text of the super call built so far (arguments are appended)
   * @param newFunction    text of the new signature built so far (parameters are appended)
   * @param addComma       whether the next argument appended to {@code superCall} needs a leading comma
   * @param problemParams  texts of the fixed function's parameters that have default values
   * @param functionParams names of all parameters of the fixed function except the first
   * @param starName       text of the fixed function's {@code *} parameter, or {@code null} if it has none
   * @param doubleStarName text of the fixed function's {@code **} parameter, or {@code null} if it has none
   */
  private static void addParametersFromSuper(@NotNull final PyFunction superInit, @NotNull final StringBuilder superCall,
                                             @NotNull final StringBuilder newFunction, boolean addComma,
                                             @NotNull final List<String> problemParams, @NotNull final List<String> functionParams,
                                             @Nullable String starName, @Nullable String doubleStarName) {
    final PyParameter[] parameters = superInit.getParameterList().getParameters();
    boolean addStar = false;
    boolean addDouble = false;
    // Index 0 (the receiver) is skipped; "<" guards against an empty parameter list.
    for (int i = 1; i < parameters.length; i++) {
      final PyParameter parameter = parameters[i];
      if (parameter.getDefaultValue() != null) {
        // Parameters with defaults are optional for the super constructor; do not pass them.
        continue;
      }
      final String text = parameter.getText();
      if (text.startsWith("**")) {
        addDouble = true;
        if (doubleStarName == null) {
          doubleStarName = text;
        }
        continue;
      }
      if (text.startsWith("*")) {
        addStar = true;
        if (starName == null) {
          starName = text;
        }
        continue;
      }
      final String name = parameter.getName();
      if (addComma) {
        superCall.append(",");
      }
      superCall.append(name);
      if (!functionParams.contains(name)) {
        newFunction.append(",").append(name);
      }
      addComma = true;
    }

    for (String withDefault : problemParams) {
      newFunction.append(",").append(withDefault);
    }

    if (starName != null) {
      newFunction.append(",").append(starName);
      if (addStar) {
        if (addComma) {
          superCall.append(",");
        }
        superCall.append(starName);
        addComma = true;
      }
    }
    if (doubleStarName != null) {
      newFunction.append(",").append(doubleStarName);
      if (addDouble) {
        if (addComma) {
          superCall.append(",");
        }
        superCall.append(doubleStarName);
      }
    }
  }
}