blob: 0de8a9044e19cc1c844e03b949311124c624b006 [file] [log] [blame]
//===- Example.cpp - Our running example ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// RUN: %p/test | FileCheck %s
#include "TestHarness.h"
#include "linalg1/Common.h"
#include "linalg1/Dialect.h"
#include "linalg2/Intrinsics.h"
#include "linalg2/Ops.h"
#include "linalg2/Transforms.h"
#include "mlir/IR/Function.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
TEST_FUNC(linalg_ops) {
MLIRContext context;
Module module(&context);
auto indexType = mlir::IndexType::get(&context);
mlir::Function *f =
makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
FuncBuilder builder(f);
ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
vA = view(alloc(floatMemRefType<2>(&context), {M, K}), {rM ,rK}),
vB = view(alloc(floatMemRefType<2>(&context), {K, N}), {rK, rN}),
vC = view(alloc(floatMemRefType<2>(&context), {M, N}), {rM, rN}),
sB = slice(vB, constant_index(0), 1),
sC = slice(vC, constant_index(0), 1),
sA = slice(vA, constant_index(0), 0),
ssC = slice(sC, constant_index(0), 0);
matmul(vA, vB, vC);
matvec(vA, sB, sC);
dot(sA, sB, ssC);
ret();
// CHECK-LABEL: func @linalg_ops(%arg0: index, %arg1: index, %arg2: index) {
// CHECK: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim: 1} : !linalg.view<?x?xf32>, index
// CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim: 1} : !linalg.view<?x?xf32>, index
// CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim: 0} : !linalg.view<?x?xf32>, index
// CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] {dim: 0} : !linalg.view<?xf32>, index
// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?x?xf32>
// CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?xf32>
// CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32>
// clang-format on
cleanupAndPrintFunction(f);
}
TEST_FUNC(linalg_ops_folded_slices) {
MLIRContext context;
Module module(&context);
auto indexType = mlir::IndexType::get(&context);
mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices",
{indexType, indexType, indexType}, {});
FuncBuilder builder(f);
ScopedContext scope(builder, f->getLoc());
// clang-format off
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
vA = view(alloc(floatMemRefType<2>(&context), {M, K}), {rM, rK}),
vB = view(alloc(floatMemRefType<2>(&context), {K, N}), {rK, rN}),
vC = view(alloc(floatMemRefType<2>(&context), {M, N}), {rM, rN}),
sB = slice(vB, constant_index(0), 1),
sC = slice(vC, constant_index(0), 1),
sA = slice(vA, constant_index(0), 0),
ssC = slice(sC, constant_index(0), 0);
matmul(vA, vB, vC);
matvec(vA, sB, sC);
dot(sA, sB, ssC);
ret();
// CHECK-LABEL: func @linalg_ops_folded_slices(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-NOT: linalg.slice
// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?x?xf32>
// CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg.view<?xf32>
// CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32>
// clang-format on
f->walk<SliceOp>([](SliceOp slice) {
auto *sliceResult = slice.getResult();
auto viewOp = emitAndReturnFullyComposedView(sliceResult);
sliceResult->replaceAllUsesWith(viewOp.getResult());
slice.erase();
});
cleanupAndPrintFunction(f);
}
int main() {
mlir::registerDialect<linalg::LinalgDialect>();
RUN_TESTS();
return 0;
}