[TensorExpr] TensorExprKernel: switch type of tensors_ from Tensor to Buf. (#56318)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56318

`TensorExprKernel` kept a map `tensors_` from `torch::jit::Value*` to `Tensor*`, but nearly every use site only needed the underlying buffer and immediately reached through `->buf()`. This change replaces `tensors_` with `bufs_`, which maps to `const Buf*` directly, and updates the helpers that consumed it (`tensorType`, `broadcast`, `chunk`, `valueShape`) to take a `const Buf*`. The `bufferSizes` helper moves from a header template into kernel.cpp as a plain function on `const Buf*`. Two call sites still need the computation statement: `bindInput` now returns the `Tensor*` it creates so the caller can append its `Stmt` to the kernel body, and `convertOutputToCorrectStrides` signals "no conversion needed" by returning a `Tensor` with a null statement instead of returning the cached `Tensor` unchanged.
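
As a rough illustration of the data-structure shift, here is a minimal, self-contained sketch. The `Buf`, `Stmt`, `Tensor`, and `Value` structs below are simplified placeholders for the real NNC classes in torch/csrc/jit/tensorexpr, not their actual definitions; only the shape of the change is modeled.

```cpp
#include <cstdint>
#include <unordered_map>
#include <vector>

// Placeholder types; the real classes carry much more state.
struct Buf {
  std::vector<int64_t> dims;
  size_t ndim() const { return dims.size(); }
};
struct Stmt {};
struct Value {};  // stand-in for torch::jit::Value

// A Tensor couples a buffer with the statement that computes it.
class Tensor {
 public:
  Tensor(const Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {}
  const Buf* buf() const { return buf_; }
  Stmt* stmt() const { return stmt_; }
 private:
  const Buf* buf_;
  Stmt* stmt_;
};

int main() {
  // Before: the kernel mapped each jit::Value to a whole Tensor,
  // even though most lookups only wanted t->buf().
  std::unordered_map<const Value*, Tensor*> tensors_;
  (void)tensors_;  // unused; shown for contrast

  // After: the map stores only the Buf. The Stmt is appended to the
  // kernel body once, at the point where the Tensor is created
  // (bindInput, or the computeValue loop in compile()).
  std::unordered_map<const Value*, const Buf*> bufs_;

  Value v;
  Buf b{{2, 3}};
  Tensor t(&b, /*stmt=*/nullptr);
  bufs_.emplace(&v, t.buf());
  return bufs_.at(&v)->ndim() == 2 ? 0 : 1;  // no Tensor indirection
}
```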

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D27838748

Pulled By: ZolotukhinM

fbshipit-source-id: 371a454912be76889999eda79e60d8154b749134
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 47209ab..69a0e9a 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -239,8 +239,8 @@
   return static_cast<size_t>(idx);
 }
 
-static at::ScalarType tensorType(Tensor* t) {
-  return static_cast<at::ScalarType>(t->buf()->dtype().scalar_type());
+static at::ScalarType tensorType(const Buf* b) {
+  return static_cast<at::ScalarType>(b->dtype().scalar_type());
 }
 
 static std::vector<ExprHandle> computeIndicesToBroadcast(
@@ -267,20 +267,28 @@
 }
 
 ExprHandle TensorExprKernel::broadcast(
-    Tensor* t,
+    const Buf* b,
     const std::vector<ExprHandle>& axes) {
-  return t->load(computeIndicesToBroadcast(
-      axes, ExprVectorToExprHandleVector(t->buf()->dims())));
+  return BufHandle(b).load(
+      computeIndicesToBroadcast(axes, ExprVectorToExprHandleVector(b->dims())));
+}
+
+std::vector<int64_t> bufferSizes(const Buf* b) {
+  std::vector<int64_t> sizes;
+  for (size_t i = 0; i < b->ndim(); i++) {
+    sizes.push_back(dynamic_cast<const IntImm*>(b->dim(i))->value());
+  }
+  return sizes;
 }
 
 ExprHandle TensorExprKernel::chunk(
-    Tensor* t,
+    const Buf* b,
     size_t chunkIdx,
     int64_t dim,
     int64_t chunks,
     const std::vector<ExprHandle>& axes) {
   auto norm_dim = normalizeAndCheckIndex(dim, axes.size());
-  auto sizes = bufferSizes(t);
+  auto sizes = bufferSizes(b);
   size_t step = sizes[norm_dim] / chunks;
 
   std::vector<ExprHandle> indices;
@@ -292,7 +300,7 @@
     }
   }
 
-  return t->load(indices);
+  return BufHandle(b).load(indices);
 }
 
 ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
@@ -317,8 +325,8 @@
 ExprHandle TensorExprKernel::tensorOrConstant(
     const torch::jit::Value* v,
     const std::vector<ExprHandle>& axes) {
-  auto ti = tensors_.find(v);
-  if (ti != tensors_.end()) {
+  auto ti = bufs_.find(v);
+  if (ti != bufs_.end()) {
     return broadcast(ti->second, axes);
   }
   return constant(v);
@@ -728,11 +736,11 @@
 
 std::vector<ExprHandle> TensorExprKernel::valueShape(
     const torch::jit::Value* v) {
-  auto it = tensors_.find(v);
-  if (it == tensors_.end()) {
+  auto it = bufs_.find(v);
+  if (it == bufs_.end()) {
     return {};
   }
-  return ExprVectorToExprHandleVector(it->second->buf()->dims());
+  return ExprVectorToExprHandleVector(it->second->dims());
 }
 
 Tensor* TensorExprKernel::computeOneOperand(
@@ -1317,8 +1325,8 @@
 
     case aten::type_as: {
       auto const& n = v->node();
-      Tensor* rhs = tensors_.at(n->input(1));
-      auto dtype = rhs->buf()->dtype();
+      const Buf* rhs = bufs_.at(n->input(1));
+      auto dtype = rhs->dtype();
       return computeOneOperand(
           "aten_type_as", v, [dtype](const ExprHandle& lhs) {
             return Cast::make(dtype, lhs);
@@ -1570,7 +1578,7 @@
             int64_t chunks = n->i(attr::chunks);
             std::vector<ExprHandle> indices(axes.begin(), axes.end());
             return chunk(
-                tensors_.at(n->input(0)), v->offset(), dim, chunks, indices);
+                bufs_.at(n->input(0)), v->offset(), dim, chunks, indices);
           });
     }
 
@@ -1641,7 +1649,7 @@
             std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
             ExprHandle load = promoteToDtype(
                 tensorOrConstant(nonempty_inputs[0], newAxes), highType);
-            size_t offset = bufferSizes(tensors_.at(nonempty_inputs[0]))[dim];
+            size_t offset = bufferSizes(bufs_.at(nonempty_inputs[0]))[dim];
             newAxes[dim] = newAxes[dim] - IntImm::make(offset);
 
             for (size_t ii = 1; ii < nonempty_inputs.size(); ++ii) {
@@ -1651,7 +1659,7 @@
                   load,
                   promoteToDtype(tensorOrConstant(input, newAxes), highType));
 
-              offset += bufferSizes(tensors_.at(input))[dim];
+              offset += bufferSizes(bufs_.at(input))[dim];
               newAxes[dim] = axes[dim] - IntImm::make(offset);
             }
 
@@ -1920,8 +1928,9 @@
   input_name_map_ = std::move(value_to_name);
 }
 
-void TensorExprKernel::bindInput(const torch::jit::Value* input) {
+Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) {
   auto const& t = input->type();
+  Tensor* result = nullptr;
   switch (t->kind()) {
     case TypeKind::TensorType: {
       auto tt = input->type()->cast<TensorType>();
@@ -1936,18 +1945,18 @@
             DimArg(IntImm::make(size), "i" + c10::to_string(i)));
       }
       auto const strides = tt->strides();
-      tensors_.emplace(
-          input,
-          Compute(
-              "input" + c10::to_string(tensors_.size() + 1),
-              inputTensorDims,
-              [&](const std::vector<VarHandle>& axes) {
-                ExprHandle idx = 0;
-                for (size_t i = 0; i < axes.size(); i++) {
-                  idx = idx + axes[i] * IntImm::make(*strides[i]);
-                }
-                return inBuffer.load(idx);
-              }));
+      result = Compute(
+          "input" + c10::to_string(bufs_.size() + 1),
+          inputTensorDims,
+          [&](const std::vector<VarHandle>& axes) {
+            ExprHandle idx = 0;
+            for (size_t i = 0; i < axes.size(); i++) {
+              idx = idx + axes[i] * IntImm::make(*strides[i]);
+            }
+            return inBuffer.load(idx);
+          });
+      bufs_.emplace(input, result->buf());
+
       bufferArgs_.emplace_back(inBuffer);
       break;
     }
@@ -1974,6 +1983,7 @@
       break;
     }
   }
+  return result;
 }
 
 namespace {
@@ -2038,9 +2048,9 @@
     dtype = Dtype(*maybe_stype);
   }
   BufHandle ResultBuf("conv", shape, dtype);
-  BufHandle inp = BufHandle(tensors_.at(n->input(0))->buf());
-  BufHandle w = BufHandle(tensors_.at(n->input(1))->buf());
-  BufHandle b = BufHandle(tensors_.at(n->input(2))->buf());
+  BufHandle inp = BufHandle(bufs_.at(n->input(0)));
+  BufHandle w = BufHandle(bufs_.at(n->input(1)));
+  BufHandle b = BufHandle(bufs_.at(n->input(2)));
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   int sH, sW;
@@ -2097,8 +2107,8 @@
     dtype = Dtype(*maybe_stype);
   }
   BufHandle ResultBuf("matmul", shape, dtype);
-  const Buf* a = tensors_.at(n->input(0))->buf();
-  const Buf* b = tensors_.at(n->input(1))->buf();
+  const Buf* a = bufs_.at(n->input(0));
+  const Buf* b = bufs_.at(n->input(1));
 
   auto size_a = ExprVectorToExprHandleVector(a->dims());
   auto size_b = ExprVectorToExprHandleVector(b->dims());
@@ -2355,7 +2365,7 @@
         store_indices[i] = for_vars[i];
       }
     }
-    auto inp_buf = tensors_.at(inp)->buf();
+    auto inp_buf = bufs_.at(inp);
     auto load_expr = new Load(inp_buf, load_indices);
     auto load_promoted = promoteToDtype(ExprHandle(load_expr), highType);
     Stmt* st = new Store(output_buf, store_indices, load_promoted.node());
@@ -2470,8 +2480,8 @@
 
 Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
   const TensorTypePtr& tt = v->type()->expect<TensorType>();
-  TORCH_INTERNAL_ASSERT(tensors_.count(v));
-  Tensor* tensor = tensors_[v];
+  TORCH_INTERNAL_ASSERT(bufs_.count(v));
+  const Buf* buf = bufs_.at(v);
 
   TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes());
   const auto sizes = *tt->sizes().concrete_sizes();
@@ -2481,12 +2491,12 @@
   // All Tensors in NNC are laid out in default, contiguous layout.
   // If the output is also default contiguous we don't need to do anything
   if (strides == default_strides) {
-    return tensor;
+    return new Tensor(buf, nullptr);
   }
   // If the tensor is not dense or overlaps, we have
   // no way of matching the profiled striding
   if (!denseAndNonOverlapping(sizes, strides)) {
-    return tensor;
+    return new Tensor(buf, nullptr);
   }
 
   auto dims = dimsFromSizes(sizesForValue(v));
@@ -2528,7 +2538,7 @@
               Mod::make(absolute_position, IntImm::make(stride));
           new_axes[stride_index] = index;
         }
-        return tensor->load(new_axes);
+        return BufHandle(buf).load(new_axes);
       });
 }
 
@@ -2543,10 +2553,9 @@
   nInputs_ = graph_->inputs().size();
   genInputDebugNames();
   for (auto const& input : graph_->inputs()) {
-    bindInput(input);
     inputTypes_.push_back(input->type());
-    if (input->type()->kind() == TypeKind::TensorType) {
-      block->append_stmt(tensors_.at(input)->stmt());
+    if (Tensor* t = bindInput(input)) {
+      block->append_stmt(t->stmt());
     }
   }
 
@@ -2557,8 +2566,9 @@
     } else {
       for (auto const& output : n->outputs()) {
         if (output->hasUses()) {
-          tensors_.emplace(output, computeValue(output));
-          block->append_stmt(tensors_.at(output)->stmt());
+          Tensor* t = computeValue(output);
+          bufs_.emplace(output, t->buf());
+          block->append_stmt(t->stmt());
         }
       }
     }
@@ -2570,19 +2580,19 @@
 
   device_ = *pickDeviceType(graph_->inputs());
 
-  // Move output operands from `tensors_` to `bufOutputs_`
+  // Move output operands from `bufs_` to `bufOutputs_`
   for (const auto& output : graph_->outputs()) {
-    if (!tensors_.count(output)) {
+    if (!bufs_.count(output)) {
       throw malformed_input("cannot find output Tensor");
     }
     // The "strided" tensor will be incorrect if used in NNC,
     // since NNC views it as contiguous. Only convert it to the right
     // strides at the end of the kernel (if already contiguous it's a no-op)
     Tensor* properly_strided_output = convertOutputToCorrectStrides(output);
-    if (tensors_.at(output) != properly_strided_output) {
+    if (properly_strided_output->stmt()) {
       block->append_stmt(properly_strided_output->stmt());
     }
-    tensors_[output] = properly_strided_output;
+    bufs_[output] = properly_strided_output->buf();
     const auto& tt = output->type()->expect<TensorType>();
     auto sizes = *tt->sizes().concrete_sizes();
     tensorOutputSizes_.push_back(sizes);
@@ -2596,11 +2606,11 @@
       tensorOutputStrides_.push_back(TensorType::contiguousStridesOf(sizes));
     }
 
-    bufOutputs_.insert(tensors_.at(output)->buf());
-    bufferArgs_.emplace_back(tensors_.at(output));
+    bufOutputs_.insert(bufs_.at(output));
+    bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
     tensorOutputTensorOptions_.emplace_back(
-        c10::TensorOptions(tensorType(tensors_[output])).device(device_));
-    tensors_.erase(output);
+        c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
+    bufs_.erase(output);
   }
 
   BackendType backendType = inferBackendTypeFromDevice(device_);
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index f169651..918785f 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -10,15 +10,6 @@
 namespace jit {
 namespace tensorexpr {
 
-template <typename T>
-inline std::vector<int64_t> bufferSizes(const T& t) {
-  std::vector<int64_t> sizes;
-  for (size_t i = 0; i < t->buf()->ndim(); i++) {
-    sizes.push_back(dynamic_cast<const IntImm*>(t->buf()->dim(i))->value());
-  }
-  return sizes;
-}
-
 // Returns true if the TE fuser supports this conv2d.
 bool conv2dIsSupported(const Node* node);
 // Returns true if the TE fuser supports this matmul.
@@ -81,9 +72,9 @@
       std::vector<std::vector<ExprHandle>> shapes);
 
   ExprHandle constant(const torch::jit::Value* v);
-  ExprHandle broadcast(Tensor* t, const std::vector<ExprHandle>& axes);
+  ExprHandle broadcast(const Buf* b, const std::vector<ExprHandle>& axes);
   ExprHandle chunk(
-      Tensor* t,
+      const Buf* b,
       size_t chunkIdx,
       int64_t dim,
       int64_t chunks,
@@ -166,7 +157,7 @@
       std::vector<at::Tensor>& outputs);
   BackendType inferBackendTypeFromDevice(at::Device device);
 
-  void bindInput(const torch::jit::Value* input);
+  Tensor* bindInput(const torch::jit::Value* input);
 
   Tensor* convertOutputToCorrectStrides(torch::jit::Value* v);
 
@@ -205,7 +196,7 @@
   std::vector<std::vector<int64_t>> tensorOutputStrides_;
   std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
   std::unordered_set<const Buf*> bufOutputs_;
-  std::unordered_map<const torch::jit::Value*, Tensor*> tensors_;
+  std::unordered_map<const torch::jit::Value*, const Buf*> bufs_;
   std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
   std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
   std::unique_ptr<CodeGen> codegen_;