| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This file implements logic for lowering TensorFlow dialect to XLA dialect. |
| |
| #include <numeric> |
| |
| #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir |
| #include "mlir/Pass/Pass.h" // TF:local_config_mlir |
| #include "mlir/StandardOps/Ops.h" // TF:local_config_mlir |
| #include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/passes.h" |
| |
| using namespace mlir; |
| |
| namespace { |
| struct LegalizeTF : public FunctionPass<LegalizeTF> { |
| /// Perform the lowering to XLA dialect. |
| void runOnFunction() override; |
| }; |
| } // end anonymous namespace |
| |
| FunctionPassBase *mlir::XLA::createLegalizeTFPass() { return new LegalizeTF(); } |
| |
| /// Returns if the given TF data format string is the default format. |
| static bool isDefaultDataFormat(StringRef format) { return format == "NHWC"; } |
| |
| /// Returns the feature dimension for the given format and input type. |
| static size_t getFeatureDimension(StringAttr format, |
| RankedTensorType inputType) { |
| return isDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchNorm op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, |
| Value *input) { |
| return b.getI64IntegerAttr( |
| getFeatureDimension(format, input->getType().cast<RankedTensorType>())); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Bias op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns whether the biasAdd feature dimension is valid or not. |
| static bool hasValidBiasFeatureDimension(StringAttr format, Value *input, |
| Value *bias) { |
| auto inputType = input->getType().cast<RankedTensorType>(); |
| auto biasType = bias->getType().cast<RankedTensorType>(); |
| |
| // There must be enough biases as the feature dimension of the input tensor. |
| size_t featureDim = getFeatureDimension(format, inputType); |
| return biasType.getDimSize(0) == inputType.getDimSize(featureDim); |
| } |
| |
| /// Return a 1D ElementsAttr for the feature dimension of a BiasAdd. |
| static ElementsAttr getBiasFeatureDimension(Builder &b, StringAttr format, |
| Value *input) { |
| return b.getDenseIntElementsAttr( |
| b.getTensorType(1, b.getIntegerType(64)), |
| getFeatureDimension(format, input->getType().cast<RankedTensorType>())); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Binary op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| /// Get a constant splat for the given value type. |
| template <typename T> |
| static ElementsAttr getSplat(Builder &b, Value *val, T constant) { |
| auto valType = val->getType().cast<TensorType>(); |
| auto valElementType = valType.getElementType(); |
| |
| // Handle integer elements. |
| Attribute elementAttr; |
| if (valElementType.isa<IntegerType>()) |
| elementAttr = b.getIntegerAttr(valElementType, constant); |
| else if (valElementType.isa<FloatType>()) |
| elementAttr = b.getFloatAttr(valElementType, constant); |
| else |
| llvm_unreachable("unhandled element type"); |
| return DenseElementsAttr::get(valType, elementAttr); |
| } |
| |
| static ElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, Value *y) { |
| TensorType xType = x->getType().dyn_cast<RankedTensorType>(); |
| TensorType yType = y->getType().dyn_cast<RankedTensorType>(); |
| if (xType == yType || !xType || !yType) return {}; |
| |
| // If the shapes have the same rank, then there is nothing to do. |
| auto xRank = xType.getRank(), yRank = yType.getRank(); |
| if (xRank == yRank) return {}; |
| |
| // Otherwise if the ranks of the inputs don't match, TensorFlow automatically |
| // reshapes the smaller by padding with dimensions of size 1 as a prefix. In |
| // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to |
| // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast |
| // from lower to higher rank, but doesn't assume you want to pad as a prefix |
| // of the dimensions, and instead needs to be told which dimensions of the |
| // higher rank tensor to match to the lower rank tensor. |
| auto maxRank = std::max(xRank, yRank); |
| auto minRank = std::min(xRank, yRank); |
| |
| // Match the lower rank tensor along the larger-numbered dimensions of the |
| // higher rank tensor. |
| SmallVector<int64_t, 4> broadcastDimensions(minRank); |
| std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), |
| maxRank - minRank); |
| |
| return b.getDenseIntElementsAttr( |
| b.getTensorType({minRank}, b.getIntegerType(64)), broadcastDimensions); |
| } |
| |
| namespace mlir { |
| namespace XLA { |
| namespace { |
| #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" |
| } // end anonymous namespace |
| } // end namespace XLA |
| } // end namespace mlir |
| |
| /// Perform the lowering to XLA dialect. |
| void LegalizeTF::runOnFunction() { |
| OwningRewritePatternList patterns; |
| auto func = getFunction(); |
| |
| // Add the generated patterns to the list. |
| XLA::populateWithGenerated(func.getContext(), &patterns); |
| applyPatternsGreedily(func, std::move(patterns)); |
| } |
| |
| static PassRegistration<LegalizeTF> pass( |
| "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); |