[nnc] ported some more ops + added vectors to argvalue (#56766)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56766
Test Plan: Imported from OSS
Reviewed By: desertfire
Differential Revision: D28118331
Pulled By: Chillee
fbshipit-source-id: eb012943ad3b83e72a8cb17b594852164c3f0567
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index ae23453..f625526 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -396,6 +396,20 @@
if (ti != bufs_.end()) {
return BufHandle(ti->second);
}
+ if (v->node()->kind() == prim::ListConstruct) {
+ std::vector<ArgValue> vec;
+ for (auto el : v->node()->inputs()) {
+ vec.push_back(toArg(el));
+ }
+ if (vec.size() == 0) {
+ return BufList(); // Return arbitrarily typed vector
+ } else if (c10::get_if<BufHandle>(&vec[0])) {
+ return convertVecArgValue<BufHandle>(vec);
+ } else if (c10::get_if<int64_t>(&vec[0])) {
+ return convertVecArgValue<int64_t>(vec);
+ }
+ throw unsupported_dtype();
+ }
if (v->node()->kind() == prim::Constant) {
const auto val = toIValue(v).value();
if (val.isDouble()) {
@@ -410,7 +424,7 @@
// the operator-specific lowering code.
return ArgNone();
} else {
- throw unsupported_dtype();
+ throw unsupported_dtype(val.type()->str());
}
}
@@ -976,15 +990,166 @@
});
}
-Tensor* computeCat(
- const std::vector<ArgValue>& inputList,
- const ArgValue& argDim,
- const std::vector<ExprHandle>& outputShape);
-
+std::pair<ScalarType, std::vector<BufHandle>> processCatList(
+ const std::vector<BufHandle>& bufList) {
+ if (bufList.size() == 0) {
+ throw std::runtime_error("Empty input list is passed to aten::cat");
+ }
+ std::vector<BufHandle> bufInputs;
+ std::vector<BufHandle> nonEmptyInputs;
+ for (auto buf : bufList) {
+ bufInputs.push_back(buf);
+ TORCH_INTERNAL_ASSERT(buf.node()->dims().size() > 0);
+ if (buf.node()->dims().size() == 1 &&
+ immediateAs<int>(buf.node()->dim(0)) == 0) {
+ continue;
+ }
+ nonEmptyInputs.push_back(buf);
+ }
+ ScalarType highType = bufInputs[0].dtype().scalar_type();
+ for (const auto input : bufInputs) {
+ auto maybe_dtype = input.dtype().scalar_type();
+ highType = promoteTypes(highType, maybe_dtype);
+ }
+ return {highType, nonEmptyInputs};
+}
Tensor* computeCatWoConditionals(
- const std::vector<ArgValue>& inputList,
- const ArgValue& argDim,
- const std::vector<ExprHandle>& outputShape);
+ const std::vector<ArgValue>& inputs,
+ const std::vector<ExprHandle>& outputShape) {
+ auto input_list = c10::get<BufList>(inputs[0]);
+ auto arg_dim = inputs[1];
+ auto cat_info = processCatList(input_list);
+ ScalarType high_type = cat_info.first;
+ std::vector<BufHandle> non_empty_inputs = cat_info.second;
+
+ // Now we build one loop per input:
+ //
+ // for i
+ // for j
+ // for k
+ // output[i,j,k] = inp1[i,j,k]
+ // for i
+ // for j
+ // for k
+ // output[i,j+l1,k] = inp2[i,j,k]
+ // for i
+ // for j
+ // for k
+ // output[i,j+l2,k] = inp3[i,j,k]
+
+ auto output_sizes_expr = ExprHandleVectorToExprVector(outputShape);
+ auto output_buf = new Buf("aten_cat", output_sizes_expr, ToDtype(high_type));
+ if (non_empty_inputs.size() == 0) {
+ return new Tensor(output_buf, new tensorexpr::Block({}));
+ }
+
+ int64_t concat_dim = c10::get<int64_t>(arg_dim);
+ size_t norm_concat_dim =
+ normalizeAndCheckIndex(concat_dim, outputShape.size());
+
+ auto gen_code_for_input = [&](const BufHandle& inp,
+ size_t inp_pos,
+ const Expr* concat_dim_size,
+ const std::vector<ExprHandle>& dims) {
+ std::vector<Var*> for_vars(dims.size());
+ std::vector<const Expr*> load_indices(dims.size());
+ std::vector<const Expr*> store_indices(dims.size());
+ for (size_t i = 0; i < dims.size(); ++i) {
+ for_vars[i] = new Var(
+ "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt);
+ load_indices[i] = for_vars[i];
+ if (i == norm_concat_dim) {
+ store_indices[i] = new Add(for_vars[i], concat_dim_size);
+ } else {
+ store_indices[i] = for_vars[i];
+ }
+ }
+ auto inp_buf = inp.node();
+ auto load_expr = new Load(inp_buf, load_indices);
+ auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type);
+ Stmt* st = new Store(output_buf, store_indices, load_promoted.node());
+ for (size_t i = dims.size(); i > 0; --i) {
+ st = new For(for_vars[i - 1], new IntImm(0), dims[i - 1].node(), st);
+ }
+ return st;
+ };
+
+ Expr* concat_dim_size = nullptr;
+ auto block = new tensorexpr::Block({});
+ for (size_t i = 0; i < non_empty_inputs.size(); ++i) {
+ auto input_dims =
+ ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims());
+ if (concat_dim_size == nullptr) {
+ concat_dim_size = new IntImm(0);
+ }
+ block->append_stmt(gen_code_for_input(
+ non_empty_inputs[i], i, concat_dim_size, input_dims));
+ concat_dim_size =
+ new Add(concat_dim_size, input_dims[norm_concat_dim].node());
+ }
+ return new Tensor(output_buf, IRSimplifier::simplify(block));
+}
+
+Tensor* computeCat(
+ const std::vector<ArgValue>& inputs,
+ const std::vector<ExprHandle>& outputShape) {
+ if (getCatWoConditionals()) {
+ return computeCatWoConditionals(inputs, outputShape);
+ }
+ auto inputList = c10::get<BufList>(inputs[0]);
+ auto argDim = inputs[1];
+ auto catInfo = processCatList(inputList);
+ ScalarType highType = catInfo.first;
+ std::vector<BufHandle> nonEmptyInputs = catInfo.second;
+ return Compute(
+ "aten_cat",
+ c10::fmap<DimArg>(outputShape),
+ [&](const std::vector<VarHandle>& axes) {
+ if (nonEmptyInputs.size() == 0) {
+ return ExprHandle(0);
+ }
+
+ int64_t dim_ = c10::get<int64_t>(argDim);
+ size_t dim = normalizeAndCheckIndex(dim_, axes.size());
+ // Promote input types.
+ // Note that we need to consider all inputs, including empty - they
+ // also affect the resultant dtype.
+
+ // Now we know the final dtype, we know what inputs are non-empty,
+ // and we know that there is at least one such an input. With all
+ // that we construct a tensor expression performing the
+ // concatenation.
+ // The expression we build here is a cascading if-then-else that
+ // essentially represents:
+ //
+ // inp1[i, j, k] if 0 < i < l1,
+ //   out[i,j,k] = inp2[i, j-l1, k] if l1 <= i < l1 + l2,
+ // ...
+ // inpN[i, j-l_N_1, k] if l1+l2+...l_N_1 < i
+ // where l_i is the corresponding size of the i-th input.
+ std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
+ ExprHandle load = promoteToDtype(
+ tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
+ size_t offset =
+ dynamic_cast<const IntImm*>(nonEmptyInputs[0].node()->dim(dim))
+ ->value();
+ newAxes[dim] = newAxes[dim] - IntImm::make(offset);
+
+ for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
+ auto input = nonEmptyInputs[ii];
+ load = ifThenElse(
+ CompareSelect::make(axes[dim], IntImm::make(offset), kLT),
+ load,
+ promoteToDtype(tensorOrConstant(input, newAxes), highType));
+
+ offset +=
+ dynamic_cast<const IntImm*>(input.node()->dim(dim))->value();
+ newAxes[dim] = axes[dim] - IntImm::make(offset);
+ }
+
+ return load;
+ });
+}
Tensor* computeMatmul(
const std::vector<ArgValue>& inputs,
@@ -1500,6 +1665,16 @@
return tan(promoteIntegerToDefaultType(a));
});
} break;
+ case aten::type_as: {
+ const BufHandle rhs = c10::get<BufHandle>(inputs[1]);
+ auto dtype = rhs.dtype();
+ return computeOneOperand(
+ "aten_type_as",
+ inputs,
+ outputShape,
+ outputType,
+ [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); });
+ } break;
case aten::pow: {
return computeTwoOperand(
"aten_pow",
@@ -1770,7 +1945,19 @@
outputType,
[](const ExprHandle& a) { return cast<float>(a); });
} break;
-
+ case aten::to: {
+ // see handling of aten::to in tensorexpr_fuser.cpp for why we only
+ // need to handle the first input
+ return computeOneOperand(
+ "aten_to",
+ inputs,
+ outputShape,
+ outputType,
+ [outputType](const ExprHandle& a) {
+ TORCH_INTERNAL_ASSERT(outputType);
+ return Cast::make(ToDtype(*outputType), a);
+ });
+ } break;
case aten::threshold: {
return computeThreeOperand(
"aten_threshold",
@@ -1872,6 +2059,9 @@
case aten::matmul: {
return computeMatmul(inputs, outputShape, outputType);
}
+ case aten::cat: {
+ return computeCat(inputs, outputShape);
+ }
default: {
throw std::runtime_error("Unhandled node kind");
return nullptr;
@@ -1932,6 +2122,7 @@
case aten::cos:
case aten::sin:
case aten::tan:
+ case aten::type_as:
case aten::pow:
case aten::fmod:
case aten::lerp:
@@ -1952,13 +2143,15 @@
case aten::round:
case aten::trunc:
case aten::_cast_Float:
+ case aten::to:
case aten::threshold:
case aten::where:
case aten::frac:
case aten::lgamma:
case aten::slice:
case aten::unsqueeze:
- case aten::matmul: {
+ case aten::matmul:
+ case aten::cat: {
std::vector<ArgValue> argInputs;
for (auto inp : inputs) {
argInputs.push_back(toArg(inp));
@@ -1969,44 +2162,6 @@
v->node()->kind(), argInputs, outputShape, outputType);
} break;
- case aten::to: {
- // see handling of aten::to in tensorexpr_fuser.cpp for why we only
- // need to handle the first input
- auto outputType = findDtypeForValue(v->node()->output());
- auto outputShape = sizesForValue(v);
- auto output_dtype = findDtypeForValue(v->node()->output());
- return computeOneOperand(
- "aten_to",
- {toArg(inputs[0])},
- outputShape,
- outputType,
- [output_dtype](const ExprHandle& a) {
- TORCH_INTERNAL_ASSERT(output_dtype);
- return Cast::make(ToDtype(*output_dtype), a);
- });
- } break;
-
- case aten::type_as: {
- auto outputType = findDtypeForValue(v->node()->output());
- auto outputShape = sizesForValue(v);
- const Buf* rhs = bufs_.at(inputs[1]);
- auto dtype = rhs->dtype();
- return computeOneOperand(
- "aten_type_as",
- {toArg(inputs[0])},
- outputShape,
- outputType,
- [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); });
- } break;
- case aten::cat: {
- std::vector<ArgValue> inputList;
- for (auto inp : v->node()->input(0)->node()->inputs()) {
- inputList.push_back(toArg(inp));
- }
- auto outputShape = sizesForValue(v);
- return computeCat(inputList, toArg(v->node()->input(1)), outputShape);
- } break;
-
case prim::ConstantChunk: {
return Compute(
"prim_constantchunk",
@@ -2377,167 +2532,6 @@
} // namespace
-std::pair<ScalarType, std::vector<BufHandle>> processCatList(
- const std::vector<ArgValue>& inputList) {
- if (inputList.size() == 0) {
- throw std::runtime_error("Empty input list is passed to aten::cat");
- }
- std::vector<BufHandle> bufInputs;
- std::vector<BufHandle> nonEmptyInputs;
- for (auto input : inputList) {
- auto buf = c10::get<BufHandle>(input);
- bufInputs.push_back(buf);
- assert(buf.node()->dims().size() > 0);
- if (buf.node()->dims().size() == 1 &&
- immediateAs<int>(buf.node()->dim(0)) == 0) {
- continue;
- }
- nonEmptyInputs.push_back(buf);
- }
- auto maybe_dtype = bufInputs[0].dtype().scalar_type();
- ScalarType highType = maybe_dtype;
- for (const auto input : bufInputs) {
- auto maybe_dtype = input.dtype().scalar_type();
- highType = promoteTypes(highType, maybe_dtype);
- }
- return {highType, nonEmptyInputs};
-}
-
-Tensor* computeCat(
- const std::vector<ArgValue>& inputList,
- const ArgValue& argDim,
- const std::vector<ExprHandle>& outputShape) {
- if (getCatWoConditionals()) {
- return computeCatWoConditionals(inputList, argDim, outputShape);
- }
- auto inputs = processCatList(inputList);
- ScalarType highType = inputs.first;
- std::vector<BufHandle> nonEmptyInputs = inputs.second;
- return Compute(
- "aten_cat",
- c10::fmap<DimArg>(outputShape),
- [&](const std::vector<VarHandle>& axes) {
- if (nonEmptyInputs.size() == 0) {
- return ExprHandle(0);
- }
-
- int64_t dim_ = c10::get<int64_t>(argDim);
- size_t dim = normalizeAndCheckIndex(dim_, axes.size());
- // Promote input types.
- // Note that we need to consider all inputs, including empty - they
- // also affect the resultant dtype.
-
- // Now we know the final dtype, we know what inputs are non-empty,
- // and we know that there is at least one such an input. With all
- // that we construct a tensor expression performing the
- // concatenation.
- // The expression we build here is a cascading if-then-else that
- // essentially represents:
- //
- // inp1[i, j, k] if 0 < i < l1,
- // out[i,j,k] = inp2[i, j-l1, k] if l1 =< i < l1 + l2,
- // ...
- // inpN[i, j-l_N_1, k] if l1+l2+...l_N_1 < i
- // where l_i is the corresponding size of the i-th input.
- std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
- ExprHandle load = promoteToDtype(
- tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
- size_t offset =
- dynamic_cast<const IntImm*>(nonEmptyInputs[0].node()->dim(dim))
- ->value();
- newAxes[dim] = newAxes[dim] - IntImm::make(offset);
-
- for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
- auto input = nonEmptyInputs[ii];
- load = ifThenElse(
- CompareSelect::make(axes[dim], IntImm::make(offset), kLT),
- load,
- promoteToDtype(tensorOrConstant(input, newAxes), highType));
-
- offset +=
- dynamic_cast<const IntImm*>(input.node()->dim(dim))->value();
- newAxes[dim] = axes[dim] - IntImm::make(offset);
- }
-
- return load;
- });
-}
-Tensor* computeCatWoConditionals(
- const std::vector<ArgValue>& input_list,
- const ArgValue& arg_dim,
- const std::vector<ExprHandle>& output_shape) {
- auto inputs = processCatList(input_list);
- ScalarType high_type = inputs.first;
- std::vector<BufHandle> non_empty_inputs = inputs.second;
-
- // Now we build one loop per input:
- //
- // for i
- // for j
- // for k
- // output[i,j,k] = inp1[i,j,k]
- // for i
- // for j
- // for k
- // output[i,j+l1,k] = inp2[i,j,k]
- // for i
- // for j
- // for k
- // output[i,j+l2,k] = inp3[i,j,k]
-
- auto output_sizes_expr = ExprHandleVectorToExprVector(output_shape);
- auto output_buf = new Buf("aten_cat", output_sizes_expr, ToDtype(high_type));
- if (non_empty_inputs.size() == 0) {
- return new Tensor(output_buf, new tensorexpr::Block({}));
- }
-
- int64_t concat_dim = c10::get<int64_t>(arg_dim);
- size_t norm_concat_dim =
- normalizeAndCheckIndex(concat_dim, output_shape.size());
-
- auto gen_code_for_input = [&](const BufHandle& inp,
- size_t inp_pos,
- const Expr* concat_dim_size,
- const std::vector<ExprHandle>& dims) {
- std::vector<Var*> for_vars(dims.size());
- std::vector<const Expr*> load_indices(dims.size());
- std::vector<const Expr*> store_indices(dims.size());
- for (size_t i = 0; i < dims.size(); ++i) {
- for_vars[i] = new Var(
- "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt);
- load_indices[i] = for_vars[i];
- if (i == norm_concat_dim) {
- store_indices[i] = new Add(for_vars[i], concat_dim_size);
- } else {
- store_indices[i] = for_vars[i];
- }
- }
- auto inp_buf = inp.node();
- auto load_expr = new Load(inp_buf, load_indices);
- auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type);
- Stmt* st = new Store(output_buf, store_indices, load_promoted.node());
- for (size_t i = dims.size(); i > 0; --i) {
- st = new For(for_vars[i - 1], new IntImm(0), dims[i - 1].node(), st);
- }
- return st;
- };
-
- Expr* concat_dim_size = nullptr;
- auto block = new tensorexpr::Block({});
- for (size_t i = 0; i < non_empty_inputs.size(); ++i) {
- auto input_dims =
- ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims());
- if (concat_dim_size == nullptr) {
- concat_dim_size = new IntImm(0);
- }
- block->append_stmt(gen_code_for_input(
- non_empty_inputs[i], i, concat_dim_size, input_dims));
- concat_dim_size =
- new Add(concat_dim_size, input_dims[norm_concat_dim].node());
- }
- return new Tensor(output_buf, IRSimplifier::simplify(block));
-}
-
Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) {
auto reduction_info = getReductionInfo(v->node());
return Reduce(
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index c2a9d98..368f562 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -25,14 +25,24 @@
}
using ArgNone = c10::monostate;
using BufList = std::vector<tensorexpr::BufHandle>;
+using IntList = std::vector<int64_t>;
using ArgValue = c10::variant<
tensorexpr::BufHandle,
tensorexpr::VarHandle,
- BufList,
double,
int64_t,
bool,
+ BufList,
+ IntList,
ArgNone>;
+template <class T>
+std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
+ std::vector<T> res;
+ for (const auto& x : v) {
+ res.push_back(c10::get<T>(x));
+ }
+ return res;
+}
enum ElementType {
kAllTypes = 0,