[XLA/GPU] Unify functions that convert memref/tensor values to an XLA shape.
PiperOrigin-RevId: 382826142
Change-Id: Iceef48e0c9663eda55187b2a126292ec73ad227f
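
The new GetShape(mlir::Value) helper in ir_emission_utils dispatches on the MLIR
type (memref, tensor, tuple) and, for tensor values, folds the "minor_to_major"
attribute of the defining op into the layout of the resulting XLA shape. A
minimal sketch of the intended call-site migration (illustrative only;
OperandShape is a hypothetical wrapper, not part of this change):

    #include "mlir/IR/Value.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
    #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

    xla::Shape OperandShape(mlir::Value value) {
      // Before this change: convert the MLIR type directly; layouts stored in
      // the "minor_to_major" attribute had to be re-applied by hand.
      //   return xla::TypeToShape(value.getType());

      // After: a single helper handles memref, tensor, and tuple values and
      // fills in the layout for tensors.
      return xla::gpu::GetShape(value);
    }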
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 36a8ae9..4024751 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -443,6 +443,7 @@
copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_nccl(["-DGOOGLE_XCCL=1"]),
deps = [
":buffer_allocations",
+ ":ir_emission_utils",
":thunk",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b09c6ca..a6f65c7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -1166,8 +1166,7 @@
}
mlir::BlockArgument arg = func.getArgument(i);
- sub_shapes.push_back(
- std::make_pair(shape_index, TypeToShape(arg.getType())));
+ sub_shapes.push_back(std::make_pair(shape_index, GetShape(arg)));
}
}
  // Expects result_xla_shape as an XLA shape in string form.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index f9caaf6..f62f810 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -82,6 +82,27 @@
return values;
}
+Shape GetShapeFromTensorType(mlir::Value value) {
+ constexpr char kDefaultLayoutAttrName[] = "minor_to_major";
+
+ mlir::Operation* op = value.getDefiningOp();
+ CHECK(op);
+ CHECK(value.getType().isa<mlir::TensorType>());
+ Shape shape = TypeToShape(value.getType());
+ if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(
+ kDefaultLayoutAttrName)) {
+ std::vector<int64> minor_to_major;
+ absl::c_transform(
+ attr, std::back_inserter(minor_to_major),
+ std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
+ *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+ } else {
+ *shape.mutable_layout() = LayoutUtil::MakeDescendingLayout(
+ value.getType().cast<mlir::ShapedType>().getShape().size());
+ }
+ return shape;
+}
+
} // namespace
bool IsMatrixMultiplication(const HloInstruction& dot) {
@@ -308,13 +329,12 @@
// Propagate layouts inside fusion region.
for (mlir::Operation& op : fusion_op.region().front().without_terminator()) {
if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) {
- add_layout(load, TypeToShape(load.memref().getType()).layout());
+ add_layout(load, GetShape(load.memref()).layout());
} else if (auto store = mlir::dyn_cast<mlir::memref::TensorStoreOp>(op)) {
      // Propagate the stored memref layout to the value if it does not have an
// inferred layout already. This prefers load coalescing over stores.
if (layouts_.count(store.tensor()) == 0) {
- add_layout(store.tensor(),
- TypeToShape(store.memref().getType()).layout());
+ add_layout(store.tensor(), GetShape(store.memref()).layout());
}
} else if (auto bitcast = mlir::dyn_cast<mlir::mhlo::BitcastOp>(op)) {
auto attr =
@@ -389,7 +409,7 @@
// Enable this code to check mismatch between the inferred layout and what was
// there before. Based on actual runs, some mismatches are expected.
#if 0
- Shape operand_shape_ir = TypeToShape(input.getType());
+ Shape operand_shape_ir = GetShape(input);
if (auto tensor_type = input.getType().dyn_cast<mlir::TensorType>()) {
if (auto attr = mlir::GetLayoutFromMlirHlo(input.getDefiningOp())) {
std::vector<int64> minor_to_major;
@@ -491,7 +511,7 @@
ReductionDimensions GetReductionKindAndContiguousComponents(
mlir::Operation* reduce) {
mlir::Value input = reduce->getOperand(0);
- Shape operand_shape = TypeToShape(input.getType());
+ Shape operand_shape = GetShape(input);
std::vector<int64> dimensions;
{
auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
@@ -911,5 +931,17 @@
return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}
+Shape GetShape(mlir::Value value) {
+ if (value.getType().isa<mlir::MemRefType>()) {
+ return TypeToShape(value.getType());
+ } else if (value.getType().isa<mlir::TensorType>()) {
+ return GetShapeFromTensorType(value);
+ } else if (value.getType().isa<mlir::TupleType>()) {
+ return TypeToShape(value.getType());
+ }
+ LOG(FATAL) << "Unexpected value type to get shape for";
+ return {};
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 9e1bcb8..b4af15f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -22,6 +22,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@@ -303,6 +304,8 @@
mlir::lmhlo::FusionOp fusion,
absl::Span<const BufferAllocation> allocations);
+Shape GetShape(mlir::Value value);
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index a614ca1..3379aa8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -154,8 +154,6 @@
// efficient.
const int64 kMinDimensionToTransposeTiled = 16;
-constexpr char kDefaultLayoutAttrName[] = "minor_to_major";
-
// Updates the launch dimensions in "thunk" and annotates the launch dimensions
// of the corresponding IR kernel in "llvm_module".
// Precondition: "thunk" must be a KernelThunk.
@@ -282,14 +280,18 @@
}
// Computes the maximum valid unroll factor for a given instruction.
-int ComputeMaxUnrollFactor(const Shape& shape,
+int ComputeMaxUnrollFactor(mlir::Type type,
const HloModuleConfig& hlo_module_config) {
int max_unroll_factor =
hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
// Find the largest possible power of two to unroll by.
// TODO(kramerb): Make this smarter.
- int64 num_elements = ShapeUtil::ElementsIn(shape);
+
+ auto shaped_type = type.cast<mlir::ShapedType>();
+ int64 num_elements = std::accumulate(shaped_type.getShape().begin(),
+ shaped_type.getShape().end(), int64{1},
+ std::multiplies<int64>());
for (int i = max_unroll_factor; i > 1; i /= 2) {
if (num_elements % i == 0) {
return i;
@@ -301,18 +303,10 @@
}
// Computes the maximum valid unroll factor for a given instruction.
-int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
- const Shape& element_shape = hlo->IsMultiOutputFusion()
- ? ShapeUtil::GetSubshape(hlo->shape(), {0})
- : hlo->shape();
- return ComputeMaxUnrollFactor(element_shape, hlo->GetModule()->config());
-}
-
-// Computes the maximum valid unroll factor for a given instruction.
int ComputeMaxUnrollFactor(mlir::Operation* op,
const HloModuleConfig& hlo_module_config) {
- Shape element_shape = [&] {
- std::vector<Shape> shapes;
+ mlir::Type element_shape = [&] {
+ std::vector<mlir::Type> shapes;
// Detect multi-output fusion. Notice that for a reduce in the fusion that
// returns a tuple, we don't want to treat it as multi-output fusion. We
// want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
@@ -320,17 +314,14 @@
if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
for (mlir::Value result : fusion_outputs[0]->getResults()) {
- shapes.push_back(TypeToShape(result.getType()));
+ return result.getType();
}
} else {
for (mlir::Value result : GetHloOutputs(op)) {
- shapes.push_back(TypeToShape(result.getType()));
+ return result.getType();
}
}
- if (shapes.size() > 1) {
- return ShapeUtil::MakeTupleShape(shapes);
- }
- return shapes[0];
+ CHECK(false);
}();
return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
}
@@ -419,13 +410,13 @@
// Check the size of result tensors
for (auto result : GetHloOutputs(op)) {
- if (!shape_in_range(TypeToShape(result.getType()))) {
+ if (!shape_in_range(GetShape(result))) {
return i64_ty;
}
}
auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
- return shape_in_range(TypeToShape(operand.getType()));
+ return shape_in_range(GetShape(operand));
};
// Check the size of input tensors
@@ -583,8 +574,7 @@
std::vector<mlir::Value> operands;
for (mlir::Value memref : range) {
auto load = b.create<mlir::memref::TensorLoadOp>(loc, memref);
- HloFunctionImporter::SetLayoutForMlir(load,
- TypeToShape(memref.getType()));
+ HloFunctionImporter::SetLayoutForMlir(load, GetShape(memref));
operands.push_back(load);
}
return operands;
@@ -592,10 +582,9 @@
if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(op)) {
auto operand = b.create<mlir::memref::TensorLoadOp>(loc, copy.operand());
- HloFunctionImporter::SetLayoutForMlir(
- operand, TypeToShape(copy.operand().getType()));
+ HloFunctionImporter::SetLayoutForMlir(operand, GetShape(copy.operand()));
auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand);
- output_shape = TypeToShape(copy.output().getType());
+ output_shape = GetShape(copy.output());
HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
b.create<mlir::memref::TensorStoreOp>(loc, fused_copy, copy.output());
} else if (auto reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(op)) {
@@ -609,7 +598,7 @@
for (int i = 0; i < reduce.out().size(); i++) {
b.create<mlir::memref::TensorStoreOp>(loc, fused_reduce.getResult(i),
reduce.out()[i]);
- auto shape = TypeToShape(reduce.out()[i].getType());
+ auto shape = GetShape(reduce.out()[i]);
if (i == 0) {
HloFunctionImporter::SetLayoutForMlir(fused_reduce, shape);
}
@@ -791,10 +780,8 @@
// pseudo code for PadToStatic on a 2d array
// int* source_array = input[0];
// int* dest_array = output[0];
- const Shape& data_shape =
- TypeToShape(pad_to_static.output().front().getType());
- const Shape& input_shape =
- TypeToShape(pad_to_static.args().front().getType());
+ const Shape& data_shape = GetShape(pad_to_static.output().front());
+ const Shape& input_shape = GetShape(pad_to_static.args().front());
llvm::Value* source_buffer = source_array.GetBasePointer();
llvm::Value* raw_buffer =
b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
@@ -812,7 +799,7 @@
for (int64 i = 1; i < pad_to_static.output().size(); ++i) {
// Dynamic size of each dimension is attached at the end of the source
    // array (operand(0)). We need to extract these values.
- const Shape& dim_shape = TypeToShape(pad_to_static.output()[i].getType());
+ const Shape& dim_shape = GetShape(pad_to_static.output()[i]);
TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
const int64 dim_index = i - 1;
@@ -910,11 +897,9 @@
auto kernel_thunk,
BuildKernelThunk(slice_to_dynamic, GetThunkInfo(op), &ir_arrays));
- const Shape& input_shape =
- TypeToShape(slice_to_dynamic.args().front().getType());
+ const Shape& input_shape = GetShape(slice_to_dynamic.args().front());
TF_RET_CHECK(slice_to_dynamic.output().size() == 1);
- const Shape& data_shape =
- TypeToShape(slice_to_dynamic.output().front().getType());
+ const Shape& data_shape = GetShape(slice_to_dynamic.output().front());
// TODO(jurahul): data_shape here is the static shape of the output (which has
// a dynamic shape in XLA). Currently, we are mapping that to a static shaped
@@ -1098,13 +1083,11 @@
GpuConvDescriptor descriptor;
auto fill_conv_descriptor = [&](auto op) {
- descriptor.operand0_shape =
- apply_layout(TypeToShape(op->getOperand(0).getType()),
- op.backend_config().operand_0_layout());
- descriptor.operand1_shape =
- apply_layout(TypeToShape(op->getOperand(1).getType()),
- op.backend_config().operand_1_layout());
- descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
+ descriptor.operand0_shape = apply_layout(
+ GetShape(op->getOperand(0)), op.backend_config().operand_0_layout());
+ descriptor.operand1_shape = apply_layout(
+ GetShape(op->getOperand(1)), op.backend_config().operand_1_layout());
+ descriptor.result_shape = apply_layout(GetShape(conv_result),
op.backend_config().result_layout());
descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
descriptor.scratch_size = scratch_slice.size();
@@ -1195,9 +1178,9 @@
GpuGemmConfig config;
GemmBackendConfig& backend = config.backend_config;
- config.output_shape = TypeToShape(op.output().getType());
- config.lhs_shape = TypeToShape(op.lhs().getType());
- config.rhs_shape = TypeToShape(op.rhs().getType());
+ config.output_shape = GetShape(op.output());
+ config.lhs_shape = GetShape(op.lhs());
+ config.rhs_shape = GetShape(op.rhs());
backend.Clear();
if (op.algorithm()) {
backend.set_selected_algorithm(*op.algorithm());
@@ -1253,7 +1236,7 @@
/*source_buffer=*/bias,
/*destination_buffer=*/output,
/*mem_size=*/
- ShapeUtil::ByteSizeOf(TypeToShape(gemm.output().getType()))));
+ ShapeUtil::ByteSizeOf(GetShape(gemm.output()))));
TF_ASSIGN_OR_RETURN(
auto thunk,
make_thunk_for_gemm(gemm, bias, gemm_bias_beta,
@@ -1321,8 +1304,7 @@
// Only tested when the inputs are row-major. So only
      // enable that case. Maybe it would work if only the
      // inner dimensions are contiguous.
- return LayoutUtil::IsMonotonicWithDim0Major(
- TypeToShape(value.getType()).layout());
+ return LayoutUtil::IsMonotonicWithDim0Major(GetShape(value).layout());
};
bool row_vectorized =
fusion.getFusionResults().size() == 1 && // Not tested with MOF.
@@ -1365,7 +1347,7 @@
broadcast_dimensions.push_back(int_value.getSExtValue());
}
- auto rank = TypeToShape(broadcast.getResult().getType()).rank();
+ auto rank = GetShape(broadcast.getResult()).rank();
if (broadcast_dimensions.size() == 1 &&
broadcast_dimensions.back() == (rank - 1)) {
some_row_broadcasting = true;
@@ -1385,7 +1367,7 @@
Status IrEmitterUnnested::EmitBatchNormThunk(mlir::Operation* op) {
auto get_batch_norm_config = [](auto op, mlir::Value output) {
CudnnBatchNormConfig config;
- config.output_shape = TypeToShape(output.getType());
+ config.output_shape = GetShape(output);
config.output_type = config.output_shape.element_type();
config.epsilon = op.epsilon().convertToFloat();
config.feature_index = op.feature_index();
@@ -1461,7 +1443,7 @@
GetAllocationSlice(bn_grad.grad_offset()));
CudnnBatchNormConfig config;
- config.output_shape = TypeToShape(bn_grad.grad_output().getType());
+ config.output_shape = GetShape(bn_grad.grad_output());
config.output_type = config.output_shape.element_type();
config.epsilon = bn_grad.epsilon().convertToFloat();
config.feature_index = bn_grad.feature_index();
@@ -1521,7 +1503,7 @@
Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op);
- const Shape shape = TypeToShape(cholesky_op.input().getType());
+ const Shape shape = GetShape(cholesky_op.input());
int ndim = shape.dimensions_size();
CHECK_GE(ndim, 2);
int64 n = shape.dimensions(ndim - 1);
@@ -1632,8 +1614,8 @@
Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) {
auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(op);
- const Shape operand_shape = TypeToShape(fft_op.operand().getType());
- const Shape output_shape = TypeToShape(fft_op.output().getType());
+ const Shape operand_shape = GetShape(fft_op.operand());
+ const Shape output_shape = GetShape(fft_op.output());
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
@@ -1665,10 +1647,9 @@
TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
- const Shape b_shape = TypeToShape(triangular_solve_op.b().getType());
+ const Shape b_shape = GetShape(triangular_solve_op.b());
- const Shape output_shape =
- TypeToShape(triangular_solve_op.output().getType());
+ const Shape output_shape = GetShape(triangular_solve_op.output());
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
GetAllocationSlice(triangular_solve_op.a()));
@@ -1759,33 +1740,14 @@
for (auto load : loads) {
auto arg = region->addArgument(load.getType());
load.replaceAllUsesWith(arg);
- Shape shape = TypeToShape(load.getType());
- if (auto attr = load->getAttrOfType<mlir::DenseIntElementsAttr>(
- kDefaultLayoutAttrName)) {
- std::vector<int64> minor_to_major;
- absl::c_transform(
- attr, std::back_inserter(minor_to_major),
- std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
- *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
- } else {
- *shape.mutable_layout() =
- LayoutUtil::MakeDescendingLayout(load.getType().getShape().size());
- }
+ Shape shape = GetShape(load.getResult());
operand_shapes->push_back(std::move(shape));
load.erase();
}
std::vector<mlir::Value> returned_values;
for (auto store : stores) {
- Shape shape = TypeToShape(store.memref().getType());
- if (auto attr = store->getAttrOfType<mlir::DenseIntElementsAttr>(
- kDefaultLayoutAttrName)) {
- std::vector<int64> minor_to_major;
- absl::c_transform(
- attr, std::back_inserter(minor_to_major),
- std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
- *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
- }
+ Shape shape = GetShape(store.memref());
output_shapes->push_back(shape);
returned_values.push_back(store.tensor());
@@ -2132,8 +2094,8 @@
Status IrEmitterUnnested::EmitCopy(mlir::Operation* op) {
auto copy = mlir::cast<mlir::lmhlo::CopyOp>(op);
- auto operand_shape = TypeToShape(copy.operand().getType());
- auto output_shape = TypeToShape(copy.output().getType());
+ auto operand_shape = GetShape(copy.operand());
+ auto output_shape = GetShape(copy.output());
CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
auto maybe_slice = GetAllocationSlice(copy.operand());
@@ -2204,10 +2166,8 @@
Status IrEmitterUnnested::EmitSelectAndScatter(mlir::Operation* op) {
auto select_and_scatter_op = mlir::cast<mlir::lmhlo::SelectAndScatterOp>(op);
- const Shape source_shape =
- TypeToShape(select_and_scatter_op.source().getType());
- const Shape operand_shape =
- TypeToShape(select_and_scatter_op.operand().getType());
+ const Shape source_shape = GetShape(select_and_scatter_op.source());
+ const Shape operand_shape = GetShape(select_and_scatter_op.operand());
const int64 rank = operand_shape.rank();
CHECK_EQ(rank, source_shape.rank());
@@ -2408,8 +2368,7 @@
InBoundsGEP(selected_index_address, {b_.getInt32(i)});
selected_multi_index.push_back(Load(selected_index_address_slot));
}
- const Shape output_shape =
- TypeToShape(select_and_scatter_op.out().getType());
+ const Shape output_shape = GetShape(select_and_scatter_op.out());
llvm::Value* source_value_address =
source_array.EmitArrayElementAddress(source_index, &b_);
IrArray::Index selected_index(selected_multi_index, output_shape,
@@ -2478,7 +2437,7 @@
llvm::Value* old_state =
llvm_ir::RngGetAndUpdateState(rng_op.delta(), module_, &b_);
- const Shape shape = TypeToShape(rng_op.state().getType());
+ const Shape shape = GetShape(rng_op.state());
llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
llvm_ir::IrArray::Index(
@@ -2515,7 +2474,7 @@
/*source_address=*/operand_buffer,
/*destination_buffer=*/output_buffer,
/*mem_size=*/
- ShapeUtil::ByteSizeOf(TypeToShape(scatter_op.output().getType()))));
+ ShapeUtil::ByteSizeOf(GetShape(scatter_op.output()))));
}
// Create kernel thunk for all operands except the first one (`operand`). The
@@ -2568,9 +2527,8 @@
const llvm_ir::ElementGenerator& scatter_indices_gen,
const llvm_ir::ElementGenerator& updates_gen,
std::function<llvm::Type*(int64)> get_index_type) {
- const Shape operand_shape = TypeToShape(scatter.operand().getType());
- CHECK(
- ShapeUtil::Equal(TypeToShape(scatter.output().getType()), operand_shape));
+ const Shape operand_shape = GetShape(scatter.operand());
+ CHECK(ShapeUtil::Equal(GetShape(scatter.output()), operand_shape));
TF_ASSIGN_OR_RETURN(
const HloComputation* update_computation,
@@ -2580,8 +2538,8 @@
ScatterDescriptor desc;
desc.name = mlir::GetNameFromLoc(scatter.getLoc());
desc.operand_shape = operand_shape;
- desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType());
- desc.updates_shape = TypeToShape(scatter.updates().getType());
+ desc.scatter_indices_shape = GetShape(scatter.scatter_indices());
+ desc.updates_shape = GetShape(scatter.updates());
desc.dim_numbers = scatter.scatter_dimension_numbers();
desc.unique_indices = scatter.unique_indices();
desc.update_computation = update_computation;
@@ -3064,7 +3022,7 @@
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
GetAllocationSlice(collective_permute_op.output()));
- const Shape shape = TypeToShape(collective_permute_op.operand().getType());
+ const Shape shape = GetShape(collective_permute_op.operand());
const int64 replica_count = hlo_module_config_.replica_count();
const int64 partition_count = hlo_module_config_.num_partitions();
@@ -3141,7 +3099,7 @@
for (auto it : llvm::zip(op.operands(), op.results())) {
mlir::Value operand = std::get<0>(it);
mlir::Value result = std::get<1>(it);
- const Shape shape = TypeToShape(operand.getType());
+ const Shape shape = GetShape(operand);
TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand));
TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result));
buffers.push_back(NcclCollectiveThunk::Buffer{
@@ -3176,7 +3134,7 @@
CollectiveOpGroupModeToString(group_mode), op.operands().size(),
NcclThunkType::NcclIsEnabled());
if (!op.operands().empty()) {
- const Shape shape = TypeToShape(op.operands().front().getType());
+ const Shape shape = GetShape(op.operands().front());
absl::StrAppendFormat(&message, "; first operand array element-type: %s",
PrimitiveType_Name(shape.element_type()));
}
@@ -3187,7 +3145,7 @@
// assignment expects a copy, so that's what we do.
ThunkSequence thunks;
for (int64 i = 0; i < buffers.size(); i++) {
- const Shape shape = TypeToShape(op.operands()[i].getType());
+ const Shape shape = GetShape(op.operands()[i]);
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
buffers.size() == 1 ? GetThunkInfo(op) : Thunk::ThunkInfo(),
/*source_address=*/buffers[i].source_buffer,
@@ -3229,7 +3187,7 @@
for (mlir::Value output : infeed_op.outputs()) {
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(output));
- const Shape& shape = TypeToShape(output.getType());
+ const Shape& shape = GetShape(output);
dest_slices.push_back(ShapedSlice{slice, shape});
}
@@ -3246,7 +3204,7 @@
for (mlir::Value operand : outfeed_op.operands()) {
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
- const Shape& shape = TypeToShape(operand.getType());
+ const Shape& shape = GetShape(operand);
source_slices.push_back(ShapedSlice{slice, shape});
}
@@ -3376,7 +3334,7 @@
TF_ASSIGN_OR_RETURN(slice.buffer_slice,
GetAllocationSlice(operand, &slice.constant_name));
slice.written = WritesMlirBuffer(op, operand);
- slice.shape = TypeToShape(operand.getType());
+ slice.shape = GetShape(operand);
}
std::string name = mlir::GetNameFromLoc(op->getLoc());
return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays);
@@ -3396,7 +3354,7 @@
TF_ASSIGN_OR_RETURN(slice.buffer_slice,
GetAllocationSlice(operand, &slice.constant_name));
slice.written = false;
- slice.shape = TypeToShape(operand.getType());
+ slice.shape = GetShape(operand);
}
for (auto output : outputs) {
slices.emplace_back();
@@ -3404,7 +3362,7 @@
TF_ASSIGN_OR_RETURN(slice.buffer_slice,
GetAllocationSlice(output, &slice.constant_name));
slice.written = true;
- slice.shape = TypeToShape(output.getType());
+ slice.shape = GetShape(output);
}
std::string name = mlir::GetNameFromLoc(op->getLoc());
return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays);
@@ -3480,7 +3438,7 @@
TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest));
- const Shape dest_shape = TypeToShape(dest.getType());
+ const Shape dest_shape = GetShape(dest);
auto thunk =
BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
if (thunk) {
@@ -3511,7 +3469,7 @@
const llvm_ir::IrArray init_array = ir_arrays[0];
const llvm_ir::IrArray dest_array = ir_arrays[1];
- const Shape dest_shape = TypeToShape(dest.getType());
+ const Shape dest_shape = GetShape(dest);
TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
CalculateLaunchDimensions(
dest_shape, ir_emitter_context_->gpu_device_info()));
@@ -3554,7 +3512,7 @@
const llvm_ir::IrArray dest_array =
ir_arrays[input_buffers.size() + output_index];
- const Shape dest_shape = TypeToShape(dest.getType());
+ const Shape dest_shape = GetShape(dest);
TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
CalculateLaunchDimensions(
dest_shape, ir_emitter_context_->gpu_device_info()));
@@ -4901,7 +4859,7 @@
const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
return absl::c_count_if(
fusion.getFusionParameters(), [&](mlir::Value parameter) {
- Shape parameter_shape = TypeToShape(parameter.getType());
+ Shape parameter_shape = GetShape(parameter);
return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
AreUsersElementwise(parameter, use_chain_endings);
});
@@ -4914,7 +4872,7 @@
int64 num_elements = ShapeUtil::ElementsIn(shape);
return absl::c_count_if(
fusion.getFusionParameters(), [&](mlir::Value parameter) {
- Shape parameter_shape = TypeToShape(parameter.getType());
+ Shape parameter_shape = GetShape(parameter);
return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
});
}
@@ -4999,7 +4957,7 @@
<< reduction_dimensions.dimensions[2];
auto get_dtype_bits = [](mlir::Value i) {
// TODO(timshen): may not be efficient.
- return primitive_util::BitWidth(TypeToShape(i.getType()).element_type());
+ return primitive_util::BitWidth(GetShape(i).element_type());
};
// For fusion with multiple inputs, use the smallest input dtype to
@@ -5283,7 +5241,7 @@
}
}
}
- Shape input_shape = TypeToShape(first_reduce->getOperand(0).getType());
+ Shape input_shape = GetShape(first_reduce->getOperand(0));
// The layout of a reduction input is either set by LayoutAssignment for
// unnested kReduce or by InstructionFusion for fused kReduce.
CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
@@ -5688,10 +5646,10 @@
auto operands = GetHloOperands(op);
auto outputs = GetHloOutputs(op);
for (auto operand : operands) {
- operand_shapes.push_back(TypeToShape(operand.getType()));
+ operand_shapes.push_back(GetShape(operand));
}
for (auto output : outputs) {
- output_shapes.push_back(TypeToShape(output.getType()));
+ output_shapes.push_back(GetShape(output));
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
index 4e69477..e1c2640 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
@@ -24,6 +24,7 @@
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
@@ -41,7 +42,7 @@
/*static*/ bool NcclAllGatherThunk::CanImplement(mlir::lmhlo::AllGatherOp op) {
return absl::c_all_of(op.operands(), [&](mlir::Value operand) {
- Shape shape = TypeToShape(operand.getType());
+ Shape shape = GetShape(operand);
return LayoutUtil::IsDenseArray(shape) &&
IsTypeSupportedByNccl(shape.element_type()) &&
LayoutUtil::MinorToMajor(shape).back() == op.all_gather_dimension();
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 66dd7f8..0b97381 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -108,7 +108,7 @@
if (!opcode.ok()) return absl::nullopt;
// Match the operation to a reduction kind. We can represent and/or of pred as
// min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
- PrimitiveType type = TypeToShape(result.getType()).element_type();
+ PrimitiveType type = GetShape(result).element_type();
if (type == PRED) {
switch (opcode.ValueOrDie()) {
case HloOpcode::kAnd:
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
index 151749d..66cfbbb 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
@@ -25,6 +25,7 @@
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -45,7 +46,7 @@
/*static*/ bool NcclAllToAllThunk::CanImplement(mlir::lmhlo::AllToAllOp op) {
return absl::c_all_of(op.operands(), [&op](mlir::Value operand) {
- Shape shape = TypeToShape(operand.getType());
+ Shape shape = GetShape(operand);
return LayoutUtil::IsDenseArray(shape) &&
IsTypeSupportedByNccl(shape.element_type()) &&
(!op.split_dimension() ||
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc
index d33d389..68b083f 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc
@@ -24,6 +24,7 @@
#include "absl/types/optional.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -36,7 +37,7 @@
NcclCollectivePermuteConfig config;
config.operand_count = 1;
- const Shape shape = TypeToShape(op.operand().getType());
+ const Shape shape = GetShape(op.operand());
config.operand_element_type.push_back(shape.element_type());
config.SetCollectiveOpKindAndID(op);
config.group_mode = GetGroupMode(op);
@@ -89,7 +90,7 @@
/*static*/ bool NcclCollectivePermuteThunk::CanImplement(
mlir::lmhlo::CollectivePermuteOp op) {
- const Shape shape = TypeToShape(op.operand().getType());
+ const Shape shape = GetShape(op.operand());
return IsTypeSupportedByNccl(shape.element_type());
}
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
index 552c98c..8b2c7bf 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
@@ -22,6 +22,7 @@
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -98,7 +99,7 @@
config.operand_count = op.operands().size();
config.operand_element_type.reserve(config.operand_count);
for (int i = 0; i < config.operand_count; i++) {
- const Shape shape = TypeToShape(op.operands()[i].getType());
+ const Shape shape = GetShape(op.operands()[i]);
config.operand_element_type.push_back(shape.element_type());
}
config.replica_groups =