blob: 512205abd77aa7af4b60d839353d2471a04d6f1c [file] [log] [blame]
/*
* Copyright (C) 2015 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.jack.optimizations.tailrecursion;
import com.android.jack.Jack;
import com.android.jack.Options;
import com.android.jack.annotations.DisableTailRecursionOptimization;
import com.android.jack.ir.ast.JAnnotationType;
import com.android.jack.ir.ast.JAsgOperation;
import com.android.jack.ir.ast.JBlock;
import com.android.jack.ir.ast.JExpression;
import com.android.jack.ir.ast.JExpressionStatement;
import com.android.jack.ir.ast.JGoto;
import com.android.jack.ir.ast.JLabel;
import com.android.jack.ir.ast.JLabeledStatement;
import com.android.jack.ir.ast.JLocal;
import com.android.jack.ir.ast.JLocalRef;
import com.android.jack.ir.ast.JMethod;
import com.android.jack.ir.ast.JMethodBody;
import com.android.jack.ir.ast.JMethodCall;
import com.android.jack.ir.ast.JModifier;
import com.android.jack.ir.ast.JParameter;
import com.android.jack.ir.ast.JParameterRef;
import com.android.jack.ir.ast.JReturnStatement;
import com.android.jack.ir.ast.JStatement;
import com.android.jack.ir.ast.JStatementList;
import com.android.jack.ir.ast.JThisRef;
import com.android.jack.ir.ast.JTryStatement;
import com.android.jack.ir.ast.JVisitor;
import com.android.jack.ir.sourceinfo.SourceInfo;
import com.android.jack.scheduling.filter.SourceTypeFilter;
import com.android.jack.transformations.request.AddJLocalInMethodBody;
import com.android.jack.transformations.request.AppendStatement;
import com.android.jack.transformations.request.PrependStatement;
import com.android.jack.transformations.request.Remove;
import com.android.jack.transformations.request.TransformationRequest;
import com.android.jack.transformations.threeaddresscode.ThreeAddressCodeForm;
import com.android.jack.util.NamingTools;
import com.android.sched.item.Description;
import com.android.sched.schedulable.Constraint;
import com.android.sched.schedulable.Filter;
import com.android.sched.schedulable.RunnableSchedulable;
import com.android.sched.schedulable.Transform;
import com.android.sched.util.config.ThreadConfig;
import com.android.sched.util.log.Tracer;
import com.android.sched.util.log.TracerFactory;
import com.android.sched.util.log.stats.Counter;
import com.android.sched.util.log.stats.CounterImpl;
import com.android.sched.util.log.stats.StatisticId;
import java.util.ArrayList;
import java.util.Iterator;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
/**
 * Schedulable that rewrites self tail calls of the form {@code return m(args);} into assignments
 * of the new argument values to the parameters of {@code m}, followed by a {@code goto} back to
 * the start of the method.
 */
@Description("Optimizes tail recursive calls")
@Constraint(need = {JReturnStatement.class, JMethodCall.class},
    no = {ThreeAddressCodeForm.class})
@Transform(add = {JLabeledStatement.class,
    JLabel.class,
    JBlock.class,
    JLocal.class,
    JAsgOperation.class,
    JLocalRef.class,
    JParameterRef.class,
    JExpressionStatement.class,
    JGoto.class})
@Filter(SourceTypeFilter.class)
public class TailRecursionOptimizer implements RunnableSchedulable<JMethod> {
  /*
   * This schedulable transforms detected tail recursive calls into a series of statements that
   * evaluate the call arguments into temporary variables, copy those temporaries into the method
   * parameters, and jump back to the beginning of the method.
   *
   * Example:
   *
   * static int factTail(int k, int ret) {
   *   if (k > 0) {
   *     java.lang.System.out.println(k);
   *     return test1.Test.factTail(k - 1, ret * k);
   *   }
   *   return ret;
   * }
   *
   * is transformed to:
   *
   * static int factTail(int k, int ret) {
   *   method.start : {
   *   }
   *   if (k > 0) {
   *     java.lang.System.out.println(k);
   *     tmp.k = k - 1;
   *     tmp.ret = ret * k;
   *     k = tmp.k;
   *     ret = tmp.ret;
   *     goto method.start;
   *   }
   *   return ret;
   * }
   *
   * Every argument is evaluated into its temporary before any parameter is overwritten, so an
   * argument expression that reads a parameter (like "ret * k" above) still sees the value that
   * parameter had before the call.
   *
   * The current implementation does not handle tail recursion in void methods, because it is
   * difficult to check whether a recursive call is the last instruction executed.
   */

  /** Method filter from the {@code Options.METHOD_FILTER} configuration. */
  @Nonnull
  private final com.android.jack.util.filter.Filter<JMethod> filter =
      ThreadConfig.get(Options.METHOD_FILTER);

  @Nonnull
  private final Tracer tracer = TracerFactory.getTracer();

  /** Counts the tail calls that were rewritten. */
  @Nonnull
  private static final StatisticId<Counter> TAIL_RECURSION_OPTS = new StatisticId<Counter>(
      "jack.optimization.tail-recursion", "Tail recursion optimizations",
      CounterImpl.class, Counter.class);

  /** Annotation type that disables this optimization on a method or on a whole type. */
  @Nonnull
  private final JAnnotationType annotationType =
      Jack.getSession().getPhantomLookup().getAnnotationType(
          NamingTools.getTypeSignatureName(DisableTailRecursionOptimization.class.getName()));

  /**
   * Visitor that finds self tail calls in a single method and records their rewrite in a
   * {@link TransformationRequest}; the visitor itself only appends to that request.
   */
  private class TailRecursionVisitor extends JVisitor {
    /** The method whose body is being scanned. */
    @Nonnull
    private final JMethod enclosingMethod;

    /** Request collecting every rewrite produced for {@link #enclosingMethod}. */
    @Nonnull
    private final TransformationRequest tr;

    /** Jump target for the generated gotos; created lazily on the first rewrite. */
    @CheckForNull
    private JLabeledStatement labeledFirstStatement = null;

    private TailRecursionVisitor(@Nonnull JMethod method,
        @Nonnull TransformationRequest tr) {
      this.enclosingMethod = method;
      this.tr = tr;
    }

    /**
     * Sets {@link #labeledFirstStatement} to a labeled statement at the very start of the method
     * body. An already-labeled first statement is reused; otherwise an empty block labeled
     * "method.start" is prepended to the body through {@link #tr}.
     */
    private void labelFirstStatement() {
      JMethodBody body = (JMethodBody) enclosingMethod.getBody();
      assert body != null;
      JBlock block = body.getBlock();
      JStatement firstStatement = block.getStatements().get(0);
      assert firstStatement != null;
      if (firstStatement instanceof JLabeledStatement) {
        labeledFirstStatement = (JLabeledStatement) firstStatement;
      } else {
        SourceInfo srcInfo = firstStatement.getSourceInfo();
        labeledFirstStatement = new JLabeledStatement(srcInfo, new JLabel(srcInfo, "method.start"),
            new JBlock(srcInfo));
        tr.append(new PrependStatement(block, labeledFirstStatement));
      }
    }

    /**
     * Does not descend into try statements, so tail calls inside them are never rewritten
     * (presumably because the generated goto would leave the protected region and bypass its
     * handlers).
     *
     * NOTE(review): only try statements are excluded. Confirm that synchronized blocks are
     * lowered into try statements before this pass; otherwise a tail call inside one would be
     * turned into a goto that jumps out of it.
     */
    @Override
    public boolean visit(@Nonnull JTryStatement tryStatement) {
      return false;
    }

    /**
     * Rewrites {@code returnStatement} when it directly returns a self tail call. The returned
     * expression itself is never visited.
     */
    @Override
    public boolean visit(@Nonnull JReturnStatement returnStatement) {
      JExpression retExpr = returnStatement.getExpr();
      if (retExpr instanceof JMethodCall) {
        JMethodCall methodCall = (JMethodCall) retExpr;
        if (isSelfTailCall(methodCall)) {
          tracer.getStatistic(TAIL_RECURSION_OPTS).incValue();
          replaceWithGoto(returnStatement, methodCall);
        }
      }
      return false;
    }

    /**
     * Returns whether {@code methodCall} invokes {@link #enclosingMethod} itself. This requires
     * the same method id and a receiver type equal to the enclosing type. For an instance call,
     * the instance must also be {@code this}.
     */
    private boolean isSelfTailCall(@Nonnull JMethodCall methodCall) {
      if (!methodCall.getMethodIdWide().equals(enclosingMethod.getMethodIdWide())) {
        return false;
      }
      JExpression instance = methodCall.getInstance();
      if (instance == null) {
        // Static call. Mirror the receiver-type check done for instance calls below, so that a
        // same-signature static method invoked on another type (e.g. a supertype method hidden
        // by this one) is never mistaken for this method.
        return enclosingMethod.getEnclosingType().isSameType(methodCall.getReceiverType());
      }
      // Rejects calls such as super.m(), whose receiver type differs from the type of 'this'.
      return instance instanceof JThisRef
          && instance.getType().isSameType(methodCall.getReceiverType());
    }

    /**
     * Records in {@link #tr} the replacement of {@code returnStatement}. The replacement assigns
     * each argument of {@code methodCall} to a fresh temporary local, then each temporary to its
     * parameter, then jumps to the labeled first statement of the method. The return statement
     * itself is removed.
     */
    private void replaceWithGoto(@Nonnull JReturnStatement returnStatement,
        @Nonnull JMethodCall methodCall) {
      if (labeledFirstStatement == null) {
        labelFirstStatement();
      }
      assert labeledFirstStatement != null;
      SourceInfo srcInfo = returnStatement.getSourceInfo();
      JMethodBody body = (JMethodBody) enclosingMethod.getBody();
      assert body != null;

      ArrayList<JStatement> tmpAssignments = new ArrayList<JStatement>();
      ArrayList<JStatement> argAssignments = new ArrayList<JStatement>();
      Iterator<JParameter> paramIt = enclosingMethod.getParams().iterator();
      Iterator<JExpression> exprIt = methodCall.getArgs().iterator();
      while (paramIt.hasNext() && exprIt.hasNext()) {
        JParameter param = paramIt.next();
        JExpression expr = exprIt.next();
        JLocal tempVar = new JLocal(srcInfo, "tmp." + param.getName(),
            param.getType(), JModifier.FINAL, body);
        tr.append(new AddJLocalInMethodBody(tempVar, body));
        // tmp.<param> = <argument>
        JAsgOperation asgToTemp = new JAsgOperation(srcInfo, tempVar.makeRef(srcInfo), expr);
        tmpAssignments.add(new JExpressionStatement(srcInfo, asgToTemp));
        // <param> = tmp.<param>
        JAsgOperation tempToArg =
            new JAsgOperation(srcInfo, param.makeRef(srcInfo), tempVar.makeRef(srcInfo));
        argAssignments.add(new JExpressionStatement(srcInfo, tempToArg));
      }

      JStatementList returnStmtParent = (JStatementList) returnStatement.getParent();
      assert returnStmtParent != null;
      // All temporaries are filled before any parameter is overwritten, because an argument
      // expression may read any parameter.
      for (JStatement asgStmt : tmpAssignments) {
        tr.append(new AppendStatement(returnStmtParent, asgStmt));
      }
      for (JStatement asgStmt : argAssignments) {
        tr.append(new AppendStatement(returnStmtParent, asgStmt));
      }
      JGoto tailCall = new JGoto(srcInfo, labeledFirstStatement);
      tailCall.setCatchBlocks(returnStatement.getJCatchBlocks());
      tr.append(new AppendStatement(returnStmtParent, tailCall));
      tr.append(new Remove(returnStatement));
    }
  }

  /**
   * Rewrites the self tail calls of {@code method}. The method is skipped in any of these cases:
   * <ul>
   *   <li>it is native or abstract;</li>
   *   <li>it could be overridden, i.e. it is none of final, private or static;</li>
   *   <li>it is synthetic;</li>
   *   <li>it is rejected by the method filter;</li>
   *   <li>it or its enclosing type carries the {@code DisableTailRecursionOptimization}
   *       annotation.</li>
   * </ul>
   */
  @Override
  public void run(@Nonnull JMethod method) {
    if (method.isNative()
        || method.isAbstract()
        // method should not be overridable: a call on 'this' could otherwise dispatch elsewhere
        || !(method.isFinal() || method.isPrivate() || method.isStatic())
        || method.isSynthetic()
        || !filter.accept(this.getClass(), method)
        || !method.getAnnotations(annotationType).isEmpty()
        || !method.getEnclosingType().getAnnotations(annotationType).isEmpty()) {
      return;
    }
    TransformationRequest request = new TransformationRequest(method);
    TailRecursionVisitor visitor = new TailRecursionVisitor(method, request);
    visitor.accept(method);
    request.commit();
  }
}