Add edsc::ops for pointwise, conv and dilated_conv
This CL adds more Linalg EDSC ops and tests to support building pointwise operations along with conv and dilated_conv.
This also fixes a bug in the existing linalg_matmul EDSC and beefs up the test.
The current set of ops is already enough to build an interesting, albeit simple, model used internally.
PiperOrigin-RevId: 285838012
Change-Id: I35edf4bed5eef32a22900c89c4482f5426b4645d
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 9be60bb..131afd5 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -2287,6 +2287,7 @@
hdrs = [
"include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h",
"include/mlir/Dialect/Linalg/EDSC/Builders.h",
+ "include/mlir/Dialect/Linalg/EDSC/Intrinsics.h",
"include/mlir/Dialect/Linalg/IR/LinalgOps.h",
"include/mlir/Dialect/Linalg/IR/LinalgTraits.h",
"include/mlir/Dialect/Linalg/IR/LinalgTypes.h",
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index 00da1d6..4213420 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -22,15 +22,17 @@
#ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
#define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
namespace mlir {
class BlockArgument;
-namespace edsc {
+namespace edsc {
enum class IterType { Parallel, Reduction };
inline StringRef toString(IterType t) {
@@ -38,7 +40,7 @@
case IterType::Parallel:
return getParallelIteratorTypeName();
case IterType::Reduction:
- return getParallelIteratorTypeName();
+ return getReductionIteratorTypeName();
default:
llvm_unreachable("Unsupport IterType");
}
@@ -78,20 +80,83 @@
Operation *makeLinalgGenericOp(
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
ArrayRef<StructuredIndexed> outputs,
- decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder,
+ llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder =
+ defaultRegionBuilder,
ArrayRef<Value *> otherValues = {},
ArrayRef<Attribute> otherAttributes = {});
+namespace ops {
+using edsc::StructuredIndexed;
+using edsc::ValueHandle;
+using edsc::intrinsics::linalg_yield;
+
//===----------------------------------------------------------------------===//
// EDSC builders for linalg generic operations.
//===----------------------------------------------------------------------===//
+/// Build the body of a region to compute a multiply-accumulate, under the
+/// current ScopedContext, at the current insert point.
+void macRegionBuilder(ArrayRef<BlockArgument *> args);
+
/// TODO(ntv): In the future we should tie these implementations to something in
/// Tablegen that generates the proper interfaces and the proper sugared named
/// ops.
-/// Build a linalg.generic that represents C = A * B in the current
-/// ScopedContext.
+/// Build a linalg.pointwise, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+/// (i0, ..., in) = (par, ..., par)
+/// |
+/// | O...(some_subset...(i0, ..., in)) =
+/// | some_pointwise_func...(I...(some_other_subset...(i0, ..., in)))
+/// ```
+///
+/// This is a very generic entry point that can be configured in many ways to
+/// build a perfect loop nest of parallel loops with arbitrarily complex
+/// innermost loop code and whatever (explicit) broadcast semantics.
+///
+/// This can be used with both out-of-place and in-place semantics.
+/// The client is responsible for ensuring the region operations are compatible
+/// with in-place semantics and parallelism.
+
+/// Unary pointwise operation (with broadcast) entry point.
+using UnaryPointwiseOpBuilder = llvm::function_ref<Value *(ValueHandle)>;
+Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
+ StructuredIndexed I, StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = tanh(I)`. The client is responsible for specifying the proper
+/// indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
+
+/// Binary pointwise operation (with broadcast) entry point.
+using BinaryPointwiseOpBuilder =
+ llvm::function_ref<Value *(ValueHandle, ValueHandle)>;
+Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
+ StructuredIndexed I1, StructuredIndexed I2,
+ StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = I1 + I2`. The client is responsible for specifying the proper
+/// indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
+ StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = max(I!, I2)`. The client is responsible for specifying the
+/// proper indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
+ StructuredIndexed O);
+
+// TODO(ntv): Implement more useful pointwise operations on a per-need basis.
+
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+/// (m, n, k) = (par, par, seq)
+/// |
+/// | C(m, n) += A(m, k) * B(k, n)
+/// ```
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
template <typename Container> Operation *linalg_matmul(Container values) {
@@ -99,6 +164,76 @@
return linalg_matmul(values[0], values[1], values[2]);
}
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+/// (batch, f, [h, w, ...], [kh, kw, ...], c) =
+/// | (par, par, [par, par, ...], [red, red, ...], red)
+/// |
+/// | O(batch, [h, w, ...], f) +=
+/// | I(batch,
+/// | [
+/// | stride[0] * h + dilations[0] * kh,
+/// | stride[1] * w + dilations[1] * kw, ...
+/// ],
+/// | c)
+/// | *
+/// | W([kh, kw, ...], c, f)
+/// ```
+/// If `dilations` or `strides` are left empty, the default value of `1` is used
+/// along each relevant dimension.
+///
+/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
+///
+// TODO(ntv) Extend convolution rank with some template magic.
+Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO,
+ ArrayRef<int> strides = {},
+ ArrayRef<int> dilations = {});
+
+template <typename Container>
+Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
+ ArrayRef<int> dilations = {}) {
+ assert(values.size() == 3 && "Expected exactly 3 values");
+ return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
+}
+
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+/// (batch, dm, c, [h, w, ...], [kh, kw, ...]) =
+/// | (par, par, par, [par, par, ...], [red, red, ...])
+/// |
+/// | O(batch, [h, w, ...], c * depth_multiplier) +=
+/// | I(batch,
+/// | [
+/// | stride[0] * h + dilations[0] * kh,
+/// | stride[1] * w + dilations[1] * kw, ...
+/// ],
+/// | c)
+/// | *
+/// | W([kh, kw, ...], c, depth_multiplier)
+/// ```
+/// If `dilations` or `strides` are left empty, the default value of `1` is used
+/// along each relevant dimension.
+///
+/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
+///
+// TODO(ntv) Extend convolution rank with some template magic.
+Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
+ ValueHandle vO, int depth_multiplier = 1,
+ ArrayRef<int> strides = {},
+ ArrayRef<int> dilations = {});
+
+template <typename Container>
+Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
+ ArrayRef<int> strides = {},
+ ArrayRef<int> dilations = {}) {
+ assert(values.size() == 3 && "Expected exactly 3 values");
+ return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
+ depth_multiplier, strides, dilations);
+}
+
+} // namespace ops
} // namespace edsc
} // namespace mlir
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
new file mode 100644
index 0000000..f1acab6
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -0,0 +1,35 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+
+using linalg_fill = OperationBuilder<linalg::FillOp>;
+using linalg_yield = OperationBuilder<linalg::YieldOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
index 4f9621c..1f24a90 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
@@ -247,13 +247,13 @@
}
def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> {
- let arguments = (ins AnyStridedMemRef:$input,
+ let arguments = (ins AnyStridedMemRef:$output,
AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value);
let extraClassDeclaration = libraryCallName # [{
ArrayAttr indexing_maps();
ArrayAttr iterator_types() {
- unsigned nPar = input()->getType().cast<ShapedType>().getRank();
+ unsigned nPar = output()->getType().cast<ShapedType>().getRank();
MLIRContext *ctx = getContext();
SmallVector<Attribute, 8> iters(
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
diff --git a/third_party/mlir/include/mlir/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
index 68bd210..6dbb343 100644
--- a/third_party/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
@@ -154,22 +154,22 @@
/// Folder-based
template <typename... Args>
- ValueBuilder(OperationFolder &folder, Args... args)
+ ValueBuilder(OperationFolder *folder, Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {}
- ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs)
+ ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs)
: ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {}
template <typename... Args>
- ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs, Args... args)
+ ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T, typename... Args>
- ValueBuilder(OperationFolder &folder, T t, ArrayRef<ValueHandle> vs,
+ ValueBuilder(OperationFolder *folder, T t, ArrayRef<ValueHandle> vs,
Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t),
detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T1, typename T2, typename... Args>
- ValueBuilder(OperationFolder &folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
+ ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
Args... args)
: ValueHandle(ValueHandle::create<Op>(
folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
@@ -200,6 +200,7 @@
OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
};
+using addf = ValueBuilder<AddFOp>;
using affine_apply = ValueBuilder<AffineApplyOp>;
using affine_if = OperationBuilder<AffineIfOp>;
using affine_load = ValueBuilder<AffineLoadOp>;
@@ -212,11 +213,14 @@
using dealloc = OperationBuilder<DeallocOp>;
using dim = ValueBuilder<DimOp>;
using muli = ValueBuilder<MulIOp>;
+using mulf = ValueBuilder<MulFOp>;
+using memref_cast = ValueBuilder<MemRefCastOp>;
using ret = OperationBuilder<ReturnOp>;
using select = ValueBuilder<SelectOp>;
using std_load = ValueBuilder<LoadOp>;
using std_store = OperationBuilder<StoreOp>;
using subi = ValueBuilder<SubIOp>;
+using tanh = ValueBuilder<TanhOp>;
using view = ValueBuilder<ViewOp>;
/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
diff --git a/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 3daeafe..77e3a1e 100644
--- a/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
@@ -26,6 +27,7 @@
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
+using namespace mlir::edsc::ops;
static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
unsigned &pos) {
@@ -42,24 +44,26 @@
Operation *mlir::edsc::makeLinalgGenericOp(
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
ArrayRef<StructuredIndexed> outputs,
- decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues,
- ArrayRef<Attribute> otherAttributes) {
+ llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder,
+ ArrayRef<Value *> otherValues, ArrayRef<Attribute> otherAttributes) {
auto &builder = edsc::ScopedContext::getBuilder();
auto *ctx = builder.getContext();
unsigned nInputs = inputs.size();
unsigned nOutputs = outputs.size();
- unsigned rank = 0;
- getMaxDimIndex(inputs, rank);
- getMaxDimIndex(outputs, rank);
+ unsigned maxPos = 0;
+ getMaxDimIndex(inputs, maxPos);
+ getMaxDimIndex(outputs, maxPos);
+ // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
+ unsigned nDims = maxPos + 1;
SmallVector<AffineMap, 4> maps;
maps.reserve(nInputs + nOutputs);
for (auto in : inputs)
maps.push_back(
- AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs()));
+ AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
for (auto out : outputs)
maps.push_back(
- AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs()));
+ AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
unsigned nViews = nInputs + nOutputs;
SmallVector<Value *, 4> values;
@@ -105,23 +109,148 @@
return op;
}
-using linalg_yield = OperationBuilder<linalg::YieldOp>;
+void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument *> args) {
+ using edsc::op::operator+;
+ using edsc::op::operator*;
+ assert(args.size() == 3 && "expected 3 block arguments");
+ ValueHandle a(args[0]), b(args[1]), c(args[2]);
+ linalg_yield((c + a * b).getValue());
+}
-Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB,
- ValueHandle vC) {
+Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
+ StructuredIndexed I,
+ StructuredIndexed O) {
+ SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
+ edsc::IterType::Parallel);
+ auto fun = [&unaryOp](ArrayRef<BlockArgument *> args) {
+ assert(args.size() == 2 && "expected 2 block arguments");
+ ValueHandle a(args[0]);
+ linalg_yield(unaryOp(a));
+ };
+ return makeLinalgGenericOp(iterTypes, {I}, {O}, fun);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
+ StructuredIndexed O) {
+ ;
+ using edsc::intrinsics::tanh;
+ UnaryPointwiseOpBuilder unOp(
+ [](ValueHandle a) -> Value * { return tanh(a); });
+ return linalg_pointwise(unOp, I, O);
+}
+
+/// Binary pointwise operation (with broadcast) entry point.
+Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
+ StructuredIndexed I1,
+ StructuredIndexed I2,
+ StructuredIndexed O) {
+ SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
+ edsc::IterType::Parallel);
+ auto fun = [&binaryOp](ArrayRef<BlockArgument *> args) {
+ assert(args.size() == 3 && "expected 3 block arguments");
+ ValueHandle a(args[0]), b(args[1]);
+ linalg_yield(binaryOp(a, b));
+ };
+ return makeLinalgGenericOp(iterTypes, {I1, I2}, {O}, fun);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
+ StructuredIndexed I2,
+ StructuredIndexed O) {
+ using edsc::op::operator+;
+ BinaryPointwiseOpBuilder binOp(
+ [](ValueHandle a, ValueHandle b) -> Value * { return a + b; });
+ return linalg_pointwise(binOp, I1, I2, O);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
+ StructuredIndexed I2,
+ StructuredIndexed O) {
+ BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value * {
+ using edsc::intrinsics::select;
+ using edsc::op::operator>;
+ return select(a > b, a, b).getValue();
+ });
+ return linalg_pointwise(binOp, I1, I2, O);
+}
+
+Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
+ ValueHandle vC) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
StructuredIndexed A(vA), B(vB), C(vC);
return makeLinalgGenericOp(
{IterType::Parallel, IterType::Parallel, IterType::Reduction},
- {A({m, n}), B({k, n})},
+ {A({m, k}), B({k, n})},
{C({m, n})},
- [](ArrayRef<BlockArgument *> args) {
- using edsc::op::operator*;
- using edsc::op::operator+;
- ValueHandle a(args[0]), b(args[1]), c(args[2]);
- linalg_yield((c + a * b).getValue());
- });
+ macRegionBuilder);
+ // clang-format on
+}
+
+Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
+ ValueHandle vO,
+ ArrayRef<int> strides,
+ ArrayRef<int> dilations) {
+ MLIRContext *ctx = ScopedContext::getContext();
+ // TODO(ntv) some template magic to make everything rank-polymorphic.
+ assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
+ assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
+
+ // Some short names.
+ auto par = IterType::Parallel;
+ auto red = IterType::Reduction;
+ auto s = strides;
+ auto d = dilations;
+
+ AffineExpr b, f, h, w, kh, kw, c;
+ bindDims(ctx, b, f, h, w, kh, kw, c);
+ unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
+ StructuredIndexed I(vI), W(vW), O(vO);
+ // clang-format off
+ return makeLinalgGenericOp(
+ {par, par, par, par, red, red, red}, {
+ I({b,
+ // Roundtrip to flattened form to serve as canonicalization and ensure
+ // consistent ordering of subexpressions.
+ simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
+ simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
+ c}),
+ W({kh, kw, c, f})}, {
+ O({b, h, w, f})},
+ macRegionBuilder);
+ // clang-format on
+}
+
+Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
+ ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
+ ArrayRef<int> strides, ArrayRef<int> dilations) {
+ MLIRContext *ctx = ScopedContext::getContext();
+ // TODO(ntv) some template magic to make everything rank-polymorphic.
+ assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
+ assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
+
+ // Some short names.
+ auto par = IterType::Parallel;
+ auto red = IterType::Reduction;
+ auto s = strides;
+ auto d = dilations;
+
+ // clang-format off
+ AffineExpr b, dm, c, h, w, kh, kw;
+ bindDims(ctx, b, dm, c, h, w, kh, kw);
+ unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
+ StructuredIndexed I(vI), W(vW), O(vO);
+ return makeLinalgGenericOp(
+ {par, par, par, par, par, red, red}, {
+ I({b,
+ // Roundtrip to flattened form to serve as canonicalization and ensure
+ // consistent ordering of subexpressions.
+ simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
+ simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
+ c}),
+ W({kh, kw, c, dm})}, {
+ O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
+ macRegionBuilder);
// clang-format on
}