blob: af71cb1964f114aa4128b5fb64d3421cd05d107c [file] [log] [blame]
/*
* Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.graalvm.compiler.core.common.type;
import static jdk.vm.ci.meta.MetaUtil.getSimpleName;
import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import jdk.vm.ci.meta.Constant;
import jdk.vm.ci.meta.JavaKind;
import org.graalvm.compiler.core.common.calc.FloatConvert;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Add;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.And;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Div;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Mul;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Or;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Rem;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Sub;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Xor;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.IntegerConvertOp.Narrow;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.IntegerConvertOp.SignExtend;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.IntegerConvertOp.ZeroExtend;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.ShiftOp.Shl;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.ShiftOp.Shr;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.ShiftOp.UShr;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.UnaryOp.Abs;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.UnaryOp.Neg;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.UnaryOp.Not;
import org.graalvm.compiler.core.common.type.ArithmeticOpTable.UnaryOp.Sqrt;
/**
* Information about arithmetic operations.
*/
/**
 * Information about arithmetic operations.
 * <p>
 * A table holds one descriptor per arithmetic operation plus at most one {@link FloatConvertOp}
 * per {@link FloatConvert} value. Any entry may be {@code null}, meaning this table does not
 * describe that operation; {@link #EMPTY} has no entries at all.
 */
public final class ArithmeticOpTable {

    private final UnaryOp<Neg> neg;
    private final BinaryOp<Add> add;
    private final BinaryOp<Sub> sub;
    private final BinaryOp<Mul> mul;
    private final BinaryOp<Div> div;
    private final BinaryOp<Rem> rem;
    private final UnaryOp<Not> not;
    private final BinaryOp<And> and;
    private final BinaryOp<Or> or;
    private final BinaryOp<Xor> xor;
    private final ShiftOp<Shl> shl;
    private final ShiftOp<Shr> shr;
    private final ShiftOp<UShr> ushr;
    private final UnaryOp<Abs> abs;
    private final UnaryOp<Sqrt> sqrt;
    private final IntegerConvertOp<ZeroExtend> zeroExtend;
    private final IntegerConvertOp<SignExtend> signExtend;
    private final IntegerConvertOp<Narrow> narrow;

    /** Float conversion ops indexed by {@link FloatConvert#ordinal()}; unused slots stay null. */
    private final FloatConvertOp[] floatConvert;

    /** Hash code computed once in the constructor (all fields are final). */
    private final int hash;

    /**
     * Returns the operation table of {@code s} if it is an {@link ArithmeticStamp}, otherwise
     * {@link #EMPTY}.
     */
    public static ArithmeticOpTable forStamp(Stamp s) {
        if (s instanceof ArithmeticStamp) {
            return ((ArithmeticStamp) s).getOps();
        } else {
            return EMPTY;
        }
    }

    /** A table in which every operation, including every float conversion, is {@code null}. */
    public static final ArithmeticOpTable EMPTY = new ArithmeticOpTable(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);

    /**
     * Creates a table from the given operation descriptors. Any argument may be {@code null}.
     * {@code null} entries in {@code floatConvert} are ignored; each non-null float conversion
     * op is stored under its own {@link FloatConvertOp#getFloatConvert()} kind.
     */
    public ArithmeticOpTable(UnaryOp<Neg> neg, BinaryOp<Add> add, BinaryOp<Sub> sub, BinaryOp<Mul> mul, BinaryOp<Div> div, BinaryOp<Rem> rem, UnaryOp<Not> not, BinaryOp<And> and, BinaryOp<Or> or,
                    BinaryOp<Xor> xor, ShiftOp<Shl> shl, ShiftOp<Shr> shr, ShiftOp<UShr> ushr, UnaryOp<Abs> abs, UnaryOp<Sqrt> sqrt, IntegerConvertOp<ZeroExtend> zeroExtend,
                    IntegerConvertOp<SignExtend> signExtend, IntegerConvertOp<Narrow> narrow, FloatConvertOp... floatConvert) {
        this(neg, add, sub, mul, div, rem, not, and, or, xor, shl, shr, ushr, abs, sqrt, zeroExtend, signExtend, narrow, Stream.of(floatConvert));
    }

    /**
     * Callback used by {@link ArithmeticOpTable#wrap} to transform each non-null operation of a
     * table, one method per operation category.
     */
    public interface ArithmeticOpWrapper {

        <OP> UnaryOp<OP> wrapUnaryOp(UnaryOp<OP> op);

        <OP> BinaryOp<OP> wrapBinaryOp(BinaryOp<OP> op);

        <OP> ShiftOp<OP> wrapShiftOp(ShiftOp<OP> op);

        <OP> IntegerConvertOp<OP> wrapIntegerConvertOp(IntegerConvertOp<OP> op);

        FloatConvertOp wrapFloatConvertOp(FloatConvertOp op);
    }

    /** Applies {@code wrapper} to {@code obj}, passing {@code null} through unchanged. */
    private static <T> T wrapIfNonNull(Function<T, T> wrapper, T obj) {
        if (obj == null) {
            return null;
        } else {
            return wrapper.apply(obj);
        }
    }

    /**
     * Creates a new table in which every non-null operation of {@code inner} has been passed
     * through the matching method of {@code wrapper}. {@code null} operations remain
     * {@code null}, and absent float conversions stay absent.
     */
    public static ArithmeticOpTable wrap(ArithmeticOpWrapper wrapper, ArithmeticOpTable inner) {
        UnaryOp<Neg> neg = wrapIfNonNull(wrapper::wrapUnaryOp, inner.getNeg());
        BinaryOp<Add> add = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getAdd());
        BinaryOp<Sub> sub = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getSub());
        BinaryOp<Mul> mul = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getMul());
        BinaryOp<Div> div = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getDiv());
        BinaryOp<Rem> rem = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getRem());

        UnaryOp<Not> not = wrapIfNonNull(wrapper::wrapUnaryOp, inner.getNot());
        BinaryOp<And> and = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getAnd());
        BinaryOp<Or> or = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getOr());
        BinaryOp<Xor> xor = wrapIfNonNull(wrapper::wrapBinaryOp, inner.getXor());

        ShiftOp<Shl> shl = wrapIfNonNull(wrapper::wrapShiftOp, inner.getShl());
        ShiftOp<Shr> shr = wrapIfNonNull(wrapper::wrapShiftOp, inner.getShr());
        ShiftOp<UShr> ushr = wrapIfNonNull(wrapper::wrapShiftOp, inner.getUShr());

        UnaryOp<Abs> abs = wrapIfNonNull(wrapper::wrapUnaryOp, inner.getAbs());
        UnaryOp<Sqrt> sqrt = wrapIfNonNull(wrapper::wrapUnaryOp, inner.getSqrt());

        IntegerConvertOp<ZeroExtend> zeroExtend = wrapIfNonNull(wrapper::wrapIntegerConvertOp, inner.getZeroExtend());
        IntegerConvertOp<SignExtend> signExtend = wrapIfNonNull(wrapper::wrapIntegerConvertOp, inner.getSignExtend());
        IntegerConvertOp<Narrow> narrow = wrapIfNonNull(wrapper::wrapIntegerConvertOp, inner.getNarrow());

        // Only the slots that are actually populated are wrapped.
        Stream<FloatConvertOp> floatConvert = Stream.of(inner.floatConvert).filter(Objects::nonNull).map(wrapper::wrapFloatConvertOp);

        return new ArithmeticOpTable(neg, add, sub, mul, div, rem, not, and, or, xor, shl, shr, ushr, abs, sqrt, zeroExtend, signExtend, narrow, floatConvert);
    }

    /**
     * Shared constructor. Float conversion ops are placed into {@link #floatConvert} by the
     * ordinal of their {@link FloatConvert} kind; {@code null} stream elements are skipped.
     */
    private ArithmeticOpTable(UnaryOp<Neg> neg, BinaryOp<Add> add, BinaryOp<Sub> sub, BinaryOp<Mul> mul, BinaryOp<Div> div, BinaryOp<Rem> rem, UnaryOp<Not> not, BinaryOp<And> and, BinaryOp<Or> or,
                    BinaryOp<Xor> xor, ShiftOp<Shl> shl, ShiftOp<Shr> shr, ShiftOp<UShr> ushr, UnaryOp<Abs> abs, UnaryOp<Sqrt> sqrt, IntegerConvertOp<ZeroExtend> zeroExtend,
                    IntegerConvertOp<SignExtend> signExtend, IntegerConvertOp<Narrow> narrow, Stream<FloatConvertOp> floatConvert) {
        this.neg = neg;
        this.add = add;
        this.sub = sub;
        this.mul = mul;
        this.div = div;
        this.rem = rem;
        this.not = not;
        this.and = and;
        this.or = or;
        this.xor = xor;
        this.shl = shl;
        this.shr = shr;
        this.ushr = ushr;
        this.abs = abs;
        this.sqrt = sqrt;
        this.zeroExtend = zeroExtend;
        this.signExtend = signExtend;
        this.narrow = narrow;
        this.floatConvert = new FloatConvertOp[FloatConvert.values().length];
        // A null float conversion op means "absent", consistent with every other slot; without
        // this filter a null varargs entry would throw a NullPointerException here.
        floatConvert.filter(Objects::nonNull).forEach(op -> this.floatConvert[op.getFloatConvert().ordinal()] = op);
        // Include the float conversions so the hash covers everything equals() compares.
        this.hash = 31 * Objects.hash(neg, add, sub, mul, div, rem, not, and, or, xor, shl, shr, ushr, abs, sqrt, zeroExtend, signExtend, narrow) + Arrays.hashCode(this.floatConvert);
    }

    @Override
    public int hashCode() {
        return hash;
    }

    /**
     * Describes the unary negation operation.
     */
    public UnaryOp<Neg> getNeg() {
        return neg;
    }

    /**
     * Describes the addition operation.
     */
    public BinaryOp<Add> getAdd() {
        return add;
    }

    /**
     * Describes the subtraction operation.
     */
    public BinaryOp<Sub> getSub() {
        return sub;
    }

    /**
     * Describes the multiplication operation.
     */
    public BinaryOp<Mul> getMul() {
        return mul;
    }

    /**
     * Describes the division operation.
     */
    public BinaryOp<Div> getDiv() {
        return div;
    }

    /**
     * Describes the remainder operation.
     */
    public BinaryOp<Rem> getRem() {
        return rem;
    }

    /**
     * Describes the bitwise not operation.
     */
    public UnaryOp<Not> getNot() {
        return not;
    }

    /**
     * Describes the bitwise and operation.
     */
    public BinaryOp<And> getAnd() {
        return and;
    }

    /**
     * Describes the bitwise or operation.
     */
    public BinaryOp<Or> getOr() {
        return or;
    }

    /**
     * Describes the bitwise xor operation.
     */
    public BinaryOp<Xor> getXor() {
        return xor;
    }

    /**
     * Describes the shift left operation.
     */
    public ShiftOp<Shl> getShl() {
        return shl;
    }

    /**
     * Describes the signed shift right operation.
     */
    public ShiftOp<Shr> getShr() {
        return shr;
    }

    /**
     * Describes the unsigned shift right operation.
     */
    public ShiftOp<UShr> getUShr() {
        return ushr;
    }

    /**
     * Describes the absolute value operation.
     */
    public UnaryOp<Abs> getAbs() {
        return abs;
    }

    /**
     * Describes the square root operation.
     */
    public UnaryOp<Sqrt> getSqrt() {
        return sqrt;
    }

    /**
     * Describes the zero extend conversion.
     */
    public IntegerConvertOp<ZeroExtend> getZeroExtend() {
        return zeroExtend;
    }

    /**
     * Describes the sign extend conversion.
     */
    public IntegerConvertOp<SignExtend> getSignExtend() {
        return signExtend;
    }

    /**
     * Describes the narrowing conversion.
     */
    public IntegerConvertOp<Narrow> getNarrow() {
        return narrow;
    }

    /**
     * Describes integer/float/double conversions.
     *
     * @param op the conversion kind
     * @return the descriptor for {@code op}, or {@code null} if this table has none
     */
    public FloatConvertOp getFloatConvert(FloatConvert op) {
        return floatConvert[op.ordinal()];
    }

    /**
     * Renders each op as its operator followed by its simple class name in braces, or
     * {@code "null"} for a missing op, joined with commas.
     */
    public static String toString(Op... ops) {
        return Arrays.stream(ops).map(o -> o == null ? "null" : o.operator + "{" + getSimpleName(o.getClass(), false) + "}").collect(Collectors.joining(","));
    }

    /** Compares every non-float-conversion operation of this table with those of {@code that}. */
    private boolean opsEquals(ArithmeticOpTable that) {
        // @formatter:off
        return Objects.equals(neg, that.neg) &&
               Objects.equals(add, that.add) &&
               Objects.equals(sub, that.sub) &&
               Objects.equals(mul, that.mul) &&
               Objects.equals(div, that.div) &&
               Objects.equals(rem, that.rem) &&
               Objects.equals(not, that.not) &&
               Objects.equals(and, that.and) &&
               Objects.equals(or, that.or) &&
               Objects.equals(xor, that.xor) &&
               Objects.equals(shl, that.shl) &&
               Objects.equals(shr, that.shr) &&
               Objects.equals(ushr, that.ushr) &&
               Objects.equals(abs, that.abs) &&
               Objects.equals(sqrt, that.sqrt) &&
               Objects.equals(zeroExtend, that.zeroExtend) &&
               Objects.equals(signExtend, that.signExtend) &&
               Objects.equals(narrow, that.narrow);
        // @formatter:on
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (getClass() != obj.getClass()) {
            return false;
        }
        ArithmeticOpTable that = (ArithmeticOpTable) obj;
        return opsEquals(that) && Arrays.equals(this.floatConvert, that.floatConvert);
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "[" + toString(neg, add, sub, mul, div, rem, not, and, or, xor, shl, shr, ushr, abs, sqrt, zeroExtend, signExtend, narrow) + ",floatConvert[" +
                        toString(floatConvert) + "]]";
    }

    /**
     * Base class of all operation descriptors. Two ops are equal when they have the same runtime
     * class and the same operator string.
     */
    public abstract static class Op {

        private final String operator;

        protected Op(String operator) {
            this.operator = operator;
        }

        @Override
        public String toString() {
            return operator;
        }

        @Override
        public int hashCode() {
            return operator.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            Op that = (Op) obj;
            // Previously both paths returned true, so ops with different operators compared equal.
            return operator.equals(that.operator);
        }
    }

    /**
     * Describes a unary arithmetic operation.
     */
    public abstract static class UnaryOp<T> extends Op {

        public abstract static class Neg extends UnaryOp<Neg> {

            protected Neg() {
                super("-");
            }
        }

        public abstract static class Not extends UnaryOp<Not> {

            protected Not() {
                super("~");
            }
        }

        public abstract static class Abs extends UnaryOp<Abs> {

            protected Abs() {
                super("ABS");
            }
        }

        public abstract static class Sqrt extends UnaryOp<Sqrt> {

            protected Sqrt() {
                super("SQRT");
            }
        }

        protected UnaryOp(String operation) {
            super(operation);
        }

        /**
         * Apply the operation to a {@link Constant}.
         */
        public abstract Constant foldConstant(Constant value);

        /**
         * Apply the operation to a {@link Stamp}.
         */
        public abstract Stamp foldStamp(Stamp stamp);

        /** Returns this op; subclasses that wrap another op may override. */
        public UnaryOp<T> unwrap() {
            return this;
        }
    }

    /**
     * Describes a binary arithmetic operation.
     */
    public abstract static class BinaryOp<T> extends Op {

        public abstract static class Add extends BinaryOp<Add> {

            protected Add(boolean associative, boolean commutative) {
                super("+", associative, commutative);
            }
        }

        public abstract static class Sub extends BinaryOp<Sub> {

            protected Sub(boolean associative, boolean commutative) {
                super("-", associative, commutative);
            }
        }

        public abstract static class Mul extends BinaryOp<Mul> {

            protected Mul(boolean associative, boolean commutative) {
                super("*", associative, commutative);
            }
        }

        public abstract static class Div extends BinaryOp<Div> {

            protected Div(boolean associative, boolean commutative) {
                super("/", associative, commutative);
            }
        }

        public abstract static class Rem extends BinaryOp<Rem> {

            protected Rem(boolean associative, boolean commutative) {
                super("%", associative, commutative);
            }
        }

        public abstract static class And extends BinaryOp<And> {

            protected And(boolean associative, boolean commutative) {
                super("&", associative, commutative);
            }
        }

        public abstract static class Or extends BinaryOp<Or> {

            protected Or(boolean associative, boolean commutative) {
                super("|", associative, commutative);
            }
        }

        public abstract static class Xor extends BinaryOp<Xor> {

            protected Xor(boolean associative, boolean commutative) {
                super("^", associative, commutative);
            }
        }

        private final boolean associative;
        private final boolean commutative;

        protected BinaryOp(String operation, boolean associative, boolean commutative) {
            super(operation);
            this.associative = associative;
            this.commutative = commutative;
        }

        /**
         * Apply the operation to two {@linkplain Constant Constants}.
         */
        public abstract Constant foldConstant(Constant a, Constant b);

        /**
         * Apply the operation to two {@linkplain Stamp Stamps}.
         */
        public abstract Stamp foldStamp(Stamp a, Stamp b);

        /**
         * Checks whether this operation is associative. An operation is associative when
         * {@code (a . b) . c == a . (b . c)} for all a, b, c. Note that you still have to be
         * careful with inverses. For example the integer subtraction operation will report
         * {@code true} here, since you can still reassociate as long as the correct negations are
         * inserted.
         */
        public final boolean isAssociative() {
            return associative;
        }

        /**
         * Checks whether this operation is commutative. An operation is commutative when
         * {@code a . b == b . a} for all a, b.
         */
        public final boolean isCommutative() {
            return commutative;
        }

        /**
         * Check whether a {@link Constant} is a neutral element for this operation. A neutral
         * element is any element {@code n} where {@code a . n == a} for all a.
         *
         * @param n the {@link Constant} that should be tested
         * @return true iff for all {@code a}: {@code a . n == a}
         */
        public boolean isNeutral(Constant n) {
            return false;
        }

        /**
         * Check whether this operation has a zero {@code z == a . a} for each a. Examples of
         * operations having such an element are subtraction and exclusive-or. Note that this may be
         * different from the numbers tested by {@link #isNeutral}.
         *
         * @param stamp a {@link Stamp}
         * @return a unique {@code z} such that {@code z == a . a} for each {@code a} in
         *         {@code stamp} if it exists, otherwise {@code null}
         */
        public Constant getZero(Stamp stamp) {
            return null;
        }

        /** Returns this op; subclasses that wrap another op may override. */
        public BinaryOp<T> unwrap() {
            return this;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = super.hashCode();
            result = prime * result + (associative ? 1231 : 1237);
            result = prime * result + (commutative ? 1231 : 1237);
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!super.equals(obj)) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            BinaryOp<?> that = (BinaryOp<?>) obj;
            if (associative != that.associative) {
                return false;
            }
            if (commutative != that.commutative) {
                return false;
            }
            return true;
        }

        @Override
        public String toString() {
            if (associative) {
                if (commutative) {
                    return super.toString() + "[AC]";
                } else {
                    return super.toString() + "[A]";
                }
            } else if (commutative) {
                return super.toString() + "[C]";
            }
            return super.toString();
        }
    }

    /**
     * Describes a shift operation. The right argument of a shift operation always has kind
     * {@link JavaKind#Int}.
     */
    public abstract static class ShiftOp<OP> extends Op {

        public abstract static class Shl extends ShiftOp<Shl> {

            public Shl() {
                super("<<");
            }
        }

        public abstract static class Shr extends ShiftOp<Shr> {

            public Shr() {
                super(">>");
            }
        }

        public abstract static class UShr extends ShiftOp<UShr> {

            public UShr() {
                super(">>>");
            }
        }

        protected ShiftOp(String operation) {
            super(operation);
        }

        /**
         * Apply the shift to a constant.
         */
        public abstract Constant foldConstant(Constant c, int amount);

        /**
         * Apply the shift to a stamp.
         */
        public abstract Stamp foldStamp(Stamp s, IntegerStamp amount);

        /**
         * Get the shift amount mask for a given result stamp.
         */
        public abstract int getShiftAmountMask(Stamp s);
    }

    /**
     * Describes a conversion between integer and floating-point values, identified by a
     * {@link FloatConvert} kind whose name is used as the operator string.
     */
    public abstract static class FloatConvertOp extends UnaryOp<FloatConvertOp> {

        private final FloatConvert op;

        protected FloatConvertOp(FloatConvert op) {
            super(op.name());
            this.op = op;
        }

        /** Returns the conversion kind this op describes. */
        public FloatConvert getFloatConvert() {
            return op;
        }

        @Override
        public FloatConvertOp unwrap() {
            return this;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            return prime * super.hashCode() + op.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!super.equals(obj)) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            FloatConvertOp that = (FloatConvertOp) obj;
            if (op != that.op) {
                return false;
            }
            return true;
        }
    }

    /**
     * Describes an integer width conversion (zero extend, sign extend or narrow) from
     * {@code inputBits} to {@code resultBits}.
     */
    public abstract static class IntegerConvertOp<T> extends Op {

        public abstract static class ZeroExtend extends IntegerConvertOp<ZeroExtend> {

            protected ZeroExtend() {
                super("ZeroExtend");
            }
        }

        public abstract static class SignExtend extends IntegerConvertOp<SignExtend> {

            protected SignExtend() {
                super("SignExtend");
            }
        }

        public abstract static class Narrow extends IntegerConvertOp<Narrow> {

            protected Narrow() {
                super("Narrow");
            }
        }

        protected IntegerConvertOp(String op) {
            super(op);
        }

        public abstract Constant foldConstant(int inputBits, int resultBits, Constant value);

        public abstract Stamp foldStamp(int inputBits, int resultBits, Stamp stamp);

        /** Returns this op; subclasses that wrap another op may override. */
        public IntegerConvertOp<T> unwrap() {
            return this;
        }
    }
}