/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/type_to_shape.h"

#include <iostream>

#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"
using mlir::Builder;
using mlir::MemRefType;
using mlir::MLIRContext;
using mlir::RankedTensorType;
using mlir::UnrankedTensorType;
using mlir::VectorType;

namespace xla {
namespace {

// Simple implementation of a proto matcher comparing serialized string
// representations. Only works because ShapeProto's serialization is
// deterministic.
class ProtoStringMatcher {
public:
explicit ProtoStringMatcher(const tensorflow::protobuf::Message& expected)
: expected_(expected.SerializeAsString()) {}
template <typename Message>
bool MatchAndExplain(const Message& p, testing::MatchResultListener*) const {
return p.SerializeAsString() == expected_;
}
void DescribeTo(::std::ostream* os) const { *os << expected_; }
void DescribeNegationTo(::std::ostream* os) const {
*os << "not equal to expected message: " << expected_;
}
private:
const std::string expected_;
};

inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
const tensorflow::protobuf::Message& x) {
return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x));
}

TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
MLIRContext context;
Builder b(&context);
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(1)), PrimitiveType::PRED);
EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(17)),
PrimitiveType::PRIMITIVE_TYPE_INVALID);
}

TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
MLIRContext context;
Builder b(&context);
EXPECT_TRUE(
ShapeUtil::IsScalarWithElementType(TypeToShape(b.getF32Type()), F32));
EXPECT_THAT(
TypeToShape(VectorType::get({8, 128}, b.getIntegerType(32))).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::S32, {8, 128}).ToProto()));
EXPECT_THAT(
TypeToShape(VectorType::get({8, 128}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
// MLIR type that is not representable as an XLA Shape.
EXPECT_THAT(
TypeToShape(VectorType::get({8, 128}, b.getIntegerType(17))).ToProto(),
EqualsProto(Shape().ToProto()));
}

TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
MLIRContext context;
Builder b(&context);
// Memref without any affine map. Note: memory space is ignored for shape.
EXPECT_THAT(
TypeToShape(MemRefType::get({8, 128}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
EXPECT_THAT(
TypeToShape(MemRefType::get({100, 13, 210}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210}).ToProto()));
// Vector types are "flattened" into the end of the shape.
EXPECT_THAT(
TypeToShape(MemRefType::get({100, 13, 210},
VectorType::get({8, 128}, b.getF32Type())))
.ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210, 8, 128})
.ToProto()));
}

TEST(TypeToShapeTest, ConvertTensorTypeToTypes) {
MLIRContext context;
Builder b(&context);
EXPECT_THAT(
TypeToShape(RankedTensorType::get({8, 128}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
// Shape cannot represent dynamic shapes.
// TODO(b/115638799): Update once Shape can support dynamic shapes.
EXPECT_THAT(TypeToShape(UnrankedTensorType::get(b.getF32Type())).ToProto(),
EqualsProto(Shape().ToProto()));
// TODO(jpienaar): Expand to handle more complicated tensor types.
EXPECT_THAT(
TypeToShape(RankedTensorType::get(
{8, 128}, VectorType::get({16, 16}, b.getF32Type())))
.ToProto(),
EqualsProto(Shape().ToProto()));
}

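// Verifies that TypeToShape forwards the TensorShape and DataType derived from
// the MLIR type to a caller-provided shape representation function.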
TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) {
tensorflow::DataType captured_dtype;
tensorflow::TensorShape captured_tensor_shape;
// A dummy shape representation function that does nothing other than capture
// the arguments passed to it.
auto test_shape_representation_fn = [&](const tensorflow::TensorShape& shape,
tensorflow::DataType dtype) {
captured_tensor_shape = shape;
captured_dtype = dtype;
return xla::Shape();
};
MLIRContext context;
Builder b(&context);
StatusOr<Shape> status_or_shape;
// Non-fully-defined shape: the dynamic dimension cannot be represented as a
// TensorShape, so the conversion fails with INVALID_ARGUMENT.
status_or_shape =
TypeToShape(RankedTensorType::get({-1, 2, 3}, b.getF32Type()),
test_shape_representation_fn);
EXPECT_EQ(status_or_shape.status().code(),
tensorflow::errors::Code::INVALID_ARGUMENT);
// Scalar Int32 Tensor.
status_or_shape =
TypeToShape(b.getIntegerType(32), test_shape_representation_fn);
EXPECT_TRUE(status_or_shape.ok());
EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_INT32);
EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape());
// Ranked Float32 Tensor.
status_or_shape =
TypeToShape(RankedTensorType::get({1, 2, 3}, b.getF32Type()),
test_shape_representation_fn);
EXPECT_TRUE(status_or_shape.ok());
EXPECT_EQ(captured_dtype, tensorflow::DataType::DT_FLOAT);
EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3}));
}

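// Round-trip test: converts a Shape with a non-default layout to a MemRefType
// and back, expecting the layout to be preserved.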
TEST(TypeToShapeTest, ConvertMemRefToShape) {
Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, {10, 20, 30},
{2, 0, 1});
MLIRContext context;
mlir::Builder builder(&context);
StatusOr<mlir::Type> mlir_type =
ConvertShapeToType<MemRefType>(shape, builder);
ASSERT_TRUE(mlir_type.ok());
mlir::Type type = mlir_type.ConsumeValueOrDie();
Shape converted = TypeToShape(type);
EXPECT_TRUE(ShapeUtil::Equal(
converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32,
{10, 20, 30}, {2, 0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
}

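// Same round trip as above, with a complex element type (C64) and a
// four-dimensional layout.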
TEST(TypeToShapeTest, ConvertMemRefToShape2) {
Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, {2, 4, 3, 3},
{2, 3, 1, 0});
MLIRContext context;
mlir::Builder builder(&context);
StatusOr<mlir::Type> mlir_type =
ConvertShapeToType<MemRefType>(shape, builder);
ASSERT_TRUE(mlir_type.ok());
mlir::Type type = mlir_type.ConsumeValueOrDie();
Shape converted = TypeToShape(type);
EXPECT_TRUE(ShapeUtil::Equal(
converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64,
{2, 4, 3, 3}, {2, 3, 1, 0})));
EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
}

} // namespace
} // namespace xla