[TensorExpr] Implement shape inference for TE. (#41451)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41451

Since TE operates on a limited subset of ops with well-defined
semantics, we can easily infer the shapes of intermediate and output
tensors given the shapes of the inputs.

There are a couple of ops that are not yet supported by shape
inference. Once we add them, we can relax the shape info requirements
in the TE fuser: currently it requires all values in the fusion group to
have known shapes, and we can change that to require only the inputs.

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D22543470

Pulled By: ZolotukhinM

fbshipit-source-id: 256bae921028cb6ec3af91977f12bb870c385f40
diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp
index 892b767..f6bd2f1 100644
--- a/test/cpp/tensorexpr/test_kernel.cpp
+++ b/test/cpp/tensorexpr/test_kernel.cpp
@@ -105,5 +105,119 @@
   }
 }
 
+void testKernel_4() {
+  // Test TensorExpr shape inference capabilities: it should only require shapes
+  // for the inputs
+  {
+    KernelScope kernel_scope;
+
+    const auto graph_string = R"IR(
+      graph(%0 : Float(5:3,  3:1, device=cpu),
+            %1 : Float(5:12, 3:2, device=cpu)):
+        %2 : Tensor = aten::mul(%0, %1)
+        %3 : Tensor = aten::mul(%0, %2)
+        return (%3))IR";
+    auto graph = std::make_shared<Graph>();
+    parseIR(graph_string, &*graph);
+
+    auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
+                 .index({Slice(None, None, 2), Slice(None, None, 2)});
+    auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto ref = a * (a * b);
+    TensorExprKernel k(graph);
+    std::vector<at::Tensor> inputs = {a, b};
+    Stmt* s = k.getCodeGenStmt();
+
+    std::vector<IValue> stack = fmap<IValue>(inputs);
+    k.run(stack);
+    o = stack[0].toTensor();
+    for (size_t i = 0; i < 5 * 3; i++) {
+      CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+    }
+  }
+  {
+    KernelScope kernel_scope;
+
+    const auto graph_string = R"IR(
+      graph(%0 : Float(8:8, 8:1, device=cpu),
+            %1 : Float(8:8, 8:1, device=cpu)):
+        %2 : Tensor = aten::mul(%0, %1)
+        %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2)
+        %r : Tensor = aten::mul(%3, %4)
+        return (%r))IR";
+    auto graph = std::make_shared<Graph>();
+    parseIR(graph_string, &*graph);
+
+    auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto t = torch::chunk(a * b, 2, 1);
+    auto ref = t[0] * t[1];
+    TensorExprKernel k(graph);
+    std::vector<at::Tensor> inputs = {a, b};
+    Stmt* s = k.getCodeGenStmt();
+
+    std::vector<IValue> stack = fmap<IValue>(inputs);
+    k.run(stack);
+    o = stack[0].toTensor();
+    CHECK_EQ(o.sizes()[0], 8);
+    CHECK_EQ(o.sizes()[1], 4);
+    for (size_t i = 0; i < 8 * 4; i++) {
+      CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+    }
+  }
+  {
+    // Test that shape inference handles aten::unsqueeze
+    KernelScope kernel_scope;
+
+    const auto graph_string = R"IR(
+      graph(%a : Float(4:2, 2:1, device=cpu),
+            %b : Float(4:6, 3:2, 2:1, device=cpu),
+            %c : Float(3:4, 2:2, 2:1, device=cpu)):
+        %one : int = prim::Constant[value=1]()
+        %minus_one : int = prim::Constant[value=-1]()
+        %three : int = prim::Constant[value=3]()
+        %minus_four : int = prim::Constant[value=-4]()
+        %a1 : Tensor = aten::unsqueeze(%a, %one)        # new size: [4,1,2]
+        %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1]
+        %b1 : Tensor = aten::unsqueeze(%b, %three)      # new size: [4,3,2,1]
+        %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2]
+        %ab : Tensor = aten::mul(%a2, %b1)         # expected size: [4,3,2,1]
+        %abc : Tensor = aten::mul(%ab, %c1)        # expected size: [4,3,2,2]
+        return (%abc))IR";
+    auto graph = std::make_shared<Graph>();
+    parseIR(graph_string, &*graph);
+
+    auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
+    auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) *
+        at::unsqueeze(c, -4);
+
+    TensorExprKernel k(graph);
+    std::vector<at::Tensor> inputs = {a, b, c};
+    Stmt* s = k.getCodeGenStmt();
+
+    std::vector<IValue> stack = fmap<IValue>(inputs);
+    k.run(stack);
+    o = stack[0].toTensor();
+
+    // Check sizes
+    CHECK_EQ(o.sizes().size(), ref.sizes().size());
+    size_t num_el = 1;
+    for (auto idx = 0; idx < ref.sizes().size(); idx++) {
+      CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
+      num_el *= ref.sizes()[idx];
+    }
+
+    // Check the contents
+    for (size_t i = 0; i < num_el; i++) {
+      CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+    }
+  }
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index b40f1b8..d79cf90 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -193,6 +193,7 @@
   _(Kernel_1)                               \
   _(Kernel_2)                               \
   _(Kernel_3)                               \
+  _(Kernel_4)                               \
   _(FuserPass_1)                            \
   _(FuserPass_2)
 
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 6f9a447..d2f6420 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -56,7 +56,7 @@
   return static_cast<at::ScalarType>(t->body()->dtype().scalar_type());
 }
 
-static std::vector<ExprHandle> texprSizes(
+std::vector<ExprHandle> TensorExprKernel::sizesFromVaryingShape(
     const c10::VaryingShape<int64_t>& shape) {
   std::vector<ExprHandle> dims;
   for (size_t i = 0; i < *shape.size(); i++) {
@@ -65,27 +65,162 @@
   return dims;
 }
 
-static std::vector<DimArg> texprDims(const torch::jit::Value* v) {
-  if (v->type()->kind() != TypeKind::TensorType) {
-    throw malformed_input("type is not Tensor");
-  }
-
-  auto tt = v->type()->cast<TensorType>();
+std::vector<DimArg> TensorExprKernel::dimsFromSizes(
+    const std::vector<ExprHandle>& sizes) {
   std::vector<DimArg> dimArgs;
-  int i = 0;
-  for (auto const& s : texprSizes(tt->sizes())) {
-    dimArgs.emplace_back(DimArg(s, "i" + c10::to_string(i++)));
+  for (size_t idx = 0; idx < sizes.size(); idx++) {
+    dimArgs.emplace_back(DimArg(sizes[idx], "i" + c10::to_string(idx)));
   }
   return dimArgs;
 }
 
-template <typename T>
-int64_t bufferSize(T t) {
-  int64_t size = 1;
-  for (int i = 0; i < t.ndim(); i++) {
-    size *= t.dim(i).template AsNode<IntImm>()->value();
+std::vector<ExprHandle> TensorExprKernel::sizesForValue(
+    const torch::jit::Value* v) {
+  if (known_sizes_.count(v)) {
+    return known_sizes_.at(v);
   }
-  return size;
+
+  // If the shape is present in the type info, just extract it from here. No
+  // need to infer it.
+  if (v->type()->kind() == TypeKind::TensorType) {
+    auto tt = v->type()->cast<TensorType>();
+    if (tt->isComplete()) {
+      return sizesFromVaryingShape(tt->sizes());
+    }
+  }
+
+  known_sizes_[v] = inferSizesForValue(v);
+  return known_sizes_.at(v);
+}
+
+std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
+    const torch::jit::Value* v) {
+  switch (v->node()->kind()) {
+    case aten::_cast_Float:
+    case aten::sigmoid:
+    case aten::reciprocal:
+    case aten::neg:
+    case aten::relu:
+    case aten::log:
+    case aten::log10:
+    case aten::log2:
+    case aten::exp:
+    case aten::expm1:
+    case aten::erf:
+    case aten::erfc:
+    case aten::cos:
+    case aten::sin:
+    case aten::tan:
+    case aten::rand_like:
+    case aten::acos:
+    case aten::asin:
+    case aten::cosh:
+    case aten::sinh:
+    case aten::atan:
+    case aten::tanh:
+    case aten::sqrt:
+    case aten::rsqrt:
+    case aten::abs:
+    case aten::ceil:
+    case aten::floor:
+    case aten::round:
+    case aten::trunc:
+    case aten::frac:
+    case aten::lgamma:
+      return sizesForValue(v->node()->input());
+
+    case aten::sub:
+    case aten::add:
+    case aten::mul:
+    case aten::div:
+    case aten::__and__:
+    case aten::__or__:
+    case aten::__xor__:
+    case aten::__lshift__:
+    case aten::__rshift__:
+    case aten::eq:
+    case aten::ne:
+    case aten::ge:
+    case aten::gt:
+    case aten::le:
+    case aten::lt:
+    case aten::min:
+    case aten::max:
+    case aten::type_as:
+    case aten::pow:
+    case aten::fmod:
+    case aten::remainder:
+    case aten::atan2:
+    case aten::_sigmoid_backward:
+    case aten::_tanh_backward: {
+      std::vector<std::vector<ExprHandle>> shapes;
+      for (size_t idx = 0; idx < 2; idx++) {
+        torch::jit::Value* inp = v->node()->input(idx);
+        shapes.push_back(sizesForValue(inp));
+      }
+      return broadcastShapes(shapes);
+    }
+
+    case aten::lerp:
+    case aten::clamp:
+    case aten::threshold:
+    case aten::where: {
+      std::vector<std::vector<ExprHandle>> shapes;
+      for (size_t idx = 0; idx < 3; idx++) {
+        torch::jit::Value* inp = v->node()->input(idx);
+        shapes.push_back(sizesForValue(inp));
+      }
+      return broadcastShapes(shapes);
+    }
+
+    case aten::addcmul: {
+      std::vector<std::vector<ExprHandle>> shapes;
+      for (size_t idx = 0; idx < 4; idx++) {
+        torch::jit::Value* inp = v->node()->input(idx);
+        shapes.push_back(sizesForValue(inp));
+      }
+      return broadcastShapes(shapes);
+    }
+
+    case prim::ConstantChunk: {
+      auto shape = sizesForValue(v->node()->input());
+      int dim = v->node()->i(attr::dim);
+      int chunks = v->node()->i(attr::chunks);
+      shape[dim] = IRSimplifier::simplify(shape[dim] / chunks);
+      return shape;
+    }
+
+    case aten::unsqueeze: {
+      auto const& n = v->node();
+      auto shape = sizesForValue(n->input(0));
+
+      int64_t dim = constant(n->input(1)).AsNode<IntImm>()->value();
+      // From the documentation
+      // (https://pytorch.org/docs/master/generated/torch.unsqueeze.html):
+      //
+      // A dim value within the range [-input.dim() - 1, input.dim() + 1) can be
+      // used. Negative dim will correspond to unsqueeze() applied at dim = dim
+      // + input.dim() + 1.
+      if (dim < 0) {
+        dim = dim + shape.size() + 1;
+      }
+      if (dim < 0 || dim > shape.size()) {
+        throw std::runtime_error("Invalid 'dim' input in aten::unsqueeze");
+      }
+
+      shape.insert(shape.begin() + dim, ExprHandle(1));
+      return shape;
+    }
+
+    case aten::cat:
+    case aten::slice:
+      throw std::runtime_error(
+          "Shape info is not implemented for this kind of node");
+
+    default: {
+      throw std::runtime_error("Unhandled node kind");
+    }
+  }
 }
 
 ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
@@ -159,6 +294,10 @@
     return e;
   }
 
+  if (!v->isCompleteTensor()) {
+    return e;
+  }
+
   auto tt = *v->type()->cast<TensorType>()->scalarType();
 
   if (tt == static_cast<at::ScalarType>(e.dtype().scalar_type())) {
@@ -189,21 +328,32 @@
   return n->value() == 1;
 }
 
-static std::pair<std::vector<ExprHandle>, bool> broadcastShapes(
+std::vector<ExprHandle> TensorExprKernel::broadcastShapes(
+    std::vector<std::vector<ExprHandle>> shapes) {
+  size_t n = shapes.size();
+  if (n == 1) {
+    return shapes[0];
+  }
+  auto res1 = broadcastShapes(shapes[n - 2], shapes[n - 1]);
+  shapes[n - 2] = res1;
+  shapes.pop_back();
+  auto res2 = broadcastShapes(shapes);
+  return res2;
+}
+std::vector<ExprHandle> TensorExprKernel::broadcastShapes(
     const std::vector<ExprHandle>& a,
     const std::vector<ExprHandle>& b) {
-  bool broadcast = false;
   auto at = a.rbegin();
   auto bt = b.rbegin();
   std::vector<ExprHandle> ret;
   while (at != a.rend() || bt != b.rend()) {
     if (at == a.rend()) {
-      broadcast = true;
+      hasBroadcast_ = true;
       ret.push_back(*bt++);
       continue;
     }
     if (bt == b.rend()) {
-      broadcast = true;
+      hasBroadcast_ = true;
       ret.push_back(*at++);
       continue;
     }
@@ -214,7 +364,7 @@
     if (isOne(*at)) {
       if (!isOne(*bt)) {
         dim = *bt;
-        broadcast = true;
+        hasBroadcast_ = true;
       }
     }
     ret.push_back(dim);
@@ -222,17 +372,7 @@
     bt++;
   }
   std::reverse(ret.begin(), ret.end());
-  return {ret, broadcast};
-}
-
-template <typename... Args>
-static std::pair<std::vector<ExprHandle>, bool> broadcastShapes(
-    const std::vector<ExprHandle>& a,
-    const std::vector<ExprHandle>& b,
-    Args... args) {
-  auto const& res = broadcastShapes(a, b);
-  auto const& res2 = broadcastShapes(res.first, args...);
-  return {res2.first, res.second || res2.second};
+  return ret;
 }
 
 std::vector<ExprHandle> TensorExprKernel::valueShape(
@@ -270,10 +410,8 @@
     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
         innerExpr) {
   auto const& n = v->node();
-  auto const& res =
+  auto const& shape =
       broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
-  auto const& shape = res.first;
-  hasBroadcast_ |= res.second;
   return Compute(
       name,
       c10::fmap<DimArg>(shape),
@@ -296,10 +434,8 @@
     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
         innerExpr) {
   auto const& n = v->node();
-  auto const& res =
+  auto const& shape =
       broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
-  auto const& shape = res.first;
-  hasBroadcast_ |= res.second;
   return Compute(
       name,
       c10::fmap<DimArg>(shape),
@@ -324,12 +460,12 @@
         ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
         innerExpr) {
   auto const& n = v->node();
-  auto const& res = broadcastShapes(
-      valueShape(n->inputs()[0]),
-      valueShape(n->inputs()[1]),
-      valueShape(n->inputs()[2]));
-  auto const& shape = res.first;
-  hasBroadcast_ |= res.second;
+  // Output shape is the broadcast of all three operands passed to innerExpr.
+  std::vector<std::vector<ExprHandle>> shapes;
+  for (size_t idx = 0; idx < 3; idx++) {
+    shapes.push_back(sizesForValue(n->input(idx)));
+  }
+  auto const& shape = broadcastShapes(shapes);
   return Compute(
       name,
       c10::fmap<DimArg>(shape),
@@ -355,12 +491,12 @@
         ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
         innerExpr) {
   auto const& n = v->node();
-  auto const& res = broadcastShapes(
-      valueShape(n->inputs()[0]),
-      valueShape(n->inputs()[1]),
-      valueShape(n->inputs()[2]));
-  auto const& shape = res.first;
-  hasBroadcast_ |= res.second;
+  std::vector<std::vector<ExprHandle>> shapes;
+  for (size_t idx = 0; idx < 3; idx++) {
+    torch::jit::Value* inp = n->input(idx);
+    shapes.push_back(sizesForValue(inp));
+  }
+  auto const& shape = broadcastShapes(shapes);
   return Compute(
       name,
       c10::fmap<DimArg>(shape),
@@ -387,13 +523,12 @@
         const ExprHandle&,
         const ExprHandle&)>& innerExpr) {
   auto const& n = v->node();
-  auto const& res = broadcastShapes(
-      valueShape(n->inputs()[0]),
-      valueShape(n->inputs()[1]),
-      valueShape(n->inputs()[2]),
-      valueShape(n->inputs()[3]));
-  auto const& shape = res.first;
-  hasBroadcast_ |= res.second;
+  std::vector<std::vector<ExprHandle>> shapes;
+  for (size_t idx = 0; idx < 4; idx++) {
+    torch::jit::Value* inp = n->input(idx);
+    shapes.push_back(sizesForValue(inp));
+  }
+  auto const& shape = broadcastShapes(shapes);
   return Compute(
       name,
       c10::fmap<DimArg>(shape),
@@ -871,7 +1006,7 @@
     case prim::ConstantChunk: {
       return Compute(
           "prim_constantchunk",
-          texprDims(v),
+          dimsFromSizes(sizesForValue(v)),
           [this, v](const std::vector<VarHandle>& axes) {
             auto const& n = v->node();
             int64_t dim = n->i(attr::dim);
@@ -888,7 +1023,7 @@
     case aten::cat: {
       return Compute(
           "aten_cat",
-          texprDims(v),
+          dimsFromSizes(sizesForValue(v)),
           [this, v](const std::vector<VarHandle>& axes) {
             auto const& n = v->node();
             auto inputs = n->inputs()[0]->node()->inputs();
@@ -911,11 +1046,10 @@
             return load;
           });
     }
-
     case aten::slice: {
       return Compute(
           "aten_slice",
-          texprDims(v),
+          dimsFromSizes(sizesForValue(v)),
           [this, v](const std::vector<VarHandle>& axes) {
             auto const& n = v->node();
             int dim = constant(n->inputs()[1]).AsNode<IntImm>()->value();
@@ -931,7 +1065,7 @@
     case aten::unsqueeze: {
       return Compute(
           "aten_unsqueeze",
-          texprDims(v),
+          dimsFromSizes(sizesForValue(v)),
           [this, v](const std::vector<VarHandle>& axes) {
             auto const& n = v->node();
             int64_t dim = constant(n->inputs()[1]).AsNode<IntImm>()->value();
@@ -939,13 +1073,21 @@
               if (axes.size() == 0) {
                 throw malformed_input("axes are zero handling unsqueeze");
               }
-
-              dim += axes.size() - 1;
+              dim += axes.size();
+            }
+            // To construct an expression for an 'unsqueezed' tensor we need to
+            // drop the DIM-th axis, i.e.
+            //    unsqueezed_v[i,j,k,l] = v[i,j,l] # dim = 2 - drop index 'k'
+            //                 0 1 2 3
+            std::vector<ExprHandle> indices;
+            int64_t i = 0;
+            for (auto a : axes) {
+              if (i++ != dim) {
+                indices.emplace_back(ExprHandle(a.node()));
+              }
             }
 
-            std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
-            newAxes.erase(newAxes.begin() + dim);
-            return tensorOrConstant(n->inputs()[0], newAxes);
+            return tensorOrConstant(n->inputs()[0], indices);
           });
     }
 
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index bb2144e..548e2a7 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -66,6 +66,18 @@
 
   void runKernel(Stack& stack);
 
+  std::vector<DimArg> dimsFromSizes(const std::vector<ExprHandle>& sizes);
+  std::vector<ExprHandle> sizesForValue(const torch::jit::Value* v);
+  std::vector<ExprHandle> inferSizesForValue(const torch::jit::Value* v);
+  std::vector<ExprHandle> sizesFromVaryingShape(
+      const c10::VaryingShape<int64_t>& shape);
+
+  std::vector<ExprHandle> broadcastShapes(
+      const std::vector<ExprHandle>& a,
+      const std::vector<ExprHandle>& b);
+  std::vector<ExprHandle> broadcastShapes(
+      std::vector<std::vector<ExprHandle>> shapes);
+
   ExprHandle constant(const torch::jit::Value* v);
 
   template <typename T, typename T1>
@@ -219,6 +231,8 @@
   bool fallback_{false};
   bool hasRandom_{false};
   bool hasBroadcast_{false};
+  std::unordered_map<const torch::jit::Value*, std::vector<ExprHandle>>
+      known_sizes_;
 };
 
 TORCH_API int& getTECudaPointwiseLoopLevels();