[TensorExpr] TensorExprKernel: switch type of tensors_ from Tensor to Buf. (#56318)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56318
Test Plan: Imported from OSS
Reviewed By: bertmaher
Differential Revision: D27838748
Pulled By: ZolotukhinM
fbshipit-source-id: 371a454912be76889999eda79e60d8154b749134
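This diff replaces `TensorExprKernel`'s `tensors_` map (`torch::jit::Value* -> Tensor*`) with a `bufs_` map (`torch::jit::Value* -> const Buf*`). Call sites that used to reach through `tensors_.at(v)->buf()` now use the stored `Buf` directly; loads go through `BufHandle(b).load(...)` instead of `Tensor::load`; `bufferSizes` turns from a header template over `Tensor*` into a plain function over `const Buf*` in kernel.cpp; and `bindInput` now returns the `Tensor*` it builds so the caller can still append its statement. A minimal sketch of the before/after access pattern (not part of the patch; the maps and helper functions are illustrative stand-ins, assuming the NNC headers at this revision):

```cpp
#include <torch/csrc/jit/ir/ir.h>             // torch::jit::Value
#include <torch/csrc/jit/tensorexpr/tensor.h> // Tensor, Buf, BufHandle

#include <unordered_map>
#include <vector>

using namespace torch::jit::tensorexpr;

// Before: values mapped to Tensors; loads went through the Tensor,
// which indexed its own buf() internally.
std::unordered_map<const torch::jit::Value*, Tensor*> tensors_;
ExprHandle loadOld(const torch::jit::Value* v,
                   const std::vector<ExprHandle>& axes) {
  Tensor* t = tensors_.at(v);
  return t->load(axes);
}

// After: values map straight to const Buf*; a load wraps the Buf in a
// BufHandle. No Tensor object is needed once its Buf is recorded.
std::unordered_map<const torch::jit::Value*, const Buf*> bufs_;
ExprHandle loadNew(const torch::jit::Value* v,
                   const std::vector<ExprHandle>& axes) {
  const Buf* b = bufs_.at(v);
  return BufHandle(b).load(axes);
}
```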
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 47209ab..69a0e9a 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -239,8 +239,8 @@
return static_cast<size_t>(idx);
}
-static at::ScalarType tensorType(Tensor* t) {
- return static_cast<at::ScalarType>(t->buf()->dtype().scalar_type());
+static at::ScalarType tensorType(const Buf* b) {
+ return static_cast<at::ScalarType>(b->dtype().scalar_type());
}
static std::vector<ExprHandle> computeIndicesToBroadcast(
@@ -267,20 +267,28 @@
}
ExprHandle TensorExprKernel::broadcast(
- Tensor* t,
+ const Buf* b,
const std::vector<ExprHandle>& axes) {
- return t->load(computeIndicesToBroadcast(
- axes, ExprVectorToExprHandleVector(t->buf()->dims())));
+ return BufHandle(b).load(
+ computeIndicesToBroadcast(axes, ExprVectorToExprHandleVector(b->dims())));
+}
+
+std::vector<int64_t> bufferSizes(const Buf* b) {
+ std::vector<int64_t> sizes;
+ for (size_t i = 0; i < b->ndim(); i++) {
+ sizes.push_back(dynamic_cast<const IntImm*>(b->dim(i))->value());
+ }
+ return sizes;
}
ExprHandle TensorExprKernel::chunk(
- Tensor* t,
+ const Buf* b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
const std::vector<ExprHandle>& axes) {
auto norm_dim = normalizeAndCheckIndex(dim, axes.size());
- auto sizes = bufferSizes(t);
+ auto sizes = bufferSizes(b);
size_t step = sizes[norm_dim] / chunks;
std::vector<ExprHandle> indices;
@@ -292,7 +300,7 @@
}
}
- return t->load(indices);
+ return BufHandle(b).load(indices);
}
ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
@@ -317,8 +325,8 @@
ExprHandle TensorExprKernel::tensorOrConstant(
const torch::jit::Value* v,
const std::vector<ExprHandle>& axes) {
- auto ti = tensors_.find(v);
- if (ti != tensors_.end()) {
+ auto ti = bufs_.find(v);
+ if (ti != bufs_.end()) {
return broadcast(ti->second, axes);
}
return constant(v);
@@ -728,11 +736,11 @@
std::vector<ExprHandle> TensorExprKernel::valueShape(
const torch::jit::Value* v) {
- auto it = tensors_.find(v);
- if (it == tensors_.end()) {
+ auto it = bufs_.find(v);
+ if (it == bufs_.end()) {
return {};
}
- return ExprVectorToExprHandleVector(it->second->buf()->dims());
+ return ExprVectorToExprHandleVector(it->second->dims());
}
Tensor* TensorExprKernel::computeOneOperand(
@@ -1317,8 +1325,8 @@
case aten::type_as: {
auto const& n = v->node();
- Tensor* rhs = tensors_.at(n->input(1));
- auto dtype = rhs->buf()->dtype();
+ const Buf* rhs = bufs_.at(n->input(1));
+ auto dtype = rhs->dtype();
return computeOneOperand(
"aten_type_as", v, [dtype](const ExprHandle& lhs) {
return Cast::make(dtype, lhs);
@@ -1570,7 +1578,7 @@
int64_t chunks = n->i(attr::chunks);
std::vector<ExprHandle> indices(axes.begin(), axes.end());
return chunk(
- tensors_.at(n->input(0)), v->offset(), dim, chunks, indices);
+ bufs_.at(n->input(0)), v->offset(), dim, chunks, indices);
});
}
@@ -1641,7 +1649,7 @@
std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
ExprHandle load = promoteToDtype(
tensorOrConstant(nonempty_inputs[0], newAxes), highType);
- size_t offset = bufferSizes(tensors_.at(nonempty_inputs[0]))[dim];
+ size_t offset = bufferSizes(bufs_.at(nonempty_inputs[0]))[dim];
newAxes[dim] = newAxes[dim] - IntImm::make(offset);
for (size_t ii = 1; ii < nonempty_inputs.size(); ++ii) {
@@ -1651,7 +1659,7 @@
load,
promoteToDtype(tensorOrConstant(input, newAxes), highType));
- offset += bufferSizes(tensors_.at(input))[dim];
+ offset += bufferSizes(bufs_.at(input))[dim];
newAxes[dim] = axes[dim] - IntImm::make(offset);
}
@@ -1920,8 +1928,9 @@
input_name_map_ = std::move(value_to_name);
}
-void TensorExprKernel::bindInput(const torch::jit::Value* input) {
+Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) {
auto const& t = input->type();
+ Tensor* result = nullptr;
switch (t->kind()) {
case TypeKind::TensorType: {
auto tt = input->type()->cast<TensorType>();
@@ -1936,18 +1945,18 @@
DimArg(IntImm::make(size), "i" + c10::to_string(i)));
}
auto const strides = tt->strides();
- tensors_.emplace(
- input,
- Compute(
- "input" + c10::to_string(tensors_.size() + 1),
- inputTensorDims,
- [&](const std::vector<VarHandle>& axes) {
- ExprHandle idx = 0;
- for (size_t i = 0; i < axes.size(); i++) {
- idx = idx + axes[i] * IntImm::make(*strides[i]);
- }
- return inBuffer.load(idx);
- }));
+ result = Compute(
+ "input" + c10::to_string(bufs_.size() + 1),
+ inputTensorDims,
+ [&](const std::vector<VarHandle>& axes) {
+ ExprHandle idx = 0;
+ for (size_t i = 0; i < axes.size(); i++) {
+ idx = idx + axes[i] * IntImm::make(*strides[i]);
+ }
+ return inBuffer.load(idx);
+ });
+ bufs_.emplace(input, result->buf());
+
bufferArgs_.emplace_back(inBuffer);
break;
}
@@ -1974,6 +1983,7 @@
break;
}
}
+ return result;
}
namespace {
@@ -2038,9 +2048,9 @@
dtype = Dtype(*maybe_stype);
}
BufHandle ResultBuf("conv", shape, dtype);
- BufHandle inp = BufHandle(tensors_.at(n->input(0))->buf());
- BufHandle w = BufHandle(tensors_.at(n->input(1))->buf());
- BufHandle b = BufHandle(tensors_.at(n->input(2))->buf());
+ BufHandle inp = BufHandle(bufs_.at(n->input(0)));
+ BufHandle w = BufHandle(bufs_.at(n->input(1)));
+ BufHandle b = BufHandle(bufs_.at(n->input(2)));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int sH, sW;
@@ -2097,8 +2107,8 @@
dtype = Dtype(*maybe_stype);
}
BufHandle ResultBuf("matmul", shape, dtype);
- const Buf* a = tensors_.at(n->input(0))->buf();
- const Buf* b = tensors_.at(n->input(1))->buf();
+ const Buf* a = bufs_.at(n->input(0));
+ const Buf* b = bufs_.at(n->input(1));
auto size_a = ExprVectorToExprHandleVector(a->dims());
auto size_b = ExprVectorToExprHandleVector(b->dims());
@@ -2355,7 +2365,7 @@
store_indices[i] = for_vars[i];
}
}
- auto inp_buf = tensors_.at(inp)->buf();
+ auto inp_buf = bufs_.at(inp);
auto load_expr = new Load(inp_buf, load_indices);
auto load_promoted = promoteToDtype(ExprHandle(load_expr), highType);
Stmt* st = new Store(output_buf, store_indices, load_promoted.node());
@@ -2470,8 +2480,8 @@
Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
const TensorTypePtr& tt = v->type()->expect<TensorType>();
- TORCH_INTERNAL_ASSERT(tensors_.count(v));
- Tensor* tensor = tensors_[v];
+ TORCH_INTERNAL_ASSERT(bufs_.count(v));
+ const Buf* buf = bufs_.at(v);
TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes());
const auto sizes = *tt->sizes().concrete_sizes();
@@ -2481,12 +2491,12 @@
// All Tensors in NNC are laid out in the default, contiguous layout.
// If the output is also default contiguous we don't need to do anything
if (strides == default_strides) {
- return tensor;
+ return new Tensor(buf, nullptr);
}
// If the tensor is not dense or overlaps, we have
// no way of matching the profiled striding
if (!denseAndNonOverlapping(sizes, strides)) {
- return tensor;
+ return new Tensor(buf, nullptr);
}
auto dims = dimsFromSizes(sizesForValue(v));
@@ -2528,7 +2538,7 @@
Mod::make(absolute_position, IntImm::make(stride));
new_axes[stride_index] = index;
}
- return tensor->load(new_axes);
+ return BufHandle(buf).load(new_axes);
});
}
@@ -2543,10 +2553,9 @@
nInputs_ = graph_->inputs().size();
genInputDebugNames();
for (auto const& input : graph_->inputs()) {
- bindInput(input);
inputTypes_.push_back(input->type());
- if (input->type()->kind() == TypeKind::TensorType) {
- block->append_stmt(tensors_.at(input)->stmt());
+ if (Tensor* t = bindInput(input)) {
+ block->append_stmt(t->stmt());
}
}
@@ -2557,8 +2566,9 @@
} else {
for (auto const& output : n->outputs()) {
if (output->hasUses()) {
- tensors_.emplace(output, computeValue(output));
- block->append_stmt(tensors_.at(output)->stmt());
+ Tensor* t = computeValue(output);
+ bufs_.emplace(output, t->buf());
+ block->append_stmt(t->stmt());
}
}
}
@@ -2570,19 +2580,19 @@
device_ = *pickDeviceType(graph_->inputs());
- // Move output operands from `tensors_` to `bufOutputs_`
+ // Move output operands from `bufs_` to `bufOutputs_`
for (const auto& output : graph_->outputs()) {
- if (!tensors_.count(output)) {
+ if (!bufs_.count(output)) {
throw malformed_input("cannot find output Tensor");
}
// The "strided" tensor will be incorrect if used in NNC,
// since NNC views it as contiguous. Only convert it to the right
// strides at the end of the kernel (if already contiguous it's a no-op)
Tensor* properly_strided_output = convertOutputToCorrectStrides(output);
- if (tensors_.at(output) != properly_strided_output) {
+ if (properly_strided_output->stmt()) {
block->append_stmt(properly_strided_output->stmt());
}
- tensors_[output] = properly_strided_output;
+ bufs_[output] = properly_strided_output->buf();
const auto& tt = output->type()->expect<TensorType>();
auto sizes = *tt->sizes().concrete_sizes();
tensorOutputSizes_.push_back(sizes);
@@ -2596,11 +2606,11 @@
tensorOutputStrides_.push_back(TensorType::contiguousStridesOf(sizes));
}
- bufOutputs_.insert(tensors_.at(output)->buf());
- bufferArgs_.emplace_back(tensors_.at(output));
+ bufOutputs_.insert(bufs_.at(output));
+ bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
tensorOutputTensorOptions_.emplace_back(
- c10::TensorOptions(tensorType(tensors_[output])).device(device_));
- tensors_.erase(output);
+ c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
+ bufs_.erase(output);
}
BackendType backendType = inferBackendTypeFromDevice(device_);
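One subtlety in the output-handling hunks above: since the map no longer stores `Tensor`s, `convertOutputToCorrectStrides` cannot return the original `Tensor` to signal "nothing to do". It instead wraps the stored `Buf` in a statement-less `Tensor` (`new Tensor(buf, nullptr)`), and the caller's old pointer-equality test becomes a null-`stmt()` check. A sketch of the resulting protocol (names are the patch's own; the surrounding scaffolding is illustrative):

```cpp
// A null stmt() now means "output is already contiguous, nothing to emit".
Tensor* t = convertOutputToCorrectStrides(output);
if (t->stmt()) {
  block->append_stmt(t->stmt()); // emit the restriding computation
}
bufs_[output] = t->buf(); // downstream stages only need the Buf
```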
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index f169651..918785f 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -10,15 +10,6 @@
namespace jit {
namespace tensorexpr {
-template <typename T>
-inline std::vector<int64_t> bufferSizes(const T& t) {
- std::vector<int64_t> sizes;
- for (size_t i = 0; i < t->buf()->ndim(); i++) {
- sizes.push_back(dynamic_cast<const IntImm*>(t->buf()->dim(i))->value());
- }
- return sizes;
-}
-
// Returns true if the TE fuser supports this conv2d.
bool conv2dIsSupported(const Node* node);
// Returns true if the TE fuser supports this matmul.
@@ -81,9 +72,9 @@
std::vector<std::vector<ExprHandle>> shapes);
ExprHandle constant(const torch::jit::Value* v);
- ExprHandle broadcast(Tensor* t, const std::vector<ExprHandle>& axes);
+ ExprHandle broadcast(const Buf* b, const std::vector<ExprHandle>& axes);
ExprHandle chunk(
- Tensor* t,
+ const Buf* b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
@@ -166,7 +157,7 @@
std::vector<at::Tensor>& outputs);
BackendType inferBackendTypeFromDevice(at::Device device);
- void bindInput(const torch::jit::Value* input);
+ Tensor* bindInput(const torch::jit::Value* input);
Tensor* convertOutputToCorrectStrides(torch::jit::Value* v);
@@ -205,7 +196,7 @@
std::vector<std::vector<int64_t>> tensorOutputStrides_;
std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
std::unordered_set<const Buf*> bufOutputs_;
- std::unordered_map<const torch::jit::Value*, Tensor*> tensors_;
+ std::unordered_map<const torch::jit::Value*, const Buf*> bufs_;
std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
std::unique_ptr<CodeGen> codegen_;
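For reference, the relocated `bufferSizes` simply reads static dims off a `Buf`. A worked example (illustrative, not part of the patch; the `BufHandle(name, dims, dtype)` constructor and `IntImm::make` appear elsewhere in this diff, and `node()` is assumed to return the underlying `const Buf*`):

```cpp
// A 2x3 float buffer with compile-time dims.
BufHandle b("x", {IntImm::make(2), IntImm::make(3)}, kFloat);
std::vector<int64_t> sizes = bufferSizes(b.node()); // -> {2, 3}
```

Dropping the header template also removes the `t->buf()` indirection, so the helper no longer silently accepts any type with a `buf()` method.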