[nnc] ported some more ops + added vectors to argvalue (#56766)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56766

Test Plan: Imported from OSS

Reviewed By: desertfire

Differential Revision: D28118331

Pulled By: Chillee

fbshipit-source-id: eb012943ad3b83e72a8cb17b594852164c3f0567
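
The new list-typed ArgValue alternatives (BufList, IntList) let lowering
functions take list arguments directly instead of re-walking the graph per
op. A rough illustration (not code from this diff; it mirrors the new
computeCat signature):

    // inputs[0] holds the tensor list, inputs[1] the concat dimension
    auto tensors = c10::get<BufList>(inputs[0]);
    int64_t dim = c10::get<int64_t>(inputs[1]);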
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index ae23453..f625526 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -396,6 +396,20 @@
   if (ti != bufs_.end()) {
     return BufHandle(ti->second);
   }
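+  // Lists (e.g. the tensor list fed to aten::cat) are lowered to a typed
+  // vector, dispatching on the type of the first element.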
+  if (v->node()->kind() == prim::ListConstruct) {
+    std::vector<ArgValue> vec;
+    for (auto el : v->node()->inputs()) {
+      vec.push_back(toArg(el));
+    }
+    if (vec.size() == 0) {
+      // The element type of an empty list is unknowable; return an
+      // arbitrarily typed empty vector.
+      return BufList();
+    } else if (c10::get_if<BufHandle>(&vec[0])) {
+      return convertVecArgValue<BufHandle>(vec);
+    } else if (c10::get_if<int64_t>(&vec[0])) {
+      return convertVecArgValue<int64_t>(vec);
+    }
+    throw unsupported_dtype();
+  }
   if (v->node()->kind() == prim::Constant) {
     const auto val = toIValue(v).value();
     if (val.isDouble()) {
@@ -410,7 +424,7 @@
       // the operator-specific lowering code.
       return ArgNone();
     } else {
-      throw unsupported_dtype();
+      throw unsupported_dtype(val.type()->str());
     }
   }
 
@@ -976,15 +990,166 @@
       });
 }
 
-Tensor* computeCat(
-    const std::vector<ArgValue>& inputList,
-    const ArgValue& argDim,
-    const std::vector<ExprHandle>& outputShape);
-
+std::pair<ScalarType, std::vector<BufHandle>> processCatList(
+    const std::vector<BufHandle>& bufList) {
+  if (bufList.size() == 0) {
+    throw std::runtime_error("Empty input list is passed to aten::cat");
+  }
+  std::vector<BufHandle> bufInputs;
+  std::vector<BufHandle> nonEmptyInputs;
+  for (auto buf : bufList) {
+    bufInputs.push_back(buf);
+    TORCH_INTERNAL_ASSERT(buf.node()->dims().size() > 0);
+    if (buf.node()->dims().size() == 1 &&
+        immediateAs<int>(buf.node()->dim(0)) == 0) {
+      continue;
+    }
+    nonEmptyInputs.push_back(buf);
+  }
+  ScalarType highType = bufInputs[0].dtype().scalar_type();
+  for (const auto input : bufInputs) {
+    auto maybe_dtype = input.dtype().scalar_type();
+    highType = promoteTypes(highType, maybe_dtype);
+  }
+  return {highType, nonEmptyInputs};
+}
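+
+// Lowers aten::cat as one copy loop nest per input, each writing into the
+// output at an accumulated offset along the concat dimension, avoiding the
+// cascading selects of the default lowering.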
 Tensor* computeCatWoConditionals(
-    const std::vector<ArgValue>& inputList,
-    const ArgValue& argDim,
-    const std::vector<ExprHandle>& outputShape);
+    const std::vector<ArgValue>& inputs,
+    const std::vector<ExprHandle>& outputShape) {
+  auto input_list = c10::get<BufList>(inputs[0]);
+  auto arg_dim = inputs[1];
+  auto cat_info = processCatList(input_list);
+  ScalarType high_type = cat_info.first;
+  std::vector<BufHandle> non_empty_inputs = cat_info.second;
+
+  // Now we build one loop per input:
+  //
+  // for i
+  //   for j
+  //     for k
+  //       output[i,j,k] = inp1[i,j,k]
+  // for i
+  //   for j
+  //     for k
+  //       output[i,j+l1,k] = inp2[i,j,k]
+  // for i
+  //   for j
+  //     for k
+  //       output[i,j+l2,k] = inp3[i,j,k]
+
+  auto output_sizes_expr = ExprHandleVectorToExprVector(outputShape);
+  auto output_buf = new Buf("aten_cat", output_sizes_expr, ToDtype(high_type));
+  if (non_empty_inputs.size() == 0) {
+    return new Tensor(output_buf, new tensorexpr::Block({}));
+  }
+
+  int64_t concat_dim = c10::get<int64_t>(arg_dim);
+  size_t norm_concat_dim =
+      normalizeAndCheckIndex(concat_dim, outputShape.size());
+
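+  // Emits the loop nest that copies one input into the output; the store
+  // index along the concat dimension is shifted by the running offset, and
+  // the For loops are wrapped from innermost to outermost.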
+  auto gen_code_for_input = [&](const BufHandle& inp,
+                                size_t inp_pos,
+                                const Expr* concat_dim_size,
+                                const std::vector<ExprHandle>& dims) {
+    std::vector<Var*> for_vars(dims.size());
+    std::vector<const Expr*> load_indices(dims.size());
+    std::vector<const Expr*> store_indices(dims.size());
+    for (size_t i = 0; i < dims.size(); ++i) {
+      for_vars[i] = new Var(
+          "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt);
+      load_indices[i] = for_vars[i];
+      if (i == norm_concat_dim) {
+        store_indices[i] = new Add(for_vars[i], concat_dim_size);
+      } else {
+        store_indices[i] = for_vars[i];
+      }
+    }
+    auto inp_buf = inp.node();
+    auto load_expr = new Load(inp_buf, load_indices);
+    auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type);
+    Stmt* st = new Store(output_buf, store_indices, load_promoted.node());
+    for (size_t i = dims.size(); i > 0; --i) {
+      st = new For(for_vars[i - 1], new IntImm(0), dims[i - 1].node(), st);
+    }
+    return st;
+  };
+
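+  // Emit one nest per non-empty input, accumulating the offset along the
+  // concat dimension as each input is appended.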
+  Expr* concat_dim_size = nullptr;
+  auto block = new tensorexpr::Block({});
+  for (size_t i = 0; i < non_empty_inputs.size(); ++i) {
+    auto input_dims =
+        ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims());
+    if (concat_dim_size == nullptr) {
+      concat_dim_size = new IntImm(0);
+    }
+    block->append_stmt(gen_code_for_input(
+        non_empty_inputs[i], i, concat_dim_size, input_dims));
+    concat_dim_size =
+        new Add(concat_dim_size, input_dims[norm_concat_dim].node());
+  }
+  return new Tensor(output_buf, IRSimplifier::simplify(block));
+}
+
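+// Lowers aten::cat. Dispatches to the conditional-free lowering above when
+// getCatWoConditionals() is set; otherwise builds a cascading select
+// expression over the non-empty inputs.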
+Tensor* computeCat(
+    const std::vector<ArgValue>& inputs,
+    const std::vector<ExprHandle>& outputShape) {
+  if (getCatWoConditionals()) {
+    return computeCatWoConditionals(inputs, outputShape);
+  }
+  auto inputList = c10::get<BufList>(inputs[0]);
+  auto argDim = inputs[1];
+  auto catInfo = processCatList(inputList);
+  ScalarType highType = catInfo.first;
+  std::vector<BufHandle> nonEmptyInputs = catInfo.second;
+  return Compute(
+      "aten_cat",
+      c10::fmap<DimArg>(outputShape),
+      [&](const std::vector<VarHandle>& axes) {
+        if (nonEmptyInputs.size() == 0) {
+          return ExprHandle(0);
+        }
+
+        int64_t dim_ = c10::get<int64_t>(argDim);
+        size_t dim = normalizeAndCheckIndex(dim_, axes.size());
+        // Input dtypes are promoted in processCatList. Note that all
+        // inputs, including empty ones, participate in promotion - they
+        // also affect the resultant dtype.
+
+        // Now we know the final dtype, we know which inputs are non-empty,
+        // and we know that there is at least one such input. With all
+        // that we construct a tensor expression performing the
+        // concatenation.
+        // The expression we build here is a cascading if-then-else that
+        // essentially represents:
+        //
+        //              inp1[i, j, k]                    if 0 <= j < l1,
+        // out[i,j,k] = inp2[i, j-l1, k]                 if l1 <= j < l1+l2,
+        //              ...
+        //              inpN[i, j-(l1+...+l_{N-1}), k]   otherwise,
+        // where j is the index along the concat dimension and l_i is the
+        // size of the i-th input along it.
+        std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
+        ExprHandle load = promoteToDtype(
+            tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
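+        // Input sizes are assumed to be statically known IntImms here, so
+        // the per-input offsets along the concat dim can be computed eagerly.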
+        size_t offset =
+            dynamic_cast<const IntImm*>(nonEmptyInputs[0].node()->dim(dim))
+                ->value();
+        newAxes[dim] = newAxes[dim] - IntImm::make(offset);
+
+        for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
+          auto input = nonEmptyInputs[ii];
+          load = ifThenElse(
+              CompareSelect::make(axes[dim], IntImm::make(offset), kLT),
+              load,
+              promoteToDtype(tensorOrConstant(input, newAxes), highType));
+
+          offset +=
+              dynamic_cast<const IntImm*>(input.node()->dim(dim))->value();
+          newAxes[dim] = axes[dim] - IntImm::make(offset);
+        }
+
+        return load;
+      });
+}
 
 Tensor* computeMatmul(
     const std::vector<ArgValue>& inputs,
@@ -1500,6 +1665,16 @@
             return tan(promoteIntegerToDefaultType(a));
           });
     } break;
+    case aten::type_as: {
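+      // aten::type_as casts the first input to the dtype of the second, so
+      // only the second operand's dtype is needed and the op lowers as a
+      // one-operand compute.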
+      const BufHandle rhs = c10::get<BufHandle>(inputs[1]);
+      auto dtype = rhs.dtype();
+      return computeOneOperand(
+          "aten_type_as",
+          inputs,
+          outputShape,
+          outputType,
+          [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); });
+    } break;
     case aten::pow: {
       return computeTwoOperand(
           "aten_pow",
@@ -1770,7 +1945,19 @@
           outputType,
           [](const ExprHandle& a) { return cast<float>(a); });
     } break;
-
+    case aten::to: {
+      // see handling of aten::to in tensorexpr_fuser.cpp for why we only
+      // need to handle the first input
+      return computeOneOperand(
+          "aten_to",
+          inputs,
+          outputShape,
+          outputType,
+          [outputType](const ExprHandle& a) {
+            TORCH_INTERNAL_ASSERT(outputType);
+            return Cast::make(ToDtype(*outputType), a);
+          });
+    } break;
     case aten::threshold: {
       return computeThreeOperand(
           "aten_threshold",
@@ -1872,6 +2059,9 @@
     case aten::matmul: {
       return computeMatmul(inputs, outputShape, outputType);
     }
+    case aten::cat: {
+      return computeCat(inputs, outputShape);
+    }
     default: {
       throw std::runtime_error("Unhandled node kind");
       return nullptr;
@@ -1932,6 +2122,7 @@
     case aten::cos:
     case aten::sin:
     case aten::tan:
+    case aten::type_as:
     case aten::pow:
     case aten::fmod:
     case aten::lerp:
@@ -1952,13 +2143,15 @@
     case aten::round:
     case aten::trunc:
     case aten::_cast_Float:
+    case aten::to:
     case aten::threshold:
     case aten::where:
     case aten::frac:
     case aten::lgamma:
     case aten::slice:
     case aten::unsqueeze:
-    case aten::matmul: {
+    case aten::matmul:
+    case aten::cat: {
       std::vector<ArgValue> argInputs;
       for (auto inp : inputs) {
         argInputs.push_back(toArg(inp));
@@ -1969,44 +2162,6 @@
           v->node()->kind(), argInputs, outputShape, outputType);
     } break;
 
-    case aten::to: {
-      // see handling of aten::to in tensorexpr_fuser.cpp for why we only
-      // need to handle the first input
-      auto outputType = findDtypeForValue(v->node()->output());
-      auto outputShape = sizesForValue(v);
-      auto output_dtype = findDtypeForValue(v->node()->output());
-      return computeOneOperand(
-          "aten_to",
-          {toArg(inputs[0])},
-          outputShape,
-          outputType,
-          [output_dtype](const ExprHandle& a) {
-            TORCH_INTERNAL_ASSERT(output_dtype);
-            return Cast::make(ToDtype(*output_dtype), a);
-          });
-    } break;
-
-    case aten::type_as: {
-      auto outputType = findDtypeForValue(v->node()->output());
-      auto outputShape = sizesForValue(v);
-      const Buf* rhs = bufs_.at(inputs[1]);
-      auto dtype = rhs->dtype();
-      return computeOneOperand(
-          "aten_type_as",
-          {toArg(inputs[0])},
-          outputShape,
-          outputType,
-          [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); });
-    } break;
-    case aten::cat: {
-      std::vector<ArgValue> inputList;
-      for (auto inp : v->node()->input(0)->node()->inputs()) {
-        inputList.push_back(toArg(inp));
-      }
-      auto outputShape = sizesForValue(v);
-      return computeCat(inputList, toArg(v->node()->input(1)), outputShape);
-    } break;
-
     case prim::ConstantChunk: {
       return Compute(
           "prim_constantchunk",
@@ -2377,167 +2532,6 @@
 
 } // namespace
 
-std::pair<ScalarType, std::vector<BufHandle>> processCatList(
-    const std::vector<ArgValue>& inputList) {
-  if (inputList.size() == 0) {
-    throw std::runtime_error("Empty input list is passed to aten::cat");
-  }
-  std::vector<BufHandle> bufInputs;
-  std::vector<BufHandle> nonEmptyInputs;
-  for (auto input : inputList) {
-    auto buf = c10::get<BufHandle>(input);
-    bufInputs.push_back(buf);
-    assert(buf.node()->dims().size() > 0);
-    if (buf.node()->dims().size() == 1 &&
-        immediateAs<int>(buf.node()->dim(0)) == 0) {
-      continue;
-    }
-    nonEmptyInputs.push_back(buf);
-  }
-  auto maybe_dtype = bufInputs[0].dtype().scalar_type();
-  ScalarType highType = maybe_dtype;
-  for (const auto input : bufInputs) {
-    auto maybe_dtype = input.dtype().scalar_type();
-    highType = promoteTypes(highType, maybe_dtype);
-  }
-  return {highType, nonEmptyInputs};
-}
-
-Tensor* computeCat(
-    const std::vector<ArgValue>& inputList,
-    const ArgValue& argDim,
-    const std::vector<ExprHandle>& outputShape) {
-  if (getCatWoConditionals()) {
-    return computeCatWoConditionals(inputList, argDim, outputShape);
-  }
-  auto inputs = processCatList(inputList);
-  ScalarType highType = inputs.first;
-  std::vector<BufHandle> nonEmptyInputs = inputs.second;
-  return Compute(
-      "aten_cat",
-      c10::fmap<DimArg>(outputShape),
-      [&](const std::vector<VarHandle>& axes) {
-        if (nonEmptyInputs.size() == 0) {
-          return ExprHandle(0);
-        }
-
-        int64_t dim_ = c10::get<int64_t>(argDim);
-        size_t dim = normalizeAndCheckIndex(dim_, axes.size());
-        // Promote input types.
-        // Note that we need to consider all inputs, including empty - they
-        // also affect the resultant dtype.
-
-        // Now we know the final dtype, we know what inputs are non-empty,
-        // and we know that there is at least one such an input. With all
-        // that we construct a tensor expression performing the
-        // concatenation.
-        // The expression we build here is a cascading if-then-else that
-        // essentially represents:
-        //
-        //              inp1[i, j, k]         if 0   < i < l1,
-        // out[i,j,k] = inp2[i, j-l1, k]      if l1 =< i < l1 + l2,
-        //              ...
-        //              inpN[i, j-l_N_1, k]   if l1+l2+...l_N_1  < i
-        // where l_i is the corresponding size of the i-th input.
-        std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
-        ExprHandle load = promoteToDtype(
-            tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
-        size_t offset =
-            dynamic_cast<const IntImm*>(nonEmptyInputs[0].node()->dim(dim))
-                ->value();
-        newAxes[dim] = newAxes[dim] - IntImm::make(offset);
-
-        for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
-          auto input = nonEmptyInputs[ii];
-          load = ifThenElse(
-              CompareSelect::make(axes[dim], IntImm::make(offset), kLT),
-              load,
-              promoteToDtype(tensorOrConstant(input, newAxes), highType));
-
-          offset +=
-              dynamic_cast<const IntImm*>(input.node()->dim(dim))->value();
-          newAxes[dim] = axes[dim] - IntImm::make(offset);
-        }
-
-        return load;
-      });
-}
-Tensor* computeCatWoConditionals(
-    const std::vector<ArgValue>& input_list,
-    const ArgValue& arg_dim,
-    const std::vector<ExprHandle>& output_shape) {
-  auto inputs = processCatList(input_list);
-  ScalarType high_type = inputs.first;
-  std::vector<BufHandle> non_empty_inputs = inputs.second;
-
-  // Now we build one loop per input:
-  //
-  // for i
-  //   for j
-  //     for k
-  //       output[i,j,k] = inp1[i,j,k]
-  // for i
-  //   for j
-  //     for k
-  //       output[i,j+l1,k] = inp2[i,j,k]
-  // for i
-  //   for j
-  //     for k
-  //       output[i,j+l2,k] = inp3[i,j,k]
-
-  auto output_sizes_expr = ExprHandleVectorToExprVector(output_shape);
-  auto output_buf = new Buf("aten_cat", output_sizes_expr, ToDtype(high_type));
-  if (non_empty_inputs.size() == 0) {
-    return new Tensor(output_buf, new tensorexpr::Block({}));
-  }
-
-  int64_t concat_dim = c10::get<int64_t>(arg_dim);
-  size_t norm_concat_dim =
-      normalizeAndCheckIndex(concat_dim, output_shape.size());
-
-  auto gen_code_for_input = [&](const BufHandle& inp,
-                                size_t inp_pos,
-                                const Expr* concat_dim_size,
-                                const std::vector<ExprHandle>& dims) {
-    std::vector<Var*> for_vars(dims.size());
-    std::vector<const Expr*> load_indices(dims.size());
-    std::vector<const Expr*> store_indices(dims.size());
-    for (size_t i = 0; i < dims.size(); ++i) {
-      for_vars[i] = new Var(
-          "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt);
-      load_indices[i] = for_vars[i];
-      if (i == norm_concat_dim) {
-        store_indices[i] = new Add(for_vars[i], concat_dim_size);
-      } else {
-        store_indices[i] = for_vars[i];
-      }
-    }
-    auto inp_buf = inp.node();
-    auto load_expr = new Load(inp_buf, load_indices);
-    auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type);
-    Stmt* st = new Store(output_buf, store_indices, load_promoted.node());
-    for (size_t i = dims.size(); i > 0; --i) {
-      st = new For(for_vars[i - 1], new IntImm(0), dims[i - 1].node(), st);
-    }
-    return st;
-  };
-
-  Expr* concat_dim_size = nullptr;
-  auto block = new tensorexpr::Block({});
-  for (size_t i = 0; i < non_empty_inputs.size(); ++i) {
-    auto input_dims =
-        ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims());
-    if (concat_dim_size == nullptr) {
-      concat_dim_size = new IntImm(0);
-    }
-    block->append_stmt(gen_code_for_input(
-        non_empty_inputs[i], i, concat_dim_size, input_dims));
-    concat_dim_size =
-        new Add(concat_dim_size, input_dims[norm_concat_dim].node());
-  }
-  return new Tensor(output_buf, IRSimplifier::simplify(block));
-}
-
 Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) {
   auto reduction_info = getReductionInfo(v->node());
   return Reduce(
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index c2a9d98..368f562 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -25,14 +25,24 @@
 }
 using ArgNone = c10::monostate;
 using BufList = std::vector<tensorexpr::BufHandle>;
+using IntList = std::vector<int64_t>;
 using ArgValue = c10::variant<
     tensorexpr::BufHandle,
     tensorexpr::VarHandle,
-    BufList,
     double,
     int64_t,
     bool,
+    BufList,
+    IntList,
     ArgNone>;
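+// Unwraps a vector of ArgValues that all hold the same alternative into a
+// vector of the underlying type; c10::get throws if an element mismatches.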
+template <class T>
+std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
+  std::vector<T> res;
+  for (const auto& x : v) {
+    res.push_back(c10::get<T>(x));
+  }
+  return res;
+}
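+// e.g. toArg on a prim::ListConstruct of ints yields an IntList via
+// convertVecArgValue<int64_t>, which lowerings unpack with c10::get.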
 
 enum ElementType {
   kAllTypes = 0,