[mlir][mhlo][sparse] lower general reshape to proper expand/collapse for sparse tensors
This change ensures "trivial" general reshape lowers to "trivial"
expand/collapse operations, without redundant conversions that
strip the sparsity of the operand.
PiperOrigin-RevId: 457601360
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 40509d2..ba2e97e 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1151,14 +1151,18 @@
shape[targetDim] = 1;
}
}
- auto newOperandType = RankedTensorType::get(shape, elemType);
+ // Insert a cast if types are not the same (ignoring sparse encoding).
+ auto enc = sparse_tensor::getSparseTensorEncoding(operandType);
+ auto newOperandType = RankedTensorType::get(shape, elemType, enc);
if (newOperandType != operandType) {
operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(),
newOperandType, operand);
}
+ // Generate collapse operation.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
reshapeOp, resultType, operand, *reassociationMap);
} else {
+ // Generate expand operation.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
reshapeOp, resultType, operand, *reassociationMap);
}
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir
index 806873c..12681d2 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_lower.mlir
@@ -262,3 +262,20 @@
return %0 : tensor<2x3xi32, #DCSR>
}
+// CHECK-LABEL: func @sparse_expand(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64, #{{.*}}>) -> tensor<10x10xf64, #{{.*}}> {
+// CHECK: %[[OUT:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<100xf64, #{{.*}}> into tensor<10x10xf64, #{{.*}}>
+// CHECK: return %[[OUT]] : tensor<10x10xf64, #{{.*}}>
+func.func @sparse_expand(%arg0: tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR> {
+ %0 = "mhlo.reshape"(%arg0) : (tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR>
+ return %0 : tensor<10x10xf64, #CSR>
+}
+
+// CHECK-LABEL: func @sparse_collapse(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #{{.*}}>) -> tensor<100xf64, #{{.*}}> {
+// CHECK: %[[OUT:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<10x10xf64, #{{.*}}> into tensor<100xf64, #{{.*}}>
+// CHECK: return %[[OUT]] : tensor<100xf64, #{{.*}}>
+func.func @sparse_collapse(%arg0: tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV> {
+ %0 = "mhlo.reshape"(%arg0) : (tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV>
+ return %0 : tensor<100xf64, #SV>
+}