Add InferTypeOpTrait & enable generating its member function definition
Use OpInterfaces to add an interface for ops defining a return type function.
This change does not use this trait in any meaningful way, I'll use it in a
follow up to generalize and unify some of the op type traits/constraints. Also,
currently the infer type function can only be manually specified in C++, that should rather be the fallback in future.
PiperOrigin-RevId: 271883746
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index b17761b..80b00f0 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -117,6 +117,7 @@
deps = [
":CallOpInterfacesIncGen",
":DialectSymbolRegistry",
+ ":InferTypeOpInterfaceIncGen",
":Support",
"@llvm//:support",
],
@@ -1223,6 +1224,26 @@
],
)
+gentbl(
+ name = "InferTypeOpInterfaceIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ "-gen-op-interface-decls",
+ "include/mlir/Analysis/InferTypeOpInterface.h.inc",
+ ),
+ (
+ "-gen-op-interface-defs",
+ "include/mlir/Analysis/InferTypeOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Analysis/InferTypeOpInterface.td",
+ td_srcs = [
+ ":OpBaseTdFiles",
+ ],
+)
+
cc_library(
name = "Analysis",
srcs = [
@@ -1230,6 +1251,7 @@
"lib/Analysis/AffineStructures.cpp",
"lib/Analysis/CallGraph.cpp",
"lib/Analysis/Dominance.cpp",
+ "lib/Analysis/InferTypeOpInterface.cpp",
"lib/Analysis/LoopAnalysis.cpp",
"lib/Analysis/MemRefBoundCheck.cpp",
"lib/Analysis/NestedMatcher.cpp",
@@ -1247,6 +1269,7 @@
"include/mlir/Analysis/CallGraph.h",
"include/mlir/Analysis/CallInterfaces.h",
"include/mlir/Analysis/Dominance.h",
+ "include/mlir/Analysis/InferTypeOpInterface.h",
"include/mlir/Analysis/LoopAnalysis.h",
"include/mlir/Analysis/NestedMatcher.h",
"include/mlir/Analysis/Passes.h",
@@ -1260,6 +1283,7 @@
":AffineOps",
":CallOpInterfacesIncGen",
":IR",
+ ":InferTypeOpInterfaceIncGen",
":LoopOps",
":Pass",
":StandardOps",
diff --git a/third_party/mlir/include/mlir/Analysis/CMakeLists.txt b/third_party/mlir/include/mlir/Analysis/CMakeLists.txt
index 619f4b1..3d9a7ed 100644
--- a/third_party/mlir/include/mlir/Analysis/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Analysis/CMakeLists.txt
@@ -2,3 +2,8 @@
mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRCallOpInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td)
+mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTypeInferOpInterfaceIncGen)
diff --git a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h
new file mode 100644
index 0000000..b80723e
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h
@@ -0,0 +1,40 @@
+//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the definitions of the infer op interfaces defined in
+// `InferTypeOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
+#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+#include "mlir/Analysis/InferTypeOpInterface.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
diff --git a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td
new file mode 100644
index 0000000..a155810
--- /dev/null
+++ b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td
@@ -0,0 +1,63 @@
+//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains a set of interfaces that can be used to define information
+// related to call-like and callable operations. Each of which are defined along
+// with the respective interface below.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef MLIR_INFERTYPEOPINTERFACE
+#else
+#define MLIR_INFERTYPEOPINTERFACE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+// OpInterface to compute the return type of an operation. The arguments match
+// those in Operation::create with the exception that the location is optional
+// (if no location is provided, then the method will not emit an error on
+// mismatch).
+def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
+ let description = [{
+ Interface to access a registered method to infer the return types for an
+ operation that could be used during op construction, verification or
+ type inference.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{Returns the return types that an op would generate.
+
+ The method takes an optional location which, if set, will be used to
+ report errors on. The operands and attributes correspond to those with
+ which an Operation would be created (e.g., as used in Operation;:create).
+ Regions are the nested regions of the op.
+ }],
+ /*retTy=*/"SmallVector<Type, 2>",
+ /*methodName=*/"inferReturnTypes",
+ /*args=*/(ins "llvm::Optional<Location>":$location,
+ "ArrayRef<Value*>":$operands,
+ "ArrayRef<NamedAttribute>":$attributes,
+ "ArrayRef<Region>":$regions)
+ >,
+ ];
+}
+
+#endif // MLIR_INFERTYPEOPINTERFACE
diff --git a/third_party/mlir/lib/Analysis/CMakeLists.txt b/third_party/mlir/lib/Analysis/CMakeLists.txt
index dff3df9..c16ad3f 100644
--- a/third_party/mlir/lib/Analysis/CMakeLists.txt
+++ b/third_party/mlir/lib/Analysis/CMakeLists.txt
@@ -3,6 +3,7 @@
AffineStructures.cpp
CallGraph.cpp
Dominance.cpp
+ InferTypeOpInterface.cpp
LoopAnalysis.cpp
MemRefBoundCheck.cpp
NestedMatcher.cpp
@@ -20,6 +21,7 @@
add_dependencies(MLIRAnalysis
MLIRAffineOps
MLIRCallOpInterfacesIncGen
+ MLIRTypeInferOpInterfaceIncGen
MLIRLoopOps
MLIRVectorOps
)
diff --git a/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp b/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp
new file mode 100644
index 0000000..cbbd446
--- /dev/null
+++ b/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp
@@ -0,0 +1,31 @@
+//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the definitions of the infer op interfaces defined in
+// `InferTypeOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/InferTypeOpInterface.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+#include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
+} // namespace mlir
diff --git a/third_party/mlir/test/BUILD b/third_party/mlir/test/BUILD
index a7d7848..6700843 100644
--- a/third_party/mlir/test/BUILD
+++ b/third_party/mlir/test/BUILD
@@ -39,6 +39,7 @@
td_srcs = [
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td",
+ "@local_config_mlir//:include/mlir/Analysis/InferTypeOpInterface.td",
],
)
@@ -54,11 +55,10 @@
includes = ["lib/TestDialect"],
deps = [
":TestOpsIncGen",
- "@llvm//:support",
+ "@local_config_mlir//:Analysis",
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
- "@local_config_mlir//:Support",
"@local_config_mlir//:TransformUtils",
"@local_config_mlir//:Transforms",
],
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
index 69c3bcd..d91bb1a 100644
--- a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -216,6 +216,14 @@
return operand();
}
+SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
+ llvm::Optional<Location> location, ArrayRef<Value *> operands,
+ ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
+ if (location)
+ mlir::emitError(*location) << "expected to fail";
+ return SmallVector<Type, 2>{nullptr};
+}
+
// Static initialization for Test dialect registration.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.h b/third_party/mlir/test/lib/TestDialect/TestDialect.h
index a2fceca..ffe2a1c 100644
--- a/third_party/mlir/test/lib/TestDialect/TestDialect.h
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.h
@@ -24,6 +24,7 @@
#define MLIR_TESTDIALECT_H
#include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/InferTypeOpInterface.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td
index 0c2e53b..72991ce 100644
--- a/third_party/mlir/test/lib/TestDialect/TestOps.td
+++ b/third_party/mlir/test/lib/TestDialect/TestOps.td
@@ -21,6 +21,7 @@
include "mlir/IR/OpBase.td"
include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/InferTypeOpInterface.td"
def TEST_Dialect : Dialect {
let name = "test";
@@ -318,8 +319,7 @@
def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
-def TerminatorOp : TEST_Op<"finish", [Terminator]> {
-}
+def TerminatorOp : TEST_Op<"finish", [Terminator]>;
def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
[SingleBlockImplicitTerminator<"TerminatorOp">]> {
let regions = (region SizedRegion<1>:$region);
@@ -329,6 +329,18 @@
let arguments = (ins I32ElementsAttr:$attr);
}
+def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
+ [InferTypeOpInterface]> {
+ let arguments = (ins AnyTensor:$x, AnyTensor:$y);
+ let results = (outs AnyTensor:$res);
+ // TODO(jpienaar): Remove the need to specify these here.
+ let extraClassDeclaration = [{
+ SmallVector<Type, 2> inferReturnTypes(llvm::Optional<Location> location,
+ ArrayRef<Value*> operands, ArrayRef<NamedAttribute> attributes,
+ ArrayRef<Region> regions);
+ }];
+}
+
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
diff --git a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
index 533ec1f..17a257f 100644
--- a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -52,6 +52,43 @@
pass("test-patterns", "Run test dialect patterns");
//===----------------------------------------------------------------------===//
+// ReturnType Driver.
+//===----------------------------------------------------------------------===//
+
+struct ReturnTypeOpMatch : public RewritePattern {
+ ReturnTypeOpMatch(MLIRContext *ctx)
+ : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
+ }
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
+ if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
+ SmallVector<Value *, 4> values;
+ values.reserve(op->getNumOperands());
+ for (auto &operand : op->getOpOperands())
+ values.push_back(operand.get());
+ (void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(),
+ op->getRegions());
+ }
+ return matchFailure();
+ }
+};
+
+namespace {
+struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
+ void runOnFunction() override {
+ mlir::OwningRewritePatternList patterns;
+ populateWithGenerated(&getContext(), &patterns);
+ patterns.insert<ReturnTypeOpMatch>(&getContext());
+ applyPatternsGreedily(getFunction(), patterns);
+ }
+};
+} // end anonymous namespace
+
+static mlir::PassRegistration<TestReturnTypeDriver>
+ rt_pass("test-return-type", "Run return type functions");
+
+//===----------------------------------------------------------------------===//
// Legalization Driver.
//===----------------------------------------------------------------------===//