[TensorExpr] Make Load and Store multi-dimensional. (#35800)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35800

This PR includes the following changes:
* Introduce a new `Expr` type `Buf`: it plays a role similar to `Var`, but also carries dimensions.
* Use the new `Buf` class instead of `Var` in `Store` and `Load` to specify where to store to or load from. `Buf` carries the dimension info of the buffer we're loading from or storing to, so we can keep N-d indexes without flattening them into a 1-d index ([x,y] vs [x+y*W]); a minimal sketch follows this list.
* Flattening of the indexes is now a separate pass executed in `LoopNest::prepareForCodegen`: backends still expect flattened indexes, and this PR preserves that.
* `Tensor` now contains a `Buf` instead of a `Var`, so the dimension info is now a property of the `Tensor` (previously it belonged to the `Function`). This brings us closer to `Tensor` being a combination of Buffer + Function, where the Buffer specifies the iteration domain and the Function defines the computation.
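
A minimal sketch of the new API (the names `H`, `W`, `A`, `B` are illustrative, not taken from this PR's tests):

```cpp
KernelScope kernel_scope;
const int H = 8, W = 16;
// Buffer is now constructed from a BufHandle, which carries the dimensions.
Buffer a(BufHandle("A", {ExprHandle(H), ExprHandle(W)}), kFloat);
Buffer b(BufHandle("B", {ExprHandle(H), ExprHandle(W)}), kFloat);

VarHandle y("y", kInt);
VarHandle x("x", kInt);
// Indexes are passed as a list and stay N-d in the IR:
ExprHandle load_a = Load::make(a, {y, x}, 1);  // was: Load::make(a, x + y * W, 1)
Stmt* store_b = Store::make(b, {y, x}, load_a * 2.0f, 1);
Stmt* stmt = For::make(y, 0, H, For::make(x, 0, W, store_b));
```

Backends still consume 1-d indexes, so for tensors built with `Compute` the flattening now happens in `LoopNest::prepareForCodegen()` (the CUDA tests below gained exactly this call), which rewrites an index like `{y, x}` back into `y * W + x` before codegen.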

TODOs:
* Consider merging `Buffer` with `Buf` or `BufHandle`. It seems that we don't need all of them.
* Harden the logic of how we create buffers in the fuser pass: currently we sometimes don't set dimensions.
* Use `Buf` in `Allocate` and `Free`.
* Make it clearer that `Function` doesn't "own" the dimension info: dimensions are a property of a `Tensor`, not a `Function`.

Differential Revision: D20789005

Test Plan: Imported from OSS

Reviewed By: zheng-xq

Pulled By: ZolotukhinM

fbshipit-source-id: e04188d1d297f195f1c46669c614557d6bb6cde4
diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp
index 9d098d8..1b84666 100644
--- a/test/cpp/tensorexpr/test_aten.cpp
+++ b/test/cpp/tensorexpr/test_aten.cpp
@@ -15,13 +15,13 @@
 void testATen_cast_Float() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
   ExprHandle to_float = Cast::make(kFloat, load_a);
-  Stmt* store_b = Store::make(b_buf, index, to_float, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, to_float, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -43,13 +43,13 @@
 void testATennegInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
   ExprHandle to_float = Sub::make(0, load_a);
-  Stmt* store_b = Store::make(b_buf, index, to_float, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, to_float, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -71,13 +71,13 @@
 void testATennegFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
   ExprHandle to_float = Sub::make(0, load_a);
-  Stmt* store_b = Store::make(b_buf, index, to_float, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, to_float, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -99,16 +99,16 @@
 void testATenaddInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
+  Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  ExprHandle load_c = Load::make(c_buf, index, 1);
-  Stmt* store_d = Store::make(d_buf, index, load_a + load_b * load_c, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  ExprHandle load_c = Load::make(c_buf, {index}, 1);
+  Stmt* store_d = Store::make(d_buf, {index}, load_a + load_b * load_c, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -136,16 +136,16 @@
 void testATenaddFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  ExprHandle load_c = Load::make(c_buf, index, 1);
-  Stmt* store_d = Store::make(d_buf, index, load_a + load_b * load_c, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  ExprHandle load_c = Load::make(c_buf, {index}, 1);
+  Stmt* store_d = Store::make(d_buf, {index}, load_a + load_b * load_c, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -173,16 +173,16 @@
 void testATensubInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
+  Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  ExprHandle load_c = Load::make(c_buf, index, 1);
-  Stmt* store_d = Store::make(d_buf, index, load_a - load_b * load_c, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  ExprHandle load_c = Load::make(c_buf, {index}, 1);
+  Stmt* store_d = Store::make(d_buf, {index}, load_a - load_b * load_c, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -210,16 +210,16 @@
 void testATensubFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  ExprHandle load_c = Load::make(c_buf, index, 1);
-  Stmt* store_d = Store::make(d_buf, index, load_a - load_b * load_c, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  ExprHandle load_c = Load::make(c_buf, {index}, 1);
+  Stmt* store_d = Store::make(d_buf, {index}, load_a - load_b * load_c, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -247,17 +247,17 @@
 void testATenlerp() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  ExprHandle load_c = Load::make(c_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  ExprHandle load_c = Load::make(c_buf, {index}, 1);
   Stmt* store_d =
-      Store::make(d_buf, index, load_a + load_c * (load_b - load_a), 1);
+      Store::make(d_buf, {index}, load_a + load_c * (load_b - load_a), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -285,19 +285,19 @@
 void testATenaddcmulInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer e_buf(VarHandle("E", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
+  Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kInt);
+  Buffer e_buf(BufHandle("E", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  ExprHandle load_c = Load::make(c_buf, index, 1);
-  ExprHandle load_d = Load::make(d_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  ExprHandle load_c = Load::make(c_buf, {index}, 1);
+  ExprHandle load_d = Load::make(d_buf, {index}, 1);
   Stmt* store_e =
-      Store::make(e_buf, index, load_a + load_b * load_c * load_d, 1);
+      Store::make(e_buf, {index}, load_a + load_b * load_c * load_d, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_e);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -328,19 +328,19 @@
 void testATenaddcmulFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer e_buf(VarHandle("E", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer e_buf(BufHandle("E", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  ExprHandle load_c = Load::make(c_buf, index, 1);
-  ExprHandle load_d = Load::make(d_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  ExprHandle load_c = Load::make(c_buf, {index}, 1);
+  ExprHandle load_d = Load::make(d_buf, {index}, 1);
   Stmt* store_e =
-      Store::make(e_buf, index, load_a + load_b * load_c * load_d, 1);
+      Store::make(e_buf, {index}, load_a + load_b * load_c * load_d, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_e);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -371,14 +371,14 @@
 void testATenmulInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, load_a * load_b, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, load_a * load_b, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -403,14 +403,14 @@
 void testATenmulFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, load_a * load_b, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, load_a * load_b, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -435,14 +435,14 @@
 void testATendivInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, load_a / load_b, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, load_a / load_b, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -467,14 +467,14 @@
 void testATendivFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, load_a / load_b, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, load_a / load_b, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -499,14 +499,14 @@
 void testATenmaxInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, Max::make(load_a, load_b, true), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, Max::make(load_a, load_b, true), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -531,14 +531,14 @@
 void testATenmaxFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, Max::make(load_a, load_b, true), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, Max::make(load_a, load_b, true), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -563,14 +563,14 @@
 void testATenminInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, Min::make(load_a, load_b, true), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, Min::make(load_a, load_b, true), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -595,14 +595,14 @@
 void testATenminFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
-  Stmt* store_c = Store::make(c_buf, index, Min::make(load_a, load_b, true), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
+  Stmt* store_c = Store::make(c_buf, {index}, Min::make(load_a, load_b, true), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -627,15 +627,15 @@
 void testATen_sigmoid_backward() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
   Stmt* store_c = Store::make(
-      c_buf, index, load_a * load_b * (FloatImm::make(1.0f) - load_b), 1);
+      c_buf, {index}, load_a * load_b * (FloatImm::make(1.0f) - load_b), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -660,15 +660,15 @@
 void testATen_tanh_backward() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  ExprHandle load_b = Load::make(b_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  ExprHandle load_b = Load::make(b_buf, {index}, 1);
   Stmt* store_c = Store::make(
-      c_buf, index, load_a * (FloatImm::make(1.0f) - (load_b * load_b)), 1);
+      c_buf, {index}, load_a * (FloatImm::make(1.0f) - (load_b * load_b)), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -693,12 +693,12 @@
 void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, FloatImm::make(1.0f) / load_a, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, FloatImm::make(1.0f) / load_a, 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -720,12 +720,12 @@
 void testATenreluInt() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kInt);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kInt);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, Max::make(load_a, 0, false), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, Max::make(load_a, 0, false), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<int> a_v(kTotalSize);
@@ -747,14 +747,14 @@
 void testATenreluFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
   Stmt* store_b = Store::make(
       b_buf,
-      index,
+      {index},
       Max::make(load_a, 0, false), // relu does not propagate nans
       1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
@@ -778,12 +778,12 @@
 void testATenlogFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, log(load_a), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, log(load_a), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -805,12 +805,12 @@
 void testATenlog10Float() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, log10(load_a), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, log10(load_a), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -832,12 +832,12 @@
 void testATenlog2Float() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, log2(load_a), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, log2(load_a), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -859,12 +859,12 @@
 void testATenexpFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, exp(load_a), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, exp(load_a), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -886,12 +886,12 @@
 void testATenerfFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, erf(load_a), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, erf(load_a), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -913,12 +913,12 @@
 void testATencosFloat() {
   KernelScope kernel_scope;
   const int kTotalSize = 128;
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
 
   VarHandle index = VarHandle("index", kInt);
-  ExprHandle load_a = Load::make(a_buf, index, 1);
-  Stmt* store_b = Store::make(b_buf, index, cos(load_a), 1);
+  ExprHandle load_a = Load::make(a_buf, {index}, 1);
+  Stmt* store_b = Store::make(b_buf, {index}, cos(load_a), 1);
   Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
 
   PaddedBuffer<float> a_v(kTotalSize);
@@ -940,9 +940,9 @@
 void testATeneqInt() {
   KernelScope kernel_scope;
   constexpr int N = 128;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 1);
   std::vector<int> b_buffer(N, 1);
   std::vector<int> c_buffer(N, 0);
@@ -955,10 +955,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kEQ),
           mask));
 
@@ -971,9 +971,9 @@
 void testATengeInt() {
   KernelScope kernel_scope;
   constexpr int N = 128;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 5);
   std::vector<int> b_buffer(N, 5);
   std::vector<int> c_buffer(N, 0);
@@ -986,10 +986,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kGE),
           mask));
 
@@ -1002,9 +1002,9 @@
 void testATengtInt() {
   KernelScope kernel_scope;
   constexpr int N = 128;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 6);
   std::vector<int> b_buffer(N, 3);
   std::vector<int> c_buffer(N, 0);
@@ -1017,10 +1017,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kGT),
           mask));
 
@@ -1033,9 +1033,9 @@
 void testATenleInt() {
   KernelScope kernel_scope;
   constexpr int N = 128;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 5);
   std::vector<int> b_buffer(N, 5);
   std::vector<int> c_buffer(N, 0);
@@ -1048,10 +1048,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kLE),
           mask));
 
@@ -1064,9 +1064,9 @@
 void testATenltInt() {
   KernelScope kernel_scope;
   constexpr int N = 128;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 5);
   std::vector<int> b_buffer(N, 5);
   std::vector<int> c_buffer(N, 1);
@@ -1079,10 +1079,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kLT),
           mask));
 
diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp
index 6072945..033a18f 100644
--- a/test/cpp/tensorexpr/test_cuda.cpp
+++ b/test/cpp/tensorexpr/test_cuda.cpp
@@ -43,6 +43,7 @@
   std::vector<For*> loops = l.getLoopStmtsFor(c);
   l.setGPUBlockIndex(loops[1], 0);
   l.setGPUThreadIndex(loops[2], 0);
+  l.prepareForCodegen();
   Stmt* stmt = l.root_stmt();
   CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
   const int N = block_count * block_size * num_iter;
@@ -113,6 +114,7 @@
   l.splitWithMask(loops[0], block_size, &n_outer, &n_inner);
   l.setGPUBlockIndex(n_outer, 0);
   l.setGPUThreadIndex(n_inner, 0);
+  l.prepareForCodegen();
   Stmt* stmt = l.root_stmt();
   CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
   PaddedBuffer<float> a_v(N);
@@ -161,13 +163,14 @@
   auto testWithSize = [](int32_t M, int32_t N) {
     VarHandle m("m", kInt);
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+    Buffer a(BufHandle("a", {m, n}), kFloat);
+    Buffer b(BufHandle("b", {m, n}), kFloat);
     Tensor* c = Compute(
         "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
           return a(i, j) + b(i, j);
         });
     LoopNest l({c});
+    l.prepareForCodegen();
     Stmt* s = l.root_stmt();
     CudaCodeGen cg(s, {a, b, c, m, n});
 
@@ -237,6 +240,7 @@
   std::vector<For*> loops = l.getLoopStmtsFor(c);
   l.setGPUBlockIndex(loops[1], 0);
   l.setGPUThreadIndex(loops[2], 0);
+  l.prepareForCodegen();
   Stmt* stmt = l.root_stmt();
   CudaCodeGen cuda_cg(stmt, c);
   const int N = block_count * block_size * num_iter;
@@ -280,7 +284,7 @@
   KernelScope ks;
   constexpr int N = 4096;
   VarHandle n("n", kInt);
-  Buffer a(VarHandle("a", kHandle), kFloat, {n});
+  Buffer a(BufHandle("a", {n}), kFloat);
   Tensor* b =
       Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; });
   LoopNest l({b});
diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp
index d4badc5..1f9e2c6 100644
--- a/test/cpp/tensorexpr/test_expr.cpp
+++ b/test/cpp/tensorexpr/test_expr.cpp
@@ -69,9 +69,9 @@
   Buffer a_buf("a", kFloat, {1});
   Buffer b_buf("b", kFloat, {1});
 
-  ExprHandle load_a = Load::make(a_buf, 0, 1);
+  ExprHandle load_a = Load::make(a_buf, {0}, 1);
   VarHandle var = VarHandle("v", kFloat);
-  Stmt* store_b = Store::make(b_buf, 0, var, 1);
+  Stmt* store_b = Store::make(b_buf, {0}, var, 1);
   Stmt* let_store = LetStmt::make(var, load_a, store_b);
   SimpleIREvaluator eval(let_store, a_buf, b_buf);
 
@@ -182,9 +182,9 @@
   const int kVectorCount = 128;
   const int kTotalSize = kVectorSize * kVectorCount;
 
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
 
   /*
   Build the following:
@@ -197,16 +197,16 @@
   VarHandle index = VarHandle("index", kInt);
   ExprHandle load_a = Load::make(
       a_buf,
-      Ramp::make(index * kVectorSize, 1, kVectorSize),
+      {Ramp::make(index * kVectorSize, 1, kVectorSize)},
       Broadcast::make(1, kVectorSize));
   ExprHandle load_b = Load::make(
       b_buf,
-      Ramp::make(index * kVectorSize, 1, kVectorSize),
+      {Ramp::make(index * kVectorSize, 1, kVectorSize)},
       Broadcast::make(1, kVectorSize));
   ExprHandle value = load_a + load_b;
   Stmt* store_c = Store::make(
       c_buf,
-      Ramp::make(index * kVectorSize, 1, kVectorSize),
+      {Ramp::make(index * kVectorSize, 1, kVectorSize)},
       value,
       Broadcast::make(1, kVectorSize));
   Stmt* stmt = For::make(index, 0, kVectorCount, store_c);
@@ -232,9 +232,9 @@
 void testExprCompareSelectEQ() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 1);
   std::vector<int> b_buffer(N, 1);
   std::vector<int> c_buffer(N, 0);
@@ -248,10 +248,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kEQ),
           mask));
 
@@ -403,11 +403,11 @@
   KernelScope kernel_scope;
   auto testWithSize = [](int32_t size) {
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {n});
-    Buffer c(VarHandle("c", kHandle), kFloat, {n});
+    Buffer a(BufHandle("a", {n}), kFloat);
+    Buffer b(BufHandle("b", {n}), kFloat);
+    Buffer c(BufHandle("c", {n}), kFloat);
     VarHandle i("i", kInt);
-    Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
+    Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1));
     std::vector<float> aData(size, 1.0f);
     std::vector<float> bData(size, 2.0f);
     std::vector<float> cData(size, 0.0f);
@@ -426,9 +426,9 @@
   Buffer a_buf("a", kFloat, {N});
   VarHandle index = VarHandle("index", kInt);
   Stmt* assign_x2 =
-      Store::make(VarHandle(a_buf.data()), index, cast<float>(index) * 2, 1);
+      Store::make(BufHandle(a_buf.data()), {index}, cast<float>(index) * 2, 1);
   Stmt* assign_x3 =
-      Store::make(VarHandle(a_buf.data()), index, cast<float>(index) * 3, 1);
+      Store::make(BufHandle(a_buf.data()), {index}, cast<float>(index) * 3, 1);
   ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
   Stmt* assign = Cond::make(even_cond, assign_x2, assign_x3);
   Stmt* for_stmt = For::make(index, 0, N, assign);
@@ -476,7 +476,7 @@
   Buffer a_buf("a", kInt, {N});
   VarHandle index = VarHandle("index", kInt);
   Stmt* body =
-      Store::make(VarHandle(a_buf.data()), index, 5, 1);
+      Store::make(BufHandle(a_buf.data()), {index}, 5, 1);
   Stmt* loop = For::make(index, 0, N, body);
 
   Stmt* cloned_loop = Stmt::clone(loop);
@@ -490,7 +490,7 @@
 
   // Let's add another assign to the body in the cloned loop and verify that the
   // original statement hasn't changed while the cloned one has.
-  Stmt* body_addition = Store::make(VarHandle(a_buf.data()), index, 33, 1);
+  Stmt* body_addition = Store::make(BufHandle(a_buf.data()), {index}, 33, 1);
   Block* cloned_body =
       static_cast<Block*>(static_cast<const For*>(cloned_loop)->body());
   cloned_body->append_stmt(body_addition);
diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp
index c1522b8..9953bf2 100644
--- a/test/cpp/tensorexpr/test_llvm.cpp
+++ b/test/cpp/tensorexpr/test_llvm.cpp
@@ -196,7 +196,7 @@
 
 void testLLVMBufferTest() {
   KernelScope kernel_scope;
-  Buffer a(VarHandle("A", kHandle), kFloat, {32});
+  Buffer a(BufHandle("A", {32}), kFloat);
   std::vector<int32_t> v(5);
   std::vector<void*> args({v.data()});
   auto rv = IntImm::make(0);
@@ -206,14 +206,14 @@
 
 void testLLVMBlockTest() {
   KernelScope kernel_scope;
-  Buffer a(VarHandle("A", kHandle), kInt, {32});
+  Buffer a(BufHandle("A", {32}), kInt);
   std::vector<int32_t> v = {1, 2};
   std::vector<void*> args({v.data()});
 
   auto block = Block::make({
-      Store::make(a, IntImm::make(0), IntImm::make(3), IntImm::make(1)),
-      Store::make(a, IntImm::make(1), IntImm::make(4), IntImm::make(1)),
-      Store::make(a, IntImm::make(0), IntImm::make(4), IntImm::make(1)),
+      Store::make(a, {IntImm::make(0)}, IntImm::make(3), IntImm::make(1)),
+      Store::make(a, {IntImm::make(1)}, IntImm::make(4), IntImm::make(1)),
+      Store::make(a, {IntImm::make(0)}, IntImm::make(4), IntImm::make(1)),
   });
 
   LLVMCodeGen cg(block, {a});
@@ -224,15 +224,15 @@
 
 void testLLVMLoadStoreTest() {
   KernelScope kernel_scope;
-  Buffer a(VarHandle("A", kHandle), kInt, {1});
-  Buffer b(VarHandle("B", kHandle), kInt, {1});
+  Buffer a(BufHandle("A", {1}), kInt);
+  Buffer b(BufHandle("B", {1}), kInt);
   std::vector<int32_t> a_buffer = {42};
   std::vector<int32_t> b_buffer = {-11};
 
   auto store = Store::make(
       b,
-      IntImm::make(0),
-      Load::make(a, IntImm::make(0), IntImm::make(1)),
+      {IntImm::make(0)},
+      Load::make(a, {IntImm::make(0)}, IntImm::make(1)),
       IntImm::make(1));
   LLVMCodeGen cg(store, {a, b});
   std::vector<void*> args({a_buffer.data(), b_buffer.data()});
@@ -243,19 +243,19 @@
 
 void testLLVMIfThenElseTest() {
   KernelScope kernel_scope;
-  Buffer a(VarHandle("A", kHandle), kInt, {1});
-  Buffer b(VarHandle("B", kHandle), kInt, {1});
-  Buffer c(VarHandle("C", kHandle), kInt, {1});
+  Buffer a(BufHandle("A", {1}), kInt);
+  Buffer b(BufHandle("B", {1}), kInt);
+  Buffer c(BufHandle("C", {1}), kInt);
   std::vector<int32_t> a_buffer = {42};
   std::vector<int32_t> b_buffer = {-11};
   std::vector<int32_t> c_buffer = {1};
 
   auto store = Store::make(
       b,
-      IntImm::make(0),
+      {IntImm::make(0)},
       IfThenElse::make(
-          Load::make(c, IntImm::make(0), IntImm::make(1)), // cond
-          Load::make(a, IntImm::make(0), IntImm::make(1)), // then
+          Load::make(c, {IntImm::make(0)}, IntImm::make(1)), // cond
+          Load::make(a, {IntImm::make(0)}, IntImm::make(1)), // then
           IntImm::make(0)), // else
       IntImm::make(1));
   LLVMCodeGen cg(store, {a, b, c});
@@ -267,15 +267,15 @@
 
 void testLLVMVecLoadStoreTest() {
   KernelScope kernel_scope;
-  Buffer a(VarHandle("A", kHandle), kInt, {1});
-  Buffer b(VarHandle("B", kHandle), kInt, {1});
+  Buffer a(BufHandle("A", {1}), kInt);
+  Buffer b(BufHandle("B", {1}), kInt);
   std::vector<int32_t> a_buffer = {1, 1, 1, 1};
   std::vector<int32_t> b_buffer = {2, 2, 2, 2};
 
   auto store = Store::make(
       b,
-      Ramp::make(0, 1, 4),
-      Load::make(a, Ramp::make(0, 1, 4), Broadcast::make(IntImm::make(1), 4)),
+      {Ramp::make(0, 1, 4)},
+      Load::make(a, {Ramp::make(0, 1, 4)}, Broadcast::make(IntImm::make(1), 4)),
       Broadcast::make(IntImm::make(1), 4));
   LLVMCodeGen cg(store, {a, b});
   std::vector<void*> args({a_buffer.data(), b_buffer.data()});
@@ -293,17 +293,17 @@
 #define FLOAT_INTRINSICS_TEST(Name, Lanes)                       \
   void testLLVMVecFloat_##Name##Lane##Lanes##Test() {            \
     KernelScope kernel_scope;                                    \
-    Buffer a(VarHandle("A", kHandle), kFloat, {1});              \
-    Buffer b(VarHandle("B", kHandle), kFloat, {1});              \
+    Buffer a(BufHandle("A", {1}), kFloat);              \
+    Buffer b(BufHandle("B", {1}), kFloat);              \
     float val = 0.5f;                                            \
     std::vector<float> a_buffer(Lanes, val);                     \
     std::vector<float> b_buffer(Lanes, val);                     \
     auto store = Store::make(                                    \
         b,                                                       \
-        Ramp::make(0, 1, Lanes),                                 \
+        {Ramp::make(0, 1, Lanes)},                               \
         Name(Load::make(                                         \
             a,                                                   \
-            Ramp::make(0, 1, Lanes),                             \
+            {Ramp::make(0, 1, Lanes)},                           \
             Broadcast::make(IntImm::make(1), Lanes))),           \
         Broadcast::make(IntImm::make(1), Lanes));                \
     LLVMCodeGen cg(store, {a, b});                               \
@@ -339,17 +339,17 @@
 #define DOUBLE_INTRINSICS_TEST(Name, Lanes)                      \
   void testLLVMVecDouble_##Name##Lane##Lanes##Test() {           \
     KernelScope kernel_scope;                                    \
-    Buffer a(VarHandle("A", kHandle), kDouble, {1});             \
-    Buffer b(VarHandle("B", kHandle), kDouble, {1});             \
+    Buffer a(BufHandle("A", {1}), kDouble);             \
+    Buffer b(BufHandle("B", {1}), kDouble);             \
     float val = 0.5f;                                            \
     std::vector<double> a_buffer(Lanes, val);                    \
     std::vector<double> b_buffer(Lanes, val);                    \
     auto store = Store::make(                                    \
         b,                                                       \
-        Ramp::make(0, 1, Lanes),                                 \
+        {Ramp::make(0, 1, Lanes)},                               \
         Name(Load::make(                                         \
             a,                                                   \
-            Ramp::make(0, 1, Lanes),                             \
+            {Ramp::make(0, 1, Lanes)},                           \
             Broadcast::make(IntImm::make(1), Lanes))),           \
         Broadcast::make(IntImm::make(1), Lanes));                \
     LLVMCodeGen cg(store, {a, b});                               \
@@ -384,13 +384,13 @@
 
 void testLLVMVectorizerLoadStoreTest() {
   KernelScope kernel_scope;
-  Buffer a(VarHandle("A", kHandle), kInt, {1});
+  Buffer a(BufHandle("A", {1}), kInt);
 
   Tensor* c = Compute("c", {{4, "i"}}, [&](const VarHandle& i) {
-    return Load::make(a, i, 1);
+    return Load::make(a, {i}, 1);
   });
 
-  Buffer c_buf(VarHandle(c->func_var()), kInt, {4});
+  Buffer c_buf(BufHandle(c->func_var()), kInt);
   LoopNest l({c});
   Stmt* s = l.root_stmt();
   l.vectorize(*dynamic_cast<Block*>(s)->stmts().begin());
@@ -410,15 +410,15 @@
 void testLLVMMemcpyTest() {
   KernelScope kernel_scope;
   constexpr int N = 32;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
   std::vector<int32_t> a_buffer(N, 42);
   std::vector<int32_t> b_buffer(N, 0);
 
   auto mask = IntImm::make(1);
   VarHandle i("i", kInt);
   auto expr =
-      For::make(i, 0, N, Store::make(b, i, Load::make(a, i, mask), mask));
+      For::make(i, 0, N, Store::make(b, {i}, Load::make(a, {i}, mask), mask));
 
   LLVMCodeGen cg(expr, {a, b});
 
@@ -434,12 +434,12 @@
 void testLLVMBzeroTest() {
   KernelScope kernel_scope;
   constexpr int N = 32;
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
+  Buffer b(BufHandle("B", {N}), kInt);
   std::vector<int32_t> b_buffer(N, 11);
 
   auto mask = IntImm::make(1);
   VarHandle i("i", kInt);
-  auto expr = For::make(i, 0, N, Store::make(b, i, IntImm::make(0), mask));
+  auto expr = For::make(i, 0, N, Store::make(b, {i}, IntImm::make(0), mask));
 
   LLVMCodeGen cg(expr, {b});
 
@@ -453,9 +453,9 @@
 void testLLVMElemwiseAdd() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int32_t> a_buffer(N, 41);
   std::vector<int32_t> b_buffer(N, 1);
   std::vector<int32_t> c_buffer(N, 1);
@@ -468,8 +468,8 @@
       N,
       Store::make(
           c,
-          i,
-          Add::make(Load::make(a, i, mask), Load::make(b, i, mask)),
+          {i},
+          Add::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask)),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -488,9 +488,9 @@
 void testLLVMElemwiseAddFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, 41);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -501,7 +501,7 @@
       i,
       0,
       N,
-      Store::make(c, i, Load::make(a, i, mask) + Load::make(b, i, mask), mask));
+      Store::make(c, {i}, Load::make(a, {i}, mask) + Load::make(b, {i}, mask), mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
 
@@ -519,8 +519,8 @@
 void testLLVMElemwiseLog10Float() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
   std::vector<float> a_buffer(N, 10.0f);
   std::vector<float> b_buffer(N, 2.0f);
 
@@ -532,8 +532,8 @@
       N / 4,
       Store::make(
           b,
-          Ramp::make(i * 4, 1, 4),
-          log10(Load::make(a, Ramp::make(i * 4, 1, 4), mask)),
+          {Ramp::make(i * 4, 1, 4)},
+          log10(Load::make(a, {Ramp::make(i * 4, 1, 4)}, mask)),
           mask));
 
   LLVMCodeGen cg(expr, {a, b});
@@ -550,9 +550,9 @@
 void testLLVMElemwiseMaxInt() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 41);
   std::vector<int> b_buffer(N, 1);
   std::vector<int> c_buffer(N, 1);
@@ -565,8 +565,8 @@
       N,
       Store::make(
           c,
-          i,
-          Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+          {i},
+          Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -585,9 +585,9 @@
 void testLLVMElemwiseMinInt() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 41);
   std::vector<int> b_buffer(N, 1);
   std::vector<int> c_buffer(N, 1);
@@ -600,8 +600,8 @@
       N,
       Store::make(
           c,
-          i,
-          Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+          {i},
+          Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -620,9 +620,9 @@
 void testLLVMElemwiseMaxNumFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, 41);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -635,8 +635,8 @@
       N,
       Store::make(
           c,
-          i,
-          Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+          {i},
+          Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -655,9 +655,9 @@
 void testLLVMElemwiseMaxNumNaNFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, NAN);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -670,8 +670,8 @@
       N,
       Store::make(
           c,
-          i,
-          Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+          {i},
+          Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -689,9 +689,9 @@
 void testLLVMElemwiseMinNumFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, 41);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -704,8 +704,8 @@
       N,
       Store::make(
           c,
-          i,
-          Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+          {i},
+          Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -724,9 +724,9 @@
 void testLLVMElemwiseMinNumNaNFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, NAN);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -739,8 +739,8 @@
       N,
       Store::make(
           c,
-          i,
-          Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
+          {i},
+          Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -759,9 +759,9 @@
 void testLLVMElemwiseMaximumFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, 41);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -774,8 +774,8 @@
       N,
       Store::make(
           c,
-          i,
-          Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+          {i},
+          Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -794,9 +794,9 @@
 void testLLVMElemwiseMaximumNaNFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, NAN);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -809,8 +809,8 @@
       N,
       Store::make(
           c,
-          i,
-          Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+          {i},
+          Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -830,9 +830,9 @@
 void testLLVMElemwiseMinimumFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, 41);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -845,8 +845,8 @@
       N,
       Store::make(
           c,
-          i,
-          Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+          {i},
+          Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -865,9 +865,9 @@
 void testLLVMElemwiseMinimumNaNFloat() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kFloat, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kFloat);
   std::vector<float> a_buffer(N, NAN);
   std::vector<float> b_buffer(N, 1);
   std::vector<float> c_buffer(N, 1);
@@ -880,8 +880,8 @@
       N,
       Store::make(
           c,
-          i,
-          Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
+          {i},
+          Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), true),
           mask));
 
   LLVMCodeGen cg(expr, {a, b, c});
@@ -902,9 +902,9 @@
 void testLLVMCompareSelectIntEQ() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<int> a_buffer(N, 1);
   std::vector<int> b_buffer(N, 1);
   std::vector<int> c_buffer(N, 0);
@@ -923,10 +923,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kEQ),
           mask));
 
@@ -948,9 +948,9 @@
 void testLLVMCompareSelectFloatEQ() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kFloat, {N});
-  Buffer b(VarHandle("B", kHandle), kFloat, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kFloat);
+  Buffer b(BufHandle("B", {N}), kFloat);
+  Buffer c(BufHandle("C", {N}), kInt);
   std::vector<float> a_buffer(N, 1.0f);
   std::vector<float> b_buffer(N, 1.0f);
   std::vector<int> c_buffer(N, 0);
@@ -963,10 +963,10 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kEQ),
           mask));
 
@@ -986,10 +986,10 @@
 
 void testLLVMStoreFloat() {
   KernelScope kernel_scope;
-  Buffer result(VarHandle("result", kHandle), kFloat, {1});
+  Buffer result(BufHandle("result", {1}), kFloat);
   std::vector<float> result_buffer = {0.0f};
   auto expr = Store::make(
-      result, IntImm::make(0), FloatImm::make(3.14f), IntImm::make(1));
+      result, {IntImm::make(0)}, FloatImm::make(3.14f), IntImm::make(1));
   LLVMCodeGen cg(expr, {result});
   std::vector<void*> args({result_buffer.data()});
   ASSERT_EQ(cg.value<int>(args), 0);
@@ -1004,7 +1004,7 @@
   });
   LoopNest l({tensor});
   Stmt* stmt = l.root_stmt();
-  Buffer f_buf(VarHandle(tensor->func_var()), kFloat, {N});
+  Buffer f_buf(BufHandle(tensor->func_var()), kFloat);
   LLVMCodeGen cg(stmt, {f_buf});
 
   PaddedBuffer<float> f_v(N, "f_v");
@@ -1021,13 +1021,13 @@
 void testLLVMComputeMul() {
   KernelScope kernel_scope;
   const int N = 1024;
-  Buffer a(VarHandle("a", kHandle), kFloat, {N});
-  Buffer b(VarHandle("b", kHandle), kFloat, {N});
+  Buffer a(BufHandle("a", {N}), kFloat);
+  Buffer b(BufHandle("b", {N}), kFloat);
   Tensor* c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) {
-    return Load::make(a, i, 1) * Load::make(b, i, 1);
+    return Load::make(a, {i}, 1) * Load::make(b, {i}, 1);
   });
 
-  Buffer c_buf(VarHandle(c->func_var()), kFloat, {N});
+  Buffer c_buf(BufHandle(c->func_var()), kFloat);
   LoopNest l({c});
   Stmt* s = l.root_stmt();
 
@@ -1045,16 +1045,17 @@
   KernelScope kernel_scope;
   const int M = 32;
   const int N = 1024;
-  Buffer a(VarHandle("a", kHandle), kFloat, {M, N});
-  Buffer b(VarHandle("b", kHandle), kFloat, {N});
+  Buffer a(BufHandle("a", {M, N}), kFloat);
+  Buffer b(BufHandle("b", {N}), kFloat);
   Tensor* c = Compute(
       "c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
         ExprHandle mask(1);
-        return Load::make(a, i * N + j, mask) + Load::make(b, j, mask);
+        return Load::make(a, {i, j}, mask) + Load::make(b, {j}, mask);
       });
 
-  Buffer c_buf(VarHandle(c->func_var()), kFloat, {M, N});
+  Buffer c_buf(BufHandle(c->func_var()), kFloat);
   LoopNest l({c});
+  l.prepareForCodegen();
   Stmt* s = l.root_stmt();
 
   LLVMCodeGen cg(s, {a, b, c_buf});
@@ -1091,11 +1092,11 @@
   KernelScope kernel_scope;
   auto testWithSize = [](int32_t size) {
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {n});
-    Buffer c(VarHandle("c", kHandle), kFloat, {n});
+    Buffer a(BufHandle("a", {n}), kFloat);
+    Buffer b(BufHandle("b", {n}), kFloat);
+    Buffer c(BufHandle("c", {n}), kFloat);
     VarHandle i("i", kInt);
-    Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
+    Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1));
     std::vector<float> aData(size, 1.0f);
     std::vector<float> bData(size, 2.0f);
     std::vector<float> cData(size, 0.0f);
@@ -1113,11 +1114,11 @@
   KernelScope kernel_scope;
   auto testWithSize = [](int32_t size) {
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {n});
-    Buffer c(VarHandle("c", kHandle), kFloat, {n});
+    Buffer a(BufHandle("a", {n}), kFloat);
+    Buffer b(BufHandle("b", {n}), kFloat);
+    Buffer c(BufHandle("c", {n}), kFloat);
     VarHandle i("i", kInt);
-    Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
+    Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1));
     std::vector<float> aData(size, 1.0f);
     std::vector<float> bData(size, 2.0f);
     std::vector<float> cData(size, 0.0f);
@@ -1134,8 +1135,8 @@
   KernelScope kernel_scope;
   auto testWithSize = [](int32_t size) {
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {n});
+    Buffer a(BufHandle("a", {n}), kFloat);
+    Buffer b(BufHandle("b", {n}), kFloat);
     Tensor* c = Compute(
         "c", {{n, "n"}}, [&](const VarHandle& i) { return a(i) + b(i); });
     LoopNest l({c});
@@ -1157,13 +1158,14 @@
   auto testWithSize = [](int32_t M, int32_t N) {
     VarHandle m("m", kInt);
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+    Buffer a(BufHandle("a", {m, n}), kFloat);
+    Buffer b(BufHandle("b", {m, n}), kFloat);
     Tensor* c = Compute(
         "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
           return a(i, j) + b(i, j);
         });
     LoopNest l({c});
+    l.prepareForCodegen();
     Stmt* s = l.root_stmt();
     LLVMCodeGen cg(s, {a, b, c, m, n});
     std::vector<float> aData(M * N, 1.0f);
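
Note: the 2-D LLVM tests above illustrate the new contract: indices stay multi-dimensional in the IR, and LoopNest::prepareForCodegen() must run before constructing the codegen so the backend still sees flat 1-D indices. A minimal sketch of the pattern, modeled on the broadcast-add test above (illustrative only; M and N are assumed constants):

  Buffer a(BufHandle("a", {M, N}), kFloat);
  Buffer b(BufHandle("b", {N}), kFloat);
  Tensor* c = Compute(
      "c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
        return Load::make(a, {i, j}, 1) + Load::make(b, {j}, 1);  // N-d indices
      });
  Buffer c_buf(BufHandle(c->func_var()), kFloat);
  LoopNest l({c});
  l.prepareForCodegen();  // flattens {i, j} into i * N + j for the backend
  LLVMCodeGen cg(l.root_stmt(), {a, b, c_buf});
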
diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp
index 387c71e..76c61e4 100644
--- a/test/cpp/tensorexpr/test_loopnest.cpp
+++ b/test/cpp/tensorexpr/test_loopnest.cpp
@@ -77,7 +77,7 @@
     VarHandle x_inner("x_inner", kInt);
     VarHandle y("y", kInt);
     VarHandle x_tail("x_tail", kInt);
-    VarHandle f("f", kHandle);
+    BufHandle f("f", {26, 5});
     ExprHandle x_1 = x_outer * 4 + x_inner;
     ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4;
     For* stmt1 = For::make(
@@ -89,13 +89,13 @@
             0,
             4,
             For::make(
-                y, 0, 5, Store::make(f, x_1 * 5 + y * 1, func(x_1, y), 1))));
+                y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y), 1))));
     ExprHandle x_2 = x_tail + x_outer_end * 4;
     For* stmt2 = For::make(
         x_tail,
         0,
         (ExprHandle(26) - 0) % 4,
-        For::make(y, 0, 5, Store::make(f, x_2 * 5 + y * 1, func(x_2, y), 1)));
+        For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y), 1)));
     Stmt* stmt = Block::make({stmt1, stmt2});
 
     std::ostringstream oss_ref;
@@ -107,6 +107,7 @@
     PaddedBuffer<float> f_v(26, 5, "f_v");
     PaddedBuffer<float> f_ref(26, 5, "f_res");
 
+    stmt = FlattenIndexes(stmt);
     SimpleIREvaluator ir_eval(stmt, tensor);
     ir_eval(f_v);
 
@@ -145,7 +146,7 @@
     VarHandle x_inner("x_inner", kInt);
     VarHandle y("y", kInt);
     VarHandle x_tail("x_tail", kInt);
-    VarHandle f("f", kHandle);
+    BufHandle f("f", {24, 5});
     ExprHandle x_1 = x_outer * 4 + x_inner;
     ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4;
     For* stmt = For::make(
@@ -157,8 +158,7 @@
             0,
             4,
             For::make(
-                y, 0, 5, Store::make(f, x_1 * 5 + y * 1, func(x_1, y), 1))));
-    // Stmt stmt = Block::make({stmt1, stmt2});
+                y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y), 1))));
 
     std::ostringstream oss_ref;
     oss_ref << *stmt;
@@ -429,6 +429,7 @@
               (c_buf(m, n) * d_buf(m, k) + a_buf(m, n) * b_buf(n, k));
         });
     LoopNest l2({z2});
+    l2.prepareForCodegen();
     Stmt* stmt2 = l2.root_stmt();
 
     std::ostringstream oss2;
@@ -454,7 +455,7 @@
   const int kVectorCount = 128;
   const int kTotalSize = kVectorSize * kVectorCount;
 
-  Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
 
   Tensor* b = Compute(
       "f", {{kTotalSize, "i"}}, [&](const std::vector<VarHandle>& axes) {
@@ -487,10 +488,10 @@
   const int kVectorCount = 128;
   const int kTotalSize = kVectorSize * kVectorCount;
 
-  Buffer a(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer b(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer c(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)});
-  Buffer d(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)});
+  Buffer a(BufHandle("A", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer b(BufHandle("B", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer c(BufHandle("C", {ExprHandle(kTotalSize)}), kFloat);
+  Buffer d(BufHandle("D", {ExprHandle(kTotalSize)}), kFloat);
 
   Tensor* e = Compute("e", {{kTotalSize, "i"}}, [&](const VarHandle& i) {
     return a(i) + b(i);
@@ -525,8 +526,8 @@
   auto testWithSize = [](int32_t M, int32_t N) {
     VarHandle m("m", kInt);
     VarHandle n("n", kInt);
-    Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
-    Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+    Buffer a(BufHandle("a", {m, n}), kFloat);
+    Buffer b(BufHandle("b", {m, n}), kFloat);
     Tensor* c = Compute(
         "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
           return a(i, j) + b(i, j);
diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp
index 8a03e0e..26356be 100644
--- a/test/cpp/tensorexpr/test_simplify.cpp
+++ b/test/cpp/tensorexpr/test_simplify.cpp
@@ -346,9 +346,9 @@
 void testHashLargeExpression() {
   KernelScope kernel_scope;
   constexpr int N = 1024;
-  Buffer a(VarHandle("A", kHandle), kInt, {N});
-  Buffer b(VarHandle("B", kHandle), kInt, {N});
-  Buffer c(VarHandle("C", kHandle), kInt, {N});
+  Buffer a(BufHandle("A", {N}), kInt);
+  Buffer b(BufHandle("B", {N}), kInt);
+  Buffer c(BufHandle("C", {N}), kInt);
   auto mask = IntImm::make(1);
   VarHandle i("i", kInt);
   auto memcpy_stmt = For::make(
@@ -357,25 +357,25 @@
       N,
       Store::make(
           c,
-          i,
+          {i},
           CompareSelect::make(
-              Load::make(a, i, mask),
-              Load::make(b, i, mask),
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
               CompareSelectOperation::kEQ),
           mask));
 
-  Buffer d(VarHandle("D", kHandle), kInt, {1});
-  Buffer e(VarHandle("E", kHandle), kInt, {1});
+  Buffer d(BufHandle("D", {1}), kInt);
+  Buffer e(BufHandle("E", {1}), kInt);
   auto store_ramp_stmt = Store::make(
       e,
-      Ramp::make(0, 1, 4),
-      Load::make(d, Ramp::make(0, 1, 4), Broadcast::make(IntImm::make(1), 4)),
+      {Ramp::make(0, 1, 4)},
+      Load::make(d, {Ramp::make(0, 1, 4)}, Broadcast::make(IntImm::make(1), 4)),
       Broadcast::make(Cast::make(kInt, DoubleImm::make(1)), 4));
 
   auto if_stmt = Cond::make(
       CompareSelect::make(
-          Load::make(a, i, mask),
-          Load::make(b, i, mask),
+          Load::make(a, {i}, mask),
+          Load::make(b, {i}, mask),
           CompareSelectOperation::kGE),
       memcpy_stmt,
       store_ramp_stmt);
diff --git a/torch/csrc/jit/tensorexpr/buffer.h b/torch/csrc/jit/tensorexpr/buffer.h
index 6cd1861..026c4c2 100644
--- a/torch/csrc/jit/tensorexpr/buffer.h
+++ b/torch/csrc/jit/tensorexpr/buffer.h
@@ -8,18 +8,13 @@
 
 class Buffer {
  public:
-  Buffer(
-      const VarHandle& data,
-      const Dtype& dtype,
-      const std::vector<ExprHandle>& dims)
-      : data_(data.node()),
-        dtype_(dtype),
-        dims_(ExprHandleVectorToExprVector(dims)) {
+  Buffer(const BufHandle& data, const Dtype& dtype)
+      : data_(data.node()), dtype_(dtype) {
     if (data.dtype() != kHandle) {
       throw malformed_input();
     }
 
-    std::vector<ExprHandle> stride_handles(dims.size());
+    std::vector<ExprHandle> stride_handles(ndim());
     for (int i = ndim() - 1; i >= 0; i--) {
       if (i == ndim() - 1) {
         stride_handles[i] = 1;
@@ -33,100 +28,60 @@
       const std::string& name,
       const Dtype& dtype,
       const std::vector<ExprHandle>& dims)
-      : Buffer(VarHandle(name, kHandle), dtype, dims) {}
+      : Buffer(BufHandle(name, dims), dtype) {}
 
-  const Var* data() const {
+  const Buf* data() const {
     return data_;
   }
   const Dtype& dtype() const {
     return dtype_;
   }
   int ndim() const {
-    return dims_.size();
+    return data_->ndim();
   }
   const Expr* dim(int index) const {
-    return dims_[index];
+    return data_->dim(index);
+  }
+  std::vector<const Expr*> dims() const {
+    return data_->dims();
   }
 
   // TODO: consider deferring the storage flattening to a later stage.
   template <typename... Args>
   ExprHandle operator()(Args... args) const {
-    ExprHandle index = Index(std::forward<Args>(args)...);
-    return LoadValue(index);
+    return LoadValue(std::forward<Args>(args)...);
+  }
+  ExprHandle LoadValue(
+      const ExprHandle& x,
+      const ExprHandle& y,
+      const ExprHandle& z) const {
+    return Load::make(*this, {x, y, z}, ExprHandle(1));
+  }
+  ExprHandle LoadValue(const ExprHandle& x, const ExprHandle& y) const {
+    return Load::make(*this, {x, y}, ExprHandle(1));
+  }
+  ExprHandle LoadValue(const ExprHandle& x) const {
+    return Load::make(*this, {x}, ExprHandle(1));
   }
 
   template <typename T>
   ExprHandle call(const std::vector<T>& args) const {
     std::vector<ExprHandle> params(args.begin(), args.end());
-    ExprHandle index = Index(params);
-    return LoadValue(index);
+    return LoadValue(params);
   }
 
  private:
-  ExprHandle Index(const ExprHandle& x) const {
-    if (ndim() != 1) {
-      throw malformed_input();
-    }
-    return x;
-  }
-  ExprHandle Index(const ExprHandle& x, const ExprHandle& y) const {
-    if (ndim() != 2) {
-      throw malformed_input();
-    }
-    return x * ExprHandle(strides_[0]) + y;
-  }
-  ExprHandle Index(
-      const ExprHandle& x,
-      const ExprHandle& y,
-      const ExprHandle& z) const {
-    if (ndim() != 3) {
-      throw malformed_input();
-    }
-    return x * ExprHandle(strides_[0]) + y * ExprHandle(strides_[1]) + z;
-  }
-  ExprHandle Index(
-      const ExprHandle& x,
-      const ExprHandle& y,
-      const ExprHandle& z,
-      const ExprHandle& w) const {
-    if (ndim() != 4) {
-      throw malformed_input();
-    }
-    return x * ExprHandle(strides_[0]) + y * ExprHandle(strides_[1]) +
-        z * ExprHandle(strides_[2]) + w;
-  }
-  ExprHandle Index(const std::vector<ExprHandle>& indices) const {
-    if (ndim() != (int)indices.size()) {
-      throw malformed_input();
-    }
-    ExprHandle total_index;
-    for (size_t i = 0; i < indices.size(); i++) {
-      ExprHandle index;
-      if (i == indices.size() - 1) {
-        index = indices[i];
-      } else {
-        index = indices[i] * ExprHandle(strides_[i]);
-      }
-      if (i == 0) {
-        total_index = index;
-      } else {
-        total_index = total_index + index;
-      }
-    }
-    return total_index;
-  }
+  ExprHandle LoadValue(const std::vector<ExprHandle>& indices) const;
 
-  ExprHandle LoadValue(const ExprHandle& index) const;
-
-  const Var* data_;
+  const Buf* data_;
   Dtype dtype_;
-  std::vector<const Expr*> dims_;
   std::vector<const Expr*> strides_;
   // TODO: add strides
 };
 
-inline ExprHandle Buffer::LoadValue(const ExprHandle& index) const {
-  return Load::make(*this, index, ExprHandle(1));
+inline ExprHandle Buffer::LoadValue(
+    const std::vector<ExprHandle>& indices) const {
+  return Load::make(*this, indices, ExprHandle(1));
 }
 
 } // namespace tensorexpr
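
Note: with dimensions now carried by Buf, Buffer reduces to a BufHandle plus a dtype, derives its strides from ndim(), and operator() forwards to a multi-index Load instead of folding the indices into one expression itself. The two construction forms are equivalent (a sketch; M, N, i, j are assumed handles):

  Buffer a(BufHandle("a", {M, N}), kFloat);
  Buffer b("b", kFloat, {M, N});  // convenience overload, delegates to the above

  ExprHandle e = a(i, j);  // builds Load::make(a, {i, j}, 1) -- no flattening here
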
diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h
index e232e5b..c1f63b2 100644
--- a/torch/csrc/jit/tensorexpr/codegen.h
+++ b/torch/csrc/jit/tensorexpr/codegen.h
@@ -56,12 +56,14 @@
 class CodeGen::BufferArg {
  public:
   BufferArg(const Buffer& buffer)
-      : var_(buffer.data()), dtype_(buffer.dtype()) {}
+      : var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {}
   BufferArg(Tensor* tensor)
-      : var_(tensor->function()->func_var(tensor->output_index())),
+      : var_(tensor->function()
+                 ->func_var(tensor->output_index())
+                 ->base_handle()),
         dtype_(tensor->function()->body(tensor->output_index())->dtype()) {}
   BufferArg(const Function& func)
-      : var_(func.func_var(0)), dtype_(func.body(0)->dtype()) {
+      : var_(func.func_var(0)->base_handle()), dtype_(func.body(0)->dtype()) {
     // TODO: Support multiple-output functions
     if (func.func_vars().size() != 1) {
       throw unimplemented_lowering();
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index 0d5e21b..2119b1e 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -161,14 +161,15 @@
 void CudaPrinter::visit(const Load* v) {
   // TODO: find a better metric in using ldg or not. Support different dtypes.
   if (v->dtype().scalar_type() == ScalarType::Half) {
-    os() << "__half2float(" << *v->base_handle() << "[" << *v->index() << "])";
+    os() << "__half2float(" << *v->base_handle() << "[" << *v->flat_index()
+         << "])";
   } else {
-    os() << "__ldg(" << *v->base_handle() << " + " << *v->index() << ")";
+    os() << "__ldg(" << *v->base_handle() << " + " << *v->flat_index() << ")";
   }
 }
 
 void CudaPrinter::visit(const Store* v) {
-  os() << *v->base_handle() << "[" << *v->index() << "] = ";
+  os() << *v->base_handle() << "[" << *v->flat_index() << "] = ";
   if (v->value()->dtype().scalar_type() == ScalarType::Half) {
     os() << "__float2half(" << *v->value() << ");";
   } else {
diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h
index 23dc185..66404ce 100644
--- a/torch/csrc/jit/tensorexpr/eval.h
+++ b/torch/csrc/jit/tensorexpr/eval.h
@@ -582,7 +582,8 @@
     }
     void* ptr = iter->second;
 
-    v->index()->accept(this);
+    const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
+    flat_idx->accept(this);
     std::vector<int> index = value().as_vec<int>();
     v->mask()->accept(this);
     std::vector<int> mask = value().as_vec<int>();
@@ -615,7 +616,8 @@
 
     void* ptr = iter->second;
 
-    v->index()->accept(this);
+    const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
+    flat_idx->accept(this);
     std::vector<int> index = value().as_vec<int>();
     v->mask()->accept(this);
     std::vector<int> mask = value().as_vec<int>();
@@ -857,7 +859,13 @@
       : dtype_(expr.dtype()) {
     std::vector<BufferArg> buffer_args_extended = buffer_args;
     Buffer ret_buf("ret_val", dtype_, {1});
-    Stmt* store_stmt = Store::make(VarHandle(ret_buf.data()), 0, expr);
+    std::vector<const Expr*> indices;
+    const Expr* zero = new IntImm(0);
+    for (size_t i = 0; i < ret_buf.data()->ndim(); i++) {
+      indices.push_back(zero);
+    }
+    Stmt* store_stmt =
+        new Store(ret_buf.data(), indices, expr.node(), new IntImm(1));
     buffer_args_extended.push_back(ret_buf);
     codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
   }
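
Note: the evaluator handles unflattened accesses by linearizing on the fly: each Load/Store visit first evaluates flatten_index(buf->dims(), indices) and then uses the resulting integer(s) to address the raw allocation. The pattern is the same in both visitors (excerpted conceptually):

  const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
  flat_idx->accept(this);                          // evaluate i*stride + j + ...
  std::vector<int> index = value().as_vec<int>();  // then read/write ptr[index[k]]
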
diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp
index 9f7ed26..82b09f2 100644
--- a/torch/csrc/jit/tensorexpr/expr.cpp
+++ b/torch/csrc/jit/tensorexpr/expr.cpp
@@ -203,6 +203,17 @@
   return IfThenElse::make(c, t, f);
 }
 
+ExprHandle Buf::make(
+    const std::string& name_hint,
+    const std::vector<ExprHandle>& dims) {
+  return ExprHandle(
+      new Buf(new Var(name_hint, kHandle), ExprHandleVectorToExprVector(dims)));
+}
+
+ExprHandle Buf::make(const std::vector<ExprHandle>& dims) {
+  return Buf::make("", dims);
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 14eb4bb..2af4830 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -161,6 +161,66 @@
   std::string name_hint_;
 };
 
+class TORCH_API Buf : public ExprNode<Buf> {
+ public:
+  static ExprHandle make(
+      const std::string& name_hint,
+      const std::vector<ExprHandle>& dims);
+  static ExprHandle make(const std::vector<ExprHandle>& dims);
+
+  // TODO: unique_name
+  const Var* base_handle() const {
+    return base_handle_;
+  }
+  const std::string& name_hint() const {
+    return base_handle_->name_hint();
+  }
+
+  Buf(const Var* var, const std::vector<const Expr*>& dims)
+      : ExprNodeBase(kHandle, kPrimitive), base_handle_(var), dims_(dims) {
+    TORCH_CHECK(var);
+  }
+
+  int ndim() const {
+    return dims_.size();
+  }
+  const Expr* dim(int index) const {
+    return dims_[index];
+  }
+  std::vector<const Expr*> dims() const {
+    return dims_;
+  }
+
+ private:
+  const Var* base_handle_;
+  std::vector<const Expr*> dims_;
+};
+
+class TORCH_API BufHandle : public ExprHandle {
+ public:
+  BufHandle() : ExprHandle(nullptr) {}
+  //   explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make(dtype)) {}
+  BufHandle(const std::string& name_hint, const std::vector<ExprHandle>& dims)
+      : ExprHandle(Buf::make(name_hint, dims)) {}
+  explicit BufHandle(const Buf* node) : ExprHandle(node) {}
+  const Buf* node() const {
+    return static_cast<const Buf*>(ExprHandle::node());
+  }
+  bool operator==(const BufHandle& other) const {
+    return this->node() == other.node();
+  }
+  bool operator!=(const BufHandle& other) const {
+    return !(*this == other);
+  }
+
+  const std::string& name_hint() const {
+    return this->node()->name_hint();
+  }
+  bool empty() const {
+    return (this->node() == nullptr);
+  }
+};
+
 // An expression to construct the underlying variable node.
 // Note: do not store any info here, since it is often possible to slice this
 // object. For example: VarHandle x('x'); ExprHandle x2 = x;
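
Note: Buf pairs a base Var (still of kHandle dtype) with a dimension list, and BufHandle is its user-facing handle, mirroring the existing Var/VarHandle split. A minimal usage sketch:

  BufHandle f("f", {26, 5});              // 2-D buffer handle; int dims promote
  const Buf* node = f.node();
  int rank = node->ndim();                // 2
  const Expr* rows = node->dim(0);        // the expression 26
  const Var* base = node->base_handle();  // underlying kHandle Var named "f"
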
diff --git a/torch/csrc/jit/tensorexpr/function.cpp b/torch/csrc/jit/tensorexpr/function.cpp
index 97996f7..4a9a9a6 100644
--- a/torch/csrc/jit/tensorexpr/function.cpp
+++ b/torch/csrc/jit/tensorexpr/function.cpp
@@ -32,7 +32,8 @@
   unpack_dim_args(dim_args, &dims, &args);
   const Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
   Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  const Buf* buf = func->func_var(0);
+  return new Tensor(buf, func, 0);
 }
 
 Tensor* Compute(
@@ -48,7 +49,8 @@
   unpack_dim_args(dim_args, &dims, &args);
   const Expr* body = body_func(VarHandle(args[0])).node();
   Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  const Buf* buf = func->func_var(0);
+  return new Tensor(buf, func, 0);
 }
 
 Tensor* Compute(
@@ -64,7 +66,8 @@
   unpack_dim_args(dim_args, &dims, &args);
   const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
   Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  const Buf* buf = func->func_var(0);
+  return new Tensor(buf, func, 0);
 }
 
 Tensor* Compute(
@@ -83,7 +86,8 @@
       body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
           .node();
   Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  const Buf* buf = func->func_var(0);
+  return new Tensor(buf, func, 0);
 }
 
 Tensor* Compute(
@@ -103,37 +107,20 @@
   auto args = VarVectorToVarHandleVector(args_nodes);
   const Expr* body = body_func(args[0], args[1], args[2], args[3]).node();
   Function* func = new Function(func_name, dims, args_nodes, body);
-  return new Tensor(func, 0);
+  const Buf* buf = func->func_var(0);
+  return new Tensor(buf, func, 0);
 }
 
 Stmt* Function::ElementStmt(size_t index) {
-  std::vector<ExprHandle> strides(dims_.size());
-  for (size_t i = 0; i < strides.size(); i++) {
-    if (i == strides.size() - 1) {
-      strides[i] = ExprHandle(1);
-      continue;
-    }
-    ExprHandle stride = ExprHandle(dims_[i + 1]);
-    for (size_t j = i + 2; j < dims_.size(); j++) {
-      stride = stride * ExprHandle(dims_[j]);
-    }
-    strides[i] = stride;
-  }
-
-  ExprHandle total_index = int32_t{0};
-  for (size_t i = 0; i < dims_.size(); i++) {
-    ExprHandle index = VarHandle(this->args_[i]) * ExprHandle(strides[i]);
-    if (i == 0) {
-      total_index = index;
-    } else {
-      total_index = total_index + index;
-    }
+  const Buf* buf = func_var(index);
+  std::vector<const Expr*> indices;
+  for (size_t i = 0; i < buf->ndim(); i++) {
+    indices.push_back(this->args_[i]);
   }
 
   const Expr* mask = new IntImm(1);
 
-  Stmt* update_stmt =
-      new Store(func_var(index), total_index.node(), body(index), mask);
+  Stmt* update_stmt = new Store(buf, indices, body(index), mask);
   return update_stmt;
 }
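
Note: Function::ElementStmt no longer bakes strides into the emitted statement: it stores through the output Buf with one index per axis and leaves linearization to the later flattening pass. For a 26x5 tensor with axes x and y, the element statement changes shape as follows (illustrative printed IR):

  // Before: strides were materialized at construction time
  f[x * 5 + y] = <body>;
  // After: indices stay symbolic until prepareForCodegen/FlattenIndexes
  f[x, y] = <body>;
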
 
diff --git a/torch/csrc/jit/tensorexpr/function.h b/torch/csrc/jit/tensorexpr/function.h
index eba3ebe..b526543 100644
--- a/torch/csrc/jit/tensorexpr/function.h
+++ b/torch/csrc/jit/tensorexpr/function.h
@@ -17,8 +17,9 @@
       const std::vector<const Expr*>& dims,
       const std::vector<const Var*>& args,
       const Expr* body)
-      : func_vars_({VarHandle(func_name, kHandle).node()}),
-        dims_(dims),
+      // TODO: Function should not create buffers; they should be created
+      // manually before constructing a function.
+      : func_vars_({new Buf(new Var(func_name, kHandle), dims)}),
         args_(args),
         bodies_({body}) {}
   Function(
@@ -26,30 +27,14 @@
       const std::vector<const Expr*>& dims,
       const std::vector<const Var*>& args,
       const std::vector<const Expr*>& bodies)
-      : func_vars_(func_names.size()),
-        dims_(dims),
-        args_(args),
-        bodies_(bodies) {
+      : func_vars_(func_names.size()), args_(args), bodies_(bodies) {
     for (size_t i = 0; i < func_names.size(); i++) {
-      func_vars_[i] = new Var(func_names[i], kHandle);
+      func_vars_[i] = new Buf(new Var(func_names[i], kHandle), dims);
     }
   }
 
-  int ndim() const {
-    return dims_.size();
-  }
-  const Expr* dim(int index) const {
-    if (index < 0 || index >= ndim()) {
-      throw out_of_range_index();
-    }
-
-    return dims_[index];
-  }
-  const std::vector<const Expr*>& dims() const {
-    return dims_;
-  }
   const Var* arg(int index) const {
-    if (index < 0 || index >= ndim()) {
+    if (index < 0 || index >= args_.size()) {
       throw out_of_range_index();
     }
 
@@ -70,10 +55,10 @@
     return bodies_[index];
   }
 
-  std::vector<const Var*> func_vars() const {
+  std::vector<const Buf*> func_vars() const {
     return func_vars_;
   }
-  const Var* func_var(size_t index) const {
+  const Buf* func_var(size_t index) const {
     if (index >= func_vars_.size()) {
       throw out_of_range_index();
     }
@@ -83,8 +68,7 @@
   Stmt* ElementStmt(size_t index);
 
  private:
-  std::vector<const Var*> func_vars_;
-  std::vector<const Expr*> dims_;
+  std::vector<const Buf*> func_vars_;
   std::vector<const Var*> args_;
   std::vector<const Expr*> bodies_;
 };
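
Note: Function thus stops owning dimension info altogether: dims_, ndim(), and dim() are gone, and every former func->ndim() call now asks the output Buf instead, as the loopnest.cpp hunks below show. A sketch of the replacement query path:

  const Buf* out = func->func_var(0);  // now a Buf, not a Var
  int rank = out->ndim();
  const Expr* extent0 = out->dim(0);
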
diff --git a/torch/csrc/jit/tensorexpr/hash_provider.cpp b/torch/csrc/jit/tensorexpr/hash_provider.cpp
index f77ec3a..546a99c 100644
--- a/torch/csrc/jit/tensorexpr/hash_provider.cpp
+++ b/torch/csrc/jit/tensorexpr/hash_provider.cpp
@@ -172,21 +172,26 @@
 void HashProvider::visit(const Load* v) {
   CACHE_GUARD();
   v->base_handle()->accept(this);
-  v->index()->accept(this);
+  SimplifierHashType indices_hash;
+  for (const Expr* ind : v->indices()) {
+    ind->accept(this);
+    indices_hash = hash_combine(indices_hash, hashOf(ind));
+  }
   v->mask()->accept(this);
   putHash(
       v,
       hash_combine(
-          "load",
-          hashOf(v->base_handle()),
-          hashOf(v->index()),
-          hashOf(v->mask())));
+          "load", hashOf(v->base_handle()), indices_hash, hashOf(v->mask())));
 }
 
 void HashProvider::visit(const Store* v) {
   CACHE_GUARD();
   v->base_handle()->accept(this);
-  v->index()->accept(this);
+  SimplifierHashType indices_hash;
+  for (const Expr* ind : v->indices()) {
+    ind->accept(this);
+    indices_hash = hash_combine(indices_hash, hashOf(ind));
+  }
   v->value()->accept(this);
   v->mask()->accept(this);
   putHash(
@@ -194,7 +199,7 @@
       hash_combine(
           "store",
           hashOf(v->base_handle()),
-          hashOf(v->index()),
+          indices_hash,
           hashOf(v->value()),
           hashOf(v->mask())));
 }
diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp
index 68e34ab..651c4a1 100644
--- a/torch/csrc/jit/tensorexpr/ir.cpp
+++ b/torch/csrc/jit/tensorexpr/ir.cpp
@@ -10,46 +10,183 @@
   return Dtype(buffer_dtype, index_dtype.lanes());
 }
 
-Load::Load(const Buffer& buffer, const Expr* index, const Expr* mask)
+static Dtype dtypeOfIndices(const std::vector<const Expr*>& indices) {
+  if (!indices.size()) {
+    throw malformed_input();
+  }
+  Dtype dt = indices.at(0)->dtype();
+  for (size_t i = 1; i < indices.size(); ++i) {
+    if (indices.at(i)->dtype() != dt) {
+      throw malformed_input();
+    }
+  }
+  return dt;
+}
+
+static bool indicesValid(const std::vector<const Expr*>& indices) {
+  if (indices.size() == 0) {
+    return false;
+  }
+  Dtype index_dtype = dtypeOfIndices(indices);
+  if (indices.size() > 1 && index_dtype.lanes() > 1) {
+    // Multilane is only allowed in a flattened (i.e. 1D) index
+    return false;
+  }
+  if (index_dtype.scalar_type() != ScalarType::Int) {
+    return false;
+  }
+  return true;
+}
+
+Load::Load(
+    const Buffer& buffer,
+    const std::vector<const Expr*>& indices,
+    const Expr* mask)
     : Load(
-          ChooseDtype(buffer.dtype(), index->dtype()),
+          ChooseDtype(buffer.dtype(), dtypeOfIndices(indices)),
           buffer.data(),
-          index,
+          indices,
           mask) {}
 
 Load::Load(
     Dtype dtype,
-    const Var* base_handle,
-    const Expr* index,
+    const Buf* buf,
+    const std::vector<const Expr*>& indices,
     const Expr* mask)
-    : ExprNodeBase(dtype),
-      base_handle_(base_handle),
-      index_(index),
-      mask_(mask) {
-  if (base_handle->dtype() != kHandle) {
+    : ExprNodeBase(dtype), buf_(buf), indices_(indices), mask_(mask) {
+  if (buf->base_handle()->dtype() != kHandle) {
     throw malformed_input();
   }
-
-  if (index->dtype().lanes() != mask->dtype().lanes()) {
+  if (!indicesValid(indices)) {
     throw malformed_input();
   }
-
-  if (index->dtype().scalar_type() != ScalarType::Int) {
-    throw unsupported_dtype();
+  Dtype index_dtype = dtypeOfIndices(indices);
+  if (index_dtype.lanes() != mask->dtype().lanes()) {
+    throw malformed_input();
   }
 }
 
+ExprHandle Load::make(
+    const Buffer& buffer,
+    const std::vector<ExprHandle>& indices,
+    const ExprHandle& mask) {
+  return ExprHandle(
+      new Load(buffer, ExprHandleVectorToExprVector(indices), mask.node()));
+}
+ExprHandle Load::make(
+    Dtype dtype,
+    const BufHandle& buf,
+    const std::vector<ExprHandle>& indices,
+    const ExprHandle& mask) {
+  return ExprHandle(new Load(
+      dtype, buf.node(), ExprHandleVectorToExprVector(indices), mask.node()));
+}
+
 Store::Store(
     const Buffer& buffer,
-    const Expr* index,
+    const std::vector<const Expr*>& indices,
     const Expr* value,
     const Expr* mask)
-    : Store(buffer.data(), index, value, mask) {
+    : Store(buffer.data(), indices, value, mask) {
   if (buffer.dtype().scalar_type() != value->dtype().scalar_type()) {
     throw malformed_input();
   }
 }
 
+Store::Store(
+    const Buf* buf,
+    std::vector<const Expr*> indices,
+    const Expr* value,
+    const Expr* mask)
+    : buf_(buf), indices_(std::move(indices)), value_(value), mask_(mask) {
+  if (buf->dtype() != kHandle) {
+    throw malformed_input();
+  }
+  /*
+  TODO: Reenable the checks.
+  The reason they are disabled is that kernel.cpp is using Buffers somewhat
+  loosely: we don't set dimensions properly and just construct index expressions
+  directly. We should harden that part and then we'd be able to turn on these
+  checks.
+
+  if (!indicesValid(indices)) {
+    throw malformed_input();
+  }
+  if (!mask || !value) {
+    throw malformed_input();
+  }
+  Dtype index_dtype = dtypeOfIndices(indices);
+  if (index_dtype.lanes() != mask->dtype().lanes()) {
+    throw malformed_input();
+  }
+  if (index_dtype.lanes() != value->dtype().lanes()) {
+    throw malformed_input();
+  }
+  */
+}
+
+Store* Store::make(
+    const Buffer& buffer,
+    const std::vector<ExprHandle>& indices,
+    const ExprHandle& value,
+    const ExprHandle& mask) {
+  return new Store(
+      buffer, ExprHandleVectorToExprVector(indices), value.node(), mask.node());
+}
+
+Store* Store::make(
+    const BufHandle& buf,
+    const std::vector<ExprHandle>& indices,
+    const ExprHandle& value,
+    const ExprHandle& mask) {
+  return new Store(
+      buf.node(),
+      ExprHandleVectorToExprVector(indices),
+      value.node(),
+      mask.node());
+}
+
+Store* Store::make(
+    const BufHandle& buf,
+    const std::vector<ExprHandle>& indices,
+    const ExprHandle& value) {
+  return new Store(
+      buf.node(),
+      ExprHandleVectorToExprVector(indices),
+      value.node(),
+      ExprHandle(1).node());
+}
+
+const Expr* flatten_index(
+    const std::vector<const Expr*>& dims,
+    const std::vector<const Expr*>& indices) {
+  // Handle already flattened indices first
+  if (indices.size() == 1) {
+    return indices[0];
+  }
+
+  size_t ndim = dims.size();
+  if (ndim != indices.size()) {
+    throw malformed_input();
+  }
+  if (ndim == 0) {
+    return new IntImm(0);
+  }
+  std::vector<const Expr*> strides(ndim);
+  // stride[i] = stride[i+1]*dims[i+1], i < ndim-1
+  // stride[i] = 1,                     i = ndim-1
+  strides[ndim - 1] = new IntImm(1);
+  for (size_t i = 1; i < ndim; i++) {
+    strides[ndim - 1 - i] = new Mul(strides[ndim - i], dims[ndim - i]);
+  }
+
+  const Expr* total_index = new IntImm(0);
+  for (size_t i = 0; i < ndim; i++) {
+    total_index = new Add(total_index, new Mul(indices[i], strides[i]));
+  }
+  return total_index;
+}
+
 Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) {
   // TODO: check the op_type and make a real decision
   return dt1;
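
Note: flatten_index implements the standard row-major linearization spelled out in the comment: strides[ndim-1] = 1 and strides[i] = strides[i+1] * dims[i+1], so an index vector {i, j} over dims = {M, N} flattens to i * N + j. Worked examples (illustrative):

  // dims = {26, 5}   -> strides = {5, 1}:    flatten({x, y})    == x*5 + y
  // dims = {4, 3, 2} -> strides = {6, 2, 1}: flatten({i, j, k}) == i*6 + j*2 + k
  // A single index is returned unchanged; a rank-0 access yields IntImm(0).
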
diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h
index 136339f..2b7b878 100644
--- a/torch/csrc/jit/tensorexpr/ir.h
+++ b/torch/csrc/jit/tensorexpr/ir.h
@@ -442,39 +442,44 @@
 class TORCH_API Load : public ExprNode<Load> {
  public:
   const Var* base_handle() const {
-    return base_handle_;
+    return buf_->base_handle();
   }
-  const Expr* index() const {
-    return index_;
+  std::vector<const Expr*> indices() const {
+    return indices_;
+  }
+  const Expr* flat_index() const {
+    TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
+    return indices_[0];
   }
   const Expr* mask() const {
     return mask_;
   }
+  const Buf* buf() const {
+    return buf_;
+  }
   static ExprHandle make(
       const Buffer& buffer,
-      const ExprHandle& index,
-      const ExprHandle& mask) {
-    return ExprHandle(new Load(buffer, index.node(), mask.node()));
-  }
+      const std::vector<ExprHandle>& indices,
+      const ExprHandle& mask);
   static ExprHandle make(
       Dtype dtype,
-      const VarHandle& base_handle,
-      const ExprHandle& index,
-      const ExprHandle& mask) {
-    return ExprHandle(
-        new Load(dtype, base_handle.node(), index.node(), mask.node()));
-  }
+      const BufHandle& buf,
+      const std::vector<ExprHandle>& indices,
+      const ExprHandle& mask);
 
-  Load(const Buffer& buffer, const Expr* index, const Expr* mask);
+  Load(
+      const Buffer& buffer,
+      const std::vector<const Expr*>& indices,
+      const Expr* mask);
   Load(
       Dtype dtype,
-      const Var* base_handle,
-      const Expr* index,
+      const Buf* base_handle,
+      const std::vector<const Expr*>& indices,
       const Expr* mask);
 
  private:
-  const Var* base_handle_;
-  const Expr* index_;
+  const Buf* buf_;
+  std::vector<const Expr*> indices_;
   const Expr* mask_;
 };
 
@@ -872,6 +877,9 @@
     const std::vector<VarHandle>&);
 TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
     const std::vector<const Var*>&);
+TORCH_API const Expr* flatten_index(
+    const std::vector<const Expr*>& dims,
+    const std::vector<const Expr*>& indices);
 
 } // namespace tensorexpr
 } // namespace jit
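
Note: at the handle level the factory functions now take index vectors, and flat_index() deliberately asserts that flattening already ran, turning a missed FlattenIndexes pass into a loud check failure rather than a silent miscompile. Updated call sites look like this (a sketch; a_buf, c_buf, i, j are assumed):

  ExprHandle ld = Load::make(a_buf, {i, j}, 1);  // load a[i, j], mask 1
  Stmt* st = Store::make(c_buf, {i, j}, ld, 1);  // store into c[i, j]
  // After FlattenIndexes, indices().size() == 1 and flat_index() is valid.
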
diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp
index d2a145b..71e68f1 100644
--- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp
@@ -184,18 +184,28 @@
 
 const Expr* IRMutator::mutate(const Load* v) {
   Dtype dtype = v->dtype();
-  const Var* base_handle = v->base_handle();
-  const Expr* index = v->index();
+  const Buf* buf = v->buf();
+
+  bool any_index_changed = false;
+  std::vector<const Expr*> indices_new;
+  for (const Expr* ind : v->indices()) {
+    const Expr* new_ind = ind->accept_mutator(this);
+    if (new_ind != ind) {
+      any_index_changed = true;
+    }
+    indices_new.push_back(new_ind);
+  }
   const Expr* mask = v->mask();
-  const Expr* base_handle_expr = base_handle->accept_mutator(this);
-  const Var* base_handle_new = dynamic_cast<const Var*>(base_handle_expr);
-  const Expr* index_new = index->accept_mutator(this);
+  const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
   const Expr* mask_new = mask->accept_mutator(this);
-  if (base_handle == base_handle_new && index == index_new &&
-      mask == mask_new) {
+  if (buf == buf_new && !any_index_changed && mask == mask_new) {
     return v;
   }
-  return new Load(dtype, base_handle_new, index_new, mask_new);
+  return new Load(dtype, buf_new, indices_new, mask_new);
+}
+
+const Expr* IRMutator::mutate(const Buf* v) {
+  return v;
 }
 
 const Expr* IRMutator::mutate(const Broadcast* v) {
@@ -321,20 +331,27 @@
 }
 
 Stmt* IRMutator::mutate(const Store* v) {
-  const Var* base_handle = v->base_handle();
-  const Expr* index = v->index();
+  const Buf* buf = v->buf();
+
+  bool any_index_changed = false;
+  std::vector<const Expr*> indices_new;
+  for (const Expr* ind : v->indices()) {
+    const Expr* new_ind = ind->accept_mutator(this);
+    if (new_ind != ind) {
+      any_index_changed = true;
+    }
+    indices_new.push_back(new_ind);
+  }
   const Expr* value = v->value();
   const Expr* mask = v->mask();
-  const Expr* base_handle_expr = base_handle->accept_mutator(this);
-  const Var* base_handle_new = dynamic_cast<const Var*>(base_handle_expr);
-  const Expr* index_new = index->accept_mutator(this);
+  const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
   const Expr* value_new = value->accept_mutator(this);
   const Expr* mask_new = mask->accept_mutator(this);
-  if (base_handle == base_handle_new && index == index_new &&
-      value == value_new && mask == mask_new) {
+  if (buf == buf_new && !any_index_changed && value == value_new &&
+      mask == mask_new) {
     return (Stmt*)v;
   }
-  return new Store(base_handle_new, index_new, value_new, mask_new);
+  return new Store(buf_new, indices_new, value_new, mask_new);
 }
 
 Stmt* IRMutator::mutate(const Allocate* v) {
@@ -436,7 +453,7 @@
 }
 
 Stmt* StmtClone::mutate(const Store* v) {
-  return new Store(v->base_handle(), v->index(), v->value(), v->mask());
+  return new Store(v->buf(), v->indices(), v->value(), v->mask());
 }
 
 Stmt* StmtClone::mutate(const Allocate* v) {
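
Note: mutators now rebuild a Load/Store only when the Buf, any individual index, the value, or the mask actually changed; the default mutate(const Buf*) is the identity. A hypothetical custom mutator under the new interface, for illustration only:

  class ShiftIndices : public IRMutator {
    // Rewrites every load index from e to e + 1; buffers and masks pass through.
    const Expr* mutate(const Load* v) override {
      std::vector<const Expr*> idx;
      for (const Expr* ind : v->indices()) {
        idx.push_back(new Add(ind, new IntImm(1)));
      }
      return new Load(v->dtype(), v->buf(), idx, v->mask());
    }
  };
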
diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h
index 42c0fca..254a5bb 100644
--- a/torch/csrc/jit/tensorexpr/ir_mutator.h
+++ b/torch/csrc/jit/tensorexpr/ir_mutator.h
@@ -27,6 +27,7 @@
 
 class Cast;
 class Var;
+class Buf;
 class Let;
 class LetStmt;
 class Ramp;
@@ -71,6 +72,7 @@
 #undef IMM_MUTATE_DECLARE
   virtual const Expr* mutate(const Cast* v);
   virtual const Expr* mutate(const Var* v);
+  virtual const Expr* mutate(const Buf* v);
   virtual const Expr* mutate(const Let* v);
   virtual Stmt* mutate(const LetStmt* v);
   virtual const Expr* mutate(const Ramp* v);
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp
index 096a98f..5f8c619 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp
@@ -263,7 +263,15 @@
 
 void IRPrinter::visit(const Load* v) {
   // TODO: support the mask case
-  os() << *v->base_handle() << "[" << *v->index() << "]";
+  os() << *v->base_handle() << "[";
+  size_t i = 0;
+  for (const Expr* ind : v->indices()) {
+    if (i++) {
+      os() << ", ";
+    }
+    ind->accept(this);
+  }
+  os() << "]";
 }
 
 void IRPrinter::visit(const For* v) {
@@ -296,8 +304,15 @@
 void IRPrinter::visit(const Store* v) {
   // TODO: handle the mask
   emitIndent();
-  os() << *v->base_handle() << "[" << *v->index() << "] = " << *v->value()
-       << ";";
+  os() << *v->base_handle() << "[";
+  size_t i = 0;
+  for (const Expr* ind : v->indices()) {
+    if (i++) {
+      os() << ", ";
+    }
+    ind->accept(this);
+  }
+  os() << "] = " << *v->value() << ";";
 }
 
 void IRPrinter::visit(const Broadcast* v) {
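
Note: the printer renders the full index list, so unflattened and flattened forms of the same store are visually distinct (hypothetical output):

  f[x, y] = func(x, y);       // before FlattenIndexes
  f[x * 5 + y] = func(x, y);  // after
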
diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp
index ee97b3b..9506a48 100644
--- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp
@@ -97,14 +97,22 @@
 }
 
 void IRVisitor::visit(const Load* v) {
-  v->base_handle()->accept(this);
-  v->index()->accept(this);
+  v->buf()->accept(this);
+  for (const Expr* ind : v->indices()) {
+    ind->accept(this);
+  }
   v->mask()->accept(this);
 }
 
-void IRVisitor::visit(const Store* v) {
+void IRVisitor::visit(const Buf* v) {
   v->base_handle()->accept(this);
-  v->index()->accept(this);
+}
+
+void IRVisitor::visit(const Store* v) {
+  v->buf()->accept(this);
+  for (const Expr* ind : v->indices()) {
+    ind->accept(this);
+  }
   v->value()->accept(this);
   v->mask()->accept(this);
 }
@@ -151,8 +159,7 @@
 }
 
 void IRVisitor::visit(const Allocate* v) {
-  const Var* buffer_var = v->buffer_var();
-  buffer_var->accept(this);
+  v->buffer_var()->accept(this);
   std::vector<const Expr*> dims = v->dims();
   for (const Expr* dim : dims) {
     dim->accept(this);
@@ -160,8 +167,7 @@
 }
 
 void IRVisitor::visit(const Free* v) {
-  const Var* buffer_var = v->buffer_var();
-  buffer_var->accept(this);
+  v->buffer_var()->accept(this);
 }
 
 void IRVisitor::visit(const Cond* v) {
diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h
index 32819c4..1e1cdf3f 100644
--- a/torch/csrc/jit/tensorexpr/ir_visitor.h
+++ b/torch/csrc/jit/tensorexpr/ir_visitor.h
@@ -27,6 +27,7 @@
 
 class Cast;
 class Var;
+class Buf;
 class Let;
 class LetStmt;
 class Ramp;
@@ -70,6 +71,7 @@
 
   virtual void visit(const Cast* v);
   virtual void visit(const Var* v);
+  virtual void visit(const Buf* v);
   virtual void visit(const Let* v);
   virtual void visit(const LetStmt* v);
   virtual void visit(const Ramp* v);
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index 4f2f0e2..d1977a9 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -12,9 +12,8 @@
 template <typename T>
 inline std::vector<int64_t> bufferSizes(const T& t) {
   std::vector<int64_t> sizes;
-  for (int i = 0; i < t->function()->ndim(); i++) {
-    sizes.push_back(
-        dynamic_cast<const IntImm*>(t->function()->dim(i))->value());
+  for (int i = 0; i < t->buf()->ndim(); i++) {
+    sizes.push_back(dynamic_cast<const IntImm*>(t->buf()->dim(i))->value());
   }
   return sizes;
 }
@@ -70,7 +69,7 @@
   template <typename T, typename T1>
   ExprHandle broadcast(const T& t, const std::vector<T1>& axes) {
     return t->call(computeIndicesToBroadcast(
-        axes, ExprVectorToExprHandleVector(t->function()->dims())));
+        axes, ExprVectorToExprHandleVector(t->buf()->dims())));
   }
 
   template <typename T, typename T1>
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index f3b9802..4e16359 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -818,7 +818,7 @@
   if (v->dtype().lanes() == 1) {
     v->base_handle()->accept(this);
     auto base = this->value_;
-    v->index()->accept(this);
+    v->flat_index()->accept(this);
     auto idx = this->value_;
 
     auto* maskimm = dynamic_cast<const IntImm*>(v->mask());
@@ -856,7 +856,7 @@
   }
 
   // Handle the case where the load is contiguous and unmasked efficiently
-  auto* idx_ramp = dynamic_cast<const Ramp*>(v->index());
+  auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
   if (unmasked_load && idx_ramp) {
     auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
     if (stride_imm && stride_imm->value() == 1) {
@@ -876,7 +876,7 @@
   // Fallback to a scalar implementation
   v->base_handle()->accept(this);
   auto base = this->value_;
-  v->index()->accept(this);
+  v->flat_index()->accept(this);
   auto idx = this->value_;
   v->mask()->accept(this);
   auto mask = this->value_;
@@ -983,7 +983,7 @@
   if (v->value()->dtype().lanes() == 1) {
     v->base_handle()->accept(this);
     auto base = this->value_;
-    v->index()->accept(this);
+    v->flat_index()->accept(this);
     auto idx = this->value_;
     v->value()->accept(this);
     auto val = this->value_;
@@ -1018,7 +1018,7 @@
   auto val = this->value_;
 
   // Handle the case where the store is contiguous and unmasked efficiently
-  auto* idx_ramp = dynamic_cast<const Ramp*>(v->index());
+  auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
   if (unmasked_store && idx_ramp) {
     auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
     if (stride_imm && stride_imm->value() == 1) {
@@ -1035,7 +1035,7 @@
     }
   }
 
-  v->index()->accept(this);
+  v->flat_index()->accept(this);
   auto idx = this->value_;
   v->mask()->accept(this);
   auto mask = this->value_;
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 4940b04..2ce5ebf 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -29,6 +29,35 @@
 
 } // namespace
 
+class IndexFlattener : public IRMutator {
+ public:
+  Stmt* flatten(Stmt* s) {
+    return s->accept_mutator(this);
+  }
+  const Expr* mutate(const Load* v) override {
+    if (v->indices().size() == 1) {
+      return v;
+    }
+    return new Load(
+        v->dtype(),
+        v->buf(),
+        {flatten_index(v->buf()->dims(), v->indices())},
+        v->mask());
+  }
+  Stmt* mutate(const Store* v) override {
+    const Expr* value = v->value();
+    const Expr* new_value = value->accept_mutator(this);
+    if (v->indices().size() == 1 && value == new_value) {
+      return (Stmt*)v;
+    }
+    return new Store(
+        v->buf(),
+        {flatten_index(v->buf()->dims(), v->indices())},
+        new_value,
+        v->mask());
+  }
+};
+
 class Vectorizer : public IRMutator {
  public:
   Stmt* vectorize(const For* v) {
@@ -162,13 +191,13 @@
 
   const Expr* mutate(const Load* v) override {
     Dtype dtype(v->dtype().scalar_type(), lanes_);
-    const Var* base_handle = v->base_handle();
-    std::vector<const Expr*> inputs = {v->index(), v->mask()};
+    const Buf* buf = v->buf();
+    std::vector<const Expr*> inputs = {v->flat_index(), v->mask()};
     return try_vectorize(v, inputs, [&]() {
       return Load::make(
           dtype,
-          VarHandle(base_handle),
-          ExprHandle(inputs[0]),
+          BufHandle(buf),
+          {ExprHandle(inputs[0])},
           ExprHandle(inputs[1]));
     });
   }
@@ -204,12 +233,12 @@
   }
 
   Stmt* mutate(const Store* v) override {
-    const Var* base_handle = v->base_handle();
-    std::vector<const Expr*> inputs = {v->index(), v->value(), v->mask()};
+    const Buf* buf = v->buf();
+    std::vector<const Expr*> inputs = {v->flat_index(), v->value(), v->mask()};
     return try_vectorize(v, inputs, [&]() {
       return Store::make(
-          VarHandle(base_handle),
-          ExprHandle(inputs[0]),
+          BufHandle(buf),
+          {ExprHandle(inputs[0])},
           ExprHandle(inputs[1]),
           ExprHandle(inputs[2]));
     });
@@ -299,7 +328,8 @@
   Stmt* old_f = Stmt::clone(f);
   Stmt* new_f = nullptr;
   try {
-    new_f = v.vectorize(f);
+    new_f = FlattenIndexes(f);
+    new_f = v.vectorize(dynamic_cast<For*>(new_f));
   } catch (std::runtime_error& e) {
     // Partial vectorization may have corrupted f
     new_f = old_f;
@@ -312,10 +342,8 @@
  private:
   Expr* mutate(const FunctionCall* v) override {
     const Tensor* t = v->tensor();
-    Buffer buffer(
-        VarHandle(t->func_var()),
-        t->body()->dtype(),
-        ExprVectorToExprHandleVector(t->dims()));
+    const Buf* b = t->buf();
+    Buffer buffer(BufHandle(b), t->body()->dtype());
     const std::vector<const Expr*>& params = v->params();
     std::vector<ExprHandle> params_expr(params.size());
     for (size_t i = 0; i < params.size(); i++) {
@@ -333,19 +361,20 @@
       if (func->func_vars().size() != 1) {
         throw unimplemented_lowering();
       }
-      func_var_set_.insert(func->func_var(0));
+      func_var_set_.insert(func->func_var(0)->base_handle());
     }
   }
 
  protected:
   bool should_inline(Function* func) const {
-    return func_var_set_.count(func->func_var(0)) > 0;
+    return func_var_set_.count(func->func_var(0)->base_handle()) > 0;
   }
 
   // For the target function, insert the caller/callee pair into the replacement
   // mapping.
   const Expr* mutate(const FunctionCall* v) override {
     Function* func = v->tensor()->function();
+    const Buf* buf = v->tensor()->buf();
     // TODO: Support multiple-output functions
     if (func->func_vars().size() != 1) {
       throw unimplemented_lowering();
@@ -353,7 +382,7 @@
 
     if (should_inline(func)) {
       // Insert the caller/callee pair into the mapping.
-      for (int i = 0; i < func->ndim(); i++) {
+      for (int i = 0; i < buf->ndim(); i++) {
         const Var* func_callee_arg = dynamic_cast<const Var*>(func->arg(i));
         const Expr* func_caller_param = v->param(i);
         auto iter = inline_mapping_.find(func_callee_arg);
@@ -369,7 +398,7 @@
       const Expr* result = body->accept_mutator(this);
 
       // Remove the caller/callee relationship.
-      for (int i = 0; i < func->ndim(); i++) {
+      for (int i = 0; i < buf->ndim(); i++) {
         const Var* func_callee_arg = dynamic_cast<const Var*>(func->arg(i));
         auto iter = inline_mapping_.find(func_callee_arg);
         if (iter == inline_mapping_.end()) {
@@ -652,18 +681,18 @@
   stmt_to_tensor_[body] = t;
   tensor_to_stmt_[t] = body;
 
-  if (f->ndim() == 0) {
+  if (t->buf()->ndim() == 0) {
     return body;
   }
 
-  if (f->ndim() == 0) {
+  if (t->buf()->ndim() == 0) {
     throw malformed_input();
   }
 
-  for (size_t i = 0; i < f->ndim(); i++) {
+  for (size_t i = 0; i < t->buf()->ndim(); i++) {
     // Going in reverse order: from innermost loop to the outermost
-    size_t dim_index = f->ndim() - i - 1;
-    Range r(new IntImm(0), f->dim(dim_index));
+    size_t dim_index = t->buf()->ndim() - i - 1;
+    Range r(new IntImm(0), t->buf()->dim(dim_index));
     body = new For(f->arg(dim_index), r.start(), r.stop(), body);
   }
   return body;
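
The nest is wrapped from the innermost dimension outward, so dimension 0 ends up as the outermost loop. Hand-unrolled for a rank-2 tensor (illustrative only, using the same names as above):

```cpp
// Equivalent construction when t->buf()->ndim() == 2:
Stmt* s = body;
s = new For(f->arg(1), new IntImm(0), t->buf()->dim(1), s); // inner loop over dim 1
s = new For(f->arg(0), new IntImm(0), t->buf()->dim(0), s); // outer loop over dim 0
```
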
@@ -702,8 +731,8 @@
       continue;
     }
     Stmt* alloc = new Allocate(
-        tensor->func_var(), tensor->body()->dtype(), tensor->dims());
-    Stmt* free = new Free(tensor->func_var());
+        tensor->buf()->base_handle(), tensor->body()->dtype(), tensor->dims());
+    Stmt* free = new Free(tensor->buf()->base_handle());
     b->prepend_stmt(alloc);
     b->append_stmt(free);
   }
@@ -722,6 +751,8 @@
   Flattener flattener;
   root_stmt_ = root_stmt_->accept_mutator(&flattener);
 
+  root_stmt_ = FlattenIndexes(root_stmt_);
+
   // Add allocs and frees for intermediate buffers at the global level.
   root_stmt_ = insertAllocFree(root_stmt_);
 }
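
`prepareForCodegen` is now the single place where the N-d representation is lowered for backends: loop-axis flattening, then index flattening, then alloc/free insertion. Typical use, assuming a `LoopNest` built from output tensors and a `root_stmt()` accessor:

```cpp
LoopNest l({t});         // t is a hypothetical output tensor
l.prepareForCodegen();
Stmt* s = l.root_stmt(); // every Store/Load now carries exactly one index
```
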
@@ -879,6 +910,11 @@
   return tensor_to_stmt_.count(t) > 0;
 }
 
+Stmt* FlattenIndexes(Stmt* s) {
+  IndexFlattener idx_flattener;
+  return idx_flattener.flatten(s);
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
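
For reference, a sketch of what `FlattenIndexes` computes for a 2-d access, assuming a row-major stride convention (the exact formula is defined by `IndexFlattener`, which lives elsewhere in this PR):

```cpp
KernelScope kernel_scope;
BufHandle a("A", {ExprHandle(32), ExprHandle(64)}); // dims {H, W} = {32, 64}
VarHandle x("x", kInt), y("y", kInt);
Stmt* nd = Store::make(a, {x, y}, x + y); // multi-dimensional store
// Under row-major strides this flattens to a single index, x * 64 + y:
Stmt* flat = FlattenIndexes(nd);
```
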
diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h
index 37c56e4..b5d1264 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.h
+++ b/torch/csrc/jit/tensorexpr/loopnest.h
@@ -57,6 +57,8 @@
   std::unordered_set<Tensor*> intermediate_tensors_;
 };
 
+TORCH_API Stmt* FlattenIndexes(Stmt* s);
+
 // represent a range [start, stop)
 class Range {
  public:
diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h
index c5cf547..944d327 100644
--- a/torch/csrc/jit/tensorexpr/stmt.h
+++ b/torch/csrc/jit/tensorexpr/stmt.h
@@ -165,10 +165,14 @@
 class TORCH_API Store : public StmtNode<Store> {
  public:
   const Var* base_handle() const {
-    return base_handle_;
+    return buf_->base_handle();
   }
-  const Expr* index() const {
-    return index_;
+  std::vector<const Expr*> indices() const {
+    return indices_;
+  }
+  const Expr* flat_index() const {
+    TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
+    return indices_[0];
   }
   const Expr* value() const {
     return value_;
@@ -176,62 +180,43 @@
   const Expr* mask() const {
     return mask_;
   }
+  const Buf* buf() const {
+    return buf_;
+  }
 
   static Store* make(
       const Buffer& buffer,
-      const ExprHandle& index,
+      const std::vector<ExprHandle>& indices,
       const ExprHandle& value,
-      const ExprHandle& mask) {
-    return new Store(buffer, index.node(), value.node(), mask.node());
-  }
+      const ExprHandle& mask);
 
   static Store* make(
-      const VarHandle& base_handle,
-      const ExprHandle& index,
+      const BufHandle& buf,
+      const std::vector<ExprHandle>& indices,
       const ExprHandle& value,
-      const ExprHandle& mask) {
-    return new Store(
-        base_handle.node(), index.node(), value.node(), mask.node());
-  }
+      const ExprHandle& mask);
 
   static Store* make(
-      const VarHandle& base_handle,
-      const ExprHandle& index,
-      const ExprHandle& value) {
-    return new Store(
-        base_handle.node(), index.node(), value.node(), ExprHandle(1).node());
-  }
+      const BufHandle& buf,
+      const std::vector<ExprHandle>& indices,
+      const ExprHandle& value);
 
   // TODO: merge this with Load.
   Store(
       const Buffer& buffer,
-      const Expr* index,
+      const std::vector<const Expr*>& indices,
       const Expr* value,
       const Expr* mask);
 
   Store(
-      const Var* base_handle,
-      const Expr* index,
+      const Buf* buf,
+      std::vector<const Expr*> indices,
       const Expr* value,
-      const Expr* mask)
-      : base_handle_(base_handle), index_(index), value_(value), mask_(mask) {
-    if (base_handle_->dtype() != kHandle) {
-      throw malformed_input(base_handle);
-    }
-
-    if (index->dtype().lanes() != mask->dtype().lanes() ||
-        index->dtype().lanes() != value->dtype().lanes()) {
-      throw malformed_input();
-    }
-
-    if (index->dtype().scalar_type() != ScalarType::Int) {
-      throw unsupported_dtype();
-    }
-  }
+      const Expr* mask);
 
  private:
-  const Var* base_handle_;
-  const Expr* index_;
+  const Buf* buf_;
+  std::vector<const Expr*> indices_;
   const Expr* value_;
   const Expr* mask_;
 };
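
Taken together, the new overloads make the index list explicit at every call site; the three-argument form keeps the mask defaulting to 1, as in the `VarHandle` overload it replaces. An end-to-end (hypothetical) use:

```cpp
KernelScope kernel_scope;
BufHandle c("C", {ExprHandle(16), ExprHandle(16)});
VarHandle x("x", kInt), y("y", kInt);
// The indices stay multi-dimensional here; they are only collapsed later,
// by FlattenIndexes inside LoopNest::prepareForCodegen.
Stmt* st = Store::make(c, {x, y}, x * y);
```
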
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index a896ed5..f1c9dff 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -23,17 +23,17 @@
   const Expr* body() const {
     return function()->body(output_index());
   }
-  const Var* func_var() const {
+  const Buf* func_var() const {
     return function()->func_var(output_index());
   }
   int ndim() const {
-    return function()->dims().size();
+    return buf_->dims().size();
   }
   const Expr* dim(int index) const {
-    return function()->dim(index);
+    return buf_->dim(index);
   }
-  const std::vector<const Expr*>& dims() const {
-    return function()->dims();
+  std::vector<const Expr*> dims() const {
+    return buf_->dims();
   }
   const Var* arg(int index) const {
     return function()->arg(index);
@@ -42,8 +42,12 @@
     return function()->args();
   }
 
-  Tensor(Function* function, int output_index)
-      : function_(function), output_index_(output_index) {}
+  const Buf* buf() const {
+    return buf_;
+  }
+
+  Tensor(const Buf* buf, Function* function, int output_index)
+      : buf_(buf), function_(function), output_index_(output_index) {}
   template <typename... Ts>
   inline ExprHandle operator()(const Ts&... ts);
   template <typename T>
@@ -52,6 +56,7 @@
   inline ExprHandle call(const Ts&... ts);
 
  private:
+  const Buf* buf_;
   Function* function_;
   int output_index_;
 };
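
With `buf_` stored on the `Tensor`, shape queries no longer detour through the `Function`. A short sketch; the `Compute` builder signature here is assumed, not part of this diff:

```cpp
KernelScope kernel_scope;
// Assumed Compute(...) builder producing a rank-2 tensor.
Tensor* t = Compute(
    "f",
    {{ExprHandle(8), "i"}, {ExprHandle(8), "j"}},
    [](const VarHandle& i, const VarHandle& j) { return i + j; });
const Buf* b = t->buf();
int rank = t->ndim();       // forwards to b->dims().size()
const Expr* d0 = t->dim(0); // forwards to b->dim(0)
```
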