[TensorExpr] Implement shape inference for TE. (#41451)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41451
Since TE operates on a limited subset of ops with well-defined
semantics, we can easily infer the shapes of intermediate and output
tensors given the shapes of the inputs.
A couple of ops (aten::cat and aten::slice) are not yet supported by
shape inference. Once we add them, we can relax the shape info
requirements in the TE fuser: currently it requires all values in the fusion
group to have known shapes; it could then require them only for the inputs.
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D22543470
Pulled By: ZolotukhinM
fbshipit-source-id: 256bae921028cb6ec3af91977f12bb870c385f40
diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp
index 892b767..f6bd2f1 100644
--- a/test/cpp/tensorexpr/test_kernel.cpp
+++ b/test/cpp/tensorexpr/test_kernel.cpp
@@ -105,5 +105,119 @@
}
}
+void testKernel_4() {
+ // Test TensorExpr shape inference capabilities: it should only require shapes
+ // for the inputs
+ {
+ KernelScope kernel_scope;
+
+ const auto graph_string = R"IR(
+ graph(%0 : Float(5:3, 3:1, device=cpu),
+ %1 : Float(5:12, 3:2, device=cpu)):
+ %2 : Tensor = aten::mul(%0, %1)
+ %3 : Tensor = aten::mul(%0, %2)
+ return (%3))IR";
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+
+ auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
+ .index({Slice(None, None, 2), Slice(None, None, 2)});
+ auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto ref = a * (a * b);
+ TensorExprKernel k(graph);
+ std::vector<at::Tensor> inputs = {a, b};
+ Stmt* s = k.getCodeGenStmt();
+
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ o = stack[0].toTensor();
+ for (size_t i = 0; i < 5 * 3; i++) {
+ CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+ }
+ }
+ {
+ KernelScope kernel_scope;
+
+ const auto graph_string = R"IR(
+ graph(%0 : Float(8:8, 8:1, device=cpu),
+ %1 : Float(8:8, 8:1, device=cpu)):
+ %2 : Tensor = aten::mul(%0, %1)
+ %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2)
+ %r : Tensor = aten::mul(%3, %4)
+ return (%r))IR";
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+
+ auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto t = torch::chunk(a * b, 2, 1);
+ auto ref = t[0] * t[1];
+ TensorExprKernel k(graph);
+ std::vector<at::Tensor> inputs = {a, b};
+ Stmt* s = k.getCodeGenStmt();
+
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ o = stack[0].toTensor();
+ CHECK_EQ(o.sizes()[0], 8);
+ CHECK_EQ(o.sizes()[1], 4);
+ for (size_t i = 0; i < 8 * 4; i++) {
+ CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+ }
+ }
+ {
+ // Test that shape inference handles aten::unsqueeze
+ KernelScope kernel_scope;
+
+ const auto graph_string = R"IR(
+ graph(%a : Float(4:2, 2:1, device=cpu),
+ %b : Float(4:6, 3:2, 2:1, device=cpu),
+ %c : Float(3:4, 2:2, 2:1, device=cpu)):
+ %one : int = prim::Constant[value=1]()
+ %minus_one : int = prim::Constant[value=-1]()
+ %three : int = prim::Constant[value=3]()
+ %minus_four : int = prim::Constant[value=-4]()
+ %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2]
+ %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1]
+ %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1]
+ %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2]
+ %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1]
+ %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2]
+ return (%abc))IR";
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+
+ auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) *
+ at::unsqueeze(c, -4);
+
+ TensorExprKernel k(graph);
+ std::vector<at::Tensor> inputs = {a, b, c};
+ Stmt* s = k.getCodeGenStmt();
+
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ o = stack[0].toTensor();
+
+ // Check sizes
+ CHECK_EQ(o.sizes().size(), ref.sizes().size());
+ size_t num_el = 1;
+ for (auto idx = 0; idx < ref.sizes().size(); idx++) {
+ CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
+ num_el *= ref.sizes()[idx];
+ }
+
+ // Check the contents
+ for (size_t i = 0; i < num_el; i++) {
+ CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+ }
+ }
+}
+
} // namespace jit
} // namespace torch
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index b40f1b8..d79cf90 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -193,6 +193,7 @@
_(Kernel_1) \
_(Kernel_2) \
_(Kernel_3) \
+ _(Kernel_4) \
_(FuserPass_1) \
_(FuserPass_2)
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 6f9a447..d2f6420 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -56,7 +56,7 @@
return static_cast<at::ScalarType>(t->body()->dtype().scalar_type());
}
-static std::vector<ExprHandle> texprSizes(
+std::vector<ExprHandle> TensorExprKernel::sizesFromVaryingShape(
const c10::VaryingShape<int64_t>& shape) {
std::vector<ExprHandle> dims;
for (size_t i = 0; i < *shape.size(); i++) {
@@ -65,27 +65,162 @@
return dims;
}
-static std::vector<DimArg> texprDims(const torch::jit::Value* v) {
- if (v->type()->kind() != TypeKind::TensorType) {
- throw malformed_input("type is not Tensor");
- }
-
- auto tt = v->type()->cast<TensorType>();
+std::vector<DimArg> TensorExprKernel::dimsFromSizes(
+ const std::vector<ExprHandle>& sizes) {
std::vector<DimArg> dimArgs;
- int i = 0;
- for (auto const& s : texprSizes(tt->sizes())) {
- dimArgs.emplace_back(DimArg(s, "i" + c10::to_string(i++)));
+ for (size_t idx = 0; idx < sizes.size(); idx++) {
+ dimArgs.emplace_back(DimArg(sizes[idx], "i" + c10::to_string(idx)));
}
return dimArgs;
}
-template <typename T>
-int64_t bufferSize(T t) {
- int64_t size = 1;
- for (int i = 0; i < t.ndim(); i++) {
- size *= t.dim(i).template AsNode<IntImm>()->value();
+std::vector<ExprHandle> TensorExprKernel::sizesForValue(
+ const torch::jit::Value* v) {
+ if (known_sizes_.count(v)) {
+ return known_sizes_.at(v);
}
- return size;
+
+ // If the shape is present in the type info, just extract it from here. No
+ // need to infer it.
+ if (v->type()->kind() == TypeKind::TensorType) {
+ auto tt = v->type()->cast<TensorType>();
+ if (tt->isComplete()) {
+ return sizesFromVaryingShape(tt->sizes());
+ }
+ }
+
+ known_sizes_[v] = inferSizesForValue(v);
+ return known_sizes_.at(v);
+}
+
+std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
+ const torch::jit::Value* v) {
+ switch (v->node()->kind()) {
+ case aten::_cast_Float:
+ case aten::sigmoid:
+ case aten::reciprocal:
+ case aten::neg:
+ case aten::relu:
+ case aten::log:
+ case aten::log10:
+ case aten::log2:
+ case aten::exp:
+ case aten::expm1:
+ case aten::erf:
+ case aten::erfc:
+ case aten::cos:
+ case aten::sin:
+ case aten::tan:
+ case aten::rand_like:
+ case aten::acos:
+ case aten::asin:
+ case aten::cosh:
+ case aten::sinh:
+ case aten::atan:
+ case aten::tanh:
+ case aten::sqrt:
+ case aten::rsqrt:
+ case aten::abs:
+ case aten::ceil:
+ case aten::floor:
+ case aten::round:
+ case aten::trunc:
+ case aten::frac:
+ case aten::lgamma:
+ return sizesForValue(v->node()->input());
+
+ case aten::sub:
+ case aten::add:
+ case aten::mul:
+ case aten::div:
+ case aten::__and__:
+ case aten::__or__:
+ case aten::__xor__:
+ case aten::__lshift__:
+ case aten::__rshift__:
+ case aten::eq:
+ case aten::ne:
+ case aten::ge:
+ case aten::gt:
+ case aten::le:
+ case aten::lt:
+ case aten::min:
+ case aten::max:
+ case aten::type_as:
+ case aten::pow:
+ case aten::fmod:
+ case aten::remainder:
+ case aten::atan2:
+ case aten::_sigmoid_backward:
+ case aten::_tanh_backward: {
+ std::vector<std::vector<ExprHandle>> shapes;
+ for (size_t idx = 0; idx < 2; idx++) {
+ torch::jit::Value* inp = v->node()->input(idx);
+ shapes.push_back(sizesForValue(inp));
+ }
+ return broadcastShapes(shapes);
+ }
+
+ case aten::lerp:
+ case aten::clamp:
+ case aten::threshold:
+ case aten::where: {
+ std::vector<std::vector<ExprHandle>> shapes;
+ for (size_t idx = 0; idx < 3; idx++) {
+ torch::jit::Value* inp = v->node()->input(idx);
+ shapes.push_back(sizesForValue(inp));
+ }
+ return broadcastShapes(shapes);
+ }
+
+ case aten::addcmul: {
+ std::vector<std::vector<ExprHandle>> shapes;
+ for (size_t idx = 0; idx < 4; idx++) {
+ torch::jit::Value* inp = v->node()->input(idx);
+ shapes.push_back(sizesForValue(inp));
+ }
+ return broadcastShapes(shapes);
+ }
+
+ case prim::ConstantChunk: {
+ auto shape = sizesForValue(v->node()->input());
+ int dim = v->node()->i(attr::dim);
+ int chunks = v->node()->i(attr::chunks);
+ shape[dim] = IRSimplifier::simplify(shape[dim] / chunks);
+ return shape;
+ }
+
+ case aten::unsqueeze: {
+ auto const& n = v->node();
+ auto shape = sizesForValue(n->input(0));
+
+ int64_t dim = constant(n->input(1)).AsNode<IntImm>()->value();
+ // From the documentation
+ // (https://pytorch.org/docs/master/generated/torch.unsqueeze.html):
+ //
+ // A dim value within the range [-input.dim() - 1, input.dim() + 1) can be
+ // used. Negative dim will correspond to unsqueeze() applied at dim = dim
+ // + input.dim() + 1.
+ if (dim < 0) {
+ dim = dim + shape.size() + 1;
+ }
+ if (dim < 0 || dim > shape.size()) {
+ throw std::runtime_error("Invalid 'dim' input in aten::unsqueeze");
+ }
+
+ shape.insert(shape.begin() + dim, ExprHandle(1));
+ return shape;
+ }
+
+ case aten::cat:
+ case aten::slice:
+ throw std::runtime_error(
+ "Shape info is not implemented for this kind of node");
+
+ default: {
+ throw std::runtime_error("Unhandled node kind");
+ }
+ }
}
ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
@@ -159,6 +294,10 @@
return e;
}
+ if (!v->isCompleteTensor()) {
+ return e;
+ }
+
auto tt = *v->type()->cast<TensorType>()->scalarType();
if (tt == static_cast<at::ScalarType>(e.dtype().scalar_type())) {
@@ -189,21 +328,32 @@
return n->value() == 1;
}
-static std::pair<std::vector<ExprHandle>, bool> broadcastShapes(
+std::vector<ExprHandle> TensorExprKernel::broadcastShapes(
+ std::vector<std::vector<ExprHandle>> shapes) {
+ // Folds the list right-to-left with the pairwise overload; throws if empty.
+ if (shapes.empty()) {
+ throw std::runtime_error("broadcastShapes requires at least one shape");
+ }
+ for (size_t n = shapes.size(); n > 1; n--) {
+ shapes[n - 2] = broadcastShapes(shapes[n - 2], shapes[n - 1]);
+ shapes.pop_back();
+ }
+ return shapes[0];
+}
+std::vector<ExprHandle> TensorExprKernel::broadcastShapes(
const std::vector<ExprHandle>& a,
const std::vector<ExprHandle>& b) {
- bool broadcast = false;
auto at = a.rbegin();
auto bt = b.rbegin();
std::vector<ExprHandle> ret;
while (at != a.rend() || bt != b.rend()) {
if (at == a.rend()) {
- broadcast = true;
+ hasBroadcast_ = true;
ret.push_back(*bt++);
continue;
}
if (bt == b.rend()) {
- broadcast = true;
+ hasBroadcast_ = true;
ret.push_back(*at++);
continue;
}
@@ -214,7 +364,7 @@
if (isOne(*at)) {
if (!isOne(*bt)) {
dim = *bt;
- broadcast = true;
+ hasBroadcast_ = true;
}
}
ret.push_back(dim);
@@ -222,17 +372,7 @@
bt++;
}
std::reverse(ret.begin(), ret.end());
- return {ret, broadcast};
-}
-
-template <typename... Args>
-static std::pair<std::vector<ExprHandle>, bool> broadcastShapes(
- const std::vector<ExprHandle>& a,
- const std::vector<ExprHandle>& b,
- Args... args) {
- auto const& res = broadcastShapes(a, b);
- auto const& res2 = broadcastShapes(res.first, args...);
- return {res2.first, res.second || res2.second};
+ return ret;
}
std::vector<ExprHandle> TensorExprKernel::valueShape(
@@ -270,10 +410,8 @@
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr) {
auto const& n = v->node();
- auto const& res =
+ auto const& shape =
broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
- auto const& shape = res.first;
- hasBroadcast_ |= res.second;
return Compute(
name,
c10::fmap<DimArg>(shape),
@@ -296,10 +434,8 @@
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr) {
auto const& n = v->node();
- auto const& res =
+ auto const& shape =
broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
- auto const& shape = res.first;
- hasBroadcast_ |= res.second;
return Compute(
name,
c10::fmap<DimArg>(shape),
@@ -324,12 +460,12 @@
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
innerExpr) {
auto const& n = v->node();
- auto const& res = broadcastShapes(
- valueShape(n->inputs()[0]),
- valueShape(n->inputs()[1]),
- valueShape(n->inputs()[2]));
- auto const& shape = res.first;
- hasBroadcast_ |= res.second;
+ // The output shape is the broadcast of all three operands, not just two.
+ std::vector<std::vector<ExprHandle>> shapes;
+ for (size_t idx = 0; idx < 3; idx++) {
+ shapes.push_back(sizesForValue(n->input(idx)));
+ }
+ auto const& shape = broadcastShapes(shapes);
return Compute(
name,
c10::fmap<DimArg>(shape),
@@ -355,12 +491,12 @@
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
innerExpr) {
auto const& n = v->node();
- auto const& res = broadcastShapes(
- valueShape(n->inputs()[0]),
- valueShape(n->inputs()[1]),
- valueShape(n->inputs()[2]));
- auto const& shape = res.first;
- hasBroadcast_ |= res.second;
+ std::vector<std::vector<ExprHandle>> shapes;
+ for (size_t idx = 0; idx < 3; idx++) {
+ torch::jit::Value* inp = n->input(idx);
+ shapes.push_back(sizesForValue(inp));
+ }
+ auto const& shape = broadcastShapes(shapes);
return Compute(
name,
c10::fmap<DimArg>(shape),
@@ -387,13 +523,12 @@
const ExprHandle&,
const ExprHandle&)>& innerExpr) {
auto const& n = v->node();
- auto const& res = broadcastShapes(
- valueShape(n->inputs()[0]),
- valueShape(n->inputs()[1]),
- valueShape(n->inputs()[2]),
- valueShape(n->inputs()[3]));
- auto const& shape = res.first;
- hasBroadcast_ |= res.second;
+ std::vector<std::vector<ExprHandle>> shapes;
+ for (size_t idx = 0; idx < 4; idx++) {
+ torch::jit::Value* inp = n->input(idx);
+ shapes.push_back(sizesForValue(inp));
+ }
+ auto const& shape = broadcastShapes(shapes);
return Compute(
name,
c10::fmap<DimArg>(shape),
@@ -871,7 +1006,7 @@
case prim::ConstantChunk: {
return Compute(
"prim_constantchunk",
- texprDims(v),
+ dimsFromSizes(sizesForValue(v)),
[this, v](const std::vector<VarHandle>& axes) {
auto const& n = v->node();
int64_t dim = n->i(attr::dim);
@@ -888,7 +1023,7 @@
case aten::cat: {
return Compute(
"aten_cat",
- texprDims(v),
+ dimsFromSizes(sizesForValue(v)),
[this, v](const std::vector<VarHandle>& axes) {
auto const& n = v->node();
auto inputs = n->inputs()[0]->node()->inputs();
@@ -911,11 +1046,10 @@
return load;
});
}
-
case aten::slice: {
return Compute(
"aten_slice",
- texprDims(v),
+ dimsFromSizes(sizesForValue(v)),
[this, v](const std::vector<VarHandle>& axes) {
auto const& n = v->node();
int dim = constant(n->inputs()[1]).AsNode<IntImm>()->value();
@@ -931,7 +1065,7 @@
case aten::unsqueeze: {
return Compute(
"aten_unsqueeze",
- texprDims(v),
+ dimsFromSizes(sizesForValue(v)),
[this, v](const std::vector<VarHandle>& axes) {
auto const& n = v->node();
int64_t dim = constant(n->inputs()[1]).AsNode<IntImm>()->value();
@@ -939,13 +1073,21 @@
if (axes.size() == 0) {
throw malformed_input("axes are zero handling unsqueeze");
}
-
- dim += axes.size() - 1;
+ dim += axes.size();
+ }
+ // To construct an expression for an 'unsqueezed' tensor we need to
+ // drop the DIM-th axis, i.e.
+ // unsqueezed_v[i,j,k,l] = v[i,j,l] # dim = 2 - drop index 'k'
+ // 0 1 2 3
+ std::vector<ExprHandle> indices;
+ int64_t i = 0;
+ for (auto a : axes) {
+ if (i++ != dim) {
+ indices.emplace_back(ExprHandle(a.node()));
+ }
}
- std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
- newAxes.erase(newAxes.begin() + dim);
- return tensorOrConstant(n->inputs()[0], newAxes);
+ return tensorOrConstant(n->inputs()[0], indices);
});
}
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index bb2144e..548e2a7 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -66,6 +66,18 @@
void runKernel(Stack& stack);
+ std::vector<DimArg> dimsFromSizes(const std::vector<ExprHandle>& sizes);
+ std::vector<ExprHandle> sizesForValue(const torch::jit::Value* v);
+ std::vector<ExprHandle> inferSizesForValue(const torch::jit::Value* v);
+ std::vector<ExprHandle> sizesFromVaryingShape(
+ const c10::VaryingShape<int64_t>& shape);
+
+ std::vector<ExprHandle> broadcastShapes(
+ const std::vector<ExprHandle>& a,
+ const std::vector<ExprHandle>& b);
+ std::vector<ExprHandle> broadcastShapes(
+ std::vector<std::vector<ExprHandle>> shapes);
+
ExprHandle constant(const torch::jit::Value* v);
template <typename T, typename T1>
@@ -219,6 +231,8 @@
bool fallback_{false};
bool hasRandom_{false};
bool hasBroadcast_{false};
+ std::unordered_map<const torch::jit::Value*, std::vector<ExprHandle>>
+ known_sizes_;
};
TORCH_API int& getTECudaPointwiseLoopLevels();