blob: 8ecd8a5d9b0382ddb4fdf87e0fd41e69a40c68c4 [file] [log] [blame]
/*
* Copyright 2000-2013 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.jetbrains.python.inspections.quickfix;
import com.intellij.codeInsight.CodeInsightUtilCore;
import com.intellij.codeInsight.FileModificationService;
import com.intellij.codeInsight.template.TemplateBuilder;
import com.intellij.codeInsight.template.TemplateBuilderFactory;
import com.intellij.codeInspection.LocalQuickFix;
import com.intellij.codeInspection.ProblemDescriptor;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.fileEditor.FileEditorManager;
import com.intellij.openapi.fileEditor.OpenFileDescriptor;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.ui.MessageType;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.Function;
import com.jetbrains.python.PyBundle;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyBuiltinCache;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.types.PyClassType;
import com.jetbrains.python.psi.types.PyClassTypeImpl;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import static com.jetbrains.python.PyNames.FAKE_OLD_BASE;
/**
* Available on self.my_something when my_something is unresolved.
* User: dcheryasov
* Date: Apr 4, 2009 1:53:46 PM
*/
public class AddFieldQuickFix implements LocalQuickFix {
private final String myInitializer;
private final String myClassName;
private String myIdentifier;
private boolean replaceInitializer = false;
public AddFieldQuickFix(@NotNull final String identifier, @NotNull final String initializer, final String className, boolean replace) {
myIdentifier = identifier;
myInitializer = initializer;
myClassName = className;
replaceInitializer = replace;
}
@NotNull
public String getName() {
return PyBundle.message("QFIX.NAME.add.field.$0.to.class.$1", myIdentifier, myClassName);
}
@NotNull
public String getFamilyName() {
return "Add field to class";
}
@Nullable
public static PsiElement appendToMethod(PyFunction init, Function<String, PyStatement> callback) {
// add this field as the last stmt of the constructor
final PyStatementList statementList = init.getStatementList();
// name of 'self' may be different for fancier styles
PyParameter[] params = init.getParameterList().getParameters();
String selfName = PyNames.CANONICAL_SELF;
if (params.length > 0) {
selfName = params[0].getName();
}
PyStatement newStmt = callback.fun(selfName);
if (!FileModificationService.getInstance().preparePsiElementForWrite(statementList)) return null;
final PsiElement result = PyUtil.addElementToStatementList(newStmt, statementList, true);
PyPsiUtils.removeRedundantPass(statementList);
return result;
}
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
// expect the descriptor to point to the unresolved identifier.
final PsiElement element = descriptor.getPsiElement();
final PyClassType type = getClassType(element);
if (type == null) return;
final PyClass cls = type.getPyClass();
PsiElement initStatement;
if (!type.isDefinition()) {
initStatement = addFieldToInit(project, cls, myIdentifier, new CreateFieldCallback(project, myIdentifier, myInitializer));
}
else {
PyStatement field = PyElementGenerator.getInstance(project)
.createFromText(LanguageLevel.getDefault(), PyStatement.class, myIdentifier + " = " + myInitializer);
initStatement = PyUtil.addElementToStatementList(field, cls.getStatementList(), true);
}
if (initStatement != null) {
showTemplateBuilder(initStatement, cls.getContainingFile());
return;
}
// somehow we failed. tell about this
PyUtil.showBalloon(project, PyBundle.message("QFIX.failed.to.add.field"), MessageType.ERROR);
}
private static PyClassType getClassType(@NotNull final PsiElement element) {
if (element instanceof PyQualifiedExpression) {
final PyExpression qualifier = ((PyQualifiedExpression)element).getQualifier();
if (qualifier == null) return null;
final PyType type = TypeEvalContext.userInitiated(element.getContainingFile()).getType(qualifier);
return type instanceof PyClassType ? (PyClassType)type : null;
}
final PyClass aClass = PsiTreeUtil.getParentOfType(element, PyClass.class);
return aClass != null ? new PyClassTypeImpl(aClass, false) : null;
}
private void showTemplateBuilder(PsiElement initStatement, @NotNull final PsiFile file) {
initStatement = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(initStatement);
if (initStatement instanceof PyAssignmentStatement) {
final TemplateBuilder builder = TemplateBuilderFactory.getInstance().createTemplateBuilder(initStatement);
final PyExpression assignedValue = ((PyAssignmentStatement)initStatement).getAssignedValue();
final PyExpression leftExpression = ((PyAssignmentStatement)initStatement).getLeftHandSideExpression();
if (assignedValue != null && leftExpression != null) {
if (replaceInitializer)
builder.replaceElement(assignedValue, myInitializer);
else
builder.replaceElement(leftExpression.getLastChild(), myIdentifier);
final VirtualFile virtualFile = file.getVirtualFile();
if (virtualFile == null) return;
final Editor editor = FileEditorManager.getInstance(file.getProject()).openTextEditor(
new OpenFileDescriptor(file.getProject(), virtualFile), true);
if (editor == null) return;
builder.run(editor, false);
}
}
}
@Nullable
public static PsiElement addFieldToInit(Project project, PyClass cls, String itemName, Function<String, PyStatement> callback) {
if (cls != null && itemName != null) {
PyFunction init = cls.findMethodByName(PyNames.INIT, false);
if (init != null) {
return appendToMethod(init, callback);
}
else { // no init! boldly copy ancestor's.
for (PyClass ancestor : cls.getAncestorClasses()) {
init = ancestor.findMethodByName(PyNames.INIT, false);
if (init != null) break;
}
PyFunction newInit = createInitMethod(project, cls, init);
if (newInit == null) {
return null;
}
appendToMethod(newInit, callback);
PsiElement addAnchor = null;
PyFunction[] meths = cls.getMethods();
if (meths.length > 0) addAnchor = meths[0].getPrevSibling();
PyStatementList clsContent = cls.getStatementList();
newInit = (PyFunction) clsContent.addAfter(newInit, addAnchor);
PyUtil.showBalloon(project, PyBundle.message("QFIX.added.constructor.$0.for.field.$1", cls.getName(), itemName), MessageType.INFO);
final PyStatementList statementList = newInit.getStatementList();
final PyStatement[] statements = statementList.getStatements();
return statements.length != 0 ? statements[0] : null;
}
}
return null;
}
@Nullable
private static PyFunction createInitMethod(Project project, PyClass cls, @Nullable PyFunction ancestorInit) {
// found it; copy its param list and make a call to it.
if (!FileModificationService.getInstance().preparePsiElementForWrite(cls)) {
return null;
}
String paramList = ancestorInit != null ? ancestorInit.getParameterList().getText() : "(self)";
String functionText = "def " + PyNames.INIT + paramList + ":\n";
if (ancestorInit == null) functionText += " pass";
else {
final PyClass ancestorClass = ancestorInit.getContainingClass();
if (ancestorClass != null && ancestorClass != PyBuiltinCache.getInstance(ancestorInit).getClass("object") && !FAKE_OLD_BASE.equals(ancestorClass.getName())) {
StringBuilder sb = new StringBuilder();
PyParameter[] params = ancestorInit.getParameterList().getParameters();
boolean seen = false;
if (cls.isNewStyleClass()) {
// form the super() call
sb.append("super(");
if (!LanguageLevel.forElement(cls).isPy3K()) {
sb.append(cls.getName());
// NOTE: assume that we have at least the first param
String self_name = params[0].getName();
sb.append(", ").append(self_name);
}
sb.append(").").append(PyNames.INIT).append("(");
}
else {
sb.append(ancestorClass.getName());
sb.append(".__init__(self");
seen = true;
}
for (int i = 1; i < params.length; i += 1) {
if (seen) sb.append(", ");
else seen = true;
sb.append(params[i].getText());
}
sb.append(")");
functionText += " " + sb.toString();
}
else {
functionText += " pass";
}
}
return PyElementGenerator.getInstance(project).createFromText(
LanguageLevel.getDefault(), PyFunction.class, functionText,
new int[]{0}
);
}
private static class CreateFieldCallback implements Function<String, PyStatement> {
private Project myProject;
private String myItemName;
private String myInitializer;
private CreateFieldCallback(Project project, String itemName, String initializer) {
myProject = project;
myItemName = itemName;
myInitializer = initializer;
}
public PyStatement fun(String self_name) {
return PyElementGenerator.getInstance(myProject).createFromText(LanguageLevel.getDefault(), PyStatement.class, self_name + "." + myItemName + " = " + myInitializer);
}
}
}