blob: 40740097eca9d162c3bd7410dd3287ca1e2af738 [file] [log] [blame]
//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "gmock/gmock.h"
using namespace mlir;
using namespace mlir::OpTrait::impl;
namespace {
// TODO: Replace with a regular test once this trait is used by an operation in
// core.
//
// Exercises verifySameOperandsAndResultElementType on ad-hoc operations built
// from generic "some_const" producers. Every operation created here is
// explicitly destroyed so the test does not leak.
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
  MLIRContext context;
#define FILE_LOC                                                               \
  FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0,   \
                      &context)
  Builder b(&context);

  // Creates a generic "some_const" op with no operands and a single result of
  // `type`. Its only role is to provide a typed value for the ops under test.
  auto createConst = [&](Location loc, Type type) {
    return Operation::create(loc, OperationName("some_const", &context),
                             /*operands=*/{}, /*resultTypes=*/{type},
                             /*attributes=*/llvm::None, /*successors=*/{},
                             /*numRegions=*/0,
                             /*resizableOperandList=*/false, &context);
  };

  auto *operandtF32x10x10 =
      createConst(FILE_LOC, b.getTensorType({10, 10}, b.getF32Type()));
  auto *operandtF32x1 =
      createConst(FILE_LOC, b.getTensorType({1}, b.getF32Type()));
  auto *operandvF32x1 =
      createConst(FILE_LOC, b.getVectorType({1}, b.getF32Type()));
  auto *operandtI32x1 =
      createConst(FILE_LOC, b.getTensorType({1}, b.getIntegerType(32)));

  // Verifies whether an op with x and y as inputs and resultType satisfies the
  // SameOperandAndResultElementType trait. The temporary op is destroyed before
  // returning so that it releases its uses of x's and y's results.
  auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
    Operation *op = Operation::create(
        loc, OperationName("some_op", &context),
        /*operands=*/{x->getResult(0), y->getResult(0)},
        /*resultTypes=*/{resultType},
        /*attributes=*/llvm::None, /*successors=*/{},
        /*numRegions=*/0,
        /*resizableOperandList=*/false, &context);
    bool isValid = succeeded(verifySameOperandsAndResultElementType(op));
    op->destroy();
    return isValid;
  };

  // Matching f32 element types are accepted regardless of operand/result
  // shapes.
  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
                    b.getTensorType({12}, b.getF32Type())));
  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
                    b.getTensorType({5}, b.getF32Type())));

  // Any element-type mismatch, among the operands or against the result, is
  // rejected.
  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
                     b.getTensorType({7}, b.getF32Type())));
  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
                     b.getTensorType({12}, b.getIntegerType(32))));
  EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
                     b.getTensorType({9}, b.getIntegerType(32))));

  // Mixing tensor and vector containers is accepted as long as the element
  // types match.
  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
                    b.getVectorType({9}, b.getF32Type())));
  EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
                    b.getVectorType({9}, b.getF32Type())));
  EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
                    b.getTensorType({5}, b.getF32Type())));
  EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
                     b.getTensorType({5}, b.getF32Type())));

  // All users were destroyed inside `valid`, so the producers can now be
  // released.
  operandtF32x10x10->destroy();
  operandtF32x1->destroy();
  operandvF32x1->destroy();
  operandtI32x1->destroy();
#undef FILE_LOC
}
} // end namespace