Clean up Embedding op side effect handling

- there are now two TPU Embedding side effects, read and write
- dependencies between EnqueueTPUEmbedding ops with the same device ordinal
  are now properly modeled
- no Embedding-specific code is left in side effect analysis
- introduced a new `TF_MustExecute` trait that prevents an op from being
  pruned; this is useful for side-effecting ops that produce no output and
  have no dependencies to or from other ops
- ops that used `TF_TPUEmbeddingSideEffect` only to avoid pruning now use the
  new `TF_MustExecute` trait instead
- unlike the old `TF_TPUEmbeddingSideEffect`, `TF_MustExecute` prevents
  pruning independently of reachability (see the new graph pruning test)
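
To illustrate how the pieces fit together, here is a TableGen sketch with
hypothetical ops (`TF_ExampleEnqueueOp` and `TF_ExampleLoadOp` are not part
of this change): enqueue-style ops declare a write effect plus the resource
instance interface, while load/retrieve-style ops declare a read effect plus
`TF_MustExecute`:

  // Write effect on the TPU embedding resource; the resource instance is
  // derived from the device ordinal via `TF_GetResourceInstanceInterface`,
  // so only ops with the same ordinal depend on each other.
  def TF_ExampleEnqueueOp : TF_Op<"ExampleEnqueue",
      [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>,
       TF_TPUEmbeddingWriteEffect]>;

  // Read effect plus `TF_MustExecute`, so the op is never pruned even when
  // it is unreachable.
  def TF_ExampleLoadOp : TF_Op<"ExampleLoad",
      [TF_MustExecute, TF_TPUEmbeddingReadEffect]>;

Ops using the interface implement `GetResourceInstanceId()` in C++ and return
their `device_ordinal`, as the EnqueueTPUEmbedding ops below do.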

PiperOrigin-RevId: 413175982
Change-Id: I7b65c7a0e8a17b8a1683a0e01d1fd0614f7ac95a
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index a25a1d7..387c19e 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1248,6 +1248,7 @@
         ":tensorflow_analysis",
         ":tensorflow_ops",
         ":tensorflow_optimize_inc_gen",
+        ":tensorflow_side_effects",
         ":tensorflow_types",
         ":tf_data_optimization",
         ":tf_legalize_hlo",
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
index d1806a5..9ef779e 100644
--- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
@@ -35,6 +35,7 @@
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -155,11 +156,7 @@
     const MemoryEffects::EffectInstance& effect_instance, Operation* op) {
   mlir::SideEffects::Effect* effect = effect_instance.getEffect();
   SideEffects side_effects;
-  if (llvm::isa<ResourceEffects::TPUEmbedding>(effect_instance.getResource())) {
-    // TODO(mgester) This hack can be removed once b/196857154 is fixed.
-    // See definition of `TF_TPUEmbeddingSideEffect` for more details.
-    side_effects.SetRead();
-  } else if (isa<MemoryEffects::Allocate>(effect)) {
+  if (isa<MemoryEffects::Allocate>(effect)) {
     side_effects.SetAlloc();
   } else if (isa<MemoryEffects::Free>(effect)) {
     side_effects.SetFree();
@@ -356,11 +353,20 @@
         // We handle value-based side effects for which we can use resource
         // alias analysis at a different place, skip here.
         if (ShouldUseResourceAliasAnalysis(effect)) continue;
+        // `MustExecute` is a fake resource that only prevents certain ops from
+        // being considered dead or pruned; ignore it for side effect analysis.
+        if (llvm::isa<ResourceEffects::MustExecute>(effect.getResource()))
+          continue;
 
         // Add side effects for op resource ID.
+        int64_t instance_id = -1;
         SideEffects side_effects(GetSideEffectsFromEffectInstance(effect, op));
+        if (auto resource_instance_op =
+            dyn_cast<GetResourceInstanceInterface>(op)) {
+          instance_id = resource_instance_op.GetResourceInstanceId();
+        }
         ResourceId resource_id =
-            GetOpResourceId(effect.getResource()->getResourceID());
+            GetOpResourceId(effect.getResource()->getResourceID(), instance_id);
         side_effects.SetResourceId(resource_id);
         UpdateSideEffectsByResourceId(side_effects,
                                       side_effects_by_resource_id);
@@ -368,10 +374,11 @@
     }
   }
 
-  // Get internal op resource ID from MLIR type ID.
-  ResourceId GetOpResourceId(TypeID type_id) {
+  // Get internal op resource ID from MLIR type ID and instance ID.
+  ResourceId GetOpResourceId(TypeID type_id, int64_t instance_id) {
     auto emplace_result =
-        type_id_to_op_resource_id_.try_emplace(type_id, next_op_resource_id_);
+        type_instance_ids_to_op_resource_id_.try_emplace(
+            std::make_pair(type_id, instance_id), next_op_resource_id_);
     // Increment resource ID if we encounter a new (type, instance) pair.
     if (emplace_result.second) ++next_op_resource_id_;
     return emplace_result.first->getSecond();
@@ -385,9 +392,10 @@
   // Next available ID for op-based resources (resources not handled by resource
   // alias analysis).
   ResourceId next_op_resource_id_ = kMaxResourceId + 1;
-  // Maps MLIR type IDs for resource types to internal IDs for op-based
-  // resources. Also see comment above.
-  llvm::SmallDenseMap<TypeID, ResourceId> type_id_to_op_resource_id_;
+  // Maps (type ID, instance ID) pairs to internal IDs for op-based resources.
+  // Also see comment above.
+  llvm::SmallDenseMap<std::pair<TypeID, int64_t>, ResourceId>
+      type_instance_ids_to_op_resource_id_;
   // Used for faster callable resolution.
   SymbolTableCollection symbol_table_collection_;
   // Collect all op-based side effects here.
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 6196948..3bf73fb 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -4065,7 +4065,7 @@
   let hasFolder = 1;
 }
 
-def TF_EnqueueTPUEmbeddingArbitraryTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingArbitraryTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
+def TF_EnqueueTPUEmbeddingArbitraryTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingArbitraryTensorBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
   let summary = [{
 Eases the porting of code that uses tf.nn.embedding_lookup_sparse().
   }];
@@ -4117,7 +4117,7 @@
   TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
 }
 
-def TF_EnqueueTPUEmbeddingBatchOp : TF_Op<"EnqueueTPUEmbeddingBatch", [TF_TPUEmbeddingSideEffect]> {
+def TF_EnqueueTPUEmbeddingBatchOp : TF_Op<"EnqueueTPUEmbeddingBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_TPUEmbeddingWriteEffect]> {
   let summary = [{
 An op that enqueues a list of input batch tensors to TPUEmbedding.
   }];
@@ -4141,7 +4141,7 @@
   TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
 }
 
-def TF_EnqueueTPUEmbeddingIntegerBatchOp : TF_Op<"EnqueueTPUEmbeddingIntegerBatch", [TF_TPUEmbeddingSideEffect]> {
+def TF_EnqueueTPUEmbeddingIntegerBatchOp : TF_Op<"EnqueueTPUEmbeddingIntegerBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_TPUEmbeddingWriteEffect]> {
   let summary = [{
 An op that enqueues a list of input batch tensors to TPUEmbedding.
   }];
@@ -4162,7 +4162,7 @@
   TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
 }
 
-def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
+def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
   let summary = "Eases the porting of code that uses tf.nn.embedding_lookup().";
 
   let description = [{
@@ -4207,7 +4207,7 @@
   TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
 }
 
-def TF_EnqueueTPUEmbeddingSparseBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
+def TF_EnqueueTPUEmbeddingSparseBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
   let summary = [{
 An op that enqueues TPUEmbedding input indices from a SparseTensor.
   }];
@@ -4250,7 +4250,7 @@
   TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
 }
 
-def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
+def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
   let summary = [{
 Eases the porting of code that uses tf.nn.embedding_lookup_sparse().
   }];
@@ -6948,7 +6948,7 @@
   TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>;
 }
 
-def TF_LoadTPUEmbeddingADAMParametersOp : TF_Op<"LoadTPUEmbeddingADAMParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingADAMParametersOp : TF_Op<"LoadTPUEmbeddingADAMParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load ADAM embedding parameters.";
 
   let description = [{
@@ -6974,7 +6974,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingADAMParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingADAMParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -6993,7 +6993,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingAdadeltaParametersOp : TF_Op<"LoadTPUEmbeddingAdadeltaParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingAdadeltaParametersOp : TF_Op<"LoadTPUEmbeddingAdadeltaParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load Adadelta embedding parameters.";
 
   let description = [{
@@ -7019,7 +7019,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7038,7 +7038,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingAdagradParametersOp : TF_Op<"LoadTPUEmbeddingAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingAdagradParametersOp : TF_Op<"LoadTPUEmbeddingAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load Adagrad embedding parameters.";
 
   let description = [{
@@ -7063,7 +7063,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7081,7 +7081,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingCenteredRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingCenteredRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load centered RMSProp embedding parameters.";
 
   let description = [{
@@ -7108,7 +7108,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingFTRLParametersOp : TF_Op<"LoadTPUEmbeddingFTRLParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingFTRLParametersOp : TF_Op<"LoadTPUEmbeddingFTRLParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load FTRL embedding parameters.";
 
   let description = [{
@@ -7134,7 +7134,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingFTRLParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingFTRLParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7153,7 +7153,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"LoadTPUEmbeddingMDLAdagradLightParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"LoadTPUEmbeddingMDLAdagradLightParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load MDL Adagrad Light embedding parameters.";
 
   let description = [{
@@ -7180,7 +7180,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingMomentumParametersOp : TF_Op<"LoadTPUEmbeddingMomentumParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingMomentumParametersOp : TF_Op<"LoadTPUEmbeddingMomentumParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load Momentum embedding parameters.";
 
   let description = [{
@@ -7205,7 +7205,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingMomentumParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingMomentumParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7223,7 +7223,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingProximalAdagradParametersOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingProximalAdagradParametersOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load proximal Adagrad embedding parameters.";
 
   let description = [{
@@ -7248,7 +7248,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7266,7 +7266,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingProximalYogiParametersOp : TF_Op<"LoadTPUEmbeddingProximalYogiParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingProximalYogiParametersOp : TF_Op<"LoadTPUEmbeddingProximalYogiParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7284,7 +7284,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7303,7 +7303,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load RMSProp embedding parameters.";
 
   let description = [{
@@ -7329,7 +7329,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -7348,7 +7348,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Load SGD embedding parameters.";
 
   let description = [{
@@ -7372,7 +7372,7 @@
   let results = (outs);
 }
 
-def TF_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -11548,7 +11548,7 @@
   TF_DerivedResultTypeAttr tensor_type = TF_DerivedResultTypeAttr<0>;
 }
 
-def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", [TF_TPUEmbeddingSideEffect]> {
+def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "An op that receives embedding activations on the TPU.";
 
   let description = [{
@@ -12939,7 +12939,7 @@
   TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
 }
 
-def TF_RetrieveTPUEmbeddingADAMParametersOp : TF_Op<"RetrieveTPUEmbeddingADAMParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingADAMParametersOp : TF_Op<"RetrieveTPUEmbeddingADAMParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve ADAM embedding parameters.";
 
   let description = [{
@@ -12964,7 +12964,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingADAMParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingADAMParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -12983,7 +12983,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingAdadeltaParametersOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingAdadeltaParametersOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve Adadelta embedding parameters.";
 
   let description = [{
@@ -13008,7 +13008,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13027,7 +13027,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve Adagrad embedding parameters.";
 
   let description = [{
@@ -13051,7 +13051,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13069,7 +13069,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingCenteredRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingCenteredRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve centered RMSProp embedding parameters.";
 
   let description = [{
@@ -13095,7 +13095,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingFTRLParametersOp : TF_Op<"RetrieveTPUEmbeddingFTRLParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingFTRLParametersOp : TF_Op<"RetrieveTPUEmbeddingFTRLParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve FTRL embedding parameters.";
 
   let description = [{
@@ -13120,7 +13120,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13139,7 +13139,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"RetrieveTPUEmbeddingMDLAdagradLightParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"RetrieveTPUEmbeddingMDLAdagradLightParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve MDL Adagrad Light embedding parameters.";
 
   let description = [{
@@ -13165,7 +13165,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingMomentumParametersOp : TF_Op<"RetrieveTPUEmbeddingMomentumParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingMomentumParametersOp : TF_Op<"RetrieveTPUEmbeddingMomentumParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve Momentum embedding parameters.";
 
   let description = [{
@@ -13189,7 +13189,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13207,7 +13207,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingProximalAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingProximalAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve proximal Adagrad embedding parameters.";
 
   let description = [{
@@ -13231,7 +13231,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13249,7 +13249,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingProximalYogiParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingProximalYogiParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13267,7 +13267,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13286,7 +13286,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve RMSProp embedding parameters.";
 
   let description = [{
@@ -13311,7 +13311,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -13330,7 +13330,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParameters", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Retrieve SGD embedding parameters.";
 
   let description = [{
@@ -13353,7 +13353,7 @@
   );
 }
 
-def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
+def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "";
 
   let arguments = (ins
@@ -14407,7 +14407,7 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
+def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "Performs gradient updates of embedding tables.";
 
   let arguments = (ins
@@ -20173,7 +20173,7 @@
   TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>;
 }
 
-def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", [TF_TPUEmbeddingSideEffect]> {
+def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "An op that receives embedding activations on the TPU.";
 
   let description = [{
@@ -20225,7 +20225,7 @@
   );
 }
 
-def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
+def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
   let summary = "An op that performs gradient updates of embedding tables.";
 
   let description = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index 2bd55f2..020ffdc 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -165,6 +165,8 @@
 def TF_GeneratorOpResource : TF_ResourceBase<"GeneratorOp">;
 def TF_SendRecvResource : TF_ResourceBase<"SendRecv">;
 def TF_TPUCompileExecuteResource : TF_ResourceBase<"TPUCompileExecute">;
+// Fake resource, see `TF_MustExecute` below.
+def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">;
 
 // Value-based side effects
 //
@@ -214,17 +216,21 @@
 // effecting ops. Note that for `read` effects ops might be pruned if nothing
 // depends on them.
 def TF_GeneratorOpSideEffect : MemoryEffects<[MemWrite<TF_GeneratorOpResource>]>;
-// Note: We actually want a `read` effect here but then some ops with this
-// effect are considered dead and are deleted which is not desired (see
-// b/195782952).
-// Therefore, we use a `write` effect + special handling in side effect
-// analysis. Once we have proper dependencies that avoid deletion (see
-// b/196857154), or once MLIR supports a trait to mark an op as not dead, this
-// hack can be removed.
-def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
+
+def TF_TPUEmbeddingWriteEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
+def TF_TPUEmbeddingReadEffect : MemoryEffects<[MemRead<TF_TPUEmbeddingResource>]>;
+
 def TF_SendRecvSideEffect : MemoryEffects<[MemWrite<TF_SendRecvResource>]>;
 def TF_TPUCompileExecuteSideEffect : MemoryEffects<[MemWrite<TF_TPUCompileExecuteResource>]>;
 
+// Trait for enforcing that a side-effecting op is always executed, even if
+// MLIR would otherwise consider it dead (see b/195782952).
+// The trait is implemented as a write effect on a fake resource that is
+// ignored by side effect analysis, so it does not affect execution order
+// constraints or control dependencies at all (for example, multiple ops with
+// this trait do not have to execute in order).
+def TF_MustExecute : MemoryEffects<[MemWrite<TF_MustExecuteResource>]>;
+
 //===----------------------------------------------------------------------===//
 // TensorFlow op definitions
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td
index bbc4638..bf1ed85 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td
@@ -131,4 +131,23 @@
   ];
 }
 
+def TF_GetResourceInstanceInterface : OpInterface<"GetResourceInstanceInterface"> {
+  let description = [{Returns an integer corresponding to the resource instance
+                      accessed by this op}];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{Returns an integer corresponding to the resource instance
+                 accessed by this op. The implementation must guarantee that the
+                 mapping between resource instances and integers is bijective,
+                 i.e., two op instances should return the same integer if and
+                 only if they access the same resource. The interface should
+                 only be used for ops that access exactly one resource.}],
+      /*retTy=*/"int64_t",
+      /*methodName=*/"GetResourceInstanceId",
+      /*args=*/(ins)
+    >,
+  ];
+}
+
 #endif // TF_OP_INTERFACES
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 7133441..8e19bc2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -1564,6 +1564,13 @@
   let results = (outs);
 }
 
+def TF__InternalTestMustExecuteTrait_ : TF_Op<"_InternalTestMustExecuteTrait_", [TF_MustExecute]> {
+  let summary = "Internal op for testing only";
+
+  let arguments = (ins);
+  let results = (outs);
+}
+
 def TF_SetStaticDimensionBoundsOp : TF_Op<"SetStaticDimensionBounds", []> {
   let summary = "Op used to indicate to the compiler and runtime the static bounds of a tensor.";
   let description = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index a2228cb..ff37408 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -2300,6 +2300,37 @@
 }
 
 //===----------------------------------------------------------------------===//
+// EnqueueTPUEmbedding ops
+//===----------------------------------------------------------------------===//
+
+// For EnqueueTPUEmbedding ops, the device ordinal corresponds to the
+// resource instance.
+
+int64_t EnqueueTPUEmbeddingArbitraryTensorBatchOp::GetResourceInstanceId() {
+  return device_ordinal();
+}
+
+int64_t EnqueueTPUEmbeddingBatchOp::GetResourceInstanceId() {
+  return device_ordinal();
+}
+
+int64_t EnqueueTPUEmbeddingIntegerBatchOp::GetResourceInstanceId() {
+  return device_ordinal();
+}
+
+int64_t EnqueueTPUEmbeddingRaggedTensorBatchOp::GetResourceInstanceId() {
+  return device_ordinal();
+}
+
+int64_t EnqueueTPUEmbeddingSparseBatchOp::GetResourceInstanceId() {
+  return device_ordinal();
+}
+
+int64_t EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceId() {
+  return device_ordinal();
+}
+
+//===----------------------------------------------------------------------===//
 // EnsureShapeOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index de2cf2d..853cfaa 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -2171,12 +2171,7 @@
 void TPUExecuteOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  effects.reserve(args().size() + 2);
-
-  // There may be some TPU Embedding ops in the computation, so this effect is
-  // added conservatively.
-  effects.emplace_back(MemoryEffects::Write::get(),
-                       ResourceEffects::TPUEmbedding::get());
+  effects.reserve(args().size() + 1);
   effects.emplace_back(MemoryEffects::Write::get(),
                        ResourceEffects::TPUCompileExecute::get());
 
@@ -2239,12 +2234,7 @@
 void TPUExecuteAndUpdateVariablesOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  effects.reserve(device_var_reads_indices().size() + 2);
-
-  // There may be some TPU Embedding ops in the computation, so this effect is
-  // added conservatively.
-  effects.emplace_back(MemoryEffects::Write::get(),
-                       ResourceEffects::TPUEmbedding::get());
+  effects.reserve(device_var_reads_indices().size() + 1);
   effects.emplace_back(MemoryEffects::Write::get(),
                        ResourceEffects::TPUCompileExecute::get());
   auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h
index 267978d..ecc53b6 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h
@@ -77,6 +77,10 @@
   StringRef getName() final { return "<TPUCompileExecute>"; }
 };
 
+struct MustExecute : public ::mlir::SideEffects::Resource::Base<MustExecute> {
+  StringRef getName() final { return "<MustExecute>"; }
+};
+
 }  // namespace ResourceEffects
 }  // namespace TF
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
index 1f0a183..35feeeb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir
@@ -185,3 +185,22 @@
   }
   return
 }
+
+// -----
+
+// Check that an op with a must-execute effect is not pruned, even if it is
+// unreachable.
+func @must_execute_op() -> () {
+// CHECK: tf_executor.graph
+// CHECK: tf_executor.island
+// CHECK: tf._InternalTestMustExecuteTrait_
+  tf_executor.graph {
+    %1 = tf_executor.island {
+      "tf._InternalTestMustExecuteTrait_"() : () -> ()
+      tf_executor.yield
+    }
+    tf_executor.fetch
+  }
+  return
+}
+
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir
index 85c99a4..ad6e2f4 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir
@@ -1494,9 +1494,9 @@
 
 // -----
 
-// Tests that we treat different op instances with `TPUEmbeddingSideEffect` as
-// independent.
-func @embedding_effect_ops(
+// Tests that we create a dependency between op instances with
+// `TPUEmbeddingSideEffect` that have the same device ordinal.
+func @embedding_effect_same_device(
   // expected-remark@above {{ID: 7}}
   %arg0: tensor<!tf_type.string>) {
   tf_executor.graph {
@@ -1504,10 +1504,42 @@
     %island = tf_executor.island {
         // expected-remark@above {{ID: 3}}
         // expected-remark@above {{Successors: {4}}}
-        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2]} : (tensor<!tf_type.string>) -> ()
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
+        // expected-remark@above {{ID: 0}}
+        // expected-remark@above {{Successors: {1}}}
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
+        // expected-remark@above {{ID: 1}}
+        // expected-remark@above {{Predecessors: {0}}}
+        // expected-remark@above {{Successors: {2}}}
+        tf_executor.yield
+        // expected-remark@above {{ID: 2}}
+        // expected-remark@above {{Predecessors: {1}}}
+    }
+    tf_executor.fetch %island : !tf_executor.control
+    // expected-remark@above {{ID: 4}}
+    // expected-remark@above {{Predecessors: {3}}}
+  }
+  return
+  // expected-remark@above {{ID: 6}}
+  // expected-remark@above {{Sinks: {5}}}
+}
+
+// -----
+
+// Tests that we treat different op instances with `TPUEmbeddingSideEffect` as
+// independent if they have different device ordinals.
+func @embedding_effect_different_devices(
+  // expected-remark@above {{ID: 7}}
+  %arg0: tensor<!tf_type.string>) {
+  tf_executor.graph {
+    // expected-remark@above {{ID: 5}}
+    %island = tf_executor.island {
+        // expected-remark@above {{ID: 3}}
+        // expected-remark@above {{Successors: {4}}}
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
         // expected-remark@above {{ID: 0}}
         // expected-remark@above {{Successors: {2}}}
-        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2]} : (tensor<!tf_type.string>) -> ()
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 2} : (tensor<!tf_type.string>) -> ()
         // expected-remark@above {{ID: 1}}
         // expected-remark@above {{Successors: {2}}}
         tf_executor.yield
@@ -1561,6 +1593,42 @@
 
 // -----
 
+// Tests that we don't create dependencies between `EnqueueTPUEmbedding` ops
+// and other embedding ops that don't have a device ordinal.
+func @mixed_embedding_and_unknown_effects(
+  // expected-remark@above {{ID: 8}}
+  %arg0: tensor<!tf_type.string>,
+  %arg1: tensor<8xf32>,
+  %arg2: tensor<8xf32>) {
+  tf_executor.graph {
+    // expected-remark@above {{ID: 6}}
+    %island = tf_executor.island {
+        // expected-remark@above {{ID: 4}}
+        // expected-remark@above {{Successors: {5}}}
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
+        // expected-remark@above {{ID: 0}}
+        // expected-remark@above {{Successors: {3}}}
+        "tf.LoadTPUEmbeddingAdagradParameters"(%arg1, %arg2) {config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table1"} : (tensor<8xf32>, tensor<8xf32>) -> ()
+        // expected-remark@above {{ID: 1}}
+        // expected-remark@above {{Successors: {3}}}
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 2} : (tensor<!tf_type.string>) -> ()
+        // expected-remark@above {{ID: 2}}
+        // expected-remark@above {{Successors: {3}}}
+        tf_executor.yield
+        // expected-remark@above {{ID: 3}}
+        // expected-remark@above {{Predecessors: {0,1,2}}}
+    }
+    tf_executor.fetch %island : !tf_executor.control
+    // expected-remark@above {{ID: 5}}
+    // expected-remark@above {{Predecessors: {4}}}
+  }
+  return
+  // expected-remark@above {{ID: 7}}
+  // expected-remark@above {{Sinks: {6}}}
+}
+
+// -----
+
 // Tests that we create a dependency between two ops with the same op-based
 // write effect.
 func @same_op_based_write_effect(
@@ -1602,13 +1670,13 @@
     %island = tf_executor.island {
         // expected-remark@above {{ID: 4}}
         // expected-remark@above {{Successors: {5}}}
-        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2]} : (tensor<!tf_type.string>) -> ()
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
         // expected-remark@above {{ID: 0}}
         // expected-remark@above {{Successors: {3}}}
         %0 = "tf.GeneratorDataset"(%arg0, %arg0, %arg0) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", finalize_func = @__func_a, init_func = @__func_b, next_func = @__func_c, next_func.experimental_ints_on_device = true, operand_segment_sizes = dense<[1, 1, 1]> : vector<3xi32>, output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], metadata = ""} : (tensor<!tf_type.string>, tensor<!tf_type.string>, tensor<!tf_type.string>) -> tensor<!tf_type.variant>
         // expected-remark@above {{ID: 1}}
         // expected-remark@above {{Successors: {3}}}
-        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2]} : (tensor<!tf_type.string>) -> ()
+        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 5} : (tensor<!tf_type.string>) -> ()
         // expected-remark@above {{ID: 2}}
         // expected-remark@above {{Successors: {3}}}
         tf_executor.yield
@@ -1700,10 +1768,7 @@
 // -----
 
 // Tests that we create a dependency between ops with
-// `TF_TPUCompileExecuteSideEffect`. Note that this test also shows a case where
-// we could improve pruning of control dependencies (see b/201013649): The
-// dependency between the first `tf.TPUExecute` and the `tf_executor.yield` is
-// redundant.
+// `TF_TPUCompileExecuteSideEffect`.
 func @tpu_compile_execute_effect(
   // expected-remark@above {{ID: 7}}
   %arg0: tensor<!tf_type.string>,
@@ -1715,14 +1780,14 @@
         // expected-remark@above {{Successors: {4}}}
         "tf.TPUExecute"(%arg0, %arg0) : (tensor<!tf_type.string>, tensor<!tf_type.string>) -> ()
         // expected-remark@above {{ID: 0}}
-        // expected-remark@above {{Successors: {1,2}}}
+        // expected-remark@above {{Successors: {1}}}
         "tf.TPUExecute"(%arg1, %arg1) : (tensor<!tf_type.string>, tensor<!tf_type.string>) -> ()
         // expected-remark@above {{ID: 1}}
         // expected-remark@above {{Predecessors: {0}}}
         // expected-remark@above {{Successors: {2}}}
         tf_executor.yield
         // expected-remark@above {{ID: 2}}
-        // expected-remark@above {{Predecessors: {0,1}}}
+        // expected-remark@above {{Predecessors: {1}}}
     }
     tf_executor.fetch %island : !tf_executor.control
     // expected-remark@above {{ID: 4}}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
index eebbf6e..e94e593 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
@@ -27,6 +27,7 @@
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 
@@ -118,10 +119,21 @@
   getFunction().walk([this](tf_executor::GraphOp graph) { PruneGraph(graph); });
 }
 
-// An op should be preserved if its identifier is contained in
-// `ops_to_preserve_ids_`.
+// An op should be preserved if its identifier is contained in
+// `ops_to_preserve_ids_` or if it has a `MustExecute` effect.
 bool GraphPruningPass::ShouldPreserveOp(Operation* op) {
-  return ops_to_preserve_ids_.contains(op->getName().getIdentifier());
+  if (ops_to_preserve_ids_.contains(op->getName().getIdentifier())) return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
+  auto interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (interface) interface.getEffects(effects);
+
+  for (const auto& effect : effects) {
+    if (llvm::isa<TF::ResourceEffects::MustExecute>(effect.getResource())) {
+      return true;
+    }
+  }
+  return false;
 }
 
 // An island should be preserved if any of its inner ops should be preserved.