| /* |
| * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. |
| * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. |
| * |
| * This code is free software; you can redistribute it and/or modify it |
| * under the terms of the GNU General Public License version 2 only, as |
| * published by the Free Software Foundation. |
| * |
| * This code is distributed in the hope that it will be useful, but WITHOUT |
| * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
| * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License |
| * version 2 for more details (a copy is included in the LICENSE file that |
| * accompanied this code). |
| * |
| * You should have received a copy of the GNU General Public License version |
| * 2 along with this work; if not, write to the Free Software Foundation, |
| * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. |
| * |
| * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA |
| * or visit www.oracle.com if you need additional information or have any |
| * questions. |
| */ |
| |
| |
| package org.graalvm.compiler.replacements.test; |
| |
| import static org.junit.Assert.assertNotNull; |
| |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.List; |
| |
| import org.graalvm.compiler.core.common.type.IntegerStamp; |
| import org.graalvm.compiler.core.common.type.StampFactory; |
| import org.graalvm.compiler.core.test.GraalCompilerTest; |
| import org.graalvm.compiler.graph.Node; |
| import org.graalvm.compiler.nodes.NodeView; |
| import org.graalvm.compiler.nodes.ParameterNode; |
| import org.graalvm.compiler.nodes.PiNode; |
| import org.graalvm.compiler.nodes.ReturnNode; |
| import org.graalvm.compiler.nodes.StructuredGraph; |
| import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions; |
| import org.graalvm.compiler.nodes.StructuredGraph.GuardsStage; |
| import org.graalvm.compiler.nodes.ValueNode; |
| import org.graalvm.compiler.nodes.spi.LoweringTool; |
| import org.graalvm.compiler.phases.common.CanonicalizerPhase; |
| import org.graalvm.compiler.phases.common.LoweringPhase; |
| import org.graalvm.compiler.phases.tiers.HighTierContext; |
| import org.graalvm.compiler.phases.tiers.PhaseContext; |
| import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode; |
| import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode; |
| import org.junit.Assert; |
| import org.junit.Test; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.Parameterized; |
| import org.junit.runners.Parameterized.Parameters; |
| |
| @RunWith(Parameterized.class) |
| public class IntegerExactFoldTest extends GraalCompilerTest { |
| private final long lowerBoundA; |
| private final long upperBoundA; |
| private final long lowerBoundB; |
| private final long upperBoundB; |
| private final int bits; |
| private final Operation operation; |
| |
| public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) { |
| this.lowerBoundA = lowerBoundA; |
| this.upperBoundA = upperBoundA; |
| this.lowerBoundB = lowerBoundB; |
| this.upperBoundB = upperBoundB; |
| this.bits = bits; |
| this.operation = operation; |
| |
| assert bits == 32 || bits == 64; |
| assert lowerBoundA <= upperBoundA; |
| assert lowerBoundB <= upperBoundB; |
| assert bits == 64 || isInteger(lowerBoundA); |
| assert bits == 64 || isInteger(upperBoundA); |
| assert bits == 64 || isInteger(lowerBoundB); |
| assert bits == 64 || isInteger(upperBoundB); |
| } |
| |
| @Test |
| public void testFolding() { |
| StructuredGraph graph = prepareGraph(); |
| IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA); |
| IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB); |
| |
| List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot(); |
| params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode); |
| params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode); |
| |
| Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first(); |
| assertNotNull("original node must be in the graph", originalNode); |
| |
| new CanonicalizerPhase().apply(graph, getDefaultHighTierContext()); |
| ValueNode node = findNode(graph); |
| boolean overflowExpected = node instanceof IntegerExactArithmeticNode; |
| |
| IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT); |
| operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp); |
| } |
| |
| @Test |
| public void testFoldingAfterLowering() { |
| StructuredGraph graph = prepareGraph(); |
| |
| Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first(); |
| assertNotNull("original node must be in the graph", originalNode); |
| |
| graph.setGuardsStage(GuardsStage.FIXED_DEOPTS); |
| CanonicalizerPhase canonicalizer = new CanonicalizerPhase(); |
| PhaseContext context = new PhaseContext(getProviders()); |
| new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, context); |
| IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first(); |
| assertNotNull("the lowered node must be in the graph", loweredNode); |
| |
| loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA)); |
| loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB)); |
| new CanonicalizerPhase().apply(graph, context); |
| |
| ValueNode node = findNode(graph); |
| boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode; |
| |
| IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT); |
| operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp); |
| } |
| |
| private static boolean isInteger(long value) { |
| return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE; |
| } |
| |
| private static ValueNode findNode(StructuredGraph graph) { |
| ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result(); |
| assertNotNull("some node must be the returned value", resultNode); |
| return resultNode; |
| } |
| |
| protected StructuredGraph prepareGraph() { |
| String snippet = "snippetInt" + bits; |
| StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO); |
| HighTierContext context = getDefaultHighTierContext(); |
| new CanonicalizerPhase().apply(graph, context); |
| return graph; |
| } |
| |
| private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) { |
| tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation}); |
| } |
| |
| @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}") |
| public static Collection<Object[]> data() { |
| ArrayList<Object[]> tests = new ArrayList<>(); |
| |
| Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()}; |
| for (Operation operation : operations) { |
| for (int bits : new int[]{32, 64}) { |
| // zero related |
| addTest(tests, 0, 0, 1, 1, bits, operation); |
| addTest(tests, 1, 1, 0, 0, bits, operation); |
| addTest(tests, -1, 1, 0, 1, bits, operation); |
| |
| // bounds |
| addTest(tests, -2, 2, 3, 3, bits, operation); |
| addTest(tests, -1, 1, 1, 1, bits, operation); |
| addTest(tests, -1, 1, -1, 1, bits, operation); |
| |
| addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation); |
| addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation); |
| addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation); |
| addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation); |
| addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation); |
| addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation); |
| } |
| |
| // bit-specific test cases |
| addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation); |
| addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation); |
| } |
| |
| return tests; |
| } |
| |
| private abstract static class Operation { |
| abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp); |
| } |
| |
| private static final class AddOperation extends Operation { |
| @Override |
| public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { |
| try { |
| long res = addExact(lowerBoundA, lowerBoundB, bits); |
| resultStamp.contains(res); |
| res = addExact(upperBoundA, upperBoundB, bits); |
| resultStamp.contains(res); |
| Assert.assertFalse(overflowExpected); |
| } catch (ArithmeticException e) { |
| Assert.assertTrue(overflowExpected); |
| } |
| } |
| |
| private static long addExact(long x, long y, int bits) { |
| if (bits == 32) { |
| return Math.addExact((int) x, (int) y); |
| } else { |
| return Math.addExact(x, y); |
| } |
| } |
| |
| @SuppressWarnings("unused") |
| public static int snippetInt32(int a, int b) { |
| return Math.addExact(a, b); |
| } |
| |
| @SuppressWarnings("unused") |
| public static long snippetInt64(long a, long b) { |
| return Math.addExact(a, b); |
| } |
| } |
| |
| private static final class SubOperation extends Operation { |
| @Override |
| public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { |
| try { |
| long res = subExact(lowerBoundA, upperBoundB, bits); |
| Assert.assertTrue(resultStamp.contains(res)); |
| res = subExact(upperBoundA, lowerBoundB, bits); |
| Assert.assertTrue(resultStamp.contains(res)); |
| Assert.assertFalse(overflowExpected); |
| } catch (ArithmeticException e) { |
| Assert.assertTrue(overflowExpected); |
| } |
| } |
| |
| private static long subExact(long x, long y, int bits) { |
| if (bits == 32) { |
| return Math.subtractExact((int) x, (int) y); |
| } else { |
| return Math.subtractExact(x, y); |
| } |
| } |
| |
| @SuppressWarnings("unused") |
| public static int snippetInt32(int a, int b) { |
| return Math.subtractExact(a, b); |
| } |
| |
| @SuppressWarnings("unused") |
| public static long snippetInt64(long a, long b) { |
| return Math.subtractExact(a, b); |
| } |
| } |
| |
| private static final class MulOperation extends Operation { |
| @Override |
| public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { |
| // now check for all values in the stamp whether their products overflow overflow |
| boolean overflowOccurred = false; |
| |
| for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) { |
| for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) { |
| try { |
| long res = mulExact(l1, l2, bits); |
| Assert.assertTrue(resultStamp.contains(res)); |
| } catch (ArithmeticException e) { |
| overflowOccurred = true; |
| } |
| if (l2 == Long.MAX_VALUE) { |
| // do not want to overflow the check loop |
| break; |
| } |
| } |
| if (l1 == Long.MAX_VALUE) { |
| // do not want to overflow the check loop |
| break; |
| } |
| } |
| |
| Assert.assertEquals(overflowExpected, overflowOccurred); |
| } |
| |
| private static long mulExact(long x, long y, int bits) { |
| if (bits == 32) { |
| return Math.multiplyExact((int) x, (int) y); |
| } else { |
| return Math.multiplyExact(x, y); |
| } |
| } |
| |
| @SuppressWarnings("unused") |
| public static int snippetInt32(int a, int b) { |
| return Math.multiplyExact(a, b); |
| } |
| |
| @SuppressWarnings("unused") |
| public static long snippetInt64(long a, long b) { |
| return Math.multiplyExact(a, b); |
| } |
| } |
| } |