// blob: c991266f452284fa9df749a7654eb32bd539b76f
/*
* Copyright 2000-2014 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.jetbrains.python.psi.impl;
import com.google.common.collect.Lists;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Pair;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiErrorElement;
import com.intellij.psi.tree.TokenSet;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.IncorrectOperationException;
import com.intellij.util.SmartList;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.toolbox.FP;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * PSI implementation of a Python assignment statement, e.g. {@code a = b = c}.
 * <p>
 * The filtered list of targets is computed lazily, cached in {@link #myTargets}
 * and dropped again in {@link #subtreeChanged()}.
 *
 * @author yole
 */
public class PyAssignmentStatementImpl extends PyElementImpl implements PyAssignmentStatement {
  /** Cached result of {@code calcTargets(false)}; {@code null} until first requested or after the subtree changed. */
  private PyExpression[] myTargets;

  public PyAssignmentStatementImpl(ASTNode astNode) {
    super(astNode);
  }

  @Override
  protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
    pyVisitor.visitPyAssignmentStatement(this);
  }

  /**
   * Returns the assignment targets: the expressions found before the last {@code =} token, with
   * parenthesized, sequence and star expressions unwrapped, keeping only target, reference,
   * subscription and slice expressions.
   *
   * @return the cached array of targets (the same instance is returned until the subtree changes);
   *         empty if the statement contains no {@code =} token.
   */
  public PyExpression[] getTargets() {
    if (myTargets == null) {
      myTargets = calcTargets(false);
    }
    return myTargets;
  }

  /**
   * Returns every expression child found before the last {@code =} token as is, without
   * unwrapping or filtering. The result is recomputed on each call.
   *
   * @return the raw left-hand side expressions; empty if the statement contains no {@code =} token.
   */
  @NotNull
  @Override
  public PyExpression[] getRawTargets() {
    return calcTargets(true);
  }

  /**
   * Collects the expressions located before the last {@code =} token of this statement.
   *
   * @param raw if {@code true}, direct expression children are returned unchanged; otherwise they are
   *            unwrapped via {@link #addCandidate} and filtered down to expressions that can be assigned to.
   * @return the collected expressions, or an empty array when there is no {@code =} token.
   */
  private PyExpression[] calcTargets(boolean raw) {
    final ASTNode[] eqSigns = getNode().getChildren(TokenSet.create(PyTokenTypes.EQ));
    if (eqSigns.length == 0) {
      return PyExpression.EMPTY_ARRAY;
    }
    // Everything after the last '=' is the assigned value, so the scan stops there.
    final ASTNode lastEq = eqSigns[eqSigns.length - 1];
    final List<PyExpression> candidates = new ArrayList<PyExpression>();
    for (ASTNode node = getNode().getFirstChildNode(); node != null && node != lastEq; node = node.getTreeNext()) {
      final PsiElement psi = node.getPsi();
      if (!(psi instanceof PyExpression)) {
        continue;
      }
      if (raw) {
        candidates.add((PyExpression)psi);
      }
      else {
        addCandidate(candidates, (PyExpression)psi);
      }
    }
    if (raw) {
      return candidates.toArray(new PyExpression[candidates.size()]);
    }
    // Non-raw mode: keep only the kinds of expressions that may appear as assignment targets.
    final List<PyExpression> targets = new ArrayList<PyExpression>();
    for (PyExpression expr : candidates) {
      if (expr instanceof PyTargetExpression ||
          expr instanceof PyReferenceExpression ||
          expr instanceof PySubscriptionExpression ||
          expr instanceof PySliceExpression) {
        targets.add(expr);
      }
    }
    return targets.toArray(new PyExpression[targets.size()]);
  }

  /**
   * Recursively unwraps parenthesized, sequence and star expressions and appends the innermost
   * expressions to {@code candidates}.
   *
   * @param candidates the list to append to.
   * @param psi        the expression to unwrap; {@code null} (e.g. the content of empty parentheses) is skipped.
   */
  private static void addCandidate(List<PyExpression> candidates, @Nullable PyExpression psi) {
    if (psi == null) {
      // Nothing to add; a null entry would be dropped by the target filter anyway.
      return;
    }
    if (psi instanceof PyParenthesizedExpression) {
      addCandidate(candidates, ((PyParenthesizedExpression)psi).getContainedExpression());
    }
    else if (psi instanceof PySequenceExpression) {
      for (PyExpression element : ((PySequenceExpression)psi).getElements()) {
        addCandidate(candidates, element);
      }
    }
    else if (psi instanceof PyStarExpression) {
      addCandidate(candidates, ((PyStarExpression)psi).getExpression());
    }
    else {
      candidates.add(psi);
    }
  }

  /**
   * @return rightmost expression in statement, which is supposedly the assigned value, or null
   *         if an error element is encountered before any expression when scanning from the end.
   */
  @Nullable
  public PyExpression getAssignedValue() {
    PsiElement child = getLastChild();
    while (child != null && !(child instanceof PyExpression)) {
      if (child instanceof PsiErrorElement) return null; // incomplete assignment operator can't be analyzed properly, bail out.
      child = child.getPrevSibling();
    }
    return (PyExpression)child;
  }

  /**
   * Maps each left-hand side expression of this statement to the value assigned to it.
   * For {@code a = b = c} both {@code a} and {@code b} are mapped against {@code c}.
   *
   * @return list of (target, value) pairs; empty if the statement has parse errors or has
   *         fewer than two expression children.
   */
  @NotNull
  public List<Pair<PyExpression, PyExpression>> getTargetsToValuesMapping() {
    final List<Pair<PyExpression, PyExpression>> ret = new SmartList<Pair<PyExpression, PyExpression>>();
    if (PsiTreeUtil.hasErrorElements(this)) {
      return ret; // parse errors: no reliable mapping
    }
    final PyExpression[] constituents = PsiTreeUtil.getChildrenOfType(this, PyExpression.class); // "a = b = c" -> [a, b, c]
    if (constituents == null || constituents.length < 2) {
      return ret;
    }
    final int lastIndex = constituents.length - 1;
    final PyExpression rhs = constituents[lastIndex];
    // All but the last constituent are left-hand sides; most often there is exactly one.
    for (int i = 0; i < lastIndex; i++) {
      mapToValues(constituents[i], rhs, ret);
    }
    return ret;
  }

  /**
   * @return leftmost expression in statement, or null if an error element is encountered
   *         before any expression when scanning from the start.
   */
  @Nullable
  public PyExpression getLeftHandSideExpression() {
    PsiElement child = getFirstChild();
    while (child != null && !(child instanceof PyExpression)) {
      if (child instanceof PsiErrorElement) return null; // incomplete assignment operator can't be analyzed properly, bail out.
      // Scanning starts at the first child, so it must move forward; moving backward would end the search immediately.
      child = child.getNextSibling();
    }
    return (PyExpression)child;
  }

  /**
   * @param name the name to check.
   * @return true if the leftmost expression is a target expression whose name equals {@code name}.
   */
  @Override
  public boolean isAssignmentTo(@NotNull String name) {
    PyExpression lhs = getLeftHandSideExpression();
    return lhs instanceof PyTargetExpression && name.equals(lhs.getName());
  }

  /**
   * Adds to {@code map} the (target, value) pairs produced by assigning {@code rhs} to {@code lhs}.
   * <ul>
   * <li>non-sequence LHS: one pair (lhs, rhs);</li>
   * <li>sequence LHS, non-sequence RHS: each LHS element is paired with a generated expression
   * {@code (rhs)[i]};</li>
   * <li>sequence LHS, sequence RHS (possibly a parenthesized tuple): elements are zipped pairwise.</li>
   * </ul>
   *
   * @param lhs left-hand side expression; nothing is added when it is null.
   * @param rhs right-hand side expression.
   * @param map the list that receives the pairs.
   */
  private static void mapToValues(PyExpression lhs, PyExpression rhs, List<Pair<PyExpression, PyExpression>> map) {
    // Classify the left-hand side.
    PySequenceExpression lhsTuple = null;
    PyExpression lhsSingle = null;
    if (lhs instanceof PySequenceExpression) {
      lhsTuple = (PySequenceExpression)lhs;
    }
    else if (lhs != null) {
      lhsSingle = lhs;
    }

    // Classify the right-hand side; a parenthesized tuple counts as a tuple.
    PySequenceExpression rhsTuple = null;
    PyExpression rhsSingle = null;
    if (rhs instanceof PyParenthesizedExpression) {
      final PyExpression contained = ((PyParenthesizedExpression)rhs).getContainedExpression();
      if (contained instanceof PyTupleExpression) {
        rhsTuple = (PySequenceExpression)contained;
      }
      else {
        rhsSingle = rhs;
      }
    }
    else if (rhs instanceof PySequenceExpression) {
      rhsTuple = (PySequenceExpression)rhs;
    }
    else if (rhs != null) {
      rhsSingle = rhs;
    }

    if (lhsSingle != null) { // single LHS, single RHS (direct mapping) or multiple RHS (packing)
      map.add(Pair.create(lhsSingle, rhs));
    }
    else if (lhsTuple != null && rhsSingle != null) { // multiple LHS, single RHS: unpacking
      // PY-2648, PY-2649
      final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(rhsSingle.getProject());
      final LanguageLevel languageLevel = LanguageLevel.forElement(lhs);
      // Parenthesize the RHS so that the subscript applies to the whole expression,
      // e.g. "(a + b)[0]" rather than "a + b[0]".
      final String subscriptedPrefix = "(" + rhsSingle.getText() + ")[";
      int counter = 0;
      for (PyExpression tupleElement : lhsTuple.getElements()) {
        try {
          final PyExpression expression = elementGenerator.createExpressionFromText(languageLevel, subscriptedPrefix + counter + "]");
          map.add(Pair.create(tupleElement, expression));
        }
        catch (IncorrectOperationException ignored) {
          // not parsed, no problem: this element simply gets no mapping
        }
        ++counter;
      }
    }
    else if (lhsTuple != null && rhsTuple != null) { // multiple both sides: piecewise mapping
      map.addAll(FP.zipList(Arrays.asList(lhsTuple.getElements()), Arrays.asList(rhsTuple.getElements()), null, null));
    }
  }

  /**
   * @return the targets of this assignment, flattened through parentheses and stars,
   *         excluding qualified expressions.
   */
  @NotNull
  public Iterable<PyElement> iterateNames() {
    final List<PyExpression> expressions = PyUtil.flattenedParensAndStars(getTargets());
    final List<PyElement> result = new ArrayList<PyElement>();
    for (PyExpression expression : expressions) {
      final boolean qualified =
        expression instanceof PyQualifiedExpression && ((PyQualifiedExpression)expression).isQualified();
      if (!qualified) {
        result.add(expression);
      }
    }
    return result;
  }

  /**
   * Looks up a name among the elements returned by {@link #iterateNames()}.
   *
   * @param the_name the name to look for.
   * @return the matching element, or null if none is found.
   */
  public PyElement getElementNamed(final String the_name) {
    // performance: check simple case first
    final PyExpression[] targets = getTargets();
    if (targets.length == 1 && targets[0] instanceof PyTargetExpression) {
      final PyTargetExpression target = (PyTargetExpression)targets[0];
      return !target.isQualified() && the_name.equals(target.getName()) ? target : null;
    }
    return IterHelper.findName(iterateNames(), the_name);
  }

  public boolean mustResolveOutside() {
    return true; // a = a+1 resolves 'a' outside itself.
  }

  /** Drops the cached targets so that they are recomputed on the next {@link #getTargets()} call. */
  @Override
  public void subtreeChanged() {
    super.subtreeChanged();
    myTargets = null;
  }
}