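// Tests for bounds inference in the TensorExpr IR: inferring the index ranges
// accessed by buffer loads and stores, and merging overlapping or adjacent
// access bounds.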
#include <test/cpp/tensorexpr/test_base.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/buffer.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/function.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
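// Checks that the inferred start/stop bounds in 'access_info' match the
// expected constant (start, stop) pairs in 'ref'. A negative value in 'ref'
// skips the check for that dimension.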
static void verifyConstBounds(
const TensorAccessBoundsInfo& access_info,
const std::vector<std::pair<int, int>>& ref) {
size_t ndim = ref.size();
ASSERT_EQ(access_info.start.size(), ndim);
ASSERT_EQ(access_info.stop.size(), ndim);
for (size_t i = 0; i < ndim; i++) {
if (ref[i].first >= 0) { // Negative values are used to skip the check
ASSERT_TRUE(access_info.start[i]->isConstant());
int start_i = immediateAs<int>(access_info.start[i]);
ASSERT_EQ(start_i, ref[i].first);
}
if (ref[i].second >= 0) {
ASSERT_TRUE(access_info.stop[i]->isConstant());
int stop_i = immediateAs<int>(access_info.stop[i]);
ASSERT_EQ(stop_i, ref[i].second);
}
}
}
void testBoundsInference_1() {
// Verify that bounds inference works for the following example:
// for i in 0..100:
// b[i] = a[i]
// For this loop bounds inference should yield the following:
// {{b, kStore, 0, 99}, {a, kLoad, 0, 99}}
KernelScope kernel_scope;
ExprHandle n(100);
Buffer a(BufHandle("a", {n}, kFloat));
Tensor* b =
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); });
LoopNest l({b});
auto bounds_info = inferBounds(l.root_stmt());
// We should have two entries: one for 'b' and one for 'a'.
ASSERT_EQ(bounds_info.size(), 2);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 99}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}});
}
void testBoundsInference_2() {
// Verify that bounds inference works for the following example:
// for i in 0..n:
// b[i] = a[i]
// For this loop bounds inference should yield the following:
// {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}}
KernelScope kernel_scope;
VarHandle n("n", kInt);
Buffer a(BufHandle("a", {n}, kFloat));
Tensor* b =
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); });
LoopNest l({b});
auto bounds_info = inferBounds(l.root_stmt());
// We should have two entries: one for 'b' and one for 'a'.
ASSERT_EQ(bounds_info.size(), 2);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, -1}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, -1}});
}
void testBoundsInference_3() {
// Verify that bounds inference works for the following example:
// for i in 0..100:
// b[i] = a[i] * a[i+10]
// For this loop bounds inference should yield the following:
// {{b, kStore, 0, 99}, {a, kLoad, 0, 109}}
KernelScope kernel_scope;
ExprHandle n(100);
Buffer a(BufHandle("a", {n + 10}, kFloat));
Tensor* b = Compute(
"b", {{n, "i"}}, [&](const VarHandle& i) { return a(i) * a(i + 10); });
LoopNest l({b});
auto bounds_info = inferBounds(l.root_stmt());
// We should have two entries: one for 'b' and one for 'a'.
ASSERT_EQ(bounds_info.size(), 2);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 109}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}});
}
void testBoundsInference_4() {
// Verify that bounds inference works for the following example:
//
// for y in 0..200:
// for x in 0..320:
// b[y,x] = x*y
// for y in 0..200:
// for x in 0..320:
// c[y,x] = a[y,x] * b[y,x]
KernelScope kernel_scope;
ExprHandle W(320);
ExprHandle H(200);
Buffer a(BufHandle("a", {H, W}, kFloat));
Tensor* b = Compute(
"b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
return x * y;
});
Tensor* c = Compute(
"c", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
return a(y, x) * b->call(y, x);
});
LoopNest l({c});
std::vector<For*> loops = l.getLoopStmtsFor(c);
Stmt* body = l.getLoopBodyFor(c);
{
// Infer bounds on the top-level loop scope
auto bounds_info = inferBounds(loops[0]);
ASSERT_EQ(bounds_info.size(), 3);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 199}, {0, 319}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 199}, {0, 319}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 199}, {0, 319}});
}
{
// Infer bounds on the inner loop scope
auto bounds_info = inferBounds(loops[1]);
ASSERT_EQ(bounds_info.size(), 3);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {0, 319}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {0, 319}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {0, 319}});
}
{
// Infer bounds on the inner loop body's scope
auto bounds_info = inferBounds(body);
ASSERT_EQ(bounds_info.size(), 3);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {-1, -1}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {-1, -1}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {-1, -1}});
}
}
void testBoundsInference_5() {
// Verify that bounds inference works for the following example:
// for i in 0..100:
// b[i] = a[i]
//
// ==> split ==>
//
// for i_outer in 0..100/16:
// for i_inner in 0..16:
// b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner]
// for i_tail in 0..100%16:
// b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16];
KernelScope kernel_scope;
ExprHandle n(100);
Buffer a(BufHandle("a", {n}, kFloat));
Tensor* b =
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); });
LoopNest l({b});
For* outer;
For* inner;
For* tail;
std::vector<For*> loops = l.getLoopStmtsFor(b);
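  // Split the 'i' loop by a factor of 16: 'outer'/'inner' cover the first
  // 96 iterations and 'tail' covers the remaining 100 % 16 = 4 iterations.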
l.splitWithTail(loops[0], 16, &outer, &inner, &tail);
{
// Verify inferred bounds for the outer loop
auto bounds_info = inferBounds(outer);
ASSERT_EQ(bounds_info.size(), 2);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 95}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 95}});
}
{
// Verify inferred bounds for the tail loop
auto bounds_info = inferBounds(tail);
ASSERT_EQ(bounds_info.size(), 2);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{96, 99}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{96, 99}});
}
}
void testBoundsInference_6() {
// Verify that bounds inference works for the following example:
//
// for y in 0..200:
// for x in 0..320:
// b[y,x] = x*y
// for y in 0..20:
// for x in 0..32:
// c[y,x] = a[y+100,x+100] * b[y*2,x*5]
KernelScope kernel_scope;
ExprHandle W(320);
ExprHandle H(200);
ExprHandle CW(32);
ExprHandle CH(20);
Buffer a(BufHandle("a", {H, W}, kFloat));
Tensor* b = Compute(
"b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
return x * y;
});
Tensor* c = Compute(
"c", {{CH, "y"}, {CW, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
return a(y + 100, x + 100) * b->call(y * 2, x * 5);
});
LoopNest l({c});
std::vector<For*> loops = l.getLoopStmtsFor(c);
Stmt* body = l.getLoopBodyFor(c);
{
// Infer bounds on the top-level loop scope
auto bounds_info = inferBounds(loops[0]);
ASSERT_EQ(bounds_info.size(), 3);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{100, 119}, {100, 131}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 38}, {0, 155}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 19}, {0, 31}});
}
{
// Infer bounds on the inner loop scope
auto bounds_info = inferBounds(loops[1]);
ASSERT_EQ(bounds_info.size(), 3);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {100, 131}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {0, 155}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {0, 31}});
}
{
// Infer bounds on the inner loop body's scope
auto bounds_info = inferBounds(body);
ASSERT_EQ(bounds_info.size(), 3);
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {-1, -1}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {-1, -1}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {-1, -1}});
}
}
void testBoundsInferenceNonOverlapping() {
KernelScope kernel_scope;
ExprHandle H(3);
Buffer a(BufHandle("a", {10}, kFloat));
Tensor* b =
Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a(x); });
Tensor* c = Compute(
"c", {{H, "x"}}, [&](const VarHandle& x) { return a(x + H + 1); });
LoopNest l({b, c});
std::vector<For*> loops = NodeFinder<For>::find(l.root_stmt());
{
// Infer bounds on the top-level loop scope
auto bounds_info = inferBounds(loops[0]);
ASSERT_EQ(bounds_info.size(), 2);
// reads from a[0:2], writes to b[0:2]
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 2}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 2}});
}
{
// Infer bounds on the inner loop scope
auto bounds_info = inferBounds(loops[1]);
ASSERT_EQ(bounds_info.size(), 2);
// reads from a[0+4:2+4], writes to c[0:2]
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{4, 6}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 2}});
}
{
// Infer bounds on the high level program.
auto bounds_info = inferBounds(l.root_stmt());
ASSERT_EQ(bounds_info.size(), 3);
// Should be union of above 2 bounds.
ASSERT_EQ(bounds_info.at(a.data()).size(), 2);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 2}});
    ASSERT_EQ(bounds_info.at(a.data())[1].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[1], {{4, 6}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 2}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 2}});
}
}
void testBoundsInferenceAdjacent() {
KernelScope kernel_scope;
ExprHandle H(6);
Buffer a(BufHandle("a", {20}, kFloat));
Tensor* b =
Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a(x); });
Tensor* c =
Compute("c", {{H, "x"}}, [&](const VarHandle& x) { return a(x + H); });
LoopNest l({b, c});
std::vector<For*> loops = NodeFinder<For>::find(l.root_stmt());
{
// Infer bounds on the top-level loop scope
auto bounds_info = inferBounds(loops[0]);
ASSERT_EQ(bounds_info.size(), 2);
// reads from a[0:5], writes to b[0:5]
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 5}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}});
}
{
// Infer bounds on the inner loop scope
auto bounds_info = inferBounds(loops[1]);
ASSERT_EQ(bounds_info.size(), 2);
// reads from a[0+6:5+6], writes to c[0:5]
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{6, 11}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}});
}
{
// Infer bounds on the high level program.
auto bounds_info = inferBounds(l.root_stmt());
ASSERT_EQ(bounds_info.size(), 3);
    // Should be the union of the above two bounds, but this time the bounds
    // of 'a' can be merged.
ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
verifyConstBounds(bounds_info.at(a.data())[0], {{0, 11}});
ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}});
ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}});
}
}
void testMergeInferredBounds() {
KernelScope kernel_scope;
Buffer a(BufHandle("a", {10}, kFloat));
// There are seven cases to consider in mergeTensorAccesses(A, B)
// * A is lower than B and does not overlap.
// * A is higher than B and does not overlap.
// * A overlaps B on both ends.
// * B overlaps A on both ends.
// * A overlaps B on the lower end. (equiv to B overlaps A on upper end).
// * A overlaps B on the upper end. (likewise covers reverse)
// * A and B are the same range.
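  // BoundsInfo maps each buffer to the list of access bounds recorded for it;
  // mergeTensorAccesses coalesces overlapping or adjacent entries of the same
  // access kind.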
BoundsInfo info;
// Test no overlap, both ways.
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(3)}});
info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}});
info[a.data()].push_back({kLoad, {new IntImm(9)}, {new IntImm(9)}});
BoundsInfo res = mergeTensorAccesses(info);
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[a.data()].size(), 3);
ASSERT_EQ(res.at(a.data())[0].kind, kLoad);
ASSERT_EQ(res.at(a.data())[1].kind, kLoad);
ASSERT_EQ(res.at(a.data())[2].kind, kLoad);
verifyConstBounds(res.at(a.data())[0], {{1, 3}});
verifyConstBounds(res.at(a.data())[1], {{5, 7}});
verifyConstBounds(res.at(a.data())[2], {{9, 9}});
// Test full overlap, A over B.
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}});
info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{1, 7}});
// B over A.
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}});
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{1, 7}});
// Test partial overlap on the low end, A over B.
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}});
info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{3, 7}});
// Test partial overlap on the high end.
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(2)}, {new IntImm(5)}});
info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{2, 6}});
// Test equality is deduped.
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}});
info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{4, 6}});
}
void testMergeInferredLoadStoreDiff() {
KernelScope kernel_scope;
Buffer a(BufHandle("a", {10}, kFloat));
// Loads and Stores do not merge:
BoundsInfo info;
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}});
info[a.data()].push_back({kStore, {new IntImm(3)}, {new IntImm(9)}});
BoundsInfo res = mergeTensorAccesses(info);
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[a.data()].size(), 2);
ASSERT_EQ(res.at(a.data())[0].kind, kLoad);
ASSERT_EQ(res.at(a.data())[1].kind, kStore);
verifyConstBounds(res.at(a.data())[0], {{1, 7}});
verifyConstBounds(res.at(a.data())[1], {{3, 9}});
// Do merge around the other kind of access:
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(3)}});
info[a.data()].push_back({kStore, {new IntImm(3)}, {new IntImm(4)}});
info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(5)}});
info[a.data()].push_back({kStore, {new IntImm(4)}, {new IntImm(8)}});
info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
verifyConstBounds(res.at(a.data())[0], {{1, 7}});
verifyConstBounds(res.at(a.data())[1], {{3, 8}});
}
void testMergeInferred2DBounds() {
KernelScope kernel_scope;
Buffer a(BufHandle("a", {10, 10}, kFloat));
// Non overlapping in both dimensions:
BoundsInfo info;
info[a.data()].push_back(
{kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}});
info[a.data()].push_back(
{kLoad, {new IntImm(5), new IntImm(5)}, {new IntImm(9), new IntImm(9)}});
BoundsInfo res = mergeTensorAccesses(info);
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[a.data()].size(), 2);
ASSERT_EQ(res.at(a.data())[0].kind, kLoad);
ASSERT_EQ(res.at(a.data())[1].kind, kLoad);
verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}});
verifyConstBounds(res.at(a.data())[1], {{5, 9}, {5, 9}});
// Overlapping in a single dimension should mean we cannot merge.
// First dimension:
info.clear();
info[a.data()].push_back(
{kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}});
info[a.data()].push_back(
{kLoad, {new IntImm(2), new IntImm(5)}, {new IntImm(9), new IntImm(9)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}});
verifyConstBounds(res.at(a.data())[1], {{2, 9}, {5, 9}});
// Second dimension:
info.clear();
info[a.data()].push_back(
{kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}});
info[a.data()].push_back(
{kLoad, {new IntImm(5), new IntImm(2)}, {new IntImm(9), new IntImm(9)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}});
verifyConstBounds(res.at(a.data())[1], {{5, 9}, {2, 9}});
// Overlapping in both dimensions:
  // {1-6, 1-3} | {4-9, 2-7} => {1-9, 1-7}
// TODO: this will overestimate and we should fix it.
info.clear();
info[a.data()].push_back(
{kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(6), new IntImm(3)}});
info[a.data()].push_back(
{kLoad, {new IntImm(4), new IntImm(2)}, {new IntImm(9), new IntImm(7)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{1, 9}, {1, 7}});
}
void testMergeAdjacentBounds() {
KernelScope kernel_scope;
Buffer a(BufHandle("a", {10}, kFloat));
// Adjacent but not overlapping bounds can be merged.
// e.g. {1-4} | {5-9} => {1-9}
BoundsInfo info;
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}});
info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(9)}});
BoundsInfo res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{1, 9}});
// And on the other side:
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(9)}});
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
verifyConstBounds(res.at(a.data())[0], {{1, 9}});
  // A gap of one element is enough to prevent merging:
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}});
info[a.data()].push_back({kLoad, {new IntImm(6)}, {new IntImm(9)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
verifyConstBounds(res.at(a.data())[0], {{1, 4}});
verifyConstBounds(res.at(a.data())[1], {{6, 9}});
}
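// Renders the start/stop expressions of the bound at index 'idx' as printed
// strings, so symbolic bounds can be checked against their expected form.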
std::pair<std::string, std::string> boundAsStringPair(
TensorAccessBoundsInfo& info,
size_t idx = 0) {
std::ostringstream start, stop;
start << *info.start[idx];
stop << *info.stop[idx];
return {start.str(), stop.str()};
}
void testMergeSymbolicBounds() {
KernelScope kernel_scope;
Buffer a(BufHandle("a", {10}, kFloat));
VarHandle W("W", kInt);
VarHandle X("X", kInt);
VarHandle Y("Y", kInt);
VarHandle Z("Z", kInt);
// Can do nothing with fully symbolic bounds:
BoundsInfo info;
info[a.data()].push_back({kLoad, {W.node()}, {Z.node()}});
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
BoundsInfo res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
  // Can merge if the difference between bounds is constant and one range
  // encloses the other.
// {X-Y} | {X-5 - Y+10} => {X-5 - Y+10}
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
info[a.data()].push_back({kLoad,
{new Sub(X.node(), new IntImm(5))},
{new Add(Y.node(), new IntImm(10))}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
// Cannot merge otherwise.
// {X-Y} | {X+5 - Y+10} => could be 2 groups if Y < X+5.
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
info[a.data()].push_back({kLoad,
{new Add(X.node(), new IntImm(5))},
{new Add(Y.node(), new IntImm(10))}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
// Can't merge if there's a gap of at least one element:
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(4)}});
info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
  // Can't merge even though the high of the first bound is above the low of
  // the second: X could be 6 and Y could be 4, so the ranges don't overlap in
  // all cases.
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
info[a.data()].push_back({kLoad, {new IntImm(4)}, {Y.node()}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
// If either side is equal, they must be overlapping.
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {Z.node()}});
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
auto pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "X");
ASSERT_EQ(pair.second, "Max(Y, Z, 1)");
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
info[a.data()].push_back({kLoad, {Z.node()}, {Y.node()}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "Min(Z, X, 1)");
ASSERT_EQ(pair.second, "Y");
// If either side is only one apart, they must be adjacent.
info.clear();
info[a.data()].push_back(
{kLoad, {new Add(X.node(), new IntImm(1))}, {Z.node()}});
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "X");
ASSERT_EQ(pair.second, "Max(Y, Z, 1)");
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
info[a.data()].push_back(
{kLoad, {Z.node()}, {new Sub(Y.node(), new IntImm(1))}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "Min(Z, X, 1)");
ASSERT_EQ(pair.second, "Y");
  // If either side is 2 apart, they may not be overlapping:
  // in this case, if Y == X+1 the ranges don't overlap.
info.clear();
info[a.data()].push_back(
{kLoad, {new Add(X.node(), new IntImm(2))}, {Z.node()}});
info[a.data()].push_back(
{kLoad, {X.node()}, {new Sub(Y.node(), new IntImm(1))}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
// In this case they may not overlap if X == Y.
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
info[a.data()].push_back(
{kLoad, {Z.node()}, {new Sub(Y.node(), new IntImm(2))}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 2);
}
void testMergeSymbolicAdjacent() {
KernelScope kernel_scope;
Buffer a(BufHandle("a", {10}, kFloat));
VarHandle X("X", kInt);
VarHandle Y("Y", kInt);
BoundsInfo info;
// Can merge if a range is adjacent:
// {X-5} | {6-Y} => {X-Y}
info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(5)}});
info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}});
BoundsInfo res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
auto pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "X");
ASSERT_EQ(pair.second, "Y");
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}});
info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(5)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "X");
ASSERT_EQ(pair.second, "Y");
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}});
info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "X");
ASSERT_EQ(pair.second, "Y");
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "X");
ASSERT_EQ(pair.second, "Y");
  // If either the lower or upper bound is adjacent to the other range then
  // they must overlap, even if we don't know the extent.
info.clear();
info[a.data()].push_back({kLoad, {new IntImm(6)}, {X.node()}});
info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "5");
ASSERT_EQ(pair.second, "Max(Y, X, 1)");
info.clear();
info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
info[a.data()].push_back({kLoad, {Y.node()}, {new IntImm(5)}});
res = mergeTensorAccesses(info);
ASSERT_EQ(res[a.data()].size(), 1);
pair = boundAsStringPair(res[a.data()][0]);
ASSERT_EQ(pair.first, "Min(Y, X, 1)");
ASSERT_EQ(pair.second, "6");
}
} // namespace jit
} // namespace torch