Add edsc::ops for pointwise, conv and dilated_conv

This CL adds more Linalg EDSC ops and tests to support building pointwise operations along with conv and dilated_conv.
This also fixes a bug in the existing linalg_matmul EDSC and beefs up the test.

The current set of ops is already enough to build an interesting, albeit simple, model used internally.

PiperOrigin-RevId: 285838012
Change-Id: I35edf4bed5eef32a22900c89c4482f5426b4645d
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 9be60bb..131afd5 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -2287,6 +2287,7 @@
     hdrs = [
         "include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h",
         "include/mlir/Dialect/Linalg/EDSC/Builders.h",
+        "include/mlir/Dialect/Linalg/EDSC/Intrinsics.h",
         "include/mlir/Dialect/Linalg/IR/LinalgOps.h",
         "include/mlir/Dialect/Linalg/IR/LinalgTraits.h",
         "include/mlir/Dialect/Linalg/IR/LinalgTypes.h",
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index 00da1d6..4213420 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -22,15 +22,17 @@
 #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
 #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
 
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
 class BlockArgument;
-namespace edsc {
 
+namespace edsc {
 enum class IterType { Parallel, Reduction };
 
 inline StringRef toString(IterType t) {
@@ -38,7 +40,7 @@
   case IterType::Parallel:
     return getParallelIteratorTypeName();
   case IterType::Reduction:
-    return getParallelIteratorTypeName();
+    return getReductionIteratorTypeName();
   default:
     llvm_unreachable("Unsupport IterType");
   }
@@ -78,20 +80,83 @@
 Operation *makeLinalgGenericOp(
     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
-    decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder,
+    llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder =
+        defaultRegionBuilder,
     ArrayRef<Value *> otherValues = {},
     ArrayRef<Attribute> otherAttributes = {});
 
+namespace ops {
+using edsc::StructuredIndexed;
+using edsc::ValueHandle;
+using edsc::intrinsics::linalg_yield;
+
 //===----------------------------------------------------------------------===//
 // EDSC builders for linalg generic operations.
 //===----------------------------------------------------------------------===//
 
+/// Build the body of a region to compute a multiply-accumulate, under the
+/// current ScopedContext, at the current insert point.
+void macRegionBuilder(ArrayRef<BlockArgument *> args);
+
 /// TODO(ntv): In the future we should tie these implementations to something in
 /// Tablegen that generates the proper interfaces and the proper sugared named
 /// ops.
 
-/// Build a linalg.generic that represents C = A * B in the current
-/// ScopedContext.
+/// Build a linalg.pointwise, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (i0, ..., in) = (par, ..., par)
+///    |
+///    |  O...(some_subset...(i0, ..., in)) =
+///    |    some_pointwise_func...(I...(some_other_subset...(i0, ..., in)))
+/// ```
+///
+/// This is a very generic entry point that can be configured in many ways to
+/// build a perfect loop nest of parallel loops with arbitrarily complex
+/// innermost loop code and whatever (explicit) broadcast semantics.
+///
+/// This can be used with both out-of-place and in-place semantics.
+/// The client is responsible for ensuring the region operations are compatible
+/// with in-place semantics and parallelism.
+
+/// Unary pointwise operation (with broadcast) entry point.
+using UnaryPointwiseOpBuilder = llvm::function_ref<Value *(ValueHandle)>;
+Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
+                            StructuredIndexed I, StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = tanh(I)`. The client is responsible for specifying the proper
+/// indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
+
+/// Binary pointwise operation (with broadcast) entry point.
+using BinaryPointwiseOpBuilder =
+    llvm::function_ref<Value *(ValueHandle, ValueHandle)>;
+Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
+                            StructuredIndexed I1, StructuredIndexed I2,
+                            StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = I1 + I2`. The client is responsible for specifying the proper
+/// indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
+                                StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = max(I!, I2)`. The client is responsible for specifying the
+/// proper indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
+                                StructuredIndexed O);
+
+// TODO(ntv): Implement more useful pointwise operations on a per-need basis.
+
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (m, n, k) = (par, par, seq)
+///    |
+///    |  C(m, n) += A(m, k) * B(k, n)
+/// ```
 Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
 
 template <typename Container> Operation *linalg_matmul(Container values) {
@@ -99,6 +164,76 @@
   return linalg_matmul(values[0], values[1], values[2]);
 }
 
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (batch, f, [h, w, ...], [kh, kw, ...], c) =
+///    |  (par, par, [par, par, ...], [red, red, ...], red)
+///    |
+///    | O(batch, [h, w, ...], f) +=
+///    |   I(batch,
+///    |     [
+///    |       stride[0] * h + dilations[0] * kh,
+///    |       stride[1] * w + dilations[1] * kw, ...
+///          ],
+///    |     c)
+///    |   *
+///    |   W([kh, kw, ...], c, f)
+/// ```
+/// If `dilations` or `strides` are left empty, the default value of `1` is used
+/// along each relevant dimension.
+///
+/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
+///
+// TODO(ntv) Extend convolution rank with some template magic.
+Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO,
+                            ArrayRef<int> strides = {},
+                            ArrayRef<int> dilations = {});
+
+template <typename Container>
+Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
+                            ArrayRef<int> dilations = {}) {
+  assert(values.size() == 3 && "Expected exactly 3 values");
+  return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
+}
+
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (batch, dm, c, [h, w, ...], [kh, kw, ...]) =
+///    |  (par, par, par, [par, par, ...], [red, red, ...])
+///    |
+///    | O(batch, [h, w, ...], c * depth_multiplier) +=
+///    |   I(batch,
+///    |     [
+///    |       stride[0] * h + dilations[0] * kh,
+///    |       stride[1] * w + dilations[1] * kw, ...
+///          ],
+///    |     c)
+///    |   *
+///    |   W([kh, kw, ...], c, depth_multiplier)
+/// ```
+/// If `dilations` or `strides` are left empty, the default value of `1` is used
+/// along each relevant dimension.
+///
+/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
+///
+// TODO(ntv) Extend convolution rank with some template magic.
+Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
+                                    ValueHandle vO, int depth_multiplier = 1,
+                                    ArrayRef<int> strides = {},
+                                    ArrayRef<int> dilations = {});
+
+template <typename Container>
+Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
+                                    ArrayRef<int> strides = {},
+                                    ArrayRef<int> dilations = {}) {
+  assert(values.size() == 3 && "Expected exactly 3 values");
+  return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
+                                  depth_multiplier, strides, dilations);
+}
+
+} // namespace ops
 } // namespace edsc
 } // namespace mlir
 
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
new file mode 100644
index 0000000..f1acab6
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -0,0 +1,35 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+
+using linalg_fill = OperationBuilder<linalg::FillOp>;
+using linalg_yield = OperationBuilder<linalg::YieldOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
index 4f9621c..1f24a90 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
@@ -247,13 +247,13 @@
 }
 
 def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> {
-  let arguments = (ins AnyStridedMemRef:$input,
+  let arguments = (ins AnyStridedMemRef:$output,
                    AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value);
   let extraClassDeclaration = libraryCallName # [{
     ArrayAttr indexing_maps();
 
     ArrayAttr iterator_types() {
-      unsigned nPar = input()->getType().cast<ShapedType>().getRank();
+      unsigned nPar = output()->getType().cast<ShapedType>().getRank();
       MLIRContext *ctx = getContext();
       SmallVector<Attribute, 8> iters(
         nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
diff --git a/third_party/mlir/include/mlir/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
index 68bd210..6dbb343 100644
--- a/third_party/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
@@ -154,22 +154,22 @@
 
   /// Folder-based
   template <typename... Args>
-  ValueBuilder(OperationFolder &folder, Args... args)
+  ValueBuilder(OperationFolder *folder, Args... args)
       : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {}
-  ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs)
+  ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs)
       : ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {}
   template <typename... Args>
-  ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs, Args... args)
+  ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs, Args... args)
       : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs),
                                             detail::unpack(args)...)) {}
   template <typename T, typename... Args>
-  ValueBuilder(OperationFolder &folder, T t, ArrayRef<ValueHandle> vs,
+  ValueBuilder(OperationFolder *folder, T t, ArrayRef<ValueHandle> vs,
                Args... args)
       : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t),
                                             detail::unpack(vs),
                                             detail::unpack(args)...)) {}
   template <typename T1, typename T2, typename... Args>
-  ValueBuilder(OperationFolder &folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
+  ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
                Args... args)
       : ValueHandle(ValueHandle::create<Op>(
             folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
@@ -200,6 +200,7 @@
   OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
 };
 
+using addf = ValueBuilder<AddFOp>;
 using affine_apply = ValueBuilder<AffineApplyOp>;
 using affine_if = OperationBuilder<AffineIfOp>;
 using affine_load = ValueBuilder<AffineLoadOp>;
@@ -212,11 +213,14 @@
 using dealloc = OperationBuilder<DeallocOp>;
 using dim = ValueBuilder<DimOp>;
 using muli = ValueBuilder<MulIOp>;
+using mulf = ValueBuilder<MulFOp>;
+using memref_cast = ValueBuilder<MemRefCastOp>;
 using ret = OperationBuilder<ReturnOp>;
 using select = ValueBuilder<SelectOp>;
 using std_load = ValueBuilder<LoadOp>;
 using std_store = OperationBuilder<StoreOp>;
 using subi = ValueBuilder<SubIOp>;
+using tanh = ValueBuilder<TanhOp>;
 using view = ValueBuilder<ViewOp>;
 
 /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
diff --git a/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 3daeafe..77e3a1e 100644
--- a/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -16,6 +16,7 @@
 // =============================================================================
 
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Intrinsics.h"
@@ -26,6 +27,7 @@
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
+using namespace mlir::edsc::ops;
 
 static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
                            unsigned &pos) {
@@ -42,24 +44,26 @@
 Operation *mlir::edsc::makeLinalgGenericOp(
     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
-    decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues,
-    ArrayRef<Attribute> otherAttributes) {
+    llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder,
+    ArrayRef<Value *> otherValues, ArrayRef<Attribute> otherAttributes) {
   auto &builder = edsc::ScopedContext::getBuilder();
   auto *ctx = builder.getContext();
   unsigned nInputs = inputs.size();
   unsigned nOutputs = outputs.size();
-  unsigned rank = 0;
-  getMaxDimIndex(inputs, rank);
-  getMaxDimIndex(outputs, rank);
+  unsigned maxPos = 0;
+  getMaxDimIndex(inputs, maxPos);
+  getMaxDimIndex(outputs, maxPos);
+  // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
+  unsigned nDims = maxPos + 1;
 
   SmallVector<AffineMap, 4> maps;
   maps.reserve(nInputs + nOutputs);
   for (auto in : inputs)
     maps.push_back(
-        AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs()));
+        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
   for (auto out : outputs)
     maps.push_back(
-        AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs()));
+        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
 
   unsigned nViews = nInputs + nOutputs;
   SmallVector<Value *, 4> values;
@@ -105,23 +109,148 @@
   return op;
 }
 
-using linalg_yield = OperationBuilder<linalg::YieldOp>;
+void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument *> args) {
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  assert(args.size() == 3 && "expected 3 block arguments");
+  ValueHandle a(args[0]), b(args[1]), c(args[2]);
+  linalg_yield((c + a * b).getValue());
+}
 
-Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                     ValueHandle vC) {
+Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
+                                             StructuredIndexed I,
+                                             StructuredIndexed O) {
+  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
+                                           edsc::IterType::Parallel);
+  auto fun = [&unaryOp](ArrayRef<BlockArgument *> args) {
+    assert(args.size() == 2 && "expected 2 block arguments");
+    ValueHandle a(args[0]);
+    linalg_yield(unaryOp(a));
+  };
+  return makeLinalgGenericOp(iterTypes, {I}, {O}, fun);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
+                                                  StructuredIndexed O) {
+  ;
+  using edsc::intrinsics::tanh;
+  UnaryPointwiseOpBuilder unOp(
+      [](ValueHandle a) -> Value * { return tanh(a); });
+  return linalg_pointwise(unOp, I, O);
+}
+
+/// Binary pointwise operation (with broadcast) entry point.
+Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
+                                             StructuredIndexed I1,
+                                             StructuredIndexed I2,
+                                             StructuredIndexed O) {
+  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
+                                           edsc::IterType::Parallel);
+  auto fun = [&binaryOp](ArrayRef<BlockArgument *> args) {
+    assert(args.size() == 3 && "expected 3 block arguments");
+    ValueHandle a(args[0]), b(args[1]);
+    linalg_yield(binaryOp(a, b));
+  };
+  return makeLinalgGenericOp(iterTypes, {I1, I2}, {O}, fun);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
+                                                 StructuredIndexed I2,
+                                                 StructuredIndexed O) {
+  using edsc::op::operator+;
+  BinaryPointwiseOpBuilder binOp(
+      [](ValueHandle a, ValueHandle b) -> Value * { return a + b; });
+  return linalg_pointwise(binOp, I1, I2, O);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
+                                                 StructuredIndexed I2,
+                                                 StructuredIndexed O) {
+  BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value * {
+    using edsc::intrinsics::select;
+    using edsc::op::operator>;
+    return select(a > b, a, b).getValue();
+  });
+  return linalg_pointwise(binOp, I1, I2, O);
+}
+
+Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
+                                          ValueHandle vC) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(vC);
   return makeLinalgGenericOp(
     {IterType::Parallel, IterType::Parallel, IterType::Reduction},
-    {A({m, n}), B({k, n})},
+    {A({m, k}), B({k, n})},
     {C({m, n})},
-    [](ArrayRef<BlockArgument *> args) {
-      using edsc::op::operator*;
-      using edsc::op::operator+;
-      ValueHandle a(args[0]), b(args[1]), c(args[2]);
-      linalg_yield((c + a * b).getValue());
-  });
+    macRegionBuilder);
+  // clang-format on
+}
+
+Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
+                                             ValueHandle vO,
+                                             ArrayRef<int> strides,
+                                             ArrayRef<int> dilations) {
+  MLIRContext *ctx = ScopedContext::getContext();
+  // TODO(ntv) some template magic to make everything rank-polymorphic.
+  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
+  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
+
+  // Some short names.
+  auto par = IterType::Parallel;
+  auto red = IterType::Reduction;
+  auto s = strides;
+  auto d = dilations;
+
+  AffineExpr b, f, h, w, kh, kw, c;
+  bindDims(ctx, b, f, h, w, kh, kw, c);
+  unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
+  StructuredIndexed I(vI), W(vW), O(vO);
+  // clang-format off
+  return makeLinalgGenericOp(
+    {par, par, par, par, red, red, red}, {
+      I({b,
+         // Roundtrip to flattened form to serve as canonicalization and ensure
+         // consistent ordering of subexpressions.
+         simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
+         simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
+         c}),
+      W({kh, kw, c, f})}, {
+      O({b, h, w, f})},
+    macRegionBuilder);
+  // clang-format on
+}
+
+Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
+    ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
+    ArrayRef<int> strides, ArrayRef<int> dilations) {
+  MLIRContext *ctx = ScopedContext::getContext();
+  // TODO(ntv) some template magic to make everything rank-polymorphic.
+  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
+  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
+
+  // Some short names.
+  auto par = IterType::Parallel;
+  auto red = IterType::Reduction;
+  auto s = strides;
+  auto d = dilations;
+
+  // clang-format off
+  AffineExpr b, dm, c, h, w, kh, kw;
+  bindDims(ctx, b, dm, c, h, w, kh, kw);
+  unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
+  StructuredIndexed I(vI), W(vW), O(vO);
+  return makeLinalgGenericOp(
+    {par, par, par, par, par, red, red}, {
+      I({b,
+         // Roundtrip to flattened form to serve as canonicalization and ensure
+         // consistent ordering of subexpressions.
+         simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
+         simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
+         c}),
+      W({kh, kw, c, dm})}, {
+      O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
+    macRegionBuilder);
   // clang-format on
 }