Finish the port to upstream TF from TF 2.4
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 020f129..4cfb5a6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1800,6 +1800,20 @@
return input;
}
+// Reports whether `op` is row-major. Returns true when the op has no explicit
+// layout attribute; otherwise returns whether the attribute's values, read as
+// a minor-to-major list, are in non-increasing order.
+bool IsRowMajor(mlir::Operation* op) {
+  auto layout = mlir::GetLayoutFromMlirHlo(op);
+  if (!layout) return true;  // It is row major by default.
+  std::vector<int64> minor_to_major;
+  for (const llvm::APInt& dim : layout) {
+    minor_to_major.push_back(dim.getZExtValue());
+  }
+  return std::is_sorted(minor_to_major.begin(), minor_to_major.end(),
+                        std::greater<int64>());
+}
+
// TODO(timshen): update the comment once the HandleFusion code path deleted.
//
// This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
@@ -1876,18 +1890,23 @@
}();
bool row_optimized = fusion.getFusionResults().size() == 1 && // Not tested with MOF.
- absl::c_all_of(GetHloOperands(fusion), [](const mlir::Value& op) {
+ absl::c_all_of(GetHloOperands(fusion), [](const mlir::Value& value) {
// Only tested when the inputs are row-major. So only enable that case.
// Maybe it would works if only the inner dimensions is contiguous.
- return true;//TODO: LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout());
+ if (auto op = value.getDefiningOp()) {
+ return IsRowMajor(value.getDefiningOp());
+ }
+ // Reuse TypeToShape to avoid duplicating the layout conversion code.
+ return LayoutUtil::IsMonotonicWithDim0Major(TypeToShape(value.getType()).layout());
}) &&
// Only tested when the output is row-major.
- //LayoutUtil::IsMonotonicWithDim0Major(hlo.shape().layout());
- true;
+ absl::c_all_of(GetOutputOps(fusion), IsRowMajor);
+
bool some_row_broadcasting = false;
for (mlir::Operation& op : fusion.region().front()) {
if (mlir::isa<mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp,
- mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp>(op) ) {
+ mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
+ mlir::mhlo::ConstOp, mlir::lmhlo::ConstOp>(op) ) {
continue;
}
HloOpcode opcode = *MhloToHloOpcode(&op);
@@ -1895,10 +1914,10 @@
continue;
}
- if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastOp>(op)) {
+ if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
std::vector<int64> broadcast_dimensions;
- if (broadcast.broadcast_sizes().size() > 0) {
- for (const llvm::APInt& int_value : broadcast.broadcast_sizes()) {
+ if (broadcast.broadcast_dimensions().size() > 0) {
+ for (const llvm::APInt& int_value : broadcast.broadcast_dimensions()) {
broadcast_dimensions.push_back(int_value.getSExtValue());
}
}
@@ -1909,12 +1928,13 @@
continue;
}
if (broadcast_dimensions.size() == 1 &&
- broadcast_dimensions.back() != (rank - 1)) {
+ broadcast_dimensions.back() == (rank - 1)) {
some_row_broadcasting = true;
+ continue;
}
}
row_optimized = false;
- VLOG(3) << "Row vectorization not enabled due to this op: " << HloOpcodeString(opcode);
+ VLOG(2) << "Row vectorization not enabled due to this op: " << MlirToString(&op);
break;
}
// Trigger only when there is a row broadcasting.