[xla:jitrt] Work around the canonicalization bug for mhlo copy operation
mhlo.copy operation canonicalizer can break downstream gpu code generation, for now just make sure that the first pass only removes memref.get_global ops
PiperOrigin-RevId: 450593603
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
index 3e0771f..7095464 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
@@ -469,6 +469,11 @@
target.addIllegalOp<GetGlobalOp>();
target.addLegalOp<ConstantOp, memref::ViewOp>();
+ // TODO(ezhulenev): By adding MHLO and LMHLO to a set of legal dialects, we
+ // suppress any rewrites for these dialects (there are canonicalization
+ // patterns that interact badly with downstream Gpu binary code generation).
+ target.addLegalDialect<mhlo::MhloDialect, lmhlo::LmhloDialect>();
+
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}