Compute values as shapes where requested by the shape function

Detect cases where a shape function requires an input operand as a shape. If that operand is not a known constant, attempt to compute as many dimensions of the required shape as possible (e.g., this allows computing [?,1] as output rather than failing to compute the entire output). This adds a simple partial evaluator over input values. The helpers are currently local functions, but the intention is to extract them into general traits rather than hard-code op-specific behavior here.
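
For illustration, here is a condensed sketch of the new shape_inference.mlir test (the value names %unknown, %cst_512 and %input are placeholders used only in this example). The shape operand of tf.Reshape packs an unknown scalar with the constant 512, so the partial evaluator recovers [?,512] and the Reshape shape function can then refine the result type:

    // %unknown is not a constant; %cst_512 holds the constant 512.
    %shape = "tf.Pack"(%unknown, %cst_512) {N = 2 : i64, axis = 0 : i64}
        : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
    // Previously inferred as tensor<?x?xf32>; with the partial shape [?,512]
    // the shape function refines the result type to tensor<1x512xf32>.
    %reshaped = "tf.Reshape"(%input, %shape)
        : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<1x512xf32>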

PiperOrigin-RevId: 309865790
Change-Id: I4458f4694e775132ba218cfc5f292edcd0001c64
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index caac814..160bba9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -401,4 +401,13 @@
     %13 = "tf.Neg"(%12) {device = ""} : (tensor<4xf32>) -> tensor<4xf32>
     return
   }
+
+  // CHECK-LABEL: operand_as_shape
+  func @operand_as_shape(%18: tensor<i32>, %39: tensor<1x4x4x32xf32>) -> () {
+    %cst_5 = constant dense<512> : tensor<i32>
+    %19 = "tf.Pack"(%18, %cst_5) {N = 2 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+    // CHECK: -> tensor<1x512xf32>
+    %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+    return
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 30e11e9..789088b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -19,6 +19,9 @@
 #include <initializer_list>
 #include <iterator>
+#include <unordered_map>
 
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator_range.h"
@@ -26,6 +28,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
@@ -56,6 +59,9 @@
 #define DEBUG_TYPE "tf-shape-inference"
 
 using ::tensorflow::int64;
+using tensorflow::shape_inference::DimensionHandle;
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeHandle;
 
 namespace mlir {
 namespace TF {
@@ -345,6 +351,191 @@
 
 }  // namespace
 
+// Combination of a value's producer and the port of the value produced, i.e.,
+//   <result index of the producer>:<element within that result tensor>.
+// For example, for a tf.Const producing tensor<10x20xf32>, the port [0,2,18]
+// identifies a unique scalar value (result 0, element (2, 18)).
+struct ValuePort {
+  llvm::PointerUnion<Operation*, BlockArgument> producer;
+  SmallVector<unsigned int, 2> port;
+
+  bool operator==(const ValuePort& other) const {
+    return producer == other.producer && port == other.port;
+  }
+
+  // Convert output value to ValuePort.
+  explicit ValuePort(Value v) {
+    OpResult opr = v.dyn_cast<OpResult>();
+    if (opr) {
+      producer = opr.getOwner();
+      port = {opr.getResultNumber()};
+    } else {
+      producer = v.cast<BlockArgument>();
+      port = {0};
+    }
+  }
+  ValuePort(llvm::PointerUnion<Operation*, BlockArgument> producer,
+            SmallVector<unsigned int, 2> port)
+      : producer(producer), port(port) {}
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const {
+    if (auto* op = producer.dyn_cast<Operation*>())
+      os << "op " << op->getName();
+    if (auto ba = producer.dyn_cast<BlockArgument>())
+      os << "block_arg " << ba.getArgNumber();
+    os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
+    return os;
+  }
+};
+
+struct ValuePortHasher {
+  std::size_t operator()(const ValuePort& other) const {
+    return llvm::hash_combine(
+        llvm::hash_value(other.producer.getOpaqueValue()),
+        llvm::hash_value(ArrayRef<unsigned int>(other.port)));
+  }
+};
+
+using ValuePortResultMap =
+    std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
+using ComputedQueryFn = llvm::function_ref<bool(ValuePort)>;
+using ValueQueryFn = llvm::function_ref<Attribute(const ValuePort&)>;
+using ValuePortInputs = llvm::SmallVectorImpl<ValuePort>;
+
+// TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are
+// intended to be switched to op interfaces once more refined.
+LogicalResult InputsRequiredForOutput(ValuePort value_port,
+                                      ComputedQueryFn has_been_computed,
+                                      ValuePortInputs* inputs) {
+  auto op = value_port.producer.dyn_cast<Operation*>();
+  auto& port = value_port.port;
+  if (!op) return failure();
+
+  // No inputs required for constants.
+  if (matchPattern(op, m_Constant())) return success();
+
+  // Note: this focuses only on the trivial tf.Pack case and could be
+  // generalized.
+  if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
+    if (pack_op.getType().cast<TensorType>().getRank() != 1) return failure();
+    if (port.size() != 2) return failure();
+    assert(port[0] == 0);
+    ValuePort req(pack_op.getOperand(port[1]));
+    if (!has_been_computed(req)) inputs->push_back(req);
+    return success();
+  }
+
+  return failure();
+}
+
+// Computes the value produced at the given ValuePort, using the query function
+// to look up previously computed values.
+Attribute ComputeOutputComponent(const ValuePort& value_port,
+                                 ValueQueryFn values) {
+  LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
+
+  auto op = value_port.producer.dyn_cast<Operation*>();
+  if (!op) return nullptr;
+  auto& port = value_port.port;
+
+  if (port.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n");
+    return nullptr;
+  }
+
+  ElementsAttr attr;
+  if (matchPattern(op, m_Constant(&attr))) {
+    if (port.size() == 1 && port[0] == 0) return attr;
+    return nullptr;
+  }
+
+  // Note: this focuses only on the trivial tf.Pack case and could be
+  // generalized.
+  if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
+    if (pack_op.getType().cast<TensorType>().getRank() != 1) return nullptr;
+    if (port.size() != 2 || port[0] != 0) return nullptr;
+    ValuePort op_port(op->getOperand(port[1]));
+    return values(op_port);
+  }
+  return nullptr;
+}
+
+ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
+  LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
+  auto rt = result.getType().dyn_cast<RankedTensorType>();
+  if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
+  int dim_size = rt.getDimSize(0);
+
+  // Worklist to direct partial evaluation.
+  llvm::SmallVector<ValuePort, 4> worklist;
+  // The ValuePort evaluated results.
+  // TODO(jpienaar): This could be cached across invocations (e.g., part of some
+  // inference context).
+  ValuePortResultMap evaluated;
+  // Returns whether a ValuePort has been previously computed.
+  auto has_been_computed = [&evaluated](const ValuePort& port) {
+    return evaluated.find(port) != evaluated.end();
+  };
+  // Returns previously computed ValuePort value.
+  auto values = [&evaluated](const ValuePort& port) -> Attribute {
+    return evaluated[port];
+  };
+
+  // Simple evaluator that attempts to partially evaluate the input value even
+  // if it is unable to evaluate the complete output. Below is a simple
+  // stack-based evaluation: for each element, query which operands (or parts
+  // of operands) still need to be evaluated and push them on top of the
+  // operation that requires them, so that those operands are (partially)
+  // evaluated before the operation itself is revisited.
+  std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim());
+  for (unsigned int i = 0, e = dims.size(); i != e; ++i) {
+    LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n");
+
+    worklist.push_back(
+        ValuePort{result.getOwner(), {result.getResultNumber(), i}});
+    while (!worklist.empty()) {
+      auto front = worklist.pop_back_val();
+      LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front "));
+
+      SmallVector<ValuePort, 4> inputs;
+      auto res = InputsRequiredForOutput(front, has_been_computed, &inputs);
+      if (failed(res)) {
+        // Abort if unable to find which required inputs need to be computed.
+        worklist.clear();
+        break;
+      }
+
+      if (!inputs.empty()) {
+        // Enqueue required computation followed by its required operands in
+        // stack.
+        worklist.push_back(std::move(front));
+        for (auto& it : inputs) worklist.push_back(std::move(it));
+        continue;
+      }
+
+      auto ret = ComputeOutputComponent(front, values);
+      if (!ret) continue;
+
+      evaluated[front] = ret;
+      LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
+
+      // If worklist is empty, then this is the root query op.
+      if (worklist.empty()) {
+        LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
+        if (auto dea = ret.dyn_cast<mlir::DenseIntElementsAttr>()) {
+          if (dea.getNumElements() != 1) {
+            LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n");
+            return {};
+          }
+          int64_t val = (*dea.getIntValues().begin()).getSExtValue();
+          dims[i] = ic->MakeDim(val);
+        }
+      }
+    }
+  }
+  return ic->MakeShape(dims);
+}
+
 bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
                                   int64_t graph_version) {
   assert(tf_dialect == op->getDialect());
@@ -455,9 +646,9 @@
   // Perform the shape inference using an InferenceContext with the input
   // shapes. This object is abstracting the information that the ShapeInference
   // function operates on.
-  tensorflow::shape_inference::InferenceContext c(
-      graph_version, *node_def, op_reg_data->op_def, input_shapes,
-      input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
+  InferenceContext c(graph_version, *node_def, op_reg_data->op_def,
+                     input_shapes, input_tensors,
+                     /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
   auto status = c.Run(op_reg_data->shape_inference_fn);
   if (!status.ok()) {
     LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op
@@ -465,6 +656,43 @@
     return false;
   }
 
+  // Determine if, during shape computation, the shape function attempted to
+  // query an input operand as a shape where that input was not known/constant.
+  bool requires_inputs =
+      llvm::any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
+        return c.requested_input_tensor_as_partial_shape(input) &&
+               !input_tensors[input];
+      });
+  if (requires_inputs) {
+    std::vector<ShapeHandle> input_tensors_as_shapes;
+    for (int input : llvm::seq<int>(0, c.num_inputs())) {
+      if (c.requested_input_tensor_as_partial_shape(input) &&
+          !input_tensors[input]) {
+        auto op_result = op->getOperand(input).dyn_cast<OpResult>();
+        if (!op_result) continue;
+        // Resize (once) so a computed shape can be stored at its input index.
+        input_tensors_as_shapes.resize(c.num_inputs());
+        auto handle = ComputeOutputAsShape(op_result, &c);
+        LLVM_DEBUG(llvm::dbgs() << "Requested " << input << " as shape "
+                                << (handle.Handle() ? "found" : "not found"));
+        if (handle.Handle()) input_tensors_as_shapes[input] = handle;
+      }
+    }
+
+    // Attempt the shape function again with the partially computed operand
+    // shapes. Note: if no partial output could be computed above, then
+    // input_tensors_as_shapes is empty and there is nothing new to try.
+    if (!input_tensors_as_shapes.empty()) {
+      c.set_input_tensors_as_shapes(input_tensors_as_shapes);
+      auto status = c.Run(op_reg_data->shape_inference_fn);
+      if (!status.ok()) {
+        LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op
+                                << "': " << status.error_message() << "\n");
+        return false;
+      }
+    }
+  }
+
   assert(c.num_outputs() == op->getNumResults() &&
          "inference context matches the MLIR number of results.");
 
@@ -477,12 +705,11 @@
     if (!CanBeRefined(result.getType())) continue;
     auto shaped_type = result.getType().cast<ShapedType>();
 
-    tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output);
+    ShapeHandle shape_handle = c.output(output);
     LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : "
                             << c.DebugString(shape_handle) << "\n");
-    auto get_tensor_type =
-        [&c](const tensorflow::shape_inference::ShapeHandle& sh,
-             Type element_type) -> TensorType {
+    auto get_tensor_type = [&c](const ShapeHandle& sh,
+                                Type element_type) -> TensorType {
       if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type);
       // Convert the shape from TensorFlow (int64) to MLIR (int64_t).
       SmallVector<int64_t, 8> shape;