[TensorExpr] Nuke `Function` class and directly use `Tensor` instead. (#45936)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45936
`Tensor` has been a view into a `Function` that was supposed to be used
for a more general case when we have multiple computations over the same
domain (aka multiple output functions). We have never got to a point
where we need this and now have other ideas in mind on how to support
this case if need be. For now, let's just nuke `Function` to reduce the
overall system complexity.
The change should not affect any existing behavior.
Test Plan: Imported from OSS
Reviewed By: bertmaher
Differential Revision: D24153214
Pulled By: ZolotukhinM
fbshipit-source-id: 26d5f11db5d661ff5e1135f4a49eff1c6d4c1bd5
diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp
index f0bcfc4..31e0554 100644
--- a/test/cpp/tensorexpr/tutorial.cpp
+++ b/test/cpp/tensorexpr/tutorial.cpp
@@ -125,54 +125,30 @@
// independent computations over the same domain) for its elements, as a
// function of indices
//
- // We use Function objects to represent this. Let's build one.
- //
- // First, we need to specify the domain, or dimensions in which the
- // computation would be performed. Let's create a 64x32 domain:
+ // TODO: Update this section once Tensor/Function cleanup is done
std::vector<const Expr*> dims = {
new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate
// and represents an integer constant
- // Next we need to create Function arguments. The arguments of a Function
- // are Vars, and they play role of placeholders. The computation that the
- // function would describe would use these arguments.
+ // Next we need to create arguments. The arguments are Vars, and they play
+ // the role of placeholders. The computation that the tensor would describe
+ // would use these arguments.
const Var* i = new Var("i", kInt);
const Var* j = new Var("j", kInt);
std::vector<const Var*> args = {i, j};
- // Now we can define the function computations using these arguments. Let's
- // create two computations, the first would add the arguments of the
- // function, the second would multiply them.
- Expr* func_body1 = new Mul(i, j);
- Expr* func_body2 = new Add(i, j);
+ // Now we can define the body of the tensor computation using these
+ // arguments.
+ Expr* body = new Mul(i, j);
- // Finally, we pass all these pieces together to Function constructor:
- Function* func =
- new Function({"X", "Y"}, dims, args, {func_body1, func_body2});
- // Under the hood function constructor would create separate `Buf`
- // expressions for each computation (which can be accessed via
- // `func->func_var(idx)`) with the names specified by the first parameter of
- // the constructor call. In our example two `Buf` variables will be created
- // with names 'X' and 'Y', each of them would signify a domain of 64x32.
-
- // We can now print out our function:
- std::cout << "Tensor function: " << *func << std::endl;
- // Prints:
- // Tensor function: Function F(i[64], j[32]) {
- // X = i * j
- // Y = i + j
- // }
-
- // A Tensor refers to an individual computation defined by a Function. For
- // instance, we could create a following tensor given the function above:
- int output_idx = 0; // Used to index the computation
- Tensor* X = new Tensor(func, output_idx);
+ // Finally, we pass all these pieces together to Tensor constructor:
+ Tensor* X = new Tensor("X", dims, args, body);
std::cout << "Tensor computation: " << *X << std::endl;
// Prints: Tensor computation: Tensor X(i[64], j[32]) = i * j
// Similarly to how we provide a more convenient way of using handles for
// constructing Exprs, Tensors also have a more convenient API for
- // construction. It is based on Compute functions, which take a name:
+ // construction. It is based on the Compute API, which takes a name,
// dimensions, and a lambda specifying the computation body:
Tensor* Z = Compute(
"Z",
@@ -204,14 +180,6 @@
// Tensor and we use 'load' for accessing elements of an external tensor
// through its Placeholder. This is an implementation detail and could be
// changed in future.
- //
- // Why do we have Functions and Tensors and what is the relationship between
- // them? Functions are used to represent several computations performed over
- // the same domain. Tensors refer to individual computations of a Function.
- //
- // Also note that currently a lot of code only supports single-output
- // Functions, in which case they become almost identical to Tensors. This
- // probably will be changed in future.
// TODO: Show how reductions are represented and constructed
}
diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h
index 4bf9d76..a32b362 100644
--- a/torch/csrc/jit/tensorexpr/codegen.h
+++ b/torch/csrc/jit/tensorexpr/codegen.h
@@ -73,17 +73,7 @@
BufferArg(const Placeholder& buffer)
: var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {}
BufferArg(Tensor* tensor)
- : var_(tensor->function()
- ->func_var(tensor->output_index())
- ->base_handle()),
- dtype_(tensor->function()->body(tensor->output_index())->dtype()) {}
- BufferArg(const Function& func)
- : var_(func.func_var(0)->base_handle()), dtype_(func.body(0)->dtype()) {
- // TODO: Support multiple-output functions
- if (func.func_vars().size() != 1) {
- throw unimplemented_lowering();
- }
- }
+ : var_(tensor->buf()->base_handle()), dtype_(tensor->body()->dtype()) {}
BufferArg(const VarHandle& var)
: var_(var.node()), dtype_(var.dtype()), isVar_(true) {}
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 434aa52..7c64403 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -194,6 +194,9 @@
return dims_.size();
}
const Expr* dim(size_t index) const {
+ if (index >= ndim()) {
+ throw out_of_range_index();
+ }
return dims_[index];
}
std::vector<const Expr*> dims() const {
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp
index 8f47d85..5792729 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp
@@ -552,11 +552,6 @@
return stream;
}
-std::ostream& operator<<(std::ostream& stream, const Function& f) {
- stream << std::to_string(&f);
- return stream;
-}
-
void print(const Expr* expr) {
if (expr) {
IRPrinter p(std::cout);
@@ -579,10 +574,6 @@
std::cout << std::to_string(t);
}
-void print(const Function* f) {
- std::cout << std::to_string(f);
-}
-
} // namespace tensorexpr
} // namespace jit
} // namespace torch
@@ -615,24 +606,4 @@
oss << ") = " << *t->body() << "\n";
return oss.str();
}
-
-std::string to_string(const Function* f) {
- if (!f) {
- return "(null function)\n";
- }
- std::ostringstream oss;
- oss << "Function F(";
- for (size_t i = 0; i < f->ndim(); i++) {
- if (i != 0) {
- oss << ", ";
- }
- oss << *f->arg(i) << "[" << *f->dim(i) << "]";
- }
- oss << ") {\n";
- for (size_t i = 0; i < f->bodies().size(); i++) {
- oss << " " << *f->func_var(i) << " = " << *f->body(i) << "\n";
- }
- oss << "}\n";
- return oss.str();
-}
} // namespace std
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h
index 64ba352..d9079d7 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.h
+++ b/torch/csrc/jit/tensorexpr/ir_printer.h
@@ -11,7 +11,6 @@
namespace tensorexpr {
class Tensor;
-class Function;
class TORCH_API IRPrinter : public IRVisitor {
public:
@@ -95,12 +94,10 @@
TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&);
TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&);
TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&);
-TORCH_API std::ostream& operator<<(std::ostream& stream, const Function&);
TORCH_API void print(const Expr* expr);
TORCH_API void print(const Stmt* stmt);
TORCH_API void print(const Tensor* t);
-TORCH_API void print(const Function* f);
} // namespace tensorexpr
} // namespace jit
@@ -109,12 +106,10 @@
namespace std {
using torch::jit::tensorexpr::Expr;
-using torch::jit::tensorexpr::Function;
using torch::jit::tensorexpr::Stmt;
using torch::jit::tensorexpr::Tensor;
TORCH_API std::string to_string(const Expr* expr);
TORCH_API std::string to_string(const Stmt* stmt);
TORCH_API std::string to_string(const Tensor* t);
-TORCH_API std::string to_string(const Function* f);
} // namespace std
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 456a264..301f11f 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -449,11 +449,9 @@
}
Stmt* LoopNest::lowerToStmt(Tensor* t) {
- Function* f = t->function();
- // TODO: Support multiple-output functions
- Stmt* body = f->ElementStmt(0);
+ Stmt* body = t->ElementStmt();
- if (f->ndim() == 0) {
+ if (t->ndim() == 0 && t->reduce_ndim() == 0) {
return body;
}
@@ -461,18 +459,30 @@
if (initializer) {
buf_initializers_[t->buf()] = initializer;
}
+
std::vector<const Expr*> indices(t->args().begin(), t->args().end());
- for (size_t i = 0; i < f->ndim(); i++) {
- // Going in reverse order: from innermost loop to the outermost
- size_t dim_index = f->ndim() - i - 1;
- body = new For(f->arg(dim_index), new IntImm(0), f->dim(dim_index), body);
- indices.pop_back();
- if (initializer && indices.size() == t->ndim()) {
+ if (t->reduce_ndim() > 0) {
+ for (size_t i = 0; i < t->reduce_ndim(); i++) {
+ // Going in reverse order: from innermost loop to the outermost
+ size_t dim_index = t->reduce_ndim() - i - 1;
+ body = new For(
+ t->reduce_arg(dim_index),
+ new IntImm(0),
+ t->reduce_dim(dim_index),
+ body);
+ }
+ if (initializer) {
Store* init = new Store(t->buf(), indices, initializer, new IntImm(1));
body = new Block({init, body});
}
}
+
+ for (size_t i = 0; i < t->ndim(); i++) {
+ // Going in reverse order: from innermost loop to the outermost
+ size_t dim_index = t->ndim() - i - 1;
+ body = new For(t->arg(dim_index), new IntImm(0), t->dim(dim_index), body);
+ }
return body;
}
@@ -493,26 +503,21 @@
// For the target function, insert the caller/callee pair into the replacement
// mapping.
const Expr* mutate(const FunctionCall* v) override {
- Function* func = v->tensor()->function();
- const Buf* buf = v->tensor()->buf();
+ const Tensor* t = v->tensor();
+ const Buf* buf = t->buf();
if (buf != buf_) {
return IRMutator::mutate(v);
}
- // TODO: Support multiple-output functions
- if (func->func_vars().size() != 1) {
- throw unimplemented_lowering();
- }
-
if (v->nparams() != buf->ndim()) {
throw malformed_input(
"Placeholder indexed access is inconsistent with its rank", v);
}
std::vector<const Var*> index_vars;
- TORCH_INTERNAL_ASSERT(buf->ndim() == func->args().size());
+ TORCH_INTERNAL_ASSERT(buf->ndim() == t->args().size());
for (size_t i = 0; i < buf->ndim(); i++) {
- const Var* func_callee_arg = dynamic_cast<const Var*>(func->arg(i));
+ const Var* func_callee_arg = dynamic_cast<const Var*>(t->arg(i));
const Expr* func_caller_param = v->param(i);
auto iter = inline_mapping_.find(func_callee_arg);
if (iter != inline_mapping_.end()) {
diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp
index 4fad4ca..4afc1ff 100644
--- a/torch/csrc/jit/tensorexpr/tensor.cpp
+++ b/torch/csrc/jit/tensorexpr/tensor.cpp
@@ -16,8 +16,7 @@
std::vector<const Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
- Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ return new Tensor(func_name, dims, args, body);
}
Tensor* Compute(
@@ -32,8 +31,7 @@
std::vector<const Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0])).node();
- Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ return new Tensor(func_name, dims, args, body);
}
Tensor* Compute(
@@ -48,8 +46,7 @@
std::vector<const Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
- Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ return new Tensor(func_name, dims, args, body);
}
Tensor* Compute(
@@ -67,8 +64,7 @@
const Expr* body =
body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
.node();
- Function* func = new Function(func_name, dims, args, body);
- return new Tensor(func, 0);
+ return new Tensor(func_name, dims, args, body);
}
Tensor* Compute(
@@ -87,20 +83,17 @@
unpack_dim_args(dim_args, &dims, &args_nodes);
auto args = VarVectorToVarHandleVector(args_nodes);
const Expr* body = body_func(args[0], args[1], args[2], args[3]).node();
- Function* func = new Function(func_name, dims, args_nodes, body);
- return new Tensor(func, 0);
+ return new Tensor(func_name, dims, args_nodes, body);
}
-Stmt* Function::ElementStmt(size_t index) {
- const Buf* buf = func_var(index);
+Stmt* Tensor::ElementStmt() {
std::vector<const Expr*> indices;
- for (size_t i = 0; i < buf->ndim(); i++) {
- indices.push_back(this->args_[i]);
+ for (size_t i = 0; i < buf_->ndim(); i++) {
+ indices.push_back(args_[i]);
}
const Expr* mask = new IntImm(1);
-
- Stmt* update_stmt = new Store(buf, indices, body(index), mask);
+ Stmt* update_stmt = new Store(buf_, indices, body_, mask);
return update_stmt;
}
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 9d0cadc..d37f14c 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -12,129 +12,80 @@
namespace jit {
namespace tensorexpr {
-class Function : public KernelScopedObject {
+class Tensor : KernelScopedObject {
public:
- Function(
- const std::string& func_name,
+ Tensor(
+ const std::string& name,
const std::vector<const Expr*>& dims,
const std::vector<const Var*>& args,
const Expr* body)
// TODO: Function should not create buffers, they should be created
// manually before constructing a function.
- : func_vars_({new Buf(func_name, dims, body->dtype())}),
- dims_(dims),
- args_(args),
- bodies_({body}) {}
- Function(
- const std::vector<std::string>& func_names,
- const std::vector<const Expr*>& dims,
+ : buf_(new Buf(name, dims, body->dtype())), args_(args), body_(body) {}
+
+ Tensor(Buf* buf, const std::vector<const Var*>& args, const Expr* body)
+ : buf_(buf), args_(args), body_(body) {}
+
+ Tensor(
+ Buf* buf,
const std::vector<const Var*>& args,
- const std::vector<const Expr*>& bodies)
- : func_vars_(func_names.size()),
- dims_(dims),
- args_(args),
- bodies_(bodies) {
- for (size_t i = 0; i < func_names.size(); i++) {
- func_vars_[i] = new Buf(func_names[i], dims, bodies[i]->dtype());
- }
- }
- Function(
- const std::string& func_name,
- Buf* func_var,
- const std::vector<const Expr*>& dims,
- const std::vector<const Var*>& args,
+ const std::vector<const Expr*>& reduce_dims,
+ const std::vector<const Var*>& reduce_args,
const Expr* body)
- : func_vars_({func_var}), dims_(dims), args_(args), bodies_({body}) {}
-
- size_t ndim() const {
- return dims_.size();
- }
-
- const Expr* dim(size_t index) const {
- if (index < 0 || index >= dims_.size()) {
- throw out_of_range_index();
- }
-
- return dims_[index];
- }
- const std::vector<const Expr*>& dims() const {
- return dims_;
- }
-
- const Var* arg(size_t index) const {
- if (index < 0 || index >= args_.size()) {
- throw out_of_range_index();
- }
-
- return args_[index];
- }
- const std::vector<const Var*>& args() const {
- return args_;
- }
-
- std::vector<const Expr*> bodies() const {
- return bodies_;
- }
- const Expr* body(size_t index) const {
- if (index >= bodies_.size()) {
- throw out_of_range_index();
- }
-
- return bodies_[index];
- }
-
- std::vector<const Buf*> func_vars() const {
- return func_vars_;
- }
- const Buf* func_var(size_t index) const {
- if (index >= func_vars_.size()) {
- throw out_of_range_index();
- }
- return func_vars_[index];
- }
-
- Stmt* ElementStmt(size_t index);
-
- private:
- std::vector<const Buf*> func_vars_;
- std::vector<const Expr*> dims_;
- std::vector<const Var*> args_;
- std::vector<const Expr*> bodies_;
-};
-
-class Tensor : KernelScopedObject {
- public:
- Tensor(Function* function, int output_index)
- : function_(function), output_index_(output_index) {}
-
- Function* function() const {
- return function_;
- }
- int output_index() const {
- return output_index_;
- }
+ : buf_(buf),
+ args_(args),
+ body_(body),
+ reduce_dims_(reduce_dims),
+ reduce_args_(reduce_args) {}
// Wrappers over accessors to fields of the underlying function
const Expr* body() const {
- return function()->body(output_index());
+ return body_;
}
const Buf* buf() const {
- return function()->func_var(output_index());
+ return buf_;
}
- int ndim() const {
- return buf()->dims().size();
+ size_t ndim() const {
+ return buf()->ndim();
}
- const Expr* dim(int index) const {
+ const Expr* dim(size_t index) const {
+ if (index >= ndim()) {
+ throw out_of_range_index();
+ }
return buf()->dim(index);
}
std::vector<const Expr*> dims() const {
return buf()->dims();
}
- const Var* arg(int index) const {
- return function()->arg(index);
+ const Var* arg(size_t index) const {
+ if (index >= ndim()) {
+ throw out_of_range_index();
+ }
+ return args_[index];
}
const std::vector<const Var*>& args() const {
- return function()->args();
+ return args_;
+ }
+ size_t reduce_ndim() const {
+ return reduce_dims_.size();
+ }
+ std::vector<const Expr*> reduce_dims() const {
+ return reduce_dims_;
+ }
+ std::vector<const Var*> reduce_args() const {
+ return reduce_args_;
+ }
+ const Expr* reduce_dim(size_t index) const {
+ if (index >= reduce_ndim()) {
+ throw out_of_range_index();
+ }
+ return reduce_dims_[index];
+ }
+ const Var* reduce_arg(size_t index) const {
+ if (index >= reduce_ndim()) {
+ throw out_of_range_index();
+ }
+ return reduce_args_[index];
}
void initializeTo(const Expr* initializer) {
@@ -143,6 +94,7 @@
const Expr* initializer() const {
return initializer_;
}
+ Stmt* ElementStmt();
template <typename... Ts>
inline ExprHandle operator()(const Ts&... ts);
@@ -152,8 +104,12 @@
inline ExprHandle call(const Ts&... ts);
private:
- Function* function_;
- int output_index_;
+ const Buf* buf_;
+ std::vector<const Var*> args_;
+ const Expr* body_;
+ std::vector<const Expr*> reduce_dims_;
+ std::vector<const Var*> reduce_args_;
+
const Expr* initializer_{nullptr};
};
@@ -295,10 +251,8 @@
Buf* func_result = new Buf(func_name, dims, body.dtype());
const ReduceOp* reduce_op =
reducer(func_result, body, output_args, reduce_vars);
- dims.insert(dims.end(), reduce_dims.begin(), reduce_dims.end());
- Function* func =
- new Function(func_name, func_result, dims, all_vars, reduce_op);
- Tensor* t = new Tensor(func, 0);
+ Tensor* t =
+ new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
t->initializeTo(new Cast(body.dtype(), reducer.initializer()));
return t;
}
@@ -352,10 +306,7 @@
}
FunctionCall(Tensor* tensor, const std::vector<const Expr*>& params)
- : BaseClass(
- tensor->function()->body(tensor->output_index())->dtype(),
- kFunctionCall,
- params),
+ : BaseClass(tensor->body()->dtype(), kFunctionCall, params),
tensor_(tensor) {}
private: