[XLA/GPU] Unify functions that convert memref/tensor to an XLA shape.

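Introduce a single helper, xla::gpu::GetShape(mlir::Value), that returns the
XLA Shape of a value whether its type is a memref, tensor, or tuple. For
tensor values it also recovers the layout from the "minor_to_major" attribute
on the defining op (falling back to a descending layout), so call sites no
longer need to pair TypeToShape() with ad-hoc layout handling. A rough usage
sketch, using only the names introduced in this change:

  // Before: type-only conversion; tensor layouts handled at each call site.
  Shape shape = TypeToShape(value.getType());

  // After: one entry point that also picks up the layout when available.
  Shape shape = xla::gpu::GetShape(value);
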
PiperOrigin-RevId: 382826142
Change-Id: Iceef48e0c9663eda55187b2a126292ec73ad227f
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 36a8ae9..4024751 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -443,6 +443,7 @@
     copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_nccl(["-DGOOGLE_XCCL=1"]),
     deps = [
         ":buffer_allocations",
+        ":ir_emission_utils",
         ":thunk",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b09c6ca..a6f65c7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -1166,8 +1166,7 @@
       }
 
       mlir::BlockArgument arg = func.getArgument(i);
-      sub_shapes.push_back(
-          std::make_pair(shape_index, TypeToShape(arg.getType())));
+      sub_shapes.push_back(std::make_pair(shape_index, GetShape(arg)));
     }
   }
   // Expects result_xla_shape as a XLA shape in string form.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index f9caaf6..f62f810 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -82,6 +82,27 @@
   return values;
 }
 
+Shape GetShapeFromTensorType(mlir::Value value) {
+  constexpr char kDefaultLayoutAttrName[] = "minor_to_major";
+
+  mlir::Operation* op = value.getDefiningOp();
+  CHECK(op);
+  CHECK(value.getType().isa<mlir::TensorType>());
+  Shape shape = TypeToShape(value.getType());
+  if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(
+          kDefaultLayoutAttrName)) {
+    std::vector<int64> minor_to_major;
+    absl::c_transform(
+        attr, std::back_inserter(minor_to_major),
+        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
+    *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+  } else {
+    *shape.mutable_layout() = LayoutUtil::MakeDescendingLayout(
+        value.getType().cast<mlir::ShapedType>().getShape().size());
+  }
+  return shape;
+}
+
 }  // namespace
 
 bool IsMatrixMultiplication(const HloInstruction& dot) {
@@ -308,13 +329,12 @@
   // Propagate layouts inside fusion region.
   for (mlir::Operation& op : fusion_op.region().front().without_terminator()) {
     if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) {
-      add_layout(load, TypeToShape(load.memref().getType()).layout());
+      add_layout(load, GetShape(load.memref()).layout());
     } else if (auto store = mlir::dyn_cast<mlir::memref::TensorStoreOp>(op)) {
       // Propagate the stored memref layout to the value if it does not have a
       // inferred layout already. This prefers load coalescing over stores.
       if (layouts_.count(store.tensor()) == 0) {
-        add_layout(store.tensor(),
-                   TypeToShape(store.memref().getType()).layout());
+        add_layout(store.tensor(), GetShape(store.memref()).layout());
       }
     } else if (auto bitcast = mlir::dyn_cast<mlir::mhlo::BitcastOp>(op)) {
       auto attr =
@@ -389,7 +409,7 @@
   // Enable this code to check mismatch between the inferred layout and what was
   // there before. Based on actual runs, some mismatches are expected.
 #if 0
-  Shape operand_shape_ir = TypeToShape(input.getType());
+  Shape operand_shape_ir = GetShape(input);
   if (auto tensor_type = input.getType().dyn_cast<mlir::TensorType>()) {
     if (auto attr = mlir::GetLayoutFromMlirHlo(input.getDefiningOp())) {
       std::vector<int64> minor_to_major;
@@ -491,7 +511,7 @@
 ReductionDimensions GetReductionKindAndContiguousComponents(
     mlir::Operation* reduce) {
   mlir::Value input = reduce->getOperand(0);
-  Shape operand_shape = TypeToShape(input.getType());
+  Shape operand_shape = GetShape(input);
   std::vector<int64> dimensions;
   {
     auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
@@ -911,5 +931,17 @@
   return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
 }
 
+Shape GetShape(mlir::Value value) {
+  if (value.getType().isa<mlir::MemRefType>()) {
+    return TypeToShape(value.getType());
+  } else if (value.getType().isa<mlir::TensorType>()) {
+    return GetShapeFromTensorType(value);
+  } else if (value.getType().isa<mlir::TupleType>()) {
+    return TypeToShape(value.getType());
+  }
+  LOG(FATAL) << "Unexpected value type to get shape for";
+  return {};
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 9e1bcb8..b4af15f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Value.h"
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@@ -303,6 +304,8 @@
     mlir::lmhlo::FusionOp fusion,
     absl::Span<const BufferAllocation> allocations);
 
+Shape GetShape(mlir::Value value);
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index a614ca1..3379aa8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -154,8 +154,6 @@
 // efficient.
 const int64 kMinDimensionToTransposeTiled = 16;
 
-constexpr char kDefaultLayoutAttrName[] = "minor_to_major";
-
 // Updates the launch dimensions in "thunk" and annotate the launch dimensions
 // of the corresponding IR kernel in "llvm_module".
 // Precondition: "thunk" must be a KernelThunk.
@@ -282,14 +280,18 @@
 }
 
 // Computes the maximum valid unroll factor for a given instruction.
-int ComputeMaxUnrollFactor(const Shape& shape,
+int ComputeMaxUnrollFactor(mlir::Type type,
                            const HloModuleConfig& hlo_module_config) {
   int max_unroll_factor =
       hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
 
   // Find the largest possible power of two to unroll by.
   // TODO(kramerb): Make this smarter.
-  int64 num_elements = ShapeUtil::ElementsIn(shape);
+
+  auto shaped_type = type.cast<mlir::ShapedType>();
+  int64 num_elements = std::accumulate(shaped_type.getShape().begin(),
+                                       shaped_type.getShape().end(), int64{1},
+                                       std::multiplies<int64>());
   for (int i = max_unroll_factor; i > 1; i /= 2) {
     if (num_elements % i == 0) {
       return i;
@@ -301,18 +303,10 @@
 }
 
 // Computes the maximum valid unroll factor for a given instruction.
-int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
-  const Shape& element_shape = hlo->IsMultiOutputFusion()
-                                   ? ShapeUtil::GetSubshape(hlo->shape(), {0})
-                                   : hlo->shape();
-  return ComputeMaxUnrollFactor(element_shape, hlo->GetModule()->config());
-}
-
-// Computes the maximum valid unroll factor for a given instruction.
 int ComputeMaxUnrollFactor(mlir::Operation* op,
                            const HloModuleConfig& hlo_module_config) {
-  Shape element_shape = [&] {
-    std::vector<Shape> shapes;
+  mlir::Type element_shape = [&] {
+    std::vector<mlir::Type> shapes;
     // Detect multi-output fusion. Notice that for a reduce in the fusion that
     // returns a tuple, we don't want to treat it as multi-output fusion. We
     // want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
@@ -320,17 +314,14 @@
     if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
       std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
       for (mlir::Value result : fusion_outputs[0]->getResults()) {
-        shapes.push_back(TypeToShape(result.getType()));
+        return result.getType();
       }
     } else {
       for (mlir::Value result : GetHloOutputs(op)) {
-        shapes.push_back(TypeToShape(result.getType()));
+        return result.getType();
       }
     }
-    if (shapes.size() > 1) {
-      return ShapeUtil::MakeTupleShape(shapes);
-    }
-    return shapes[0];
+    CHECK(false);
   }();
   return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
 }
@@ -419,13 +410,13 @@
 
   // Check the size of result tensors
   for (auto result : GetHloOutputs(op)) {
-    if (!shape_in_range(TypeToShape(result.getType()))) {
+    if (!shape_in_range(GetShape(result))) {
       return i64_ty;
     }
   }
 
   auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
-    return shape_in_range(TypeToShape(operand.getType()));
+    return shape_in_range(GetShape(operand));
   };
 
   // Check the size of input tensors
@@ -583,8 +574,7 @@
     std::vector<mlir::Value> operands;
     for (mlir::Value memref : range) {
       auto load = b.create<mlir::memref::TensorLoadOp>(loc, memref);
-      HloFunctionImporter::SetLayoutForMlir(load,
-                                            TypeToShape(memref.getType()));
+      HloFunctionImporter::SetLayoutForMlir(load, GetShape(memref));
       operands.push_back(load);
     }
     return operands;
@@ -592,10 +582,9 @@
 
   if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(op)) {
     auto operand = b.create<mlir::memref::TensorLoadOp>(loc, copy.operand());
-    HloFunctionImporter::SetLayoutForMlir(
-        operand, TypeToShape(copy.operand().getType()));
+    HloFunctionImporter::SetLayoutForMlir(operand, GetShape(copy.operand()));
     auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand);
-    output_shape = TypeToShape(copy.output().getType());
+    output_shape = GetShape(copy.output());
     HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
     b.create<mlir::memref::TensorStoreOp>(loc, fused_copy, copy.output());
   } else if (auto reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(op)) {
@@ -609,7 +598,7 @@
     for (int i = 0; i < reduce.out().size(); i++) {
       b.create<mlir::memref::TensorStoreOp>(loc, fused_reduce.getResult(i),
                                             reduce.out()[i]);
-      auto shape = TypeToShape(reduce.out()[i].getType());
+      auto shape = GetShape(reduce.out()[i]);
       if (i == 0) {
         HloFunctionImporter::SetLayoutForMlir(fused_reduce, shape);
       }
@@ -791,10 +780,8 @@
   // pseudo code for PadToStatic on a 2d array
   //   int* source_array = input[0];
   //   int* dest_array = output[0];
-  const Shape& data_shape =
-      TypeToShape(pad_to_static.output().front().getType());
-  const Shape& input_shape =
-      TypeToShape(pad_to_static.args().front().getType());
+  const Shape& data_shape = GetShape(pad_to_static.output().front());
+  const Shape& input_shape = GetShape(pad_to_static.args().front());
   llvm::Value* source_buffer = source_array.GetBasePointer();
   llvm::Value* raw_buffer =
       b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
@@ -812,7 +799,7 @@
   for (int64 i = 1; i < pad_to_static.output().size(); ++i) {
     // Dynamic size of each dimension is attached at the end of the source
     // array(operand(0)). We need to extract these value.
-    const Shape& dim_shape = TypeToShape(pad_to_static.output()[i].getType());
+    const Shape& dim_shape = GetShape(pad_to_static.output()[i]);
     TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
 
     const int64 dim_index = i - 1;
@@ -910,11 +897,9 @@
       auto kernel_thunk,
       BuildKernelThunk(slice_to_dynamic, GetThunkInfo(op), &ir_arrays));
 
-  const Shape& input_shape =
-      TypeToShape(slice_to_dynamic.args().front().getType());
+  const Shape& input_shape = GetShape(slice_to_dynamic.args().front());
   TF_RET_CHECK(slice_to_dynamic.output().size() == 1);
-  const Shape& data_shape =
-      TypeToShape(slice_to_dynamic.output().front().getType());
+  const Shape& data_shape = GetShape(slice_to_dynamic.output().front());
 
   // TODO(jurahul): data_shape here is the static shape of the output (which has
   // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
@@ -1098,13 +1083,11 @@
   GpuConvDescriptor descriptor;
 
   auto fill_conv_descriptor = [&](auto op) {
-    descriptor.operand0_shape =
-        apply_layout(TypeToShape(op->getOperand(0).getType()),
-                     op.backend_config().operand_0_layout());
-    descriptor.operand1_shape =
-        apply_layout(TypeToShape(op->getOperand(1).getType()),
-                     op.backend_config().operand_1_layout());
-    descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
+    descriptor.operand0_shape = apply_layout(
+        GetShape(op->getOperand(0)), op.backend_config().operand_0_layout());
+    descriptor.operand1_shape = apply_layout(
+        GetShape(op->getOperand(1)), op.backend_config().operand_1_layout());
+    descriptor.result_shape = apply_layout(GetShape(conv_result),
                                            op.backend_config().result_layout());
     descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
     descriptor.scratch_size = scratch_slice.size();
@@ -1195,9 +1178,9 @@
 
     GpuGemmConfig config;
     GemmBackendConfig& backend = config.backend_config;
-    config.output_shape = TypeToShape(op.output().getType());
-    config.lhs_shape = TypeToShape(op.lhs().getType());
-    config.rhs_shape = TypeToShape(op.rhs().getType());
+    config.output_shape = GetShape(op.output());
+    config.lhs_shape = GetShape(op.lhs());
+    config.rhs_shape = GetShape(op.rhs());
     backend.Clear();
     if (op.algorithm()) {
       backend.set_selected_algorithm(*op.algorithm());
@@ -1253,7 +1236,7 @@
           /*source_buffer=*/bias,
           /*destination_buffer=*/output,
           /*mem_size=*/
-          ShapeUtil::ByteSizeOf(TypeToShape(gemm.output().getType()))));
+          ShapeUtil::ByteSizeOf(GetShape(gemm.output()))));
       TF_ASSIGN_OR_RETURN(
           auto thunk,
           make_thunk_for_gemm(gemm, bias, gemm_bias_beta,
@@ -1321,8 +1304,7 @@
     // Only tested when the inputs are row-major. So only
     // enable that case. Maybe it would works if only the
     // inner dimensions is contiguous.
-    return LayoutUtil::IsMonotonicWithDim0Major(
-        TypeToShape(value.getType()).layout());
+    return LayoutUtil::IsMonotonicWithDim0Major(GetShape(value).layout());
   };
   bool row_vectorized =
       fusion.getFusionResults().size() == 1 &&  // Not tested with MOF.
@@ -1365,7 +1347,7 @@
         broadcast_dimensions.push_back(int_value.getSExtValue());
       }
 
-      auto rank = TypeToShape(broadcast.getResult().getType()).rank();
+      auto rank = GetShape(broadcast.getResult()).rank();
       if (broadcast_dimensions.size() == 1 &&
           broadcast_dimensions.back() == (rank - 1)) {
         some_row_broadcasting = true;
@@ -1385,7 +1367,7 @@
 Status IrEmitterUnnested::EmitBatchNormThunk(mlir::Operation* op) {
   auto get_batch_norm_config = [](auto op, mlir::Value output) {
     CudnnBatchNormConfig config;
-    config.output_shape = TypeToShape(output.getType());
+    config.output_shape = GetShape(output);
     config.output_type = config.output_shape.element_type();
     config.epsilon = op.epsilon().convertToFloat();
     config.feature_index = op.feature_index();
@@ -1461,7 +1443,7 @@
                         GetAllocationSlice(bn_grad.grad_offset()));
 
     CudnnBatchNormConfig config;
-    config.output_shape = TypeToShape(bn_grad.grad_output().getType());
+    config.output_shape = GetShape(bn_grad.grad_output());
     config.output_type = config.output_shape.element_type();
     config.epsilon = bn_grad.epsilon().convertToFloat();
     config.feature_index = bn_grad.feature_index();
@@ -1521,7 +1503,7 @@
 Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
   auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op);
 
-  const Shape shape = TypeToShape(cholesky_op.input().getType());
+  const Shape shape = GetShape(cholesky_op.input());
   int ndim = shape.dimensions_size();
   CHECK_GE(ndim, 2);
   int64 n = shape.dimensions(ndim - 1);
@@ -1632,8 +1614,8 @@
 
 Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) {
   auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(op);
-  const Shape operand_shape = TypeToShape(fft_op.operand().getType());
-  const Shape output_shape = TypeToShape(fft_op.output().getType());
+  const Shape operand_shape = GetShape(fft_op.operand());
+  const Shape output_shape = GetShape(fft_op.output());
   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
 
@@ -1665,10 +1647,9 @@
   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
 
-  const Shape b_shape = TypeToShape(triangular_solve_op.b().getType());
+  const Shape b_shape = GetShape(triangular_solve_op.b());
 
-  const Shape output_shape =
-      TypeToShape(triangular_solve_op.output().getType());
+  const Shape output_shape = GetShape(triangular_solve_op.output());
 
   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
                       GetAllocationSlice(triangular_solve_op.a()));
@@ -1759,33 +1740,14 @@
   for (auto load : loads) {
     auto arg = region->addArgument(load.getType());
     load.replaceAllUsesWith(arg);
-    Shape shape = TypeToShape(load.getType());
-    if (auto attr = load->getAttrOfType<mlir::DenseIntElementsAttr>(
-            kDefaultLayoutAttrName)) {
-      std::vector<int64> minor_to_major;
-      absl::c_transform(
-          attr, std::back_inserter(minor_to_major),
-          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
-      *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
-    } else {
-      *shape.mutable_layout() =
-          LayoutUtil::MakeDescendingLayout(load.getType().getShape().size());
-    }
+    Shape shape = GetShape(load.getResult());
     operand_shapes->push_back(std::move(shape));
     load.erase();
   }
 
   std::vector<mlir::Value> returned_values;
   for (auto store : stores) {
-    Shape shape = TypeToShape(store.memref().getType());
-    if (auto attr = store->getAttrOfType<mlir::DenseIntElementsAttr>(
-            kDefaultLayoutAttrName)) {
-      std::vector<int64> minor_to_major;
-      absl::c_transform(
-          attr, std::back_inserter(minor_to_major),
-          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
-      *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
-    }
+    Shape shape = GetShape(store.memref());
     output_shapes->push_back(shape);
 
     returned_values.push_back(store.tensor());
@@ -2132,8 +2094,8 @@
 
 Status IrEmitterUnnested::EmitCopy(mlir::Operation* op) {
   auto copy = mlir::cast<mlir::lmhlo::CopyOp>(op);
-  auto operand_shape = TypeToShape(copy.operand().getType());
-  auto output_shape = TypeToShape(copy.output().getType());
+  auto operand_shape = GetShape(copy.operand());
+  auto output_shape = GetShape(copy.output());
 
   CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
   auto maybe_slice = GetAllocationSlice(copy.operand());
@@ -2204,10 +2166,8 @@
 Status IrEmitterUnnested::EmitSelectAndScatter(mlir::Operation* op) {
   auto select_and_scatter_op = mlir::cast<mlir::lmhlo::SelectAndScatterOp>(op);
 
-  const Shape source_shape =
-      TypeToShape(select_and_scatter_op.source().getType());
-  const Shape operand_shape =
-      TypeToShape(select_and_scatter_op.operand().getType());
+  const Shape source_shape = GetShape(select_and_scatter_op.source());
+  const Shape operand_shape = GetShape(select_and_scatter_op.operand());
   const int64 rank = operand_shape.rank();
 
   CHECK_EQ(rank, source_shape.rank());
@@ -2408,8 +2368,7 @@
           InBoundsGEP(selected_index_address, {b_.getInt32(i)});
       selected_multi_index.push_back(Load(selected_index_address_slot));
     }
-    const Shape output_shape =
-        TypeToShape(select_and_scatter_op.out().getType());
+    const Shape output_shape = GetShape(select_and_scatter_op.out());
     llvm::Value* source_value_address =
         source_array.EmitArrayElementAddress(source_index, &b_);
     IrArray::Index selected_index(selected_multi_index, output_shape,
@@ -2478,7 +2437,7 @@
   llvm::Value* old_state =
       llvm_ir::RngGetAndUpdateState(rng_op.delta(), module_, &b_);
 
-  const Shape shape = TypeToShape(rng_op.state().getType());
+  const Shape shape = GetShape(rng_op.state());
 
   llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
       llvm_ir::IrArray::Index(
@@ -2515,7 +2474,7 @@
         /*source_address=*/operand_buffer,
         /*destination_buffer=*/output_buffer,
         /*mem_size=*/
-        ShapeUtil::ByteSizeOf(TypeToShape(scatter_op.output().getType()))));
+        ShapeUtil::ByteSizeOf(GetShape(scatter_op.output()))));
   }
 
   // Create kernel thunk for all operands except the first one (`operand`). The
@@ -2568,9 +2527,8 @@
     const llvm_ir::ElementGenerator& scatter_indices_gen,
     const llvm_ir::ElementGenerator& updates_gen,
     std::function<llvm::Type*(int64)> get_index_type) {
-  const Shape operand_shape = TypeToShape(scatter.operand().getType());
-  CHECK(
-      ShapeUtil::Equal(TypeToShape(scatter.output().getType()), operand_shape));
+  const Shape operand_shape = GetShape(scatter.operand());
+  CHECK(ShapeUtil::Equal(GetShape(scatter.output()), operand_shape));
 
   TF_ASSIGN_OR_RETURN(
       const HloComputation* update_computation,
@@ -2580,8 +2538,8 @@
   ScatterDescriptor desc;
   desc.name = mlir::GetNameFromLoc(scatter.getLoc());
   desc.operand_shape = operand_shape;
-  desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType());
-  desc.updates_shape = TypeToShape(scatter.updates().getType());
+  desc.scatter_indices_shape = GetShape(scatter.scatter_indices());
+  desc.updates_shape = GetShape(scatter.updates());
   desc.dim_numbers = scatter.scatter_dimension_numbers();
   desc.unique_indices = scatter.unique_indices();
   desc.update_computation = update_computation;
@@ -3064,7 +3022,7 @@
   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
                       GetAllocationSlice(collective_permute_op.output()));
 
-  const Shape shape = TypeToShape(collective_permute_op.operand().getType());
+  const Shape shape = GetShape(collective_permute_op.operand());
   const int64 replica_count = hlo_module_config_.replica_count();
   const int64 partition_count = hlo_module_config_.num_partitions();
 
@@ -3141,7 +3099,7 @@
   for (auto it : llvm::zip(op.operands(), op.results())) {
     mlir::Value operand = std::get<0>(it);
     mlir::Value result = std::get<1>(it);
-    const Shape shape = TypeToShape(operand.getType());
+    const Shape shape = GetShape(operand);
     TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand));
     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result));
     buffers.push_back(NcclCollectiveThunk::Buffer{
@@ -3176,7 +3134,7 @@
         CollectiveOpGroupModeToString(group_mode), op.operands().size(),
         NcclThunkType::NcclIsEnabled());
     if (!op.operands().empty()) {
-      const Shape shape = TypeToShape(op.operands().front().getType());
+      const Shape shape = GetShape(op.operands().front());
       absl::StrAppendFormat(&message, "; first operand array element-type: %s",
                             PrimitiveType_Name(shape.element_type()));
     }
@@ -3187,7 +3145,7 @@
   // assignment expects a copy, so that's what we do.
   ThunkSequence thunks;
   for (int64 i = 0; i < buffers.size(); i++) {
-    const Shape shape = TypeToShape(op.operands()[i].getType());
+    const Shape shape = GetShape(op.operands()[i]);
     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
         buffers.size() == 1 ? GetThunkInfo(op) : Thunk::ThunkInfo(),
         /*source_address=*/buffers[i].source_buffer,
@@ -3229,7 +3187,7 @@
 
   for (mlir::Value output : infeed_op.outputs()) {
     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(output));
-    const Shape& shape = TypeToShape(output.getType());
+    const Shape& shape = GetShape(output);
     dest_slices.push_back(ShapedSlice{slice, shape});
   }
 
@@ -3246,7 +3204,7 @@
 
   for (mlir::Value operand : outfeed_op.operands()) {
     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
-    const Shape& shape = TypeToShape(operand.getType());
+    const Shape& shape = GetShape(operand);
     source_slices.push_back(ShapedSlice{slice, shape});
   }
 
@@ -3376,7 +3334,7 @@
     TF_ASSIGN_OR_RETURN(slice.buffer_slice,
                         GetAllocationSlice(operand, &slice.constant_name));
     slice.written = WritesMlirBuffer(op, operand);
-    slice.shape = TypeToShape(operand.getType());
+    slice.shape = GetShape(operand);
   }
   std::string name = mlir::GetNameFromLoc(op->getLoc());
   return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays);
@@ -3396,7 +3354,7 @@
       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
                           GetAllocationSlice(operand, &slice.constant_name));
       slice.written = false;
-      slice.shape = TypeToShape(operand.getType());
+      slice.shape = GetShape(operand);
     }
     for (auto output : outputs) {
       slices.emplace_back();
@@ -3404,7 +3362,7 @@
       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
                           GetAllocationSlice(output, &slice.constant_name));
       slice.written = true;
-      slice.shape = TypeToShape(output.getType());
+      slice.shape = GetShape(output);
     }
     std::string name = mlir::GetNameFromLoc(op->getLoc());
     return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays);
@@ -3480,7 +3438,7 @@
 
     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest));
 
-    const Shape dest_shape = TypeToShape(dest.getType());
+    const Shape dest_shape = GetShape(dest);
     auto thunk =
         BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
     if (thunk) {
@@ -3511,7 +3469,7 @@
   const llvm_ir::IrArray init_array = ir_arrays[0];
   const llvm_ir::IrArray dest_array = ir_arrays[1];
 
-  const Shape dest_shape = TypeToShape(dest.getType());
+  const Shape dest_shape = GetShape(dest);
   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
                       CalculateLaunchDimensions(
                           dest_shape, ir_emitter_context_->gpu_device_info()));
@@ -3554,7 +3512,7 @@
   const llvm_ir::IrArray dest_array =
       ir_arrays[input_buffers.size() + output_index];
 
-  const Shape dest_shape = TypeToShape(dest.getType());
+  const Shape dest_shape = GetShape(dest);
   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
                       CalculateLaunchDimensions(
                           dest_shape, ir_emitter_context_->gpu_device_info()));
@@ -4901,7 +4859,7 @@
     const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
   return absl::c_count_if(
       fusion.getFusionParameters(), [&](mlir::Value parameter) {
-        Shape parameter_shape = TypeToShape(parameter.getType());
+        Shape parameter_shape = GetShape(parameter);
         return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
                AreUsersElementwise(parameter, use_chain_endings);
       });
@@ -4914,7 +4872,7 @@
   int64 num_elements = ShapeUtil::ElementsIn(shape);
   return absl::c_count_if(
       fusion.getFusionParameters(), [&](mlir::Value parameter) {
-        Shape parameter_shape = TypeToShape(parameter.getType());
+        Shape parameter_shape = GetShape(parameter);
         return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
       });
 }
@@ -4999,7 +4957,7 @@
            << reduction_dimensions.dimensions[2];
   auto get_dtype_bits = [](mlir::Value i) {
     // TODO(timshen): may not be efficient.
-    return primitive_util::BitWidth(TypeToShape(i.getType()).element_type());
+    return primitive_util::BitWidth(GetShape(i).element_type());
   };
 
   // For fusion with multiple inputs, use the smallest input dtype to
@@ -5283,7 +5241,7 @@
       }
     }
   }
-  Shape input_shape = TypeToShape(first_reduce->getOperand(0).getType());
+  Shape input_shape = GetShape(first_reduce->getOperand(0));
   // The layout of a reduction input is either set by LayoutAssignment for
   // unnested kReduce or by InstructionFusion for fused kReduce.
   CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
@@ -5688,10 +5646,10 @@
   auto operands = GetHloOperands(op);
   auto outputs = GetHloOutputs(op);
   for (auto operand : operands) {
-    operand_shapes.push_back(TypeToShape(operand.getType()));
+    operand_shapes.push_back(GetShape(operand));
   }
   for (auto output : outputs) {
-    output_shapes.push_back(TypeToShape(output.getType()));
+    output_shapes.push_back(GetShape(output));
   }
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
index 4e69477..e1c2640 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.cc
@@ -24,6 +24,7 @@
 
 #include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -41,7 +42,7 @@
 
 /*static*/ bool NcclAllGatherThunk::CanImplement(mlir::lmhlo::AllGatherOp op) {
   return absl::c_all_of(op.operands(), [&](mlir::Value operand) {
-    Shape shape = TypeToShape(operand.getType());
+    Shape shape = GetShape(operand);
     return LayoutUtil::IsDenseArray(shape) &&
            IsTypeSupportedByNccl(shape.element_type()) &&
            LayoutUtil::MinorToMajor(shape).back() == op.all_gather_dimension();
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 66dd7f8..0b97381 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -108,7 +108,7 @@
   if (!opcode.ok()) return absl::nullopt;
   // Match the operation to a reduction kind. We can represent and/or of pred as
   // min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
-  PrimitiveType type = TypeToShape(result.getType()).element_type();
+  PrimitiveType type = GetShape(result).element_type();
   if (type == PRED) {
     switch (opcode.ValueOrDie()) {
       case HloOpcode::kAnd:
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
index 151749d..66cfbbb 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.cc
@@ -25,6 +25,7 @@
 #include "absl/strings/str_format.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -45,7 +46,7 @@
 
 /*static*/ bool NcclAllToAllThunk::CanImplement(mlir::lmhlo::AllToAllOp op) {
   return absl::c_all_of(op.operands(), [&op](mlir::Value operand) {
-    Shape shape = TypeToShape(operand.getType());
+    Shape shape = GetShape(operand);
     return LayoutUtil::IsDenseArray(shape) &&
            IsTypeSupportedByNccl(shape.element_type()) &&
            (!op.split_dimension() ||
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc
index d33d389..68b083f 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.cc
@@ -24,6 +24,7 @@
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -36,7 +37,7 @@
   NcclCollectivePermuteConfig config;
 
   config.operand_count = 1;
-  const Shape shape = TypeToShape(op.operand().getType());
+  const Shape shape = GetShape(op.operand());
   config.operand_element_type.push_back(shape.element_type());
   config.SetCollectiveOpKindAndID(op);
   config.group_mode = GetGroupMode(op);
@@ -89,7 +90,7 @@
 
 /*static*/ bool NcclCollectivePermuteThunk::CanImplement(
     mlir::lmhlo::CollectivePermuteOp op) {
-  const Shape shape = TypeToShape(op.operand().getType());
+  const Shape shape = GetShape(op.operand());
   return IsTypeSupportedByNccl(shape.element_type());
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
index 552c98c..8b2c7bf 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h
@@ -22,6 +22,7 @@
 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -98,7 +99,7 @@
   config.operand_count = op.operands().size();
   config.operand_element_type.reserve(config.operand_count);
   for (int i = 0; i < config.operand_count; i++) {
-    const Shape shape = TypeToShape(op.operands()[i].getType());
+    const Shape shape = GetShape(op.operands()[i]);
     config.operand_element_type.push_back(shape.element_type());
   }
   config.replica_groups =