blob: 6cfd3ae403749b139ce85bee85de99475eeffbee [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the XLA dialect.
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h.inc"
using namespace mlir;
using namespace mlir::XLA;
// Constructs the XLA HLO dialect for `context` and registers its operations.
XlaHloDialect::XlaHloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
// The template argument list of addOperations<> is supplied by the
// GET_OP_LIST section of xla_ops.cc.inc (presumably a generated file), which
// the preprocessor expands in place between the angle brackets.
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.cc.inc"
>();
// Support unknown operations because not all XLA operations are registered.
// NOTE(review): the call below is commented out, so unknown operations are not
// enabled by this constructor as written — confirm whether that is intended.
// allowUnknownOperations();
}
// Materializes a constant operation for `value` with result type `type` at
// `loc`. Only opaque elements attributes are handled; they become an
// xla_hlo.constant. Any other attribute kind yields nullptr.
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
                                              Attribute value, Type type,
                                              Location loc) {
  // Nothing to build unless the attribute is an opaque elements attribute.
  if (!value.isa<OpaqueElementsAttr>()) return nullptr;

  return builder.create<XLA::ConstOp>(loc, type, value.cast<ElementsAttr>());
}
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.cc.inc"
//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//
// Folds a constant to the attribute it holds. A constant takes no operands,
// so `operands` is expected to be empty.
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // The fold result is simply the op's own `value` attribute.
  auto heldValue = value();
  return heldValue;
}
// Builds a constant op holding the attribute `value`.
//
// An ElementsAttr is used as-is and the result takes its type. A scalar
// BoolAttr, FloatAttr or IntegerAttr is wrapped into a rank-0 tensor first.
// Any other attribute kind is unsupported and trips the assertion below.
void ConstOp::build(Builder* builder, OperationState* result, Attribute value) {
  const bool isScalarAttr = value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
                            value.isa<IntegerAttr>();

  Type resultType;
  if (auto elements = value.dyn_cast<ElementsAttr>()) {
    resultType = elements.getType();
  } else if (isScalarAttr) {
    // All XLA types must be tensor types. For flexibility, build() also
    // accepts attributes of scalar type, but they have to be wrapped in an
    // ElementsAttr over a rank-0 tensor to form a valid XLA constant.
    auto scalarTensorType =
        RankedTensorType::get(/*shape=*/{}, value.getType());
    resultType = scalarTensorType;
    value = DenseElementsAttr::get(scalarTensorType, value);
  }

  // TODO: support other XLA specific types.
  assert(resultType &&
         "unsupported attribute type for building xla_hlo.constant");
  result->types.push_back(resultType);
  result->addAttribute("value", value);
}
//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//
namespace {
// Converts the values of an ElementsAttr into the corresponding type.
ElementsAttr ConvertElements(const ElementsAttr& elements, Type newType) {
auto oldType = getElementTypeOrSelf(elements);
size_t bitWidth = newType.isBF16() ? 64 : newType.getIntOrFloatBitWidth();
if (oldType.isa<FloatType>()) {
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = APInt(const APFloat&);
if (auto newFloatType = newType.dyn_cast<FloatType>()) {
// Float -> Float
return elements.mapValues(
newType, llvm::function_ref<func_type>([&newFloatType](
const APFloat& floatVal) {
APFloat newDouble(FloatAttr::getValueAsDouble(floatVal));
bool losesInfo = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return newDouble.bitcastToAPInt();
}));
}
// Float -> Int
return elements.mapValues(
newType,
llvm::function_ref<func_type>([&bitWidth](const APFloat& floatVal) {
return APInt(bitWidth, FloatAttr::getValueAsDouble(floatVal));
}));
}
// oldType is Integer
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = APInt(const APInt&);
if (auto newFloatType = newType.dyn_cast<FloatType>()) {
// Int -> Float
return elements.mapValues(
newType,
llvm::function_ref<func_type>([&newFloatType](const APInt& intVal) {
APFloat newDouble(static_cast<double>(intVal.getLimitedValue()));
bool losesInfo = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return newDouble.bitcastToAPInt();
}));
}
// newType is Integer
// Int -> Int
return elements.mapValues(
newType, llvm::function_ref<func_type>([&bitWidth](const APInt& intVal) {
return APInt(bitWidth, intVal.getLimitedValue());
}));
}
} // namespace
// Folds a convert op.
//
// A conversion whose operand and result types match is replaced by its
// operand. A conversion of a constant elements attribute is evaluated
// immediately. Otherwise nothing is folded.
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  auto input = getOperand();
  if (input->getType() == getResult()->getType()) return input;

  // Without a constant operand there is nothing to evaluate now.
  auto constantInput = operands.front().dyn_cast_or_null<ElementsAttr>();
  if (!constantInput) return {};

  return ConvertElements(constantInput, getElementTypeOrSelf(getResult()));
}
//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//
// Folds an iota op into a dense integer constant.
//
// Viewing the result in row-major order, the element at linear index `i`
// receives the value (i / stride) % dim_size. Here dim_size is the extent of
// the iota dimension and stride is the number of elements covered by one
// step along that dimension (the product of all later extents).
//
// Returns an empty result (no fold) when the constant cannot be built:
// the result shape is not fully static, the element type is not an integer,
// the result has no elements, or the iota dimension is out of range.
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
  const auto output_type = getResult()->getType().cast<ShapedType>();

  // The element count and per-dimension extents used below are only
  // meaningful for a fully static shape.
  if (!output_type.hasStaticShape()) return {};

  // The result is built as a DenseIntElementsAttr from integer values, so
  // only integer element types are folded.
  if (!output_type.getElementType().isa<IntegerType>()) return {};

  const int64_t output_size = output_type.getNumElements();
  // An empty result has a zero-sized dimension; the stride computation below
  // would then divide by zero.
  if (output_size == 0) return {};

  const uint64_t dimension = iota_dimension().getLimitedValue();
  if (dimension >= static_cast<uint64_t>(output_type.getRank())) return {};

  const int64_t dim_size = output_type.getDimSize(dimension);
  const unsigned bitwidth =
      output_type.getElementType().getIntOrFloatBitWidth();

  // Divide out the iota dimension and every dimension before it; what is
  // left is the number of elements per step along the iota dimension.
  int64_t stride = output_size;
  for (uint64_t i = 0; i <= dimension; ++i) {
    stride /= output_type.getDimSize(i);
  }

  llvm::SmallVector<APInt, 10> values;
  values.reserve(output_size);
  for (int64_t i = 0; i < output_size; ++i) {
    values.push_back(APInt(bitwidth, (i / stride) % dim_size));
  }
  return DenseIntElementsAttr::get(output_type, values);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
// Folds a reshape op. The cases are tried in this order:
//   1. Operand and result types match: the reshape is replaced by its operand.
//   2. The operand is produced by another reshape: this op is rewired in
//      place to read that reshape's operand, and its own result is returned.
//   3. The operand is a constant dense elements attribute: the attribute is
//      reshaped to the result type.
// Otherwise nothing is folded.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  auto input = getOperand();
  if (input->getType() == getType()) return input;

  auto producer = dyn_cast_or_null<ReshapeOp>(input->getDefiningOp());
  if (producer) {
    // Skip the intermediate reshape by reading its source directly.
    setOperand(producer.getOperand());
    return getResult();
  }

  auto constantInput = operands.front().dyn_cast_or_null<DenseElementsAttr>();
  if (!constantInput) return {};

  return constantInput.reshape(getResult()->getType().cast<ShapedType>());
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
// Folds a transpose op whose permutation is the identity (entry i equals i
// for every position i) to its operand. Any other permutation is not folded.
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  uint64_t position = 0;
  for (APInt target : permutation().cast<DenseIntElementsAttr>()) {
    // One entry that moves a dimension is enough to rule out the fold.
    if (target != position) return {};
    ++position;
  }
  return getOperand();
}