[TensorExpr] Make Load and Store multi-dimensional. (#35800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35800
This PR includes the following changes:
* Introduce a new `Expr` type `Buf`: it plays a role similar to `Var`, but also carries dimensions.
* Use the new `Buf` class instead of `Var` in `Store` and `Load` to specify where we store to or load from. `Buf` carries the dimension info of the buffer we're loading from or storing to, so we can keep N-d indexes without flattening them into a 1-d index ([x,y] vs [x+y*W]); a sketch of the new API follows this list.
* Flattening of indexes is now a separate pass, executed in `LoopNest::prepareForCodegen`: backends still expect flattened indexes, and this PR preserves that behavior.
* `Tensor` now contains a `Buf` instead of a `Var`, so `Tensor` now carries the dimension info (previously it was a property of `Function`, not `Tensor`). This brings us closer to `Tensor` being a combination of Buffer + Function, where the Buffer specifies the iteration domain and the Function defines the computation.
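For reference, a minimal sketch of the updated API, pieced together from the tests in this diff (illustrative only; `M` and `N` are placeholder sizes):

  KernelScope kernel_scope;
  const int M = 32;
  const int N = 1024;
  // BufHandle now carries the buffer's name and dimensions; the dtype comes second.
  Buffer a(BufHandle("a", {M, N}), kFloat);
  Buffer b(BufHandle("b", {N}), kFloat);
  // Load takes a list of N-d indexes ({i, j}) instead of a single
  // pre-flattened index (i * N + j); the trailing argument is still the mask.
  Tensor* c = Compute(
      "c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
        return Load::make(a, {i, j}, 1) + Load::make(b, {j}, 1);
      });
  LoopNest l({c});
  l.prepareForCodegen(); // flattens {i, j} back to i * N + j, as backends expect
  Stmt* s = l.root_stmt();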
TODOs:
* Consider merging `Buffer` with `Buf` or `BufHandle`: it seems we don't need all three.
* Harden the logic for creating buffers in the fuser pass: currently we sometimes fail to set dimensions.
* Use `Buf` in `Allocate` and `Free`.
* Make it clearer that `Function` doesn't "own" the dimension info and that dimensions are a property of a `Tensor`, not a `Function`.
Differential Revision: D20789005
Test Plan: Imported from OSS
Reviewed By: zheng-xq
Pulled By: ZolotukhinM
fbshipit-source-id: e04188d1d297f195f1c46669c614557d6bb6cde4
diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp
index 9d098d8..1b84666 100644
--- a/test/cpp/tensorexpr/test_aten.cpp
+++ b/test/cpp/tensorexpr/test_aten.cpp
@@ -15,13 +15,13 @@
void testATen_cast_Float() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
ExprHandle to_float = Cast::make(kFloat, load_a);
- Stmt* store_b = Store::make(b_buf, index, to_float, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, to_float, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<int> a_v(kTotalSize);
@@ -43,13 +43,13 @@
void testATennegInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
ExprHandle to_float = Sub::make(0, load_a);
- Stmt* store_b = Store::make(b_buf, index, to_float, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, to_float, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<int> a_v(kTotalSize);
@@ -71,13 +71,13 @@
void testATennegFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
ExprHandle to_float = Sub::make(0, load_a);
- Stmt* store_b = Store::make(b_buf, index, to_float, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, to_float, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -99,16 +99,16 @@
void testATenaddInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
+ Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- ExprHandle load_c = Load::make(c_buf, index, 1);
- Stmt* store_d = Store::make(d_buf, index, load_a + load_b * load_c, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ ExprHandle load_c = Load::make(c_buf, {index}, 1);
+ Stmt* store_d = Store::make(d_buf, {index}, load_a + load_b * load_c, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<int> a_v(kTotalSize);
@@ -136,16 +136,16 @@
void testATenaddFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- ExprHandle load_c = Load::make(c_buf, index, 1);
- Stmt* store_d = Store::make(d_buf, index, load_a + load_b * load_c, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ ExprHandle load_c = Load::make(c_buf, {index}, 1);
+ Stmt* store_d = Store::make(d_buf, {index}, load_a + load_b * load_c, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<float> a_v(kTotalSize);
@@ -173,16 +173,16 @@
void testATensubInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
+ Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- ExprHandle load_c = Load::make(c_buf, index, 1);
- Stmt* store_d = Store::make(d_buf, index, load_a - load_b * load_c, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ ExprHandle load_c = Load::make(c_buf, {index}, 1);
+ Stmt* store_d = Store::make(d_buf, {index}, load_a - load_b * load_c, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<int> a_v(kTotalSize);
@@ -210,16 +210,16 @@
void testATensubFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- ExprHandle load_c = Load::make(c_buf, index, 1);
- Stmt* store_d = Store::make(d_buf, index, load_a - load_b * load_c, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ ExprHandle load_c = Load::make(c_buf, {index}, 1);
+ Stmt* store_d = Store::make(d_buf, {index}, load_a - load_b * load_c, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<float> a_v(kTotalSize);
@@ -247,17 +247,17 @@
void testATenlerp() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- ExprHandle load_c = Load::make(c_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ ExprHandle load_c = Load::make(c_buf, {index}, 1);
Stmt* store_d =
- Store::make(d_buf, index, load_a + load_c * (load_b - load_a), 1);
+ Store::make(d_buf, {index}, load_a + load_c * (load_b - load_a), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<float> a_v(kTotalSize);
@@ -285,19 +285,19 @@
void testATenaddcmulInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer e_buf(VarHandle("E", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
+ Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kInt);
+ Buffer e_buf(BufHandle("E", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- ExprHandle load_c = Load::make(c_buf, index, 1);
- ExprHandle load_d = Load::make(d_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ ExprHandle load_c = Load::make(c_buf, {index}, 1);
+ ExprHandle load_d = Load::make(d_buf, {index}, 1);
Stmt* store_e =
- Store::make(e_buf, index, load_a + load_b * load_c * load_d, 1);
+ Store::make(e_buf, {index}, load_a + load_b * load_c * load_d, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_e);
PaddedBuffer<int> a_v(kTotalSize);
@@ -328,19 +328,19 @@
void testATenaddcmulFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer e_buf(VarHandle("E", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer e_buf(BufHandle("E", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- ExprHandle load_c = Load::make(c_buf, index, 1);
- ExprHandle load_d = Load::make(d_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ ExprHandle load_c = Load::make(c_buf, {index}, 1);
+ ExprHandle load_d = Load::make(d_buf, {index}, 1);
Stmt* store_e =
- Store::make(e_buf, index, load_a + load_b * load_c * load_d, 1);
+ Store::make(e_buf, {index}, load_a + load_b * load_c * load_d, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_e);
PaddedBuffer<float> a_v(kTotalSize);
@@ -371,14 +371,14 @@
void testATenmulInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, load_a * load_b, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, load_a * load_b, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
@@ -403,14 +403,14 @@
void testATenmulFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, load_a * load_b, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, load_a * load_b, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
@@ -435,14 +435,14 @@
void testATendivInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, load_a / load_b, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, load_a / load_b, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
@@ -467,14 +467,14 @@
void testATendivFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, load_a / load_b, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, load_a / load_b, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
@@ -499,14 +499,14 @@
void testATenmaxInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, Max::make(load_a, load_b, true), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, Max::make(load_a, load_b, true), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
@@ -531,14 +531,14 @@
void testATenmaxFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, Max::make(load_a, load_b, true), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, Max::make(load_a, load_b, true), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
@@ -563,14 +563,14 @@
void testATenminInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, Min::make(load_a, load_b, true), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, Min::make(load_a, load_b, true), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
@@ -595,14 +595,14 @@
void testATenminFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
- Stmt* store_c = Store::make(c_buf, index, Min::make(load_a, load_b, true), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
+ Stmt* store_c = Store::make(c_buf, {index}, Min::make(load_a, load_b, true), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
@@ -627,15 +627,15 @@
void testATen_sigmoid_backward() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
Stmt* store_c = Store::make(
- c_buf, index, load_a * load_b * (FloatImm::make(1.0f) - load_b), 1);
+ c_buf, {index}, load_a * load_b * (FloatImm::make(1.0f) - load_b), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
@@ -660,15 +660,15 @@
void testATen_tanh_backward() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- ExprHandle load_b = Load::make(b_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ ExprHandle load_b = Load::make(b_buf, {index}, 1);
Stmt* store_c = Store::make(
- c_buf, index, load_a * (FloatImm::make(1.0f) - (load_b * load_b)), 1);
+ c_buf, {index}, load_a * (FloatImm::make(1.0f) - (load_b * load_b)), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
@@ -693,12 +693,12 @@
void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, FloatImm::make(1.0f) / load_a, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, FloatImm::make(1.0f) / load_a, 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -720,12 +720,12 @@
void testATenreluInt() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, Max::make(load_a, 0, false), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, Max::make(load_a, 0, false), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<int> a_v(kTotalSize);
@@ -747,14 +747,14 @@
void testATenreluFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
Stmt* store_b = Store::make(
b_buf,
- index,
+ {index},
Max::make(load_a, 0, false), // relu does not propagate nans
1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
@@ -778,12 +778,12 @@
void testATenlogFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, log(load_a), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, log(load_a), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -805,12 +805,12 @@
void testATenlog10Float() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, log10(load_a), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, log10(load_a), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -832,12 +832,12 @@
void testATenlog2Float() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, log2(load_a), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, log2(load_a), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -859,12 +859,12 @@
void testATenexpFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, exp(load_a), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, exp(load_a), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -886,12 +886,12 @@
void testATenerfFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, erf(load_a), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, erf(load_a), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -913,12 +913,12 @@
void testATencosFloat() {
KernelScope kernel_scope;
const int kTotalSize = 128;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
VarHandle index = VarHandle("index", kInt);
- ExprHandle load_a = Load::make(a_buf, index, 1);
- Stmt* store_b = Store::make(b_buf, index, cos(load_a), 1);
+ ExprHandle load_a = Load::make(a_buf, {index}, 1);
+ Stmt* store_b = Store::make(b_buf, {index}, cos(load_a), 1);
Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
@@ -940,9 +940,9 @@
void testATeneqInt() {
KernelScope kernel_scope;
constexpr int N = 128;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 1);
std::vector<int> b_buffer(N, 1);
std::vector<int> c_buffer(N, 0);
@@ -955,10 +955,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kEQ),
mask));
@@ -971,9 +971,9 @@
void testATengeInt() {
KernelScope kernel_scope;
constexpr int N = 128;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 5);
std::vector<int> b_buffer(N, 5);
std::vector<int> c_buffer(N, 0);
@@ -986,10 +986,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kGE),
mask));
@@ -1002,9 +1002,9 @@
void testATengtInt() {
KernelScope kernel_scope;
constexpr int N = 128;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 6);
std::vector<int> b_buffer(N, 3);
std::vector<int> c_buffer(N, 0);
@@ -1017,10 +1017,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kGT),
mask));
@@ -1033,9 +1033,9 @@
void testATenleInt() {
KernelScope kernel_scope;
constexpr int N = 128;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 5);
std::vector<int> b_buffer(N, 5);
std::vector<int> c_buffer(N, 0);
@@ -1048,10 +1048,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kLE),
mask));
@@ -1064,9 +1064,9 @@
void testATenltInt() {
KernelScope kernel_scope;
constexpr int N = 128;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 5);
std::vector<int> b_buffer(N, 5);
std::vector<int> c_buffer(N, 1);
@@ -1079,10 +1079,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kLT),
mask));
diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp
index 6072945..033a18f 100644
--- a/test/cpp/tensorexpr/test_cuda.cpp
+++ b/test/cpp/tensorexpr/test_cuda.cpp
@@ -43,6 +43,7 @@
std::vector<For*> loops = l.getLoopStmtsFor(c);
l.setGPUBlockIndex(loops[1], 0);
l.setGPUThreadIndex(loops[2], 0);
+ l.prepareForCodegen();
Stmt* stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
const int N = block_count * block_size * num_iter;
@@ -113,6 +114,7 @@
l.splitWithMask(loops[0], block_size, &n_outer, &n_inner);
l.setGPUBlockIndex(n_outer, 0);
l.setGPUThreadIndex(n_inner, 0);
+ l.prepareForCodegen();
Stmt* stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
PaddedBuffer<float> a_v(N);
@@ -161,13 +163,14 @@
auto testWithSize = [](int32_t M, int32_t N) {
VarHandle m("m", kInt);
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
- Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+ Buffer a(BufHandle("a", {m, n}), kFloat);
+ Buffer b(BufHandle("b", {m, n}), kFloat);
Tensor* c = Compute(
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
return a(i, j) + b(i, j);
});
LoopNest l({c});
+ l.prepareForCodegen();
Stmt* s = l.root_stmt();
CudaCodeGen cg(s, {a, b, c, m, n});
@@ -237,6 +240,7 @@
std::vector<For*> loops = l.getLoopStmtsFor(c);
l.setGPUBlockIndex(loops[1], 0);
l.setGPUThreadIndex(loops[2], 0);
+ l.prepareForCodegen();
Stmt* stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c);
const int N = block_count * block_size * num_iter;
@@ -280,7 +284,7 @@
KernelScope ks;
constexpr int N = 4096;
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {n});
+ Buffer a(BufHandle("a", {n}), kFloat);
Tensor* b =
Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; });
LoopNest l({b});
diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp
index d4badc5..1f9e2c6 100644
--- a/test/cpp/tensorexpr/test_expr.cpp
+++ b/test/cpp/tensorexpr/test_expr.cpp
@@ -69,9 +69,9 @@
Buffer a_buf("a", kFloat, {1});
Buffer b_buf("b", kFloat, {1});
- ExprHandle load_a = Load::make(a_buf, 0, 1);
+ ExprHandle load_a = Load::make(a_buf, {0}, 1);
VarHandle var = VarHandle("v", kFloat);
- Stmt* store_b = Store::make(b_buf, 0, var, 1);
+ Stmt* store_b = Store::make(b_buf, {0}, var, 1);
Stmt* let_store = LetStmt::make(var, load_a, store_b);
SimpleIREvaluator eval(let_store, a_buf, b_buf);
@@ -182,9 +182,9 @@
const int kVectorCount = 128;
const int kTotalSize = kVectorSize * kVectorCount;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
/*
Build the following:
@@ -197,16 +197,16 @@
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = Load::make(
a_buf,
- Ramp::make(index * kVectorSize, 1, kVectorSize),
+ {Ramp::make(index * kVectorSize, 1, kVectorSize)},
Broadcast::make(1, kVectorSize));
ExprHandle load_b = Load::make(
b_buf,
- Ramp::make(index * kVectorSize, 1, kVectorSize),
+ {Ramp::make(index * kVectorSize, 1, kVectorSize)},
Broadcast::make(1, kVectorSize));
ExprHandle value = load_a + load_b;
Stmt* store_c = Store::make(
c_buf,
- Ramp::make(index * kVectorSize, 1, kVectorSize),
+ {Ramp::make(index * kVectorSize, 1, kVectorSize)},
value,
Broadcast::make(1, kVectorSize));
Stmt* stmt = For::make(index, 0, kVectorCount, store_c);
@@ -232,9 +232,9 @@
void testExprCompareSelectEQ() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 1);
std::vector<int> b_buffer(N, 1);
std::vector<int> c_buffer(N, 0);
@@ -248,10 +248,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kEQ),
mask));
@@ -403,11 +403,11 @@
KernelScope kernel_scope;
auto testWithSize = [](int32_t size) {
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {n});
- Buffer b(VarHandle("b", kHandle), kFloat, {n});
- Buffer c(VarHandle("c", kHandle), kFloat, {n});
+ Buffer a(BufHandle("a", {n}), kFloat);
+ Buffer b(BufHandle("b", {n}), kFloat);
+ Buffer c(BufHandle("c", {n}), kFloat);
VarHandle i("i", kInt);
- Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
+ Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
@@ -426,9 +426,9 @@
Buffer a_buf("a", kFloat, {N});
VarHandle index = VarHandle("index", kInt);
Stmt* assign_x2 =
- Store::make(VarHandle(a_buf.data()), index, cast<float>(index) * 2, 1);
+ Store::make(BufHandle(a_buf.data()), {index}, cast<float>(index) * 2, 1);
Stmt* assign_x3 =
- Store::make(VarHandle(a_buf.data()), index, cast<float>(index) * 3, 1);
+ Store::make(BufHandle(a_buf.data()), {index}, cast<float>(index) * 3, 1);
ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
Stmt* assign = Cond::make(even_cond, assign_x2, assign_x3);
Stmt* for_stmt = For::make(index, 0, N, assign);
@@ -476,7 +476,7 @@
Buffer a_buf("a", kInt, {N});
VarHandle index = VarHandle("index", kInt);
Stmt* body =
- Store::make(VarHandle(a_buf.data()), index, 5, 1);
+ Store::make(BufHandle(a_buf.data()), {index}, 5, 1);
Stmt* loop = For::make(index, 0, N, body);
Stmt* cloned_loop = Stmt::clone(loop);
@@ -490,7 +490,7 @@
// Let's add another assign to the body in the cloned loop and verify that the
// original statement hasn't changed while the cloned one has.
- Stmt* body_addition = Store::make(VarHandle(a_buf.data()), index, 33, 1);
+ Stmt* body_addition = Store::make(BufHandle(a_buf.data()), {index}, 33, 1);
Block* cloned_body =
static_cast<Block*>(static_cast<const For*>(cloned_loop)->body());
cloned_body->append_stmt(body_addition);
diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp
index c1522b8..9953bf2 100644
--- a/test/cpp/tensorexpr/test_llvm.cpp
+++ b/test/cpp/tensorexpr/test_llvm.cpp
@@ -196,7 +196,7 @@
void testLLVMBufferTest() {
KernelScope kernel_scope;
- Buffer a(VarHandle("A", kHandle), kFloat, {32});
+ Buffer a(BufHandle("A", {32}), kFloat);
std::vector<int32_t> v(5);
std::vector<void*> args({v.data()});
auto rv = IntImm::make(0);
@@ -206,14 +206,14 @@
void testLLVMBlockTest() {
KernelScope kernel_scope;
- Buffer a(VarHandle("A", kHandle), kInt, {32});
+ Buffer a(BufHandle("A", {32}), kInt);
std::vector<int32_t> v = {1, 2};
std::vector<void*> args({v.data()});
auto block = Block::make({
- Store::make(a, IntImm::make(0), IntImm::make(3), IntImm::make(1)),
- Store::make(a, IntImm::make(1), IntImm::make(4), IntImm::make(1)),
- Store::make(a, IntImm::make(0), IntImm::make(4), IntImm::make(1)),
+ Store::make(a, {IntImm::make(0)}, IntImm::make(3), IntImm::make(1)),
+ Store::make(a, {IntImm::make(1)}, IntImm::make(4), IntImm::make(1)),
+ Store::make(a, {IntImm::make(0)}, IntImm::make(4), IntImm::make(1)),
});
LLVMCodeGen cg(block, {a});
@@ -224,15 +224,15 @@
void testLLVMLoadStoreTest() {
KernelScope kernel_scope;
- Buffer a(VarHandle("A", kHandle), kInt, {1});
- Buffer b(VarHandle("B", kHandle), kInt, {1});
+ Buffer a(BufHandle("A", {1}), kInt);
+ Buffer b(BufHandle("B", {1}), kInt);
std::vector<int32_t> a_buffer = {42};
std::vector<int32_t> b_buffer = {-11};
auto store = Store::make(
b,
- IntImm::make(0),
- Load::make(a, IntImm::make(0), IntImm::make(1)),
+ {IntImm::make(0)},
+ Load::make(a, {IntImm::make(0)}, IntImm::make(1)),
IntImm::make(1));
LLVMCodeGen cg(store, {a, b});
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
@@ -243,19 +243,19 @@
void testLLVMIfThenElseTest() {
KernelScope kernel_scope;
- Buffer a(VarHandle("A", kHandle), kInt, {1});
- Buffer b(VarHandle("B", kHandle), kInt, {1});
- Buffer c(VarHandle("C", kHandle), kInt, {1});
+ Buffer a(BufHandle("A", {1}), kInt);
+ Buffer b(BufHandle("B", {1}), kInt);
+ Buffer c(BufHandle("C", {1}), kInt);
std::vector<int32_t> a_buffer = {42};
std::vector<int32_t> b_buffer = {-11};
std::vector<int32_t> c_buffer = {1};
auto store = Store::make(
b,
- IntImm::make(0),
+ {IntImm::make(0)},
IfThenElse::make(
- Load::make(c, IntImm::make(0), IntImm::make(1)), // cond
- Load::make(a, IntImm::make(0), IntImm::make(1)), // then
+ Load::make(c, {IntImm::make(0)}, IntImm::make(1)), // cond
+ Load::make(a, {IntImm::make(0)}, IntImm::make(1)), // then
IntImm::make(0)), // else
IntImm::make(1));
LLVMCodeGen cg(store, {a, b, c});
@@ -267,15 +267,15 @@
void testLLVMVecLoadStoreTest() {
KernelScope kernel_scope;
- Buffer a(VarHandle("A", kHandle), kInt, {1});
- Buffer b(VarHandle("B", kHandle), kInt, {1});
+ Buffer a(BufHandle("A", {1}), kInt);
+ Buffer b(BufHandle("B", {1}), kInt);
std::vector<int32_t> a_buffer = {1, 1, 1, 1};
std::vector<int32_t> b_buffer = {2, 2, 2, 2};
auto store = Store::make(
b,
- Ramp::make(0, 1, 4),
- Load::make(a, Ramp::make(0, 1, 4), Broadcast::make(IntImm::make(1), 4)),
+ {Ramp::make(0, 1, 4)},
+ Load::make(a, {Ramp::make(0, 1, 4)}, Broadcast::make(IntImm::make(1), 4)),
Broadcast::make(IntImm::make(1), 4));
LLVMCodeGen cg(store, {a, b});
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
@@ -293,17 +293,17 @@
#define FLOAT_INTRINSICS_TEST(Name, Lanes) \
void testLLVMVecFloat_##Name##Lane##Lanes##Test() { \
KernelScope kernel_scope; \
- Buffer a(VarHandle("A", kHandle), kFloat, {1}); \
- Buffer b(VarHandle("B", kHandle), kFloat, {1}); \
+ Buffer a(BufHandle("A", {1}), kFloat); \
+ Buffer b(BufHandle("B", {1}), kFloat); \
float val = 0.5f; \
std::vector<float> a_buffer(Lanes, val); \
std::vector<float> b_buffer(Lanes, val); \
auto store = Store::make( \
b, \
- Ramp::make(0, 1, Lanes), \
+ {Ramp::make(0, 1, Lanes)}, \
Name(Load::make( \
a, \
- Ramp::make(0, 1, Lanes), \
+ {Ramp::make(0, 1, Lanes)}, \
Broadcast::make(IntImm::make(1), Lanes))), \
Broadcast::make(IntImm::make(1), Lanes)); \
LLVMCodeGen cg(store, {a, b}); \
@@ -339,17 +339,17 @@
#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \
void testLLVMVecDouble_##Name##Lane##Lanes##Test() { \
KernelScope kernel_scope; \
- Buffer a(VarHandle("A", kHandle), kDouble, {1}); \
- Buffer b(VarHandle("B", kHandle), kDouble, {1}); \
+ Buffer a(BufHandle("A", {1}), kDouble); \
+ Buffer b(BufHandle("B", {1}), kDouble); \
float val = 0.5f; \
std::vector<double> a_buffer(Lanes, val); \
std::vector<double> b_buffer(Lanes, val); \
auto store = Store::make( \
b, \
- Ramp::make(0, 1, Lanes), \
+ {Ramp::make(0, 1, Lanes)}, \
Name(Load::make( \
a, \
- Ramp::make(0, 1, Lanes), \
+ {Ramp::make(0, 1, Lanes)}, \
Broadcast::make(IntImm::make(1), Lanes))), \
Broadcast::make(IntImm::make(1), Lanes)); \
LLVMCodeGen cg(store, {a, b}); \
@@ -384,13 +384,13 @@
void testLLVMVectorizerLoadStoreTest() {
KernelScope kernel_scope;
- Buffer a(VarHandle("A", kHandle), kInt, {1});
+ Buffer a(BufHandle("A", {1}), kInt);
Tensor* c = Compute("c", {{4, "i"}}, [&](const VarHandle& i) {
- return Load::make(a, i, 1);
+ return Load::make(a, {i}, 1);
});
- Buffer c_buf(VarHandle(c->func_var()), kInt, {4});
+ Buffer c_buf(BufHandle(c->func_var()), kInt);
LoopNest l({c});
Stmt* s = l.root_stmt();
l.vectorize(*dynamic_cast<Block*>(s)->stmts().begin());
@@ -410,15 +410,15 @@
void testLLVMMemcpyTest() {
KernelScope kernel_scope;
constexpr int N = 32;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
std::vector<int32_t> a_buffer(N, 42);
std::vector<int32_t> b_buffer(N, 0);
auto mask = IntImm::make(1);
VarHandle i("i", kInt);
auto expr =
- For::make(i, 0, N, Store::make(b, i, Load::make(a, i, mask), mask));
+ For::make(i, 0, N, Store::make(b, {i}, Load::make(a, {i}, mask), mask));
LLVMCodeGen cg(expr, {a, b});
@@ -434,12 +434,12 @@
void testLLVMBzeroTest() {
KernelScope kernel_scope;
constexpr int N = 32;
- Buffer b(VarHandle("B", kHandle), kInt, {N});
+ Buffer b(BufHandle("B", {N}), kInt);
std::vector<int32_t> b_buffer(N, 11);
auto mask = IntImm::make(1);
VarHandle i("i", kInt);
- auto expr = For::make(i, 0, N, Store::make(b, i, IntImm::make(0), mask));
+ auto expr = For::make(i, 0, N, Store::make(b, {i}, IntImm::make(0), mask));
LLVMCodeGen cg(expr, {b});
@@ -453,9 +453,9 @@
void testLLVMElemwiseAdd() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int32_t> a_buffer(N, 41);
std::vector<int32_t> b_buffer(N, 1);
std::vector<int32_t> c_buffer(N, 1);
@@ -468,8 +468,8 @@
N,
Store::make(
c,
- i,
- Add::make(Load::make(a, i, mask), Load::make(b, i, mask)),
+ {i},
+ Add::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask)),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -488,9 +488,9 @@
void testLLVMElemwiseAddFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, 41);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -501,7 +501,7 @@
i,
0,
N,
- Store::make(c, i, Load::make(a, i, mask) + Load::make(b, i, mask), mask));
+ Store::make(c, {i}, Load::make(a, {i}, mask) + Load::make(b, {i}, mask), mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -519,8 +519,8 @@
void testLLVMElemwiseLog10Float() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
std::vector<float> a_buffer(N, 10.0f);
std::vector<float> b_buffer(N, 2.0f);
@@ -532,8 +532,8 @@
N / 4,
Store::make(
b,
- Ramp::make(i * 4, 1, 4),
- log10(Load::make(a, Ramp::make(i * 4, 1, 4), mask)),
+ {Ramp::make(i * 4, 1, 4)},
+ log10(Load::make(a, {Ramp::make(i * 4, 1, 4)}, mask)),
mask));
LLVMCodeGen cg(expr, {a, b});
@@ -550,9 +550,9 @@
void testLLVMElemwiseMaxInt() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 41);
std::vector<int> b_buffer(N, 1);
std::vector<int> c_buffer(N, 1);
@@ -565,8 +565,8 @@
N,
Store::make(
c,
- i,
- Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+ {i},
+ Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -585,9 +585,9 @@
void testLLVMElemwiseMinInt() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 41);
std::vector<int> b_buffer(N, 1);
std::vector<int> c_buffer(N, 1);
@@ -600,8 +600,8 @@
N,
Store::make(
c,
- i,
- Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+ {i},
+ Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -620,9 +620,9 @@
void testLLVMElemwiseMaxNumFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, 41);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -635,8 +635,8 @@
N,
Store::make(
c,
- i,
- Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+ {i},
+ Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -655,9 +655,9 @@
void testLLVMElemwiseMaxNumNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, NAN);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -670,8 +670,8 @@
N,
Store::make(
c,
- i,
- Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+ {i},
+ Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -689,9 +689,9 @@
void testLLVMElemwiseMinNumFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, 41);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -704,8 +704,8 @@
N,
Store::make(
c,
- i,
- Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+ {i},
+ Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -724,9 +724,9 @@
void testLLVMElemwiseMinNumNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, NAN);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -739,8 +739,8 @@
N,
Store::make(
c,
- i,
- Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+ {i},
+ Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -759,9 +759,9 @@
void testLLVMElemwiseMaximumFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, 41);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -774,8 +774,8 @@
N,
Store::make(
c,
- i,
- Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+ {i},
+ Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -794,9 +794,9 @@
void testLLVMElemwiseMaximumNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, NAN);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -809,8 +809,8 @@
N,
Store::make(
c,
- i,
- Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+ {i},
+ Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -830,9 +830,9 @@
void testLLVMElemwiseMinimumFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, 41);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -845,8 +845,8 @@
N,
Store::make(
c,
- i,
- Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+ {i},
+ Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -865,9 +865,9 @@
void testLLVMElemwiseMinimumNaNFloat() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kFloat, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kFloat);
std::vector<float> a_buffer(N, NAN);
std::vector<float> b_buffer(N, 1);
std::vector<float> c_buffer(N, 1);
@@ -880,8 +880,8 @@
N,
Store::make(
c,
- i,
- Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+ {i},
+ Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
mask));
LLVMCodeGen cg(expr, {a, b, c});
@@ -902,9 +902,9 @@
void testLLVMCompareSelectIntEQ() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<int> a_buffer(N, 1);
std::vector<int> b_buffer(N, 1);
std::vector<int> c_buffer(N, 0);
@@ -923,10 +923,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kEQ),
mask));
@@ -948,9 +948,9 @@
void testLLVMCompareSelectFloatEQ() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kFloat, {N});
- Buffer b(VarHandle("B", kHandle), kFloat, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kFloat);
+ Buffer b(BufHandle("B", {N}), kFloat);
+ Buffer c(BufHandle("C", {N}), kInt);
std::vector<float> a_buffer(N, 1.0f);
std::vector<float> b_buffer(N, 1.0f);
std::vector<int> c_buffer(N, 0);
@@ -963,10 +963,10 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kEQ),
mask));
@@ -986,10 +986,10 @@
void testLLVMStoreFloat() {
KernelScope kernel_scope;
- Buffer result(VarHandle("result", kHandle), kFloat, {1});
+ Buffer result(BufHandle("result", {1}), kFloat);
std::vector<float> result_buffer = {0.0f};
auto expr = Store::make(
- result, IntImm::make(0), FloatImm::make(3.14f), IntImm::make(1));
+ result, {IntImm::make(0)}, FloatImm::make(3.14f), IntImm::make(1));
LLVMCodeGen cg(expr, {result});
std::vector<void*> args({result_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
@@ -1004,7 +1004,7 @@
});
LoopNest l({tensor});
Stmt* stmt = l.root_stmt();
- Buffer f_buf(VarHandle(tensor->func_var()), kFloat, {N});
+ Buffer f_buf(BufHandle(tensor->func_var()), kFloat);
LLVMCodeGen cg(stmt, {f_buf});
PaddedBuffer<float> f_v(N, "f_v");
@@ -1021,13 +1021,13 @@
void testLLVMComputeMul() {
KernelScope kernel_scope;
const int N = 1024;
- Buffer a(VarHandle("a", kHandle), kFloat, {N});
- Buffer b(VarHandle("b", kHandle), kFloat, {N});
+ Buffer a(BufHandle("a", {N}), kFloat);
+ Buffer b(BufHandle("b", {N}), kFloat);
Tensor* c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) {
- return Load::make(a, i, 1) * Load::make(b, i, 1);
+ return Load::make(a, {i}, 1) * Load::make(b, {i}, 1);
});
- Buffer c_buf(VarHandle(c->func_var()), kFloat, {N});
+ Buffer c_buf(BufHandle(c->func_var()), kFloat);
LoopNest l({c});
Stmt* s = l.root_stmt();
@@ -1045,16 +1045,17 @@
KernelScope kernel_scope;
const int M = 32;
const int N = 1024;
- Buffer a(VarHandle("a", kHandle), kFloat, {M, N});
- Buffer b(VarHandle("b", kHandle), kFloat, {N});
+ Buffer a(BufHandle("a", {M, N}), kFloat);
+ Buffer b(BufHandle("b", {N}), kFloat);
Tensor* c = Compute(
"c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
ExprHandle mask(1);
- return Load::make(a, i * N + j, mask) + Load::make(b, j, mask);
+ return Load::make(a, {i, j}, mask) + Load::make(b, {j}, mask);
});
- Buffer c_buf(VarHandle(c->func_var()), kFloat, {M, N});
+ Buffer c_buf(BufHandle(c->func_var()), kFloat);
LoopNest l({c});
+ l.prepareForCodegen();
Stmt* s = l.root_stmt();
LLVMCodeGen cg(s, {a, b, c_buf});
@@ -1091,11 +1092,11 @@
KernelScope kernel_scope;
auto testWithSize = [](int32_t size) {
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {n});
- Buffer b(VarHandle("b", kHandle), kFloat, {n});
- Buffer c(VarHandle("c", kHandle), kFloat, {n});
+ Buffer a(BufHandle("a", {n}), kFloat);
+ Buffer b(BufHandle("b", {n}), kFloat);
+ Buffer c(BufHandle("c", {n}), kFloat);
VarHandle i("i", kInt);
- Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
+ Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
@@ -1113,11 +1114,11 @@
KernelScope kernel_scope;
auto testWithSize = [](int32_t size) {
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {n});
- Buffer b(VarHandle("b", kHandle), kFloat, {n});
- Buffer c(VarHandle("c", kHandle), kFloat, {n});
+ Buffer a(BufHandle("a", {n}), kFloat);
+ Buffer b(BufHandle("b", {n}), kFloat);
+ Buffer c(BufHandle("c", {n}), kFloat);
VarHandle i("i", kInt);
- Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
+ Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
@@ -1134,8 +1135,8 @@
KernelScope kernel_scope;
auto testWithSize = [](int32_t size) {
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {n});
- Buffer b(VarHandle("b", kHandle), kFloat, {n});
+ Buffer a(BufHandle("a", {n}), kFloat);
+ Buffer b(BufHandle("b", {n}), kFloat);
Tensor* c = Compute(
"c", {{n, "n"}}, [&](const VarHandle& i) { return a(i) + b(i); });
LoopNest l({c});
@@ -1157,13 +1158,14 @@
auto testWithSize = [](int32_t M, int32_t N) {
VarHandle m("m", kInt);
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
- Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+ Buffer a(BufHandle("a", {m, n}), kFloat);
+ Buffer b(BufHandle("b", {m, n}), kFloat);
Tensor* c = Compute(
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
return a(i, j) + b(i, j);
});
LoopNest l({c});
+ l.prepareForCodegen();
Stmt* s = l.root_stmt();
LLVMCodeGen cg(s, {a, b, c, m, n});
std::vector<float> aData(M * N, 1.0f);
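Note (sketch, not part of the diff): the test churn above is mechanical and follows one pattern — dimensions move from the `Buffer` constructor into a `BufHandle`, and every scalar index becomes a brace-enclosed index list, even in the 1-d case. A minimal sketch of the new calling convention, using only APIs exercised by these tests:

void exampleNewIndexing() {
  KernelScope kernel_scope;
  constexpr int N = 16;
  // Dims now live on the BufHandle; the dtype stays on the Buffer.
  Buffer a(BufHandle("A", {N}), kFloat);
  Buffer b(BufHandle("B", {N}), kFloat);
  VarHandle i("i", kInt);
  ExprHandle mask(1);
  // Indices are passed as a list: {i} instead of i.
  Stmt* s = For::make(
      i, 0, N, Store::make(b, {i}, Load::make(a, {i}, mask), mask));
}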
diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp
index 387c71e..76c61e4 100644
--- a/test/cpp/tensorexpr/test_loopnest.cpp
+++ b/test/cpp/tensorexpr/test_loopnest.cpp
@@ -77,7 +77,7 @@
VarHandle x_inner("x_inner", kInt);
VarHandle y("y", kInt);
VarHandle x_tail("x_tail", kInt);
- VarHandle f("f", kHandle);
+ BufHandle f("f", {26, 5});
ExprHandle x_1 = x_outer * 4 + x_inner;
ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4;
For* stmt1 = For::make(
@@ -89,13 +89,13 @@
0,
4,
For::make(
- y, 0, 5, Store::make(f, x_1 * 5 + y * 1, func(x_1, y), 1))));
+ y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y), 1))));
ExprHandle x_2 = x_tail + x_outer_end * 4;
For* stmt2 = For::make(
x_tail,
0,
(ExprHandle(26) - 0) % 4,
- For::make(y, 0, 5, Store::make(f, x_2 * 5 + y * 1, func(x_2, y), 1)));
+ For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y), 1)));
Stmt* stmt = Block::make({stmt1, stmt2});
std::ostringstream oss_ref;
@@ -107,6 +107,7 @@
PaddedBuffer<float> f_v(26, 5, "f_v");
PaddedBuffer<float> f_ref(26, 5, "f_res");
+ stmt = FlattenIndexes(stmt);
SimpleIREvaluator ir_eval(stmt, tensor);
ir_eval(f_v);
@@ -145,7 +146,7 @@
VarHandle x_inner("x_inner", kInt);
VarHandle y("y", kInt);
VarHandle x_tail("x_tail", kInt);
- VarHandle f("f", kHandle);
+ BufHandle f("f", {24, 5});
ExprHandle x_1 = x_outer * 4 + x_inner;
ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4;
For* stmt = For::make(
@@ -157,8 +158,7 @@
0,
4,
For::make(
- y, 0, 5, Store::make(f, x_1 * 5 + y * 1, func(x_1, y), 1))));
- // Stmt stmt = Block::make({stmt1, stmt2});
+ y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y), 1))));
std::ostringstream oss_ref;
oss_ref << *stmt;
@@ -429,6 +429,7 @@
(c_buf(m, n) * d_buf(m, k) + a_buf(m, n) * b_buf(n, k));
});
LoopNest l2({z2});
+ l2.prepareForCodegen();
Stmt* stmt2 = l2.root_stmt();
std::ostringstream oss2;
@@ -454,7 +455,7 @@
const int kVectorCount = 128;
const int kTotalSize = kVectorSize * kVectorCount;
- Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
Tensor* b = Compute(
"f", {{kTotalSize, "i"}}, [&](const std::vector<VarHandle>& axes) {
@@ -487,10 +488,10 @@
const int kVectorCount = 128;
const int kTotalSize = kVectorSize * kVectorCount;
- Buffer a(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer b(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer c(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
- Buffer d(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+ Buffer a(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer b(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer c(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+ Buffer d(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
Tensor* e = Compute("e", {{kTotalSize, "i"}}, [&](const VarHandle& i) {
return a(i) + b(i);
@@ -525,8 +526,8 @@
auto testWithSize = [](int32_t M, int32_t N) {
VarHandle m("m", kInt);
VarHandle n("n", kInt);
- Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
- Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+ Buffer a(BufHandle("a", {m, n}), kFloat);
+ Buffer b(BufHandle("b", {m, n}), kFloat);
Tensor* c = Compute(
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
return a(i, j) + b(i, j);
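Note (sketch, not part of the diff): statements built by hand with multi-dimensional indices, like the loop nests in the tests above, must be flattened before they can run — which is what the new `FlattenIndexes` call in these tests does:

KernelScope kernel_scope;
BufHandle f("f", {26, 5});
VarHandle x("x", kInt), y("y", kInt);
Stmt* stmt = For::make(
    x, 0, 26,
    For::make(y, 0, 5, Store::make(f, {x, y}, x + y, 1)));
// The evaluator and backends still expect 1-d indices:
stmt = FlattenIndexes(stmt);  // f[x, y] becomes f[x * 5 + y]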
diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp
index 8a03e0e..26356be 100644
--- a/test/cpp/tensorexpr/test_simplify.cpp
+++ b/test/cpp/tensorexpr/test_simplify.cpp
@@ -346,9 +346,9 @@
void testHashLargeExpression() {
KernelScope kernel_scope;
constexpr int N = 1024;
- Buffer a(VarHandle("A", kHandle), kInt, {N});
- Buffer b(VarHandle("B", kHandle), kInt, {N});
- Buffer c(VarHandle("C", kHandle), kInt, {N});
+ Buffer a(BufHandle("A", {N}), kInt);
+ Buffer b(BufHandle("B", {N}), kInt);
+ Buffer c(BufHandle("C", {N}), kInt);
auto mask = IntImm::make(1);
VarHandle i("i", kInt);
auto memcpy_stmt = For::make(
@@ -357,25 +357,25 @@
N,
Store::make(
c,
- i,
+ {i},
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kEQ),
mask));
- Buffer d(VarHandle("D", kHandle), kInt, {1});
- Buffer e(VarHandle("E", kHandle), kInt, {1});
+ Buffer d(BufHandle("D", {1}), kInt);
+ Buffer e(BufHandle("E", {1}), kInt);
auto store_ramp_stmt = Store::make(
e,
- Ramp::make(0, 1, 4),
- Load::make(d, Ramp::make(0, 1, 4), Broadcast::make(IntImm::make(1), 4)),
+ {Ramp::make(0, 1, 4)},
+ Load::make(d, {Ramp::make(0, 1, 4)}, Broadcast::make(IntImm::make(1), 4)),
Broadcast::make(Cast::make(kInt, DoubleImm::make(1)), 4));
auto if_stmt = Cond::make(
CompareSelect::make(
- Load::make(a, i, mask),
- Load::make(b, i, mask),
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
CompareSelectOperation::kGE),
memcpy_stmt,
store_ramp_stmt);
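Note on the `Ramp` store above: `indicesValid` (added in ir.cpp below) only accepts a multi-lane index when it is the single, already-flattened index; a multi-lane expression inside a multi-dimensional index list is rejected. A sketch under that rule (not part of the diff):

KernelScope kernel_scope;
Buffer d(BufHandle("D", {8}), kInt);
Buffer e(BufHandle("E", {8}), kInt);
auto mask4 = Broadcast::make(IntImm::make(1), 4);
// OK: one flat, 4-lane index.
Stmt* s = Store::make(
    e,
    {Ramp::make(0, 1, 4)},
    Load::make(d, {Ramp::make(0, 1, 4)}, mask4),
    mask4);
// Not OK: {Ramp::make(0, 1, 4), j} -- multilane inside a 2-d index list.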
diff --git a/torch/csrc/jit/tensorexpr/buffer.h b/torch/csrc/jit/tensorexpr/buffer.h
index 6cd1861..026c4c2 100644
--- a/torch/csrc/jit/tensorexpr/buffer.h
+++ b/torch/csrc/jit/tensorexpr/buffer.h
@@ -8,18 +8,13 @@
class Buffer {
public:
- Buffer(
- const VarHandle& data,
- const Dtype& dtype,
- const std::vector<ExprHandle>& dims)
- : data_(data.node()),
- dtype_(dtype),
- dims_(ExprHandleVectorToExprVector(dims)) {
+ Buffer(const BufHandle& data, const Dtype& dtype)
+ : data_(data.node()), dtype_(dtype) {
if (data.dtype() != kHandle) {
throw malformed_input();
}
- std::vector<ExprHandle> stride_handles(dims.size());
+ std::vector<ExprHandle> stride_handles(ndim());
for (int i = ndim() - 1; i >= 0; i--) {
if (i == ndim() - 1) {
stride_handles[i] = 1;
@@ -33,100 +28,60 @@
const std::string& name,
const Dtype& dtype,
const std::vector<ExprHandle>& dims)
- : Buffer(VarHandle(name, kHandle), dtype, dims) {}
+ : Buffer(BufHandle(name, dims), dtype) {}
- const Var* data() const {
+ const Buf* data() const {
return data_;
}
const Dtype& dtype() const {
return dtype_;
}
int ndim() const {
- return dims_.size();
+ return data_->ndim();
}
const Expr* dim(int index) const {
- return dims_[index];
+ return data_->dim(index);
+ }
+ std::vector<const Expr*> dims() const {
+ return data_->dims();
}
// TODO: consider deferring the storage flattening to a later stage.
template <typename... Args>
ExprHandle operator()(Args... args) const {
- ExprHandle index = Index(std::forward<Args>(args)...);
- return LoadValue(index);
+ return LoadValue(std::forward<Args>(args)...);
+ }
+ ExprHandle LoadValue(
+ const ExprHandle& x,
+ const ExprHandle& y,
+ const ExprHandle& z) const {
+ return Load::make(*this, {x, y, z}, ExprHandle(1));
+ }
+ ExprHandle LoadValue(const ExprHandle& x, const ExprHandle& y) const {
+ return Load::make(*this, {x, y}, ExprHandle(1));
+ }
+ ExprHandle LoadValue(const ExprHandle& x) const {
+ return Load::make(*this, {x}, ExprHandle(1));
}
template <typename T>
ExprHandle call(const std::vector<T>& args) const {
std::vector<ExprHandle> params(args.begin(), args.end());
- ExprHandle index = Index(params);
- return LoadValue(index);
+ return LoadValue(params);
}
private:
- ExprHandle Index(const ExprHandle& x) const {
- if (ndim() != 1) {
- throw malformed_input();
- }
- return x;
- }
- ExprHandle Index(const ExprHandle& x, const ExprHandle& y) const {
- if (ndim() != 2) {
- throw malformed_input();
- }
- return x * ExprHandle(strides_[0]) + y;
- }
- ExprHandle Index(
- const ExprHandle& x,
- const ExprHandle& y,
- const ExprHandle& z) const {
- if (ndim() != 3) {
- throw malformed_input();
- }
- return x * ExprHandle(strides_[0]) + y * ExprHandle(strides_[1]) + z;
- }
- ExprHandle Index(
- const ExprHandle& x,
- const ExprHandle& y,
- const ExprHandle& z,
- const ExprHandle& w) const {
- if (ndim() != 4) {
- throw malformed_input();
- }
- return x * ExprHandle(strides_[0]) + y * ExprHandle(strides_[1]) +
- z * ExprHandle(strides_[2]) + w;
- }
- ExprHandle Index(const std::vector<ExprHandle>& indices) const {
- if (ndim() != (int)indices.size()) {
- throw malformed_input();
- }
- ExprHandle total_index;
- for (size_t i = 0; i < indices.size(); i++) {
- ExprHandle index;
- if (i == indices.size() - 1) {
- index = indices[i];
- } else {
- index = indices[i] * ExprHandle(strides_[i]);
- }
- if (i == 0) {
- total_index = index;
- } else {
- total_index = total_index + index;
- }
- }
- return total_index;
- }
+ ExprHandle LoadValue(const std::vector<ExprHandle>& indices) const;
- ExprHandle LoadValue(const ExprHandle& index) const;
-
- const Var* data_;
+ const Buf* data_;
Dtype dtype_;
- std::vector<const Expr*> dims_;
std::vector<const Expr*> strides_;
// TODO: add strides
};
-inline ExprHandle Buffer::LoadValue(const ExprHandle& index) const {
- return Load::make(*this, index, ExprHandle(1));
+inline ExprHandle Buffer::LoadValue(
+ const std::vector<ExprHandle>& indices) const {
+ return Load::make(*this, indices, ExprHandle(1));
}
} // namespace tensorexpr
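Note (sketch, not part of the diff): `Buffer::operator()` no longer computes a flat offset eagerly; the axes are forwarded straight into a multi-dimensional `Load`, and flattening is deferred to a later pass:

KernelScope kernel_scope;
const int M = 32, N = 64;
Buffer a(BufHandle("a", {M, N}), kFloat);
VarHandle i("i", kInt), j("j", kInt);
// Produces Load(a, {i, j}, 1) rather than Load(a, i * N + j, 1).
ExprHandle elem = a(i, j);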
diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h
index e232e5b..c1f63b2 100644
--- a/torch/csrc/jit/tensorexpr/codegen.h
+++ b/torch/csrc/jit/tensorexpr/codegen.h
@@ -56,12 +56,14 @@
class CodeGen::BufferArg {
public:
BufferArg(const Buffer& buffer)
- : var_(buffer.data()), dtype_(buffer.dtype()) {}
+ : var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {}
BufferArg(Tensor* tensor)
- : var_(tensor->function()->func_var(tensor->output_index())),
+ : var_(tensor->function()
+ ->func_var(tensor->output_index())
+ ->base_handle()),
dtype_(tensor->function()->body(tensor->output_index())->dtype()) {}
BufferArg(const Function& func)
- : var_(func.func_var(0)), dtype_(func.body(0)->dtype()) {
+ : var_(func.func_var(0)->base_handle()), dtype_(func.body(0)->dtype()) {
// TODO: Support multiple-output functions
if (func.func_vars().size() != 1) {
throw unimplemented_lowering();
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index 0d5e21b..2119b1e 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -161,14 +161,15 @@
void CudaPrinter::visit(const Load* v) {
// TODO: find a better metric for deciding when to use ldg. Support different dtypes.
if (v->dtype().scalar_type() == ScalarType::Half) {
- os() << "__half2float(" << *v->base_handle() << "[" << *v->index() << "])";
+ os() << "__half2float(" << *v->base_handle() << "[" << *v->flat_index()
+ << "])";
} else {
- os() << "__ldg(" << *v->base_handle() << " + " << *v->index() << ")";
+ os() << "__ldg(" << *v->base_handle() << " + " << *v->flat_index() << ")";
}
}
void CudaPrinter::visit(const Store* v) {
- os() << *v->base_handle() << "[" << *v->index() << "] = ";
+ os() << *v->base_handle() << "[" << *v->flat_index() << "] = ";
if (v->value()->dtype().scalar_type() == ScalarType::Half) {
os() << "__float2half(" << *v->value() << ");";
} else {
diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h
index 23dc185..66404ce 100644
--- a/torch/csrc/jit/tensorexpr/eval.h
+++ b/torch/csrc/jit/tensorexpr/eval.h
@@ -582,7 +582,8 @@
}
void* ptr = iter->second;
- v->index()->accept(this);
+ const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
+ flat_idx->accept(this);
std::vector<int> index = value().as_vec<int>();
v->mask()->accept(this);
std::vector<int> mask = value().as_vec<int>();
@@ -615,7 +616,8 @@
void* ptr = iter->second;
- v->index()->accept(this);
+ const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
+ flat_idx->accept(this);
std::vector<int> index = value().as_vec<int>();
v->mask()->accept(this);
std::vector<int> mask = value().as_vec<int>();
@@ -857,7 +859,13 @@
: dtype_(expr.dtype()) {
std::vector<BufferArg> buffer_args_extended = buffer_args;
Buffer ret_buf("ret_val", dtype_, {1});
- Stmt* store_stmt = Store::make(VarHandle(ret_buf.data()), 0, expr);
+ std::vector<const Expr*> indices;
+ const Expr* zero = new IntImm(0);
+ for (size_t i = 0; i < ret_buf.data()->ndim(); i++) {
+ indices.push_back(zero);
+ }
+ Stmt* store_stmt =
+ new Store(ret_buf.data(), indices, expr.node(), new IntImm(1));
buffer_args_extended.push_back(ret_buf);
codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
}
diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp
index 9f7ed26..82b09f2 100644
--- a/torch/csrc/jit/tensorexpr/expr.cpp
+++ b/torch/csrc/jit/tensorexpr/expr.cpp
@@ -203,6 +203,17 @@
return IfThenElse::make(c, t, f);
}
+ExprHandle Buf::make(
+ const std::string& name_hint,
+ const std::vector<ExprHandle>& dims) {
+ return ExprHandle(
+ new Buf(new Var(name_hint, kHandle), ExprHandleVectorToExprVector(dims)));
+}
+
+ExprHandle Buf::make(const std::vector<ExprHandle>& dims) {
+ return Buf::make("", dims);
+}
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch
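Note (sketch, not part of the diff): both `Buf::make` overloads funnel into the same `Buf` constructor; most call sites go through the `BufHandle` wrapper declared in expr.h below:

KernelScope kernel_scope;
// Named buffer with dims {64, 32}:
BufHandle a("A", {ExprHandle(64), ExprHandle(32)});
const Buf* node = a.node();
// Anonymous variant; the name hint defaults to "":
ExprHandle anon = Buf::make({ExprHandle(64)});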
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 14eb4bb..2af4830 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -161,6 +161,66 @@
std::string name_hint_;
};
+class TORCH_API Buf : public ExprNode<Buf> {
+ public:
+ static ExprHandle make(
+ const std::string& name_hint,
+ const std::vector<ExprHandle>& dims);
+ static ExprHandle make(const std::vector<ExprHandle>& dims);
+
+ // TODO: unique_name
+ const Var* base_handle() const {
+ return base_handle_;
+ }
+ const std::string& name_hint() const {
+ return base_handle_->name_hint();
+ }
+
+ Buf(const Var* var, const std::vector<const Expr*>& dims)
+ : ExprNodeBase(kHandle, kPrimitive), base_handle_(var), dims_(dims) {
+ TORCH_CHECK(var);
+ }
+
+ int ndim() const {
+ return dims_.size();
+ }
+ const Expr* dim(int index) const {
+ return dims_[index];
+ }
+ std::vector<const Expr*> dims() const {
+ return dims_;
+ }
+
+ private:
+ const Var* base_handle_;
+ std::vector<const Expr*> dims_;
+};
+
+class TORCH_API BufHandle : public ExprHandle {
+ public:
+ BufHandle() : ExprHandle(nullptr) {}
+ // explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make(dtype)) {}
+ BufHandle(const std::string& name_hint, const std::vector<ExprHandle>& dims)
+ : ExprHandle(Buf::make(name_hint, dims)) {}
+ explicit BufHandle(const Buf* node) : ExprHandle(node) {}
+ const Buf* node() const {
+ return static_cast<const Buf*>(ExprHandle::node());
+ }
+ bool operator==(const BufHandle& other) const {
+ return this->node() == other.node();
+ }
+ bool operator!=(const BufHandle& other) const {
+ return !(*this == other);
+ }
+
+ const std::string& name_hint() const {
+ return this->node()->name_hint();
+ }
+ bool empty() const {
+ return (this->node() == nullptr);
+ }
+};
+
// An expression to construct the underlying variable node.
// Note: do not store any info here, since it is often possible to slice this
// object. For example: VarHandle x('x'); ExprHandle x2 = x;
diff --git a/torch/csrc/jit/tensorexpr/function.cpp b/torch/csrc/jit/tensorexpr/function.cpp
index 97996f7..4a9a9a6 100644
--- a/torch/csrc/jit/tensorexpr/function.cpp
+++ b/torch/csrc/jit/tensorexpr/function.cpp
@@ -32,7 +32,8 @@
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ const Buf* buf = func->func_var(0);
+ return new Tensor(buf, func, 0);
}
Tensor* Compute(
@@ -48,7 +49,8 @@
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0])).node();
Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ const Buf* buf = func->func_var(0);
+ return new Tensor(buf, func, 0);
}
Tensor* Compute(
@@ -64,7 +66,8 @@
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ const Buf* buf = func->func_var(0);
+ return new Tensor(buf, func, 0);
}
Tensor* Compute(
@@ -83,7 +86,8 @@
body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
.node();
Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ const Buf* buf = func->func_var(0);
+ return new Tensor(buf, func, 0);
}
Tensor* Compute(
@@ -103,37 +107,20 @@
auto args = VarVectorToVarHandleVector(args_nodes);
const Expr* body = body_func(args[0], args[1], args[2], args[3]).node();
Function* func = new Function(func_name, dims, args_nodes, body);
- return new Tensor(func, 0);
+ const Buf* buf = func->func_var(0);
+ return new Tensor(buf, func, 0);
}
Stmt* Function::ElementStmt(size_t index) {
- std::vector<ExprHandle> strides(dims_.size());
- for (size_t i = 0; i < strides.size(); i++) {
- if (i == strides.size() - 1) {
- strides[i] = ExprHandle(1);
- continue;
- }
- ExprHandle stride = ExprHandle(dims_[i + 1]);
- for (size_t j = i + 2; j < dims_.size(); j++) {
- stride = stride * ExprHandle(dims_[j]);
- }
- strides[i] = stride;
- }
-
- ExprHandle total_index = int32_t{0};
- for (size_t i = 0; i < dims_.size(); i++) {
- ExprHandle index = VarHandle(this->args_[i]) * ExprHandle(strides[i]);
- if (i == 0) {
- total_index = index;
- } else {
- total_index = total_index + index;
- }
+ const Buf* buf = func_var(index);
+ std::vector<const Expr*> indices;
+ for (size_t i = 0; i < buf->ndim(); i++) {
+ indices.push_back(this->args_[i]);
}
const Expr* mask = new IntImm(1);
- Stmt* update_stmt =
- new Store(func_var(index), total_index.node(), body(index), mask);
+ Stmt* update_stmt = new Store(buf, indices, body(index), mask);
return update_stmt;
}
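Note: the net effect on `ElementStmt` is that for a function over axes (i, j) with dims (M, N) it now emits

// Before:  Store(func_var, i * N + j, body, 1)
// After:   Store(buf, {i, j}, body, 1)

leaving the stride arithmetic to the `IndexFlattener` pass in loopnest.cpp.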
diff --git a/torch/csrc/jit/tensorexpr/function.h b/torch/csrc/jit/tensorexpr/function.h
index eba3ebe..b526543 100644
--- a/torch/csrc/jit/tensorexpr/function.h
+++ b/torch/csrc/jit/tensorexpr/function.h
@@ -17,8 +17,9 @@
const std::vector<const Expr*>& dims,
const std::vector<const Var*>& args,
const Expr* body)
- : func_vars_({VarHandle(func_name, kHandle).node()}),
- dims_(dims),
+ // TODO: Function should not create buffers; they should be created
+ // manually before constructing a function.
+ : func_vars_({new Buf(new Var(func_name, kHandle), dims)}),
args_(args),
bodies_({body}) {}
Function(
@@ -26,30 +27,14 @@
const std::vector<const Expr*>& dims,
const std::vector<const Var*>& args,
const std::vector<const Expr*>& bodies)
- : func_vars_(func_names.size()),
- dims_(dims),
- args_(args),
- bodies_(bodies) {
+ : func_vars_(func_names.size()), args_(args), bodies_(bodies) {
for (size_t i = 0; i < func_names.size(); i++) {
- func_vars_[i] = new Var(func_names[i], kHandle);
+ func_vars_[i] = new Buf(new Var(func_names[i], kHandle), dims);
}
}
- int ndim() const {
- return dims_.size();
- }
- const Expr* dim(int index) const {
- if (index < 0 || index >= ndim()) {
- throw out_of_range_index();
- }
-
- return dims_[index];
- }
- const std::vector<const Expr*>& dims() const {
- return dims_;
- }
const Var* arg(int index) const {
- if (index < 0 || index >= ndim()) {
+ if (index < 0 || index >= args_.size()) {
throw out_of_range_index();
}
@@ -70,10 +55,10 @@
return bodies_[index];
}
- std::vector<const Var*> func_vars() const {
+ std::vector<const Buf*> func_vars() const {
return func_vars_;
}
- const Var* func_var(size_t index) const {
+ const Buf* func_var(size_t index) const {
if (index >= func_vars_.size()) {
throw out_of_range_index();
}
@@ -83,8 +68,7 @@
Stmt* ElementStmt(size_t index);
private:
- std::vector<const Var*> func_vars_;
- std::vector<const Expr*> dims_;
+ std::vector<const Buf*> func_vars_;
std::vector<const Var*> args_;
std::vector<const Expr*> bodies_;
};
diff --git a/torch/csrc/jit/tensorexpr/hash_provider.cpp b/torch/csrc/jit/tensorexpr/hash_provider.cpp
index f77ec3a..546a99c 100644
--- a/torch/csrc/jit/tensorexpr/hash_provider.cpp
+++ b/torch/csrc/jit/tensorexpr/hash_provider.cpp
@@ -172,21 +172,26 @@
void HashProvider::visit(const Load* v) {
CACHE_GUARD();
v->base_handle()->accept(this);
- v->index()->accept(this);
+ SimplifierHashType indices_hash;
+ for (const Expr* ind : v->indices()) {
+ ind->accept(this);
+ indices_hash = hash_combine(indices_hash, hashOf(ind));
+ }
v->mask()->accept(this);
putHash(
v,
hash_combine(
- "load",
- hashOf(v->base_handle()),
- hashOf(v->index()),
- hashOf(v->mask())));
+ "load", hashOf(v->base_handle()), indices_hash, hashOf(v->mask())));
}
void HashProvider::visit(const Store* v) {
CACHE_GUARD();
v->base_handle()->accept(this);
- v->index()->accept(this);
+ SimplifierHashType indices_hash;
+ for (const Expr* ind : v->indices()) {
+ ind->accept(this);
+ indices_hash = hash_combine(indices_hash, hashOf(ind));
+ }
v->value()->accept(this);
v->mask()->accept(this);
putHash(
@@ -194,7 +199,7 @@
hash_combine(
"store",
hashOf(v->base_handle()),
- hashOf(v->index()),
+ indices_hash,
hashOf(v->value()),
hashOf(v->mask())));
}
diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp
index 68e34ab..651c4a1 100644
--- a/torch/csrc/jit/tensorexpr/ir.cpp
+++ b/torch/csrc/jit/tensorexpr/ir.cpp
@@ -10,46 +10,183 @@
return Dtype(buffer_dtype, index_dtype.lanes());
}
-Load::Load(const Buffer& buffer, const Expr* index, const Expr* mask)
+static Dtype dtypeOfIndices(const std::vector<const Expr*>& indices) {
+ if (!indices.size()) {
+ throw malformed_input();
+ }
+ Dtype dt = indices.at(0)->dtype();
+ for (size_t i = 1; i < indices.size(); ++i) {
+ if (indices.at(i)->dtype() != dt) {
+ throw malformed_input();
+ }
+ }
+ return dt;
+}
+
+static bool indicesValid(const std::vector<const Expr*>& indices) {
+ if (indices.size() == 0) {
+ return false;
+ }
+ Dtype index_dtype = dtypeOfIndices(indices);
+ if (indices.size() > 1 && index_dtype.lanes() > 1) {
+ // Multilane is only allowed in a flattened (i.e. 1D) index
+ return false;
+ }
+ if (index_dtype.scalar_type() != ScalarType::Int) {
+ return false;
+ }
+ return true;
+}
+
+Load::Load(
+ const Buffer& buffer,
+ const std::vector<const Expr*>& indices,
+ const Expr* mask)
: Load(
- ChooseDtype(buffer.dtype(), index->dtype()),
+ ChooseDtype(buffer.dtype(), dtypeOfIndices(indices)),
buffer.data(),
- index,
+ indices,
mask) {}
Load::Load(
Dtype dtype,
- const Var* base_handle,
- const Expr* index,
+ const Buf* buf,
+ const std::vector<const Expr*>& indices,
const Expr* mask)
- : ExprNodeBase(dtype),
- base_handle_(base_handle),
- index_(index),
- mask_(mask) {
- if (base_handle->dtype() != kHandle) {
+ : ExprNodeBase(dtype), buf_(buf), indices_(indices), mask_(mask) {
+ if (buf->base_handle()->dtype() != kHandle) {
throw malformed_input();
}
-
- if (index->dtype().lanes() != mask->dtype().lanes()) {
+ if (!indicesValid(indices)) {
throw malformed_input();
}
-
- if (index->dtype().scalar_type() != ScalarType::Int) {
- throw unsupported_dtype();
+ Dtype index_dtype = dtypeOfIndices(indices);
+ if (index_dtype.lanes() != mask->dtype().lanes()) {
+ throw malformed_input();
}
}
+ExprHandle Load::make(
+ const Buffer& buffer,
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& mask) {
+ return ExprHandle(
+ new Load(buffer, ExprHandleVectorToExprVector(indices), mask.node()));
+}
+ExprHandle Load::make(
+ Dtype dtype,
+ const BufHandle& buf,
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& mask) {
+ return ExprHandle(new Load(
+ dtype, buf.node(), ExprHandleVectorToExprVector(indices), mask.node()));
+}
+
Store::Store(
const Buffer& buffer,
- const Expr* index,
+ const std::vector<const Expr*>& indices,
const Expr* value,
const Expr* mask)
- : Store(buffer.data(), index, value, mask) {
+ : Store(buffer.data(), indices, value, mask) {
if (buffer.dtype().scalar_type() != value->dtype().scalar_type()) {
throw malformed_input();
}
}
+Store::Store(
+ const Buf* buf,
+ std::vector<const Expr*> indices,
+ const Expr* value,
+ const Expr* mask)
+ : buf_(buf), indices_(std::move(indices)), value_(value), mask_(mask) {
+ if (buf->dtype() != kHandle) {
+ throw malformed_input();
+ }
+ /*
+ TODO: Reenable the checks.
+ The reason they are disabled is that kernel.cpp is using Buffers somewhat
+ loosely: we don't set dimensions properly and just construct index expressions
+ directly. We should harden that part and then we'd be able to turn on these
+ checks.
+
+ if (!indicesValid(indices)) {
+ throw malformed_input();
+ }
+ if (!mask || !value) {
+ throw malformed_input();
+ }
+ Dtype index_dtype = dtypeOfIndices(indices);
+ if (index_dtype.lanes() != mask->dtype().lanes()) {
+ throw malformed_input();
+ }
+ if (index_dtype.lanes() != value->dtype().lanes()) {
+ throw malformed_input();
+ }
+ */
+}
+
+Store* Store::make(
+ const Buffer& buffer,
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& value,
+ const ExprHandle& mask) {
+ return new Store(
+ buffer, ExprHandleVectorToExprVector(indices), value.node(), mask.node());
+}
+
+Store* Store::make(
+ const BufHandle& buf,
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& value,
+ const ExprHandle& mask) {
+ return new Store(
+ buf.node(),
+ ExprHandleVectorToExprVector(indices),
+ value.node(),
+ mask.node());
+}
+
+Store* Store::make(
+ const BufHandle& buf,
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& value) {
+ return new Store(
+ buf.node(),
+ ExprHandleVectorToExprVector(indices),
+ value.node(),
+ ExprHandle(1).node());
+}
+
+const Expr* flatten_index(
+ const std::vector<const Expr*>& dims,
+ const std::vector<const Expr*>& indices) {
+ // Handle already flattened indices first
+ if (indices.size() == 1) {
+ return indices[0];
+ }
+
+ size_t ndim = dims.size();
+ if (ndim != indices.size()) {
+ throw malformed_input();
+ }
+ if (ndim == 0) {
+ return new IntImm(0);
+ }
+ std::vector<const Expr*> strides(ndim);
+ // stride[i] = stride[i+1]*dims[i+1], i < ndim-1
+ // stride[i] = 1, i = ndim-1
+ strides[ndim - 1] = new IntImm(1);
+ for (size_t i = 1; i < ndim; i++) {
+ strides[ndim - 1 - i] = new Mul(strides[ndim - i], dims[ndim - i]);
+ }
+
+ const Expr* total_index = new IntImm(0);
+ for (size_t i = 0; i < ndim; i++) {
+ total_index = new Add(total_index, new Mul(indices[i], strides[i]));
+ }
+ return total_index;
+}
+
Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) {
// TODO: check the op_type and make a real decision
return dt1;
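Note (worked sketch, not part of the diff): `flatten_index` computes row-major strides, so for dims {M, N, K} and indices {i, j, k} it builds i*(N*K) + j*K + k. Concretely (IR nodes are arena-allocated, hence the `KernelScope`):

KernelScope kernel_scope;
std::vector<const Expr*> dims = {new IntImm(4), new IntImm(5)};
std::vector<const Expr*> indices = {new IntImm(2), new IntImm(3)};
const Expr* flat = flatten_index(dims, indices);
// flat is 0 + 2 * (1 * 5) + 3 * 1, i.e. element 13 of a 4x5 buffer.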
diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h
index 136339f..2b7b878 100644
--- a/torch/csrc/jit/tensorexpr/ir.h
+++ b/torch/csrc/jit/tensorexpr/ir.h
@@ -442,39 +442,44 @@
class TORCH_API Load : public ExprNode<Load> {
public:
const Var* base_handle() const {
- return base_handle_;
+ return buf_->base_handle();
}
- const Expr* index() const {
- return index_;
+ std::vector<const Expr*> indices() const {
+ return indices_;
+ }
+ const Expr* flat_index() const {
+ TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
+ return indices_[0];
}
const Expr* mask() const {
return mask_;
}
+ const Buf* buf() const {
+ return buf_;
+ }
static ExprHandle make(
const Buffer& buffer,
- const ExprHandle& index,
- const ExprHandle& mask) {
- return ExprHandle(new Load(buffer, index.node(), mask.node()));
- }
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& mask);
static ExprHandle make(
Dtype dtype,
- const VarHandle& base_handle,
- const ExprHandle& index,
- const ExprHandle& mask) {
- return ExprHandle(
- new Load(dtype, base_handle.node(), index.node(), mask.node()));
- }
+ const BufHandle& buf,
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& mask);
- Load(const Buffer& buffer, const Expr* index, const Expr* mask);
+ Load(
+ const Buffer& buffer,
+ const std::vector<const Expr*>& indices,
+ const Expr* mask);
Load(
Dtype dtype,
- const Var* base_handle,
- const Expr* index,
+ const Buf* base_handle,
+ const std::vector<const Expr*>& indices,
const Expr* mask);
private:
- const Var* base_handle_;
- const Expr* index_;
+ const Buf* buf_;
+ std::vector<const Expr*> indices_;
const Expr* mask_;
};
@@ -872,6 +877,9 @@
const std::vector<VarHandle>&);
TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
const std::vector<const Var*>&);
+TORCH_API const Expr* flatten_index(
+ const std::vector<const Expr*>& dims,
+ const std::vector<const Expr*>& indices);
} // namespace tensorexpr
} // namespace jit
diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp
index d2a145b..71e68f1 100644
--- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp
@@ -184,18 +184,28 @@
const Expr* IRMutator::mutate(const Load* v) {
Dtype dtype = v->dtype();
- const Var* base_handle = v->base_handle();
- const Expr* index = v->index();
+ const Buf* buf = v->buf();
+
+ bool any_index_changed = false;
+ std::vector<const Expr*> indices_new;
+ for (const Expr* ind : v->indices()) {
+ const Expr* new_ind = ind->accept_mutator(this);
+ if (new_ind != ind) {
+ any_index_changed = true;
+ }
+ indices_new.push_back(new_ind);
+ }
const Expr* mask = v->mask();
- const Expr* base_handle_expr = base_handle->accept_mutator(this);
- const Var* base_handle_new = dynamic_cast<const Var*>(base_handle_expr);
- const Expr* index_new = index->accept_mutator(this);
+ const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
const Expr* mask_new = mask->accept_mutator(this);
- if (base_handle == base_handle_new && index == index_new &&
- mask == mask_new) {
+ if (buf == buf_new && !any_index_changed && mask == mask_new) {
return v;
}
- return new Load(dtype, base_handle_new, index_new, mask_new);
+ return new Load(dtype, buf_new, indices_new, mask_new);
+}
+
+const Expr* IRMutator::mutate(const Buf* v) {
+ return v;
}
const Expr* IRMutator::mutate(const Broadcast* v) {
@@ -321,20 +331,27 @@
}
Stmt* IRMutator::mutate(const Store* v) {
- const Var* base_handle = v->base_handle();
- const Expr* index = v->index();
+ const Buf* buf = v->buf();
+
+ bool any_index_changed = false;
+ std::vector<const Expr*> indices_new;
+ for (const Expr* ind : v->indices()) {
+ const Expr* new_ind = ind->accept_mutator(this);
+ if (new_ind != ind) {
+ any_index_changed = true;
+ }
+ indices_new.push_back(new_ind);
+ }
const Expr* value = v->value();
const Expr* mask = v->mask();
- const Expr* base_handle_expr = base_handle->accept_mutator(this);
- const Var* base_handle_new = dynamic_cast<const Var*>(base_handle_expr);
- const Expr* index_new = index->accept_mutator(this);
+ const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
const Expr* value_new = value->accept_mutator(this);
const Expr* mask_new = mask->accept_mutator(this);
- if (base_handle == base_handle_new && index == index_new &&
- value == value_new && mask == mask_new) {
+ if (buf == buf_new && !any_index_changed && value == value_new &&
+ mask == mask_new) {
return (Stmt*)v;
}
- return new Store(base_handle_new, index_new, value_new, mask_new);
+ return new Store(buf_new, indices_new, value_new, mask_new);
}
Stmt* IRMutator::mutate(const Allocate* v) {
@@ -436,7 +453,7 @@
}
Stmt* StmtClone::mutate(const Store* v) {
- return new Store(v->base_handle(), v->index(), v->value(), v->mask());
+ return new Store(v->buf(), v->indices(), v->value(), v->mask());
}
Stmt* StmtClone::mutate(const Allocate* v) {
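Note (hypothetical sketch, not part of the diff): since `mutate(const Buf*)` is the identity, a mutator that only rewrites indices can follow the same rebuild pattern as the `Load`/`Store` mutations above. For illustration, a mutator that offsets every store index by one:

class OffsetStoreIndices : public IRMutator {
 public:
  Stmt* mutate(const Store* v) override {
    std::vector<const Expr*> new_indices;
    for (const Expr* ind : v->indices()) {
      // Shift each axis by one; illustrative only.
      new_indices.push_back(new Add(ind, new IntImm(1)));
    }
    return new Store(v->buf(), new_indices, v->value(), v->mask());
  }
};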
diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h
index 42c0fca..254a5bb 100644
--- a/torch/csrc/jit/tensorexpr/ir_mutator.h
+++ b/torch/csrc/jit/tensorexpr/ir_mutator.h
@@ -27,6 +27,7 @@
class Cast;
class Var;
+class Buf;
class Let;
class LetStmt;
class Ramp;
@@ -71,6 +72,7 @@
#undef IMM_MUTATE_DECLARE
virtual const Expr* mutate(const Cast* v);
virtual const Expr* mutate(const Var* v);
+ virtual const Expr* mutate(const Buf* v);
virtual const Expr* mutate(const Let* v);
virtual Stmt* mutate(const LetStmt* v);
virtual const Expr* mutate(const Ramp* v);
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp
index 096a98f..5f8c619 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp
@@ -263,7 +263,15 @@
void IRPrinter::visit(const Load* v) {
// TODO: support the mask case
- os() << *v->base_handle() << "[" << *v->index() << "]";
+ os() << *v->base_handle() << "[";
+ size_t i = 0;
+ for (const Expr* ind : v->indices()) {
+ if (i++) {
+ os() << ", ";
+ }
+ ind->accept(this);
+ }
+ os() << "]";
}
void IRPrinter::visit(const For* v) {
@@ -296,8 +304,15 @@
void IRPrinter::visit(const Store* v) {
// TODO: handle the mask
emitIndent();
- os() << *v->base_handle() << "[" << *v->index() << "] = " << *v->value()
- << ";";
+ os() << *v->base_handle() << "[";
+ size_t i = 0;
+ for (const Expr* ind : v->indices()) {
+ if (i++) {
+ os() << ", ";
+ }
+ ind->accept(this);
+ }
+ os() << "] = " << *v->value() << ";";
}
void IRPrinter::visit(const Broadcast* v) {
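Note: with these printer changes, unflattened and flattened IR are easy to tell apart in dumps; roughly:

// Before FlattenIndexes:  f[x, y] = x + y;
// After FlattenIndexes:   f[x * 5 + y] = x + y;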
diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp
index ee97b3b..9506a48 100644
--- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp
@@ -97,14 +97,22 @@
}
void IRVisitor::visit(const Load* v) {
- v->base_handle()->accept(this);
- v->index()->accept(this);
+ v->buf()->accept(this);
+ for (const Expr* ind : v->indices()) {
+ ind->accept(this);
+ }
v->mask()->accept(this);
}
-void IRVisitor::visit(const Store* v) {
+void IRVisitor::visit(const Buf* v) {
v->base_handle()->accept(this);
- v->index()->accept(this);
+}
+
+void IRVisitor::visit(const Store* v) {
+ v->buf()->accept(this);
+ for (const Expr* ind : v->indices()) {
+ ind->accept(this);
+ }
v->value()->accept(this);
v->mask()->accept(this);
}
@@ -151,8 +159,7 @@
}
void IRVisitor::visit(const Allocate* v) {
- const Var* buffer_var = v->buffer_var();
- buffer_var->accept(this);
+ v->buffer_var()->accept(this);
std::vector<const Expr*> dims = v->dims();
for (const Expr* dim : dims) {
dim->accept(this);
@@ -160,8 +167,7 @@
}
void IRVisitor::visit(const Free* v) {
- const Var* buffer_var = v->buffer_var();
- buffer_var->accept(this);
+ v->buffer_var()->accept(this);
}
void IRVisitor::visit(const Cond* v) {
diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h
index 32819c4..1e1cdf3f 100644
--- a/torch/csrc/jit/tensorexpr/ir_visitor.h
+++ b/torch/csrc/jit/tensorexpr/ir_visitor.h
@@ -27,6 +27,7 @@
class Cast;
class Var;
+class Buf;
class Let;
class LetStmt;
class Ramp;
@@ -70,6 +71,7 @@
virtual void visit(const Cast* v);
virtual void visit(const Var* v);
+ virtual void visit(const Buf* v);
virtual void visit(const Let* v);
virtual void visit(const LetStmt* v);
virtual void visit(const Ramp* v);
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index 4f2f0e2..d1977a9 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -12,9 +12,8 @@
template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
std::vector<int64_t> sizes;
- for (int i = 0; i < t->function()->ndim(); i++) {
- sizes.push_back(
- dynamic_cast<const IntImm*>(t->function()->dim(i))->value());
+ for (int i = 0; i < t->buf()->ndim(); i++) {
+ sizes.push_back(dynamic_cast<const IntImm*>(t->buf()->dim(i))->value());
}
return sizes;
}
@@ -70,7 +69,7 @@
template <typename T, typename T1>
ExprHandle broadcast(const T& t, const std::vector<T1>& axes) {
return t->call(computeIndicesToBroadcast(
- axes, ExprVectorToExprHandleVector(t->function()->dims())));
+ axes, ExprVectorToExprHandleVector(t->buf()->dims())));
}
template <typename T, typename T1>
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index f3b9802..4e16359 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -818,7 +818,7 @@
if (v->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
- v->index()->accept(this);
+ v->flat_index()->accept(this);
auto idx = this->value_;
auto* maskimm = dynamic_cast<const IntImm*>(v->mask());
@@ -856,7 +856,7 @@
}
// Handle the case where the load is contiguous and unmasked efficiently
- auto* idx_ramp = dynamic_cast<const Ramp*>(v->index());
+ auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
if (unmasked_load && idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
@@ -876,7 +876,7 @@
// Fallback to a scalar implementation
v->base_handle()->accept(this);
auto base = this->value_;
- v->index()->accept(this);
+ v->flat_index()->accept(this);
auto idx = this->value_;
v->mask()->accept(this);
auto mask = this->value_;
@@ -983,7 +983,7 @@
if (v->value()->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
- v->index()->accept(this);
+ v->flat_index()->accept(this);
auto idx = this->value_;
v->value()->accept(this);
auto val = this->value_;
@@ -1018,7 +1018,7 @@
auto val = this->value_;
// Handle the case where the store is contiguous and unmasked efficiently
- auto* idx_ramp = dynamic_cast<const Ramp*>(v->index());
+ auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
if (unmasked_store && idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
@@ -1035,7 +1035,7 @@
}
}
- v->index()->accept(this);
+ v->flat_index()->accept(this);
auto idx = this->value_;
v->mask()->accept(this);
auto mask = this->value_;
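Note (sketch, not part of the diff): the LLVM (and CUDA) backends now read indices through `flat_index()`, which `TORCH_CHECK`s that the index list has already been collapsed to a single entry, so statements must be flattened before codegen. `LoopNest::prepareForCodegen()` (loopnest.cpp below) does this automatically, which is what the updated tests rely on:

KernelScope kernel_scope;
const int M = 32, N = 1024;
Buffer a(BufHandle("a", {M, N}), kFloat);
Buffer b(BufHandle("b", {N}), kFloat);
Tensor* c = Compute(
    "c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
      return Load::make(a, {i, j}, 1) + Load::make(b, {j}, 1);
    });
LoopNest l({c});
l.prepareForCodegen();  // flattens {i, j} into i * N + j, among other things
LLVMCodeGen cg(l.root_stmt(), {a, b, c});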
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 4940b04..2ce5ebf 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -29,6 +29,35 @@
} // namespace
+class IndexFlattener : public IRMutator {
+ public:
+ Stmt* flatten(Stmt* s) {
+ return s->accept_mutator(this);
+ }
+ const Expr* mutate(const Load* v) override {
+ if (v->indices().size() == 1) {
+ return v;
+ }
+ return new Load(
+ v->dtype(),
+ v->buf(),
+ {flatten_index(v->buf()->dims(), v->indices())},
+ v->mask());
+ }
+ Stmt* mutate(const Store* v) override {
+ const Expr* value = v->value();
+ const Expr* new_value = value->accept_mutator(this);
+ if (v->indices().size() == 1 && value == new_value) {
+ return (Stmt*)v;
+ }
+ return new Store(
+ v->buf(),
+ {flatten_index(v->buf()->dims(), v->indices())},
+ new_value,
+ v->mask());
+ }
+};
+
class Vectorizer : public IRMutator {
public:
Stmt* vectorize(const For* v) {
@@ -162,13 +191,13 @@
const Expr* mutate(const Load* v) override {
Dtype dtype(v->dtype().scalar_type(), lanes_);
- const Var* base_handle = v->base_handle();
- std::vector<const Expr*> inputs = {v->index(), v->mask()};
+ const Buf* buf = v->buf();
+ std::vector<const Expr*> inputs = {v->flat_index(), v->mask()};
return try_vectorize(v, inputs, [&]() {
return Load::make(
dtype,
- VarHandle(base_handle),
- ExprHandle(inputs[0]),
+ BufHandle(buf),
+ {ExprHandle(inputs[0])},
ExprHandle(inputs[1]));
});
}
@@ -204,12 +233,12 @@
}
Stmt* mutate(const Store* v) override {
- const Var* base_handle = v->base_handle();
- std::vector<const Expr*> inputs = {v->index(), v->value(), v->mask()};
+ const Buf* buf = v->buf();
+ std::vector<const Expr*> inputs = {v->flat_index(), v->value(), v->mask()};
return try_vectorize(v, inputs, [&]() {
return Store::make(
- VarHandle(base_handle),
- ExprHandle(inputs[0]),
+ BufHandle(buf),
+ {ExprHandle(inputs[0])},
ExprHandle(inputs[1]),
ExprHandle(inputs[2]));
});
@@ -299,7 +328,8 @@
Stmt* old_f = Stmt::clone(f);
Stmt* new_f = nullptr;
try {
- new_f = v.vectorize(f);
+ new_f = FlattenIndexes(f);
+ new_f = v.vectorize(dynamic_cast<For*>(new_f));
} catch (std::runtime_error& e) {
// Partial vectorization may have corrupted f
new_f = old_f;
@@ -312,10 +342,8 @@
private:
Expr* mutate(const FunctionCall* v) override {
const Tensor* t = v->tensor();
- Buffer buffer(
- VarHandle(t->func_var()),
- t->body()->dtype(),
- ExprVectorToExprHandleVector(t->dims()));
+ const Buf* b = t->buf();
+ Buffer buffer(BufHandle(b), t->body()->dtype());
const std::vector<const Expr*>& params = v->params();
std::vector<ExprHandle> params_expr(params.size());
for (size_t i = 0; i < params.size(); i++) {
@@ -333,19 +361,20 @@
if (func->func_vars().size() != 1) {
throw unimplemented_lowering();
}
- func_var_set_.insert(func->func_var(0));
+ func_var_set_.insert(func->func_var(0)->base_handle());
}
}
protected:
bool should_inline(Function* func) const {
- return func_var_set_.count(func->func_var(0)) > 0;
+ return func_var_set_.count(func->func_var(0)->base_handle()) > 0;
}
// For the target function, insert the caller/callee pair into the replacement
// mapping.
const Expr* mutate(const FunctionCall* v) override {
Function* func = v->tensor()->function();
+ const Buf* buf = v->tensor()->buf();
// TODO: Support multiple-output functions
if (func->func_vars().size() != 1) {
throw unimplemented_lowering();
@@ -353,7 +382,7 @@
if (should_inline(func)) {
// Insert the caller/callee pair into the mapping.
- for (int i = 0; i < func->ndim(); i++) {
+ for (int i = 0; i < buf->ndim(); i++) {
const Var* func_callee_arg = dynamic_cast<const Var*>(func->arg(i));
const Expr* func_caller_param = v->param(i);
auto iter = inline_mapping_.find(func_callee_arg);
@@ -369,7 +398,7 @@
const Expr* result = body->accept_mutator(this);
// Remove the caller/callee relationship.
- for (int i = 0; i < func->ndim(); i++) {
+ for (int i = 0; i < buf->ndim(); i++) {
const Var* func_callee_arg = dynamic_cast<const Var*>(func->arg(i));
auto iter = inline_mapping_.find(func_callee_arg);
if (iter == inline_mapping_.end()) {
@@ -652,18 +681,18 @@
stmt_to_tensor_[body] = t;
tensor_to_stmt_[t] = body;
- if (f->ndim() == 0) {
+ if (t->buf()->ndim() == 0) {
return body;
}
- if (f->ndim() == 0) {
+ if (t->buf()->ndim() == 0) {
throw malformed_input();
}
- for (size_t i = 0; i < f->ndim(); i++) {
+ for (size_t i = 0; i < t->buf()->ndim(); i++) {
// Going in reverse order: from innermost loop to the outermost
- size_t dim_index = f->ndim() - i - 1;
- Range r(new IntImm(0), f->dim(dim_index));
+ size_t dim_index = t->buf()->ndim() - i - 1;
+ Range r(new IntImm(0), t->buf()->dim(dim_index));
body = new For(f->arg(dim_index), r.start(), r.stop(), body);
}
return body;
@@ -702,8 +731,8 @@
continue;
}
Stmt* alloc = new Allocate(
- tensor->func_var(), tensor->body()->dtype(), tensor->dims());
- Stmt* free = new Free(tensor->func_var());
+ tensor->buf()->base_handle(), tensor->body()->dtype(), tensor->dims());
+ Stmt* free = new Free(tensor->buf()->base_handle());
b->prepend_stmt(alloc);
b->append_stmt(free);
}
@@ -722,6 +751,8 @@
Flattener flattener;
root_stmt_ = root_stmt_->accept_mutator(&flattener);
+ root_stmt_ = FlattenIndexes(root_stmt_);
+
// Add allocs and frees for intermediate buffers at the global level.
root_stmt_ = insertAllocFree(root_stmt_);
}
@@ -879,6 +910,11 @@
return tensor_to_stmt_.count(t) > 0;
}
+Stmt* FlattenIndexes(Stmt* s) {
+ IndexFlattener idx_flattener;
+ return idx_flattener.flatten(s);
+}
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h
index 37c56e4..b5d1264 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.h
+++ b/torch/csrc/jit/tensorexpr/loopnest.h
@@ -57,6 +57,8 @@
std::unordered_set<Tensor*> intermediate_tensors_;
};
+TORCH_API Stmt* FlattenIndexes(Stmt* s);
+
// represent a range [start, stop)
class Range {
public:
diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h
index c5cf547..944d327 100644
--- a/torch/csrc/jit/tensorexpr/stmt.h
+++ b/torch/csrc/jit/tensorexpr/stmt.h
@@ -165,10 +165,14 @@
class TORCH_API Store : public StmtNode<Store> {
public:
const Var* base_handle() const {
- return base_handle_;
+ return buf_->base_handle();
}
- const Expr* index() const {
- return index_;
+ std::vector<const Expr*> indices() const {
+ return indices_;
+ }
+ const Expr* flat_index() const {
+ TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
+ return indices_[0];
}
const Expr* value() const {
return value_;
@@ -176,62 +180,43 @@
const Expr* mask() const {
return mask_;
}
+ const Buf* buf() const {
+ return buf_;
+ }
static Store* make(
const Buffer& buffer,
- const ExprHandle& index,
+ const std::vector<ExprHandle>& indices,
const ExprHandle& value,
- const ExprHandle& mask) {
- return new Store(buffer, index.node(), value.node(), mask.node());
- }
+ const ExprHandle& mask);
static Store* make(
- const VarHandle& base_handle,
- const ExprHandle& index,
+ const BufHandle& buf,
+ const std::vector<ExprHandle>& indices,
const ExprHandle& value,
- const ExprHandle& mask) {
- return new Store(
- base_handle.node(), index.node(), value.node(), mask.node());
- }
+ const ExprHandle& mask);
static Store* make(
- const VarHandle& base_handle,
- const ExprHandle& index,
- const ExprHandle& value) {
- return new Store(
- base_handle.node(), index.node(), value.node(), ExprHandle(1).node());
- }
+ const BufHandle& buf,
+ const std::vector<ExprHandle>& indices,
+ const ExprHandle& value);
// TODO: merge this with Load.
Store(
const Buffer& buffer,
- const Expr* index,
+ const std::vector<const Expr*>& indices,
const Expr* value,
const Expr* mask);
Store(
- const Var* base_handle,
- const Expr* index,
+ const Buf* buf,
+ std::vector<const Expr*> indices,
const Expr* value,
- const Expr* mask)
- : base_handle_(base_handle), index_(index), value_(value), mask_(mask) {
- if (base_handle_->dtype() != kHandle) {
- throw malformed_input(base_handle);
- }
-
- if (index->dtype().lanes() != mask->dtype().lanes() ||
- index->dtype().lanes() != value->dtype().lanes()) {
- throw malformed_input();
- }
-
- if (index->dtype().scalar_type() != ScalarType::Int) {
- throw unsupported_dtype();
- }
- }
+ const Expr* mask);
private:
- const Var* base_handle_;
- const Expr* index_;
+ const Buf* buf_;
+ std::vector<const Expr*> indices_;
const Expr* value_;
const Expr* mask_;
};
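Note (sketch, not part of the diff): the `Store::make` overloads mirror `Load`'s, plus a convenience form that defaults the mask to 1 (definitions are in ir.cpp above):

KernelScope kernel_scope;
BufHandle result("result", {ExprHandle(1)});
Store* s = Store::make(result, {IntImm::make(0)}, FloatImm::make(3.14f));
// Equivalent, with the mask spelled out:
//   Store::make(result, {IntImm::make(0)}, FloatImm::make(3.14f), IntImm::make(1));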
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index a896ed5..f1c9dff 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -23,17 +23,17 @@
const Expr* body() const {
return function()->body(output_index());
}
- const Var* func_var() const {
+ const Buf* func_var() const {
return function()->func_var(output_index());
}
int ndim() const {
- return function()->dims().size();
+ return buf_->dims().size();
}
const Expr* dim(int index) const {
- return function()->dim(index);
+ return buf_->dim(index);
}
- const std::vector<const Expr*>& dims() const {
- return function()->dims();
+ std::vector<const Expr*> dims() const {
+ return buf_->dims();
}
const Var* arg(int index) const {
return function()->arg(index);
@@ -42,8 +42,12 @@
return function()->args();
}
- Tensor(Function* function, int output_index)
- : function_(function), output_index_(output_index) {}
+ const Buf* buf() const {
+ return buf_;
+ }
+
+ Tensor(const Buf* buf, Function* function, int output_index)
+ : buf_(buf), function_(function), output_index_(output_index) {}
template <typename... Ts>
inline ExprHandle operator()(const Ts&... ts);
template <typename T>
@@ -52,6 +56,7 @@
inline ExprHandle call(const Ts&... ts);
private:
+ const Buf* buf_;
Function* function_;
int output_index_;
};
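Note (sketch, not part of the diff): with `Tensor` now carrying a `Buf`, dimension queries no longer detour through the `Function`:

KernelScope kernel_scope;
Tensor* t = Compute(
    "f", {{8, "m"}, {4, "n"}},
    [](const VarHandle& m, const VarHandle& n) { return m + n; });
const Buf* b = t->buf();  // dims are a property of the Tensor's Buf now
int rank = b->ndim();     // 2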