[xla:jitrt] Work around the canonicalization bug for mhlo copy operation

mhlo.copy operation canonicalizer can break downstream gpu code generation, for now just make sure that the first pass only removes memref.get_global ops

PiperOrigin-RevId: 450593603
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
index 3e0771f..7095464 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
@@ -469,6 +469,11 @@
   target.addIllegalOp<GetGlobalOp>();
   target.addLegalOp<ConstantOp, memref::ViewOp>();
 
+  // TODO(ezhulenev): By adding MHLO and LMHLO to a set of legal dialects, we
+  // suppress any rewrites for these dialects (there are canonicalization
+  // patterns that interact badly with downstream Gpu binary code generation).
+  target.addLegalDialect<mhlo::MhloDialect, lmhlo::LmhloDialect>();
+
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }