[mlir][mhlo][sparse] lower general reshape to proper expand/collapse for sparse tensors

This change ensures that a "trivial" general reshape lowers to a "trivial"
expand/collapse operation, without redundant conversions that would
strip the sparsity encoding from the operand.

PiperOrigin-RevId: 457601360
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 40509d2..ba2e97e 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1151,14 +1151,18 @@
               shape[targetDim] = 1;
           }
         }
-        auto newOperandType = RankedTensorType::get(shape, elemType);
+        // Insert a cast if types are not the same (ignoring sparse encoding).
+        auto enc = sparse_tensor::getSparseTensorEncoding(operandType);
+        auto newOperandType = RankedTensorType::get(shape, elemType, enc);
         if (newOperandType != operandType) {
           operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(),
                                                     newOperandType, operand);
         }
+        // Generate collapse operation.
         rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
             reshapeOp, resultType, operand, *reassociationMap);
       } else {
+        // Generate expand operation.
         rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
             reshapeOp, resultType, operand, *reassociationMap);
       }
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir
index 806873c..12681d2 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir
@@ -262,3 +262,20 @@
   return %0 : tensor<2x3xi32, #DCSR>
 }
 
+// CHECK-LABEL: func @sparse_expand(
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<100xf64, #{{.*}}>) -> tensor<10x10xf64, #{{.*}}> {
+// CHECK:         %[[OUT:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<100xf64, #{{.*}}> into tensor<10x10xf64, #{{.*}}>
+// CHECK:         return %[[OUT]] : tensor<10x10xf64, #{{.*}}>
+func.func @sparse_expand(%arg0: tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR>
+  return %0 : tensor<10x10xf64, #CSR>
+}
+
+// CHECK-LABEL: func @sparse_collapse(
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<10x10xf64, #{{.*}}>) -> tensor<100xf64, #{{.*}}> {
+// CHECK:         %[[OUT:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<10x10xf64, #{{.*}}> into tensor<100xf64, #{{.*}}>
+// CHECK:         return %[[OUT]] : tensor<100xf64, #{{.*}}>
+func.func @sparse_collapse(%arg0: tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV>
+  return %0 : tensor<100xf64, #SV>
+}