[TensorExpr] Switch from `ExprPtr` to `ExprHandle` in Compute impl. (#72389)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72389
This is an NFC change that just prepares the code for the upcoming
deletion of `DimArg` class. This change makes `Compute` and `Reduce`
APIs to use `ExprHandle` everywhere.
There should be no observable behavior change from this PR.
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D34030295
Pulled By: ZolotukhinM
fbshipit-source-id: 3fd035b6a6bd0a07ccfa92e118819478ae85412a
(cherry picked from commit 1b0a4b6fac54aa4d4735df435d345a30ba0d8a53)
diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp
index 7019353..31c069c 100644
--- a/test/cpp/tensorexpr/test_memdependency.cpp
+++ b/test/cpp/tensorexpr/test_memdependency.cpp
@@ -543,8 +543,7 @@
*/
StorePtr aInit = Store::make(a, {0}, 0);
- ExprHandle reduce =
- ExprHandle(Sum()(a.node(), ExprHandle(1), {x.node()}, {x.node()}));
+ ExprHandle reduce = Sum()(a, 1, {x}, {x});
StorePtr aReduce = Store::make(a, {0}, reduce);
StmtPtr loop = For::make(x, 0, 10, aReduce);
StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp
index e17727c..fb47c58 100644
--- a/torch/csrc/jit/tensorexpr/expr.cpp
+++ b/torch/csrc/jit/tensorexpr/expr.cpp
@@ -425,16 +425,30 @@
TORCH_CHECK(var);
}
-ExprHandle Buf::make(
- const std::string& name_hint,
- const std::vector<ExprHandle>& dims,
- Dtype dtype) {
- return ExprHandle(
- alloc<Buf>(name_hint, ExprHandleVectorToExprVector(dims), dtype));
+BufHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
+ return Buf::make("", dims, dtype);
}
-ExprHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
- return Buf::make("", dims, dtype);
+BufHandle Buf::make(
+ const std::string& name_hint,
+ const std::vector<ExprHandle>& dims,
+ Dtype dtype,
+ c10::optional<ExprHandle> initializer,
+ c10::optional<std::vector<ExprHandle>> strides,
+ c10::optional<ExprHandle> qscale,
+ c10::optional<ExprHandle> qzero) {
+ c10::optional<std::vector<ExprPtr>> opt_strides;
+ if (strides) {
+ opt_strides = ExprHandleVectorToExprVector(*strides);
+ }
+ return BufHandle(alloc<Buf>(
+ name_hint,
+ ExprHandleVectorToExprVector(dims),
+ dtype,
+ initializer ? initializer->node() : nullptr,
+ opt_strides,
+ qscale ? qscale->node() : nullptr,
+ qzero ? qzero->node() : nullptr));
}
std::vector<ExprHandle> BufHandle::dims() const {
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 148ce41..cf75e63 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -191,11 +191,16 @@
class TORCH_API Buf : public ExprNode<Buf> {
public:
- static ExprHandle make(
+ static BufHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
+
+ static BufHandle make(
const std::string& name_hint,
const std::vector<ExprHandle>& dims,
- Dtype dtype);
- static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
+ Dtype dtype,
+ c10::optional<ExprHandle> initializer = c10::nullopt,
+ c10::optional<std::vector<ExprHandle>> strides = c10::nullopt,
+ c10::optional<ExprHandle> qscale = c10::nullopt,
+ c10::optional<ExprHandle> qzero = c10::nullopt);
// TODO: unique_name
VarPtr base_handle() const {
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index cc49ef3..93e76f6 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -2955,7 +2955,7 @@
tmp_params,
reduceOp->reducer()(
producer,
- ExprHandle(alloc<Load>(tmp_buf, new_loop_vars_expr)),
+ alloc<Load>(tmp_buf, new_loop_vars_expr),
tmp_params,
{}));
diff --git a/torch/csrc/jit/tensorexpr/operators/quantization.cpp b/torch/csrc/jit/tensorexpr/operators/quantization.cpp
index 1c99f05..a3f1a91 100644
--- a/torch/csrc/jit/tensorexpr/operators/quantization.cpp
+++ b/torch/csrc/jit/tensorexpr/operators/quantization.cpp
@@ -727,8 +727,8 @@
auto input_height = ExprHandle(A.dim(2));
auto input_width = ExprHandle(A.dim(3));
- std::vector<ExprPtr> dims;
- std::vector<VarPtr> args;
+ std::vector<ExprHandle> dims;
+ std::vector<VarHandle> args;
unpack_dim_args(c10::fmap<DimArg>(outputShape), &dims, &args);
// Handle separately when scale is specified? as in 'scalar_t
// compute_scales_value' in UpSample.h
@@ -749,16 +749,16 @@
newAxes[3] = compute_nearest_idx(scale_w, axes[3], input_width);
return A.load(newAxes);
};
- auto e = body_func(VarVectorToVarHandleVector(args));
- BufPtr buf = alloc<Buf>(
+ auto e = body_func(args);
+ BufHandle buf = Buf::make(
"quantize_upsample_nearest2d",
- ExprHandleVectorToExprVector(outputShape),
+ outputShape,
Dtype(*outputType),
- nullptr,
c10::nullopt,
- A.node()->qscale(),
- A.node()->qzero());
- return Tensor(buf, args, e.node());
+ c10::nullopt,
+ ExprHandle(A.node()->qscale()),
+ ExprHandle(A.node()->qzero()));
+ return Tensor(buf, args, e);
}
Tensor computeUpsampleNearest2dExternalCall(
diff --git a/torch/csrc/jit/tensorexpr/reduction.cpp b/torch/csrc/jit/tensorexpr/reduction.cpp
index 1727482..0cf8cb2 100644
--- a/torch/csrc/jit/tensorexpr/reduction.cpp
+++ b/torch/csrc/jit/tensorexpr/reduction.cpp
@@ -6,12 +6,12 @@
namespace jit {
namespace tensorexpr {
-ReduceOpPtr Reducer::operator()(
- BufPtr result_buf,
+ExprHandle Reducer::operator()(
+ BufHandle result_buf,
ExprHandle body,
- const std::vector<ExprPtr>& output,
- const std::vector<VarPtr>& inner) const {
- return alloc<ReduceOp>(
+ const std::vector<ExprHandle>& output,
+ const std::vector<VarHandle>& inner) const {
+ return ReduceOp::make(
complete(result_buf, interaction_, body, output, inner), inner, *this);
}
@@ -26,6 +26,14 @@
*this);
}
+ExprHandle ReduceOp::make(
+ ExprHandle body,
+ std::vector<VarHandle> reduce_args,
+ const Reducer& reducer) {
+ return ExprHandle(alloc<ReduceOp>(
+ body.node(), VarHandleVectorToVarVector(reduce_args), reducer));
+}
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/reduction.h b/torch/csrc/jit/tensorexpr/reduction.h
index 24df1ef..c1e8401 100644
--- a/torch/csrc/jit/tensorexpr/reduction.h
+++ b/torch/csrc/jit/tensorexpr/reduction.h
@@ -36,11 +36,11 @@
return init_;
}
- ReduceOpPtr operator()(
- BufPtr result_buf,
+ ExprHandle operator()(
+ BufHandle result_buf,
ExprHandle body,
- const std::vector<ExprPtr>& output,
- const std::vector<VarPtr>& inner) const;
+ const std::vector<ExprHandle>& output,
+ const std::vector<VarHandle>& inner) const;
ReduceOpPtr operator()(
BufPtr result_buf,
@@ -111,6 +111,16 @@
auto e = interaction(accum, body);
return e.node();
}
+ static ExprHandle complete(
+ BufHandle accumulator,
+ ReduceInteraction interaction,
+ ExprHandle body,
+ const std::vector<ExprHandle>& output_args,
+ const std::vector<VarHandle>& reduce_args) {
+ ExprHandle accum = Load::make(body.dtype(), accumulator, output_args);
+ auto e = interaction(accum, body);
+ return e;
+ }
private:
ExprPtr init_;
@@ -133,6 +143,10 @@
body_(body),
reduce_args_(std::move(reduce_args)),
reducer_(reducer) {}
+ static ExprHandle make(
+ ExprHandle body,
+ std::vector<VarHandle> reduce_args,
+ const Reducer& reducer);
// return the body expression which obtains the value to be reduced.
ExprPtr body() const {
diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp
index c78f27f..10c6f96 100644
--- a/torch/csrc/jit/tensorexpr/tensor.cpp
+++ b/torch/csrc/jit/tensorexpr/tensor.cpp
@@ -53,11 +53,11 @@
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
- std::vector<ExprPtr> dims;
- std::vector<VarPtr> args;
+ std::vector<ExprHandle> dims;
+ std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
- ExprPtr body = body_func(VarVectorToVarHandleVector(args)).node();
- BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+ ExprHandle body = body_func(args);
+ BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
}
@@ -69,11 +69,11 @@
throw malformed_input("mismatch between body and arg size (1)");
}
- std::vector<ExprPtr> dims;
- std::vector<VarPtr> args;
+ std::vector<ExprHandle> dims;
+ std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
- ExprPtr body = body_func(VarHandle(args[0])).node();
- BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+ ExprHandle body = body_func(args[0]);
+ BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
}
@@ -85,11 +85,11 @@
if (dim_args.size() != 2) {
throw malformed_input("mismatch between body and arg size (2)");
}
- std::vector<ExprPtr> dims;
- std::vector<VarPtr> args;
+ std::vector<ExprHandle> dims;
+ std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
- ExprPtr body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
- BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+ ExprHandle body = body_func(args[0], args[1]);
+ BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
}
@@ -102,13 +102,11 @@
if (dim_args.size() != 3) {
throw malformed_input("mismatch between body and arg size (3)");
}
- std::vector<ExprPtr> dims;
- std::vector<VarPtr> args;
+ std::vector<ExprHandle> dims;
+ std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
- ExprPtr body =
- body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
- .node();
- BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+ ExprHandle body = body_func(args[0], args[1], args[2]);
+ BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
}
@@ -123,16 +121,11 @@
if (dim_args.size() != 4) {
throw malformed_input("mismatch between body and arg size (4)");
}
- std::vector<ExprPtr> dims;
- std::vector<VarPtr> args;
+ std::vector<ExprHandle> dims;
+ std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
- ExprPtr body = body_func(
- VarHandle(args[0]),
- VarHandle(args[1]),
- VarHandle(args[2]),
- VarHandle(args[3]))
- .node();
- BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+ ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
+ BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
}
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 1824463..d2cfe0a 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -19,6 +19,8 @@
: buf_(buf) {
stmt_ = constructStmt(args, body, {}, {});
}
+ Tensor(BufHandle buf, const std::vector<VarHandle>& args, ExprHandle body)
+ : Tensor(buf.node(), VarHandleVectorToVarVector(args), body.node()) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Tensor(
@@ -30,6 +32,18 @@
: buf_(buf) {
stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
}
+ Tensor(
+ BufHandle buf,
+ const std::vector<VarHandle>& args,
+ const std::vector<ExprHandle>& reduce_dims,
+ const std::vector<VarHandle>& reduce_args,
+ ExprHandle body)
+ : Tensor(
+ buf.node(),
+ VarHandleVectorToVarVector(args),
+ ExprHandleVectorToExprVector(reduce_dims),
+ VarHandleVectorToVarVector(reduce_args),
+ body.node()) {}
Tensor(BufPtr buf, StmtPtr stmt) : buf_(buf), stmt_(stmt) {}
@@ -87,16 +101,16 @@
inline void unpack_dim_args(
const std::vector<DimArg>& dim_args,
- std::vector<ExprPtr>* dims,
- std::vector<VarPtr>* vars) {
+ std::vector<ExprHandle>* dims,
+ std::vector<VarHandle>* vars) {
dims->clear();
vars->clear();
for (const DimArg& dim_arg : dim_args) {
- ExprPtr expr = dim_arg.dim().node();
+ ExprHandle expr = dim_arg.dim();
dims->push_back(expr);
- vars->push_back(alloc<Var>(
+ vars->push_back(VarHandle(alloc<Var>(
dim_arg.name_hint(),
- expr->dtype().scalar_type() == ScalarType::Long ? kLong : kInt));
+ expr.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)));
}
}
@@ -109,47 +123,31 @@
const InitFunc& init_func,
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<ExprPtr> dims;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<VarPtr> vars;
+ std::vector<ExprHandle> dims;
+ std::vector<VarHandle> vars;
unpack_dim_args(dim_args, &dims, &vars);
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<ExprPtr> reduce_dims;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<VarPtr> reduce_vars;
+ std::vector<ExprHandle> reduce_dims;
+ std::vector<VarHandle> reduce_vars;
unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);
// If reduce_vars is empty, then it's not a reduction, but rather a simple
// copy
if (reduce_vars.empty()) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- ExprPtr body =
- Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(vars))
- .node();
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- BufPtr func_result = alloc<Buf>(func_name, dims, body->dtype());
+ ExprHandle body = Reducer::getReduceBody(body_func, vars);
+ BufHandle func_result = Buf::make(func_name, dims, body.dtype());
return Tensor(func_result, vars, body);
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<VarPtr> all_vars;
+ std::vector<VarHandle> all_vars;
all_vars.insert(all_vars.end(), vars.begin(), vars.end());
all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());
- ExprHandle body =
- Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars));
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<ExprPtr> output_args(vars.begin(), vars.end());
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- ExprPtr init_expr = alloc<Cast>(
- body.dtype(), init_func(VarVectorToVarHandleVector(vars)).node());
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- BufPtr func_result = alloc<Buf>(func_name, dims, body.dtype(), init_expr);
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- ReduceOpPtr reduce_op = reducer(func_result, body, output_args, reduce_vars);
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ ExprHandle body = Reducer::getReduceBody(body_func, all_vars);
+ std::vector<ExprHandle> output_args(vars.begin(), vars.end());
+ ExprHandle init_expr = Cast::make(body.dtype(), init_func(vars));
+ BufHandle func_result = Buf::make(func_name, dims, body.dtype(), init_expr);
+ ExprHandle reduce_op = reducer(func_result, body, output_args, reduce_vars);
Tensor t = Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
return t;
}