[TensorExpr] Add a constructor accepting a name_hint to class Buf. (#36617)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36617
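The new overload lets a Buf be constructed directly from a name_hint string; the constructor builds the underlying kHandle Var itself, so call sites no longer have to spell out the Var by hand. At a typical call site the change looks roughly like this (dims and dtype stand in for whatever the caller already has):

    // Old form: the caller constructs the handle Var explicitly.
    const Buf* old_style = new Buf(new Var("temp", kHandle), dims, dtype);

    // New form: the name_hint overload creates the Var internally.
    const Buf* new_style = new Buf("temp", dims, dtype);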
Test Plan: Imported from OSS
Differential Revision: D21027355
Pulled By: ZolotukhinM
fbshipit-source-id: 54633f7400f24f7f9fdcaeead94c80282ccb5207
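
Note: for context, here is a minimal standalone sketch of the delegating-constructor pattern the new overload uses. The types below are simplified stand-ins, not the real tensorexpr classes (the actual Buf, Var, Expr and Dtype live in torch/csrc/jit/tensorexpr):

    #include <string>
    #include <utility>
    #include <vector>

    struct Var {
      explicit Var(std::string name) : name(std::move(name)) {}
      std::string name;
    };

    struct Buf {
      // Convenience overload: accepts a name hint, builds the handle Var
      // itself, then delegates to the primary constructor below.
      Buf(const std::string& name_hint, std::vector<int> dims)
          : Buf(new Var(name_hint), std::move(dims)) {}

      // Primary constructor: all validation/initialization stays in one place.
      Buf(const Var* var, std::vector<int> dims)
          : var_(var), dims_(std::move(dims)) {}

      const Var* var_;
      std::vector<int> dims_;
    };

Because the string overload delegates, any check added to the primary constructor (such as the TORCH_CHECK on the Var in expr.h) automatically covers both entry points.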
diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp
index 8549d1c..9920035 100644
--- a/torch/csrc/jit/tensorexpr/expr.cpp
+++ b/torch/csrc/jit/tensorexpr/expr.cpp
@@ -207,8 +207,8 @@
const std::string& name_hint,
const std::vector<ExprHandle>& dims,
Dtype dtype) {
- return ExprHandle(new Buf(
- new Var(name_hint, kHandle), ExprHandleVectorToExprVector(dims), dtype));
+ return ExprHandle(
+ new Buf(name_hint, ExprHandleVectorToExprVector(dims), dtype));
}
ExprHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 1756ef9c..5018f6b 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -178,6 +178,11 @@
return base_handle_->name_hint();
}
+ Buf(const std::string& name_hint,
+ const std::vector<const Expr*>& dims,
+ Dtype dtype)
+ : Buf(new Var(name_hint, kHandle), dims, dtype) {}
+
Buf(const Var* var, const std::vector<const Expr*>& dims, Dtype dtype)
: ExprNodeBase(dtype, kPrimitive), base_handle_(var), dims_(dims) {
TORCH_CHECK(var);
diff --git a/torch/csrc/jit/tensorexpr/function.h b/torch/csrc/jit/tensorexpr/function.h
index 8b1c479..128253d 100644
--- a/torch/csrc/jit/tensorexpr/function.h
+++ b/torch/csrc/jit/tensorexpr/function.h
@@ -19,7 +19,7 @@
const Expr* body)
// TODO: Function should not create buffers, they should be created
// manually before constructing a function.
- : func_vars_({new Buf(new Var(func_name, kHandle), dims, body->dtype())}),
+ : func_vars_({new Buf(func_name, dims, body->dtype())}),
dims_(dims),
args_(args),
bodies_({body}) {}
@@ -33,8 +33,7 @@
args_(args),
bodies_(bodies) {
for (size_t i = 0; i < func_names.size(); i++) {
- func_vars_[i] =
- new Buf(new Var(func_names[i], kHandle), dims, bodies[i]->dtype());
+ func_vars_[i] = new Buf(func_names[i], dims, bodies[i]->dtype());
}
}
Function(
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 15f27d3..c1ec69a 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -1441,8 +1441,7 @@
}
// TODO: Use name-hint of the producer instead of "temp"
- const Buf* temp_buf =
- new Buf(new Var("temp", kHandle), dims, st->value()->dtype());
+ const Buf* temp_buf = new Buf("temp", dims, st->value()->dtype());
// Generate index variables for 'temp'
std::vector<const Expr*> temp_indices(dims.size());
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 22f04a6..73208bb 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -129,7 +129,7 @@
ExprHandle body =
Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars));
std::vector<const Expr*> output_args(vars.begin(), vars.end());
- Buf* func_result = new Buf(new Var(func_name, kHandle), dims, body.dtype());
+ Buf* func_result = new Buf(func_name, dims, body.dtype());
const ReduceOp* reduce_op =
reducer(func_result, body, output_args, reduce_vars);
dims.insert(dims.end(), reduce_dims.begin(), reduce_dims.end());
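
Note: a hedged usage sketch of the ExprHandle-level factory touched in expr.cpp; the buffer name and sizes below are made up, and IntImm::make / kFloat are assumed from the existing tensorexpr headers rather than shown in this patch:

    // Creates a 64x32 float buffer handle named "A"; the name_hint overload
    // added here constructs the underlying kHandle Var internally.
    ExprHandle A = Buf::make("A", {IntImm::make(64), IntImm::make(32)}, kFloat);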