Add InferTypeOpTrait & enable generating its member function definition

Use OpInterfaces to add an interface for ops defining a return type function.

This change does not use this trait in any meaningful way, I'll use it in a
follow up to generalize and unify some of the op type traits/constraints. Also,
currently the infer type function can only be manually specified in C++, that should rather be the fallback in future.

PiperOrigin-RevId: 271883746
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index b17761b..80b00f0 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -117,6 +117,7 @@
     deps = [
         ":CallOpInterfacesIncGen",
         ":DialectSymbolRegistry",
+        ":InferTypeOpInterfaceIncGen",
         ":Support",
         "@llvm//:support",
     ],
@@ -1223,6 +1224,26 @@
     ],
 )
 
+gentbl(
+    name = "InferTypeOpInterfaceIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-op-interface-decls",
+            "include/mlir/Analysis/InferTypeOpInterface.h.inc",
+        ),
+        (
+            "-gen-op-interface-defs",
+            "include/mlir/Analysis/InferTypeOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Analysis/InferTypeOpInterface.td",
+    td_srcs = [
+        ":OpBaseTdFiles",
+    ],
+)
+
 cc_library(
     name = "Analysis",
     srcs = [
@@ -1230,6 +1251,7 @@
         "lib/Analysis/AffineStructures.cpp",
         "lib/Analysis/CallGraph.cpp",
         "lib/Analysis/Dominance.cpp",
+        "lib/Analysis/InferTypeOpInterface.cpp",
         "lib/Analysis/LoopAnalysis.cpp",
         "lib/Analysis/MemRefBoundCheck.cpp",
         "lib/Analysis/NestedMatcher.cpp",
@@ -1247,6 +1269,7 @@
         "include/mlir/Analysis/CallGraph.h",
         "include/mlir/Analysis/CallInterfaces.h",
         "include/mlir/Analysis/Dominance.h",
+        "include/mlir/Analysis/InferTypeOpInterface.h",
         "include/mlir/Analysis/LoopAnalysis.h",
         "include/mlir/Analysis/NestedMatcher.h",
         "include/mlir/Analysis/Passes.h",
@@ -1260,6 +1283,7 @@
         ":AffineOps",
         ":CallOpInterfacesIncGen",
         ":IR",
+        ":InferTypeOpInterfaceIncGen",
         ":LoopOps",
         ":Pass",
         ":StandardOps",
diff --git a/third_party/mlir/include/mlir/Analysis/CMakeLists.txt b/third_party/mlir/include/mlir/Analysis/CMakeLists.txt
index 619f4b1..3d9a7ed 100644
--- a/third_party/mlir/include/mlir/Analysis/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Analysis/CMakeLists.txt
@@ -2,3 +2,8 @@
 mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls)
 mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIRCallOpInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td)
+mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTypeInferOpInterfaceIncGen)
diff --git a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h
new file mode 100644
index 0000000..b80723e
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h
@@ -0,0 +1,40 @@
+//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the definitions of the infer op interfaces defined in
+// `InferTypeOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
+#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+#include "mlir/Analysis/InferTypeOpInterface.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
diff --git a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td
new file mode 100644
index 0000000..a155810
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td
@@ -0,0 +1,63 @@
+//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains a set of interfaces that can be used to define information
+// related to call-like and callable operations. Each of which are defined along
+// with the respective interface below.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef MLIR_INFERTYPEOPINTERFACE
+#else
+#define MLIR_INFERTYPEOPINTERFACE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+// OpInterface to compute the return type of an operation. The arguments match
+// those in Operation::create with the exception that the location is optional
+// (if no location is provided, then the method will not emit an error on
+// mismatch).
+def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
+  let description = [{
+    Interface to access a registered method to infer the return types for an
+    operation that could be used during op construction, verification or
+    type inference.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{Returns the return types that an op would generate.
+
+      The method takes an optional location which, if set, will be used to
+      report errors on. The operands and attributes correspond to those with
+      which an Operation would be created (e.g., as used in Operation;:create).
+      Regions are the nested regions of the op.
+      }],
+      /*retTy=*/"SmallVector<Type, 2>",
+      /*methodName=*/"inferReturnTypes",
+      /*args=*/(ins "llvm::Optional<Location>":$location,
+                    "ArrayRef<Value*>":$operands,
+                    "ArrayRef<NamedAttribute>":$attributes,
+                    "ArrayRef<Region>":$regions)
+    >,
+  ];
+}
+
+#endif // MLIR_INFERTYPEOPINTERFACE
diff --git a/third_party/mlir/lib/Analysis/CMakeLists.txt b/third_party/mlir/lib/Analysis/CMakeLists.txt
index dff3df9..c16ad3f 100644
--- a/third_party/mlir/lib/Analysis/CMakeLists.txt
+++ b/third_party/mlir/lib/Analysis/CMakeLists.txt
@@ -3,6 +3,7 @@
   AffineStructures.cpp
   CallGraph.cpp
   Dominance.cpp
+  InferTypeOpInterface.cpp
   LoopAnalysis.cpp
   MemRefBoundCheck.cpp
   NestedMatcher.cpp
@@ -20,6 +21,7 @@
 add_dependencies(MLIRAnalysis
   MLIRAffineOps
   MLIRCallOpInterfacesIncGen
+  MLIRTypeInferOpInterfaceIncGen
   MLIRLoopOps
   MLIRVectorOps
   )
diff --git a/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp b/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp
new file mode 100644
index 0000000..cbbd446
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp
@@ -0,0 +1,31 @@
+//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the definitions of the infer op interfaces defined in
+// `InferTypeOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/InferTypeOpInterface.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+#include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
+} // namespace mlir
diff --git a/third_party/mlir/test/BUILD b/third_party/mlir/test/BUILD
index a7d7848..6700843 100644
--- a/third_party/mlir/test/BUILD
+++ b/third_party/mlir/test/BUILD
@@ -39,6 +39,7 @@
     td_srcs = [
         "@local_config_mlir//:include/mlir/IR/OpBase.td",
         "@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td",
+        "@local_config_mlir//:include/mlir/Analysis/InferTypeOpInterface.td",
     ],
 )
 
@@ -54,11 +55,10 @@
     includes = ["lib/TestDialect"],
     deps = [
         ":TestOpsIncGen",
-        "@llvm//:support",
+        "@local_config_mlir//:Analysis",
         "@local_config_mlir//:Dialect",
         "@local_config_mlir//:IR",
         "@local_config_mlir//:Pass",
-        "@local_config_mlir//:Support",
         "@local_config_mlir//:TransformUtils",
         "@local_config_mlir//:Transforms",
     ],
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
index 69c3bcd..d91bb1a 100644
--- a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -216,6 +216,14 @@
   return operand();
 }
 
+SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
+    llvm::Optional<Location> location, ArrayRef<Value *> operands,
+    ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
+  if (location)
+    mlir::emitError(*location) << "expected to fail";
+  return SmallVector<Type, 2>{nullptr};
+}
+
 // Static initialization for Test dialect registration.
 static mlir::DialectRegistration<mlir::TestDialect> testDialect;
 
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.h b/third_party/mlir/test/lib/TestDialect/TestDialect.h
index a2fceca..ffe2a1c 100644
--- a/third_party/mlir/test/lib/TestDialect/TestDialect.h
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.h
@@ -24,6 +24,7 @@
 #define MLIR_TESTDIALECT_H
 
 #include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/InferTypeOpInterface.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td
index 0c2e53b..72991ce 100644
--- a/third_party/mlir/test/lib/TestDialect/TestOps.td
+++ b/third_party/mlir/test/lib/TestDialect/TestOps.td
@@ -21,6 +21,7 @@
 
 include "mlir/IR/OpBase.td"
 include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/InferTypeOpInterface.td"
 
 def TEST_Dialect : Dialect {
   let name = "test";
@@ -318,8 +319,7 @@
 def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
 
 
-def TerminatorOp : TEST_Op<"finish", [Terminator]> {
-}
+def TerminatorOp : TEST_Op<"finish", [Terminator]>;
 def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
     [SingleBlockImplicitTerminator<"TerminatorOp">]> {
   let regions = (region SizedRegion<1>:$region);
@@ -329,6 +329,18 @@
   let arguments = (ins I32ElementsAttr:$attr);
 }
 
+def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
+    [InferTypeOpInterface]> {
+  let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+  let results = (outs AnyTensor:$res);
+  // TODO(jpienaar): Remove the need to specify these here.
+  let extraClassDeclaration = [{
+    SmallVector<Type, 2> inferReturnTypes(llvm::Optional<Location> location,
+      ArrayRef<Value*> operands, ArrayRef<NamedAttribute> attributes,
+      ArrayRef<Region> regions);
+  }];
+}
+
 def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
 
 def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
diff --git a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
index 533ec1f..17a257f 100644
--- a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -52,6 +52,43 @@
     pass("test-patterns", "Run test dialect patterns");
 
 //===----------------------------------------------------------------------===//
+// ReturnType Driver.
+//===----------------------------------------------------------------------===//
+
+struct ReturnTypeOpMatch : public RewritePattern {
+  ReturnTypeOpMatch(MLIRContext *ctx)
+      : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
+  }
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const final {
+    if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
+      SmallVector<Value *, 4> values;
+      values.reserve(op->getNumOperands());
+      for (auto &operand : op->getOpOperands())
+        values.push_back(operand.get());
+      (void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(),
+                                       op->getRegions());
+    }
+    return matchFailure();
+  }
+};
+
+namespace {
+struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
+  void runOnFunction() override {
+    mlir::OwningRewritePatternList patterns;
+    populateWithGenerated(&getContext(), &patterns);
+    patterns.insert<ReturnTypeOpMatch>(&getContext());
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+} // end anonymous namespace
+
+static mlir::PassRegistration<TestReturnTypeDriver>
+    rt_pass("test-return-type", "Run return type functions");
+
+//===----------------------------------------------------------------------===//
 // Legalization Driver.
 //===----------------------------------------------------------------------===//