[TensorExpr] Nuke `Function` class and directly use `Tensor` instead. (#45936)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45936

`Tensor` has been a view into a `Function` that was supposed to be used
for a more general case when we have multiple computations over the same
domain (aka multiple output functions). We have never gotten to a point
where we need this, and we now have other ideas in mind for how to support
this case if need be. For now, let's just nuke `Function` to reduce the
overall system complexity.

The change should not affect any existing behavior.

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D24153214

Pulled By: ZolotukhinM

fbshipit-source-id: 26d5f11db5d661ff5e1135f4a49eff1c6d4c1bd5
diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp
index f0bcfc4..31e0554 100644
--- a/test/cpp/tensorexpr/tutorial.cpp
+++ b/test/cpp/tensorexpr/tutorial.cpp
@@ -125,54 +125,30 @@
     //   independent computations over the same domain) for its elements, as a
     //   function of indices
     //
-    // We use Function objects to represent this. Let's build one.
-    //
-    // First, we need to specify the domain, or dimensions in which the
-    // computation would be performed. Let's create a 64x32 domain:
+    // TODO: Update this section once Tensor/Function cleanup is done
     std::vector<const Expr*> dims = {
         new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate
                                          // and represents an integer constant
 
-    // Next we need to create Function arguments. The arguments of a Function
-    // are Vars, and they play role of placeholders. The computation that the
-    // function would describe would use these arguments.
+    // Next we need to create arguments. The arguments are Vars, and they play
+    // the role of placeholders. The computation that the tensor would describe
+    // would use these arguments.
     const Var* i = new Var("i", kInt);
     const Var* j = new Var("j", kInt);
     std::vector<const Var*> args = {i, j};
 
-    // Now we can define the function computations using these arguments. Let's
-    // create two computations, the first would add the arguments of the
-    // function, the second would multiply them.
-    Expr* func_body1 = new Mul(i, j);
-    Expr* func_body2 = new Add(i, j);
+    // Now we can define the body of the tensor computation using these
+    // arguments.
+    Expr* body = new Mul(i, j);
 
-    // Finally, we pass all these pieces together to Function constructor:
-    Function* func =
-        new Function({"X", "Y"}, dims, args, {func_body1, func_body2});
-    // Under the hood function constructor would create separate `Buf`
-    // expressions for each computation (which can be accessed via
-    // `func->func_var(idx)`) with the names specified by the first parameter of
-    // the constructor call. In our example two `Buf` variables will be created
-    // with names 'X' and 'Y', each of them would signify a domain of 64x32.
-
-    // We can now print out our function:
-    std::cout << "Tensor function: " << *func << std::endl;
-    // Prints:
-    // Tensor function: Function F(i[64], j[32]) {
-    //   X = i * j
-    //   Y = i + j
-    // }
-
-    // A Tensor refers to an individual computation defined by a Function. For
-    // instance, we could create a following tensor given the function above:
-    int output_idx = 0; // Used to index the computation
-    Tensor* X = new Tensor(func, output_idx);
+    // Finally, we pass all these pieces together to Tensor constructor:
+    Tensor* X = new Tensor("X", dims, args, body);
     std::cout << "Tensor computation: " << *X << std::endl;
     // Prints: Tensor computation: Tensor X(i[64], j[32]) = i * j
 
     // Similarly to how we provide a more convenient way of using handles for
     // constructing Exprs, Tensors also have a more convenient API for
-    // construction. It is based on Compute functions, which take a name:
+    // construction. It is based on the Compute API, which takes a name,
     // dimensions, and a lambda specifying the computation body:
     Tensor* Z = Compute(
         "Z",
@@ -204,14 +180,6 @@
     // Tensor and we use 'load' for accessing elements of an external tensor
     // through its Placeholder. This is an implementation detail and could be
     // changed in future.
-    //
-    // Why do we have Functions and Tensors and what is the relationship between
-    // them? Functions are used to represent several computations performed over
-    // the same domain. Tensors refer to individual computations of a Function.
-    //
-    // Also note that currently a lot of code only supports single-output
-    // Functions, in which case they become almost identical to Tensors. This
-    // probably will be changed in future.
 
     // TODO: Show how reductions are represented and constructed
   }
diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h
index 4bf9d76..a32b362 100644
--- a/torch/csrc/jit/tensorexpr/codegen.h
+++ b/torch/csrc/jit/tensorexpr/codegen.h
@@ -73,17 +73,7 @@
   BufferArg(const Placeholder& buffer)
       : var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {}
   BufferArg(Tensor* tensor)
-      : var_(tensor->function()
-                 ->func_var(tensor->output_index())
-                 ->base_handle()),
-        dtype_(tensor->function()->body(tensor->output_index())->dtype()) {}
-  BufferArg(const Function& func)
-      : var_(func.func_var(0)->base_handle()), dtype_(func.body(0)->dtype()) {
-    // TODO: Support multiple-output functions
-    if (func.func_vars().size() != 1) {
-      throw unimplemented_lowering();
-    }
-  }
+      : var_(tensor->buf()->base_handle()), dtype_(tensor->body()->dtype()) {}
   BufferArg(const VarHandle& var)
       : var_(var.node()), dtype_(var.dtype()), isVar_(true) {}
 
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 434aa52..7c64403 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -194,6 +194,9 @@
     return dims_.size();
   }
   const Expr* dim(size_t index) const {
+    if (index >= ndim()) {
+      throw out_of_range_index();
+    }
     return dims_[index];
   }
   std::vector<const Expr*> dims() const {
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp
index 8f47d85..5792729 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp
@@ -552,11 +552,6 @@
   return stream;
 }
 
-std::ostream& operator<<(std::ostream& stream, const Function& f) {
-  stream << std::to_string(&f);
-  return stream;
-}
-
 void print(const Expr* expr) {
   if (expr) {
     IRPrinter p(std::cout);
@@ -579,10 +574,6 @@
   std::cout << std::to_string(t);
 }
 
-void print(const Function* f) {
-  std::cout << std::to_string(f);
-}
-
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
@@ -615,24 +606,4 @@
   oss << ") = " << *t->body() << "\n";
   return oss.str();
 }
-
-std::string to_string(const Function* f) {
-  if (!f) {
-    return "(null function)\n";
-  }
-  std::ostringstream oss;
-  oss << "Function F(";
-  for (size_t i = 0; i < f->ndim(); i++) {
-    if (i != 0) {
-      oss << ", ";
-    }
-    oss << *f->arg(i) << "[" << *f->dim(i) << "]";
-  }
-  oss << ") {\n";
-  for (size_t i = 0; i < f->bodies().size(); i++) {
-    oss << "  " << *f->func_var(i) << " = " << *f->body(i) << "\n";
-  }
-  oss << "}\n";
-  return oss.str();
-}
 } // namespace std
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h
index 64ba352..d9079d7 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.h
+++ b/torch/csrc/jit/tensorexpr/ir_printer.h
@@ -11,7 +11,6 @@
 namespace tensorexpr {
 
 class Tensor;
-class Function;
 
 class TORCH_API IRPrinter : public IRVisitor {
  public:
@@ -95,12 +94,10 @@
 TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&);
 TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&);
 TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&);
-TORCH_API std::ostream& operator<<(std::ostream& stream, const Function&);
 
 TORCH_API void print(const Expr* expr);
 TORCH_API void print(const Stmt* stmt);
 TORCH_API void print(const Tensor* t);
-TORCH_API void print(const Function* f);
 
 } // namespace tensorexpr
 } // namespace jit
@@ -109,12 +106,10 @@
 namespace std {
 
 using torch::jit::tensorexpr::Expr;
-using torch::jit::tensorexpr::Function;
 using torch::jit::tensorexpr::Stmt;
 using torch::jit::tensorexpr::Tensor;
 
 TORCH_API std::string to_string(const Expr* expr);
 TORCH_API std::string to_string(const Stmt* stmt);
 TORCH_API std::string to_string(const Tensor* t);
-TORCH_API std::string to_string(const Function* f);
 } // namespace std
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 456a264..301f11f 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -449,11 +449,9 @@
 }
 
 Stmt* LoopNest::lowerToStmt(Tensor* t) {
-  Function* f = t->function();
-  // TODO: Support multiple-output functions
-  Stmt* body = f->ElementStmt(0);
+  Stmt* body = t->ElementStmt();
 
-  if (f->ndim() == 0) {
+  if (t->ndim() == 0 && t->reduce_ndim() == 0) {
     return body;
   }
 
@@ -461,18 +459,30 @@
   if (initializer) {
     buf_initializers_[t->buf()] = initializer;
   }
+
   std::vector<const Expr*> indices(t->args().begin(), t->args().end());
 
-  for (size_t i = 0; i < f->ndim(); i++) {
-    // Going in reverse order: from innermost loop to the outermost
-    size_t dim_index = f->ndim() - i - 1;
-    body = new For(f->arg(dim_index), new IntImm(0), f->dim(dim_index), body);
-    indices.pop_back();
-    if (initializer && indices.size() == t->ndim()) {
+  if (t->reduce_ndim() > 0) {
+    for (size_t i = 0; i < t->reduce_ndim(); i++) {
+      // Going in reverse order: from innermost loop to the outermost
+      size_t dim_index = t->reduce_ndim() - i - 1;
+      body = new For(
+          t->reduce_arg(dim_index),
+          new IntImm(0),
+          t->reduce_dim(dim_index),
+          body);
+    }
+    if (initializer) {
       Store* init = new Store(t->buf(), indices, initializer, new IntImm(1));
       body = new Block({init, body});
     }
   }
+
+  for (size_t i = 0; i < t->ndim(); i++) {
+    // Going in reverse order: from innermost loop to the outermost
+    size_t dim_index = t->ndim() - i - 1;
+    body = new For(t->arg(dim_index), new IntImm(0), t->dim(dim_index), body);
+  }
   return body;
 }
 
@@ -493,26 +503,21 @@
   // For the target function, insert the caller/callee pair into the replacement
   // mapping.
   const Expr* mutate(const FunctionCall* v) override {
-    Function* func = v->tensor()->function();
-    const Buf* buf = v->tensor()->buf();
+    const Tensor* t = v->tensor();
+    const Buf* buf = t->buf();
     if (buf != buf_) {
       return IRMutator::mutate(v);
     }
 
-    // TODO: Support multiple-output functions
-    if (func->func_vars().size() != 1) {
-      throw unimplemented_lowering();
-    }
-
     if (v->nparams() != buf->ndim()) {
       throw malformed_input(
           "Placeholder indexed access is inconsistent with its rank", v);
     }
 
     std::vector<const Var*> index_vars;
-    TORCH_INTERNAL_ASSERT(buf->ndim() == func->args().size());
+    TORCH_INTERNAL_ASSERT(buf->ndim() == t->args().size());
     for (size_t i = 0; i < buf->ndim(); i++) {
-      const Var* func_callee_arg = dynamic_cast<const Var*>(func->arg(i));
+      const Var* func_callee_arg = dynamic_cast<const Var*>(t->arg(i));
       const Expr* func_caller_param = v->param(i);
       auto iter = inline_mapping_.find(func_callee_arg);
       if (iter != inline_mapping_.end()) {
diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp
index 4fad4ca..4afc1ff 100644
--- a/torch/csrc/jit/tensorexpr/tensor.cpp
+++ b/torch/csrc/jit/tensorexpr/tensor.cpp
@@ -16,8 +16,7 @@
   std::vector<const Var*> args;
   unpack_dim_args(dim_args, &dims, &args);
   const Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
-  Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  return new Tensor(func_name, dims, args, body);
 }
 
 Tensor* Compute(
@@ -32,8 +31,7 @@
   std::vector<const Var*> args;
   unpack_dim_args(dim_args, &dims, &args);
   const Expr* body = body_func(VarHandle(args[0])).node();
-  Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  return new Tensor(func_name, dims, args, body);
 }
 
 Tensor* Compute(
@@ -48,8 +46,7 @@
   std::vector<const Var*> args;
   unpack_dim_args(dim_args, &dims, &args);
   const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
-  Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  return new Tensor(func_name, dims, args, body);
 }
 
 Tensor* Compute(
@@ -67,8 +64,7 @@
   const Expr* body =
       body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
           .node();
-  Function* func = new Function(func_name, dims, args, body);
-  return new Tensor(func, 0);
+  return new Tensor(func_name, dims, args, body);
 }
 
 Tensor* Compute(
@@ -87,20 +83,17 @@
   unpack_dim_args(dim_args, &dims, &args_nodes);
   auto args = VarVectorToVarHandleVector(args_nodes);
   const Expr* body = body_func(args[0], args[1], args[2], args[3]).node();
-  Function* func = new Function(func_name, dims, args_nodes, body);
-  return new Tensor(func, 0);
+  return new Tensor(func_name, dims, args_nodes, body);
 }
 
-Stmt* Function::ElementStmt(size_t index) {
-  const Buf* buf = func_var(index);
+Stmt* Tensor::ElementStmt() {
   std::vector<const Expr*> indices;
-  for (size_t i = 0; i < buf->ndim(); i++) {
-    indices.push_back(this->args_[i]);
+  for (size_t i = 0; i < buf_->ndim(); i++) {
+    indices.push_back(args_[i]);
   }
 
   const Expr* mask = new IntImm(1);
-
-  Stmt* update_stmt = new Store(buf, indices, body(index), mask);
+  Stmt* update_stmt = new Store(buf_, indices, body_, mask);
   return update_stmt;
 }
 
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 9d0cadc..d37f14c 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -12,129 +12,80 @@
 namespace jit {
 namespace tensorexpr {
 
-class Function : public KernelScopedObject {
+class Tensor : KernelScopedObject {
  public:
-  Function(
-      const std::string& func_name,
+  Tensor(
+      const std::string& name,
       const std::vector<const Expr*>& dims,
       const std::vector<const Var*>& args,
       const Expr* body)
       // TODO: Function should not create buffers, they should be created
       // manually before constructing a function.
-      : func_vars_({new Buf(func_name, dims, body->dtype())}),
-        dims_(dims),
-        args_(args),
-        bodies_({body}) {}
-  Function(
-      const std::vector<std::string>& func_names,
-      const std::vector<const Expr*>& dims,
+      : buf_(new Buf(name, dims, body->dtype())), args_(args), body_(body) {}
+
+  Tensor(Buf* buf, const std::vector<const Var*>& args, const Expr* body)
+      : buf_(buf), args_(args), body_(body) {}
+
+  Tensor(
+      Buf* buf,
       const std::vector<const Var*>& args,
-      const std::vector<const Expr*>& bodies)
-      : func_vars_(func_names.size()),
-        dims_(dims),
-        args_(args),
-        bodies_(bodies) {
-    for (size_t i = 0; i < func_names.size(); i++) {
-      func_vars_[i] = new Buf(func_names[i], dims, bodies[i]->dtype());
-    }
-  }
-  Function(
-      const std::string& func_name,
-      Buf* func_var,
-      const std::vector<const Expr*>& dims,
-      const std::vector<const Var*>& args,
+      const std::vector<const Expr*>& reduce_dims,
+      const std::vector<const Var*>& reduce_args,
       const Expr* body)
-      : func_vars_({func_var}), dims_(dims), args_(args), bodies_({body}) {}
-
-  size_t ndim() const {
-    return dims_.size();
-  }
-
-  const Expr* dim(size_t index) const {
-    if (index < 0 || index >= dims_.size()) {
-      throw out_of_range_index();
-    }
-
-    return dims_[index];
-  }
-  const std::vector<const Expr*>& dims() const {
-    return dims_;
-  }
-
-  const Var* arg(size_t index) const {
-    if (index < 0 || index >= args_.size()) {
-      throw out_of_range_index();
-    }
-
-    return args_[index];
-  }
-  const std::vector<const Var*>& args() const {
-    return args_;
-  }
-
-  std::vector<const Expr*> bodies() const {
-    return bodies_;
-  }
-  const Expr* body(size_t index) const {
-    if (index >= bodies_.size()) {
-      throw out_of_range_index();
-    }
-
-    return bodies_[index];
-  }
-
-  std::vector<const Buf*> func_vars() const {
-    return func_vars_;
-  }
-  const Buf* func_var(size_t index) const {
-    if (index >= func_vars_.size()) {
-      throw out_of_range_index();
-    }
-    return func_vars_[index];
-  }
-
-  Stmt* ElementStmt(size_t index);
-
- private:
-  std::vector<const Buf*> func_vars_;
-  std::vector<const Expr*> dims_;
-  std::vector<const Var*> args_;
-  std::vector<const Expr*> bodies_;
-};
-
-class Tensor : KernelScopedObject {
- public:
-  Tensor(Function* function, int output_index)
-      : function_(function), output_index_(output_index) {}
-
-  Function* function() const {
-    return function_;
-  }
-  int output_index() const {
-    return output_index_;
-  }
+      : buf_(buf),
+        args_(args),
+        body_(body),
+        reduce_dims_(reduce_dims),
+        reduce_args_(reduce_args) {}
 
   // Wrappers over accessors to fields of the underlying function
   const Expr* body() const {
-    return function()->body(output_index());
+    return body_;
   }
   const Buf* buf() const {
-    return function()->func_var(output_index());
+    return buf_;
   }
-  int ndim() const {
-    return buf()->dims().size();
+  size_t ndim() const {
+    return buf()->ndim();
   }
-  const Expr* dim(int index) const {
+  const Expr* dim(size_t index) const {
+    if (index >= ndim()) {
+      throw out_of_range_index();
+    }
     return buf()->dim(index);
   }
   std::vector<const Expr*> dims() const {
     return buf()->dims();
   }
-  const Var* arg(int index) const {
-    return function()->arg(index);
+  const Var* arg(size_t index) const {
+    if (index >= ndim()) {
+      throw out_of_range_index();
+    }
+    return args_[index];
   }
   const std::vector<const Var*>& args() const {
-    return function()->args();
+    return args_;
+  }
+  size_t reduce_ndim() const {
+    return reduce_dims_.size();
+  }
+  std::vector<const Expr*> reduce_dims() const {
+    return reduce_dims_;
+  }
+  std::vector<const Var*> reduce_args() const {
+    return reduce_args_;
+  }
+  const Expr* reduce_dim(size_t index) const {
+    if (index >= reduce_ndim()) {
+      throw out_of_range_index();
+    }
+    return reduce_dims_[index];
+  }
+  const Var* reduce_arg(size_t index) const {
+    if (index >= reduce_ndim()) {
+      throw out_of_range_index();
+    }
+    return reduce_args_[index];
   }
 
   void initializeTo(const Expr* initializer) {
@@ -143,6 +94,7 @@
   const Expr* initializer() const {
     return initializer_;
   }
+  Stmt* ElementStmt();
 
   template <typename... Ts>
   inline ExprHandle operator()(const Ts&... ts);
@@ -152,8 +104,12 @@
   inline ExprHandle call(const Ts&... ts);
 
  private:
-  Function* function_;
-  int output_index_;
+  const Buf* buf_;
+  std::vector<const Var*> args_;
+  const Expr* body_;
+  std::vector<const Expr*> reduce_dims_;
+  std::vector<const Var*> reduce_args_;
+
   const Expr* initializer_{nullptr};
 };
 
@@ -295,10 +251,8 @@
   Buf* func_result = new Buf(func_name, dims, body.dtype());
   const ReduceOp* reduce_op =
       reducer(func_result, body, output_args, reduce_vars);
-  dims.insert(dims.end(), reduce_dims.begin(), reduce_dims.end());
-  Function* func =
-      new Function(func_name, func_result, dims, all_vars, reduce_op);
-  Tensor* t = new Tensor(func, 0);
+  Tensor* t =
+      new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
   t->initializeTo(new Cast(body.dtype(), reducer.initializer()));
   return t;
 }
@@ -352,10 +306,7 @@
   }
 
   FunctionCall(Tensor* tensor, const std::vector<const Expr*>& params)
-      : BaseClass(
-            tensor->function()->body(tensor->output_index())->dtype(),
-            kFunctionCall,
-            params),
+      : BaseClass(tensor->body()->dtype(), kFunctionCall, params),
         tensor_(tensor) {}
 
  private: