// blob: a10329cea06270a3de3d875ba60dbf55296f3f85 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering TensorFlow dialect to XLA dialect.
#include <numeric>
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
using namespace mlir;
namespace {
/// Function pass that lowers TensorFlow dialect operations to the XLA dialect.
/// The actual rewriting happens in runOnFunction(), which is defined later in
/// this file.
struct LegalizeTF : public FunctionPass<LegalizeTF> {
/// Perform the lowering to XLA dialect on the function this pass is
/// currently running on.
void runOnFunction() override;
};
} // end anonymous namespace
/// Creates a new instance of the TensorFlow-to-XLA legalization pass.
/// The pass object is heap-allocated and returned as a raw pointer;
/// presumably the caller (e.g. a pass manager) takes ownership -- confirm
/// against the declaration in passes.h.
FunctionPassBase *mlir::XLA::createLegalizeTFPass() {
  auto *legalizePass = new LegalizeTF();
  return legalizePass;
}
/// Returns true when `format` names the default TensorFlow data layout,
/// i.e. when it is exactly the string "NHWC".
static bool isDefaultDataFormat(StringRef format) {
  const StringRef defaultFormat = "NHWC";
  return format == defaultFormat;
}
/// Returns the index of the feature dimension of `inputType` for the given
/// data `format`: the last dimension for the default format ("NHWC"), and
/// dimension 1 for every other format.
static size_t getFeatureDimension(StringAttr format,
                                  RankedTensorType inputType) {
  // Any non-default format keeps the features in dimension 1.
  if (!isDefaultDataFormat(format.getValue())) return 1;
  // Default format: the features are the trailing dimension.
  return inputType.getRank() - 1;
}
//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//
/// Builds a 64-bit integer attribute holding the feature dimension index of
/// `input` for the given data `format`.
///
/// `input` is cast (not dyn_cast) to a ranked tensor type, so it must have a
/// ranked tensor type.
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
                                           Value *input) {
  auto rankedInput = input->getType().cast<RankedTensorType>();
  const size_t featureDim = getFeatureDimension(format, rankedInput);
  return b.getI64IntegerAttr(featureDim);
}
//===----------------------------------------------------------------------===//
// Bias op utilities.
//===----------------------------------------------------------------------===//
/// Checks that a BiasAdd's `bias` operand matches its `input` operand: the
/// size of dimension 0 of the bias must equal the size of the input's feature
/// dimension, where the feature dimension is selected by `format`.
///
/// Both values are cast (not dyn_cast) to ranked tensor types, so both must
/// have ranked tensor types.
static bool hasValidBiasFeatureDimension(StringAttr format, Value *input,
                                         Value *bias) {
  RankedTensorType rankedInput = input->getType().cast<RankedTensorType>();
  RankedTensorType rankedBias = bias->getType().cast<RankedTensorType>();

  // One bias element is needed per entry along the input's feature dimension.
  const size_t featureIdx = getFeatureDimension(format, rankedInput);
  const auto featureSize = rankedInput.getDimSize(featureIdx);
  const auto biasSize = rankedBias.getDimSize(0);
  return biasSize == featureSize;
}
/// Builds a dense i64 elements attribute of shape [1] whose single element is
/// the feature dimension index of `input` for the given data `format`.
///
/// `input` is cast (not dyn_cast) to a ranked tensor type, so it must have a
/// ranked tensor type.
static ElementsAttr getBiasFeatureDimension(Builder &b, StringAttr format,
                                            Value *input) {
  auto rankedInput = input->getType().cast<RankedTensorType>();
  const size_t featureDim = getFeatureDimension(format, rankedInput);
  // One-element, one-dimensional tensor of 64-bit integers.
  auto i64VectorType = b.getTensorType(1, b.getIntegerType(64));
  return b.getDenseIntElementsAttr(i64VectorType, featureDim);
}
//===----------------------------------------------------------------------===//
// Binary op utilities.
//===----------------------------------------------------------------------===//
/// Builds a dense elements attribute with the same tensor type as `val`, in
/// which every element is `constant`.
///
/// Only integer and floating-point element types are supported; any other
/// element type hits llvm_unreachable. `val` is cast (not dyn_cast) to a
/// tensor type, so it must have a tensor type.
template <typename T>
static ElementsAttr getSplat(Builder &b, Value *val, T constant) {
  auto tensorTy = val->getType().cast<TensorType>();
  auto elemTy = tensorTy.getElementType();

  // Integer elements: wrap the constant in an integer attribute.
  if (elemTy.isa<IntegerType>()) {
    Attribute intValue = b.getIntegerAttr(elemTy, constant);
    return DenseElementsAttr::get(tensorTy, intValue);
  }

  // Floating-point elements: wrap the constant in a float attribute.
  if (elemTy.isa<FloatType>()) {
    Attribute floatValue = b.getFloatAttr(elemTy, constant);
    return DenseElementsAttr::get(tensorTy, floatValue);
  }

  llvm_unreachable("unhandled element type");
}
/// Computes the broadcast dimensions attribute for a binary op with operands
/// `x` and `y`.
///
/// Returns a null attribute when no explicit broadcast dimensions are needed:
/// either operand is not a ranked tensor, both operands have the same type, or
/// both have the same rank. Otherwise returns a dense i64 attribute with one
/// entry per dimension of the lower-rank operand, mapping it onto the trailing
/// dimensions of the higher-rank operand.
static ElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, Value *y) {
  TensorType lhsType = x->getType().dyn_cast<RankedTensorType>();
  TensorType rhsType = y->getType().dyn_cast<RankedTensorType>();

  // Unranked (or non-tensor) operands and identical types need no attribute.
  if (!lhsType || !rhsType) return {};
  if (lhsType == rhsType) return {};

  // Equal ranks need no dimension mapping either.
  auto lhsRank = lhsType.getRank(), rhsRank = rhsType.getRank();
  if (lhsRank == rhsRank) return {};

  // TensorFlow broadcasts operands of different rank by implicitly prepending
  // size-1 dimensions to the smaller one (a 5-vector paired with a rank-3
  // tensor behaves like shape [1,1,5]). XLA can broadcast from lower to higher
  // rank as well, but it does not assume prefix padding: it has to be told
  // which dimensions of the higher-rank operand the lower-rank operand lines
  // up with.
  auto higherRank = std::max(lhsRank, rhsRank);
  auto lowerRank = std::min(lhsRank, rhsRank);

  // Line the lower-rank operand up with the highest-numbered dimensions of
  // the higher-rank operand: [higherRank - lowerRank, higherRank).
  SmallVector<int64_t, 4> trailingDims;
  trailingDims.reserve(lowerRank);
  for (auto dim = higherRank - lowerRank; dim < higherRank; ++dim)
    trailingDims.push_back(dim);

  return b.getDenseIntElementsAttr(
      b.getTensorType({lowerRank}, b.getIntegerType(64)), trailingDims);
}
// Textually include the generated TensorFlow-to-XLA legalization code. The
// include sits inside an anonymous namespace within mlir::XLA, so everything
// it defines is local to this translation unit and is reached through the
// XLA:: qualifier (LegalizeTF::runOnFunction calls XLA::populateWithGenerated,
// which is presumably defined by this include -- confirm against the .inc).
// NOTE(review): the static helper functions defined earlier in this file are
// presumably referenced by the generated code, which is why the include comes
// after them -- confirm against the pattern definitions.
namespace mlir {
namespace XLA {
namespace {
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
} // end anonymous namespace
} // end namespace XLA
} // end namespace mlir
/// Perform the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
// Add the generated patterns to the list.
XLA::populateWithGenerated(func.getContext(), &patterns);
applyPatternsGreedily(func, std::move(patterns));
}
// Static registration object for the LegalizeTF pass. It is constructed with
// the pass name "xla-legalize-tf" and a human-readable description;
// presumably this makes the pass selectable by that name in tools that use
// the pass registry -- confirm against the PassRegistration documentation.
static PassRegistration<LegalizeTF> pass(
"xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");