Internal change

PiperOrigin-RevId: 347935883
Change-Id: I88d826b6b14135d51009adb174dbdcf2a1e00e87
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index b5455bb..5731b93 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -300,6 +300,19 @@
 // CHECK-NEXT: return %[[CST]], %[[CST1]]
 }
 
+// Tests ops that have non-local device assignment but with local device with
+// same type (CPU) are correctly evaluated.
+// CHECK-LABEL: func @testRemoteDevice() -> tensor<2x2xi32>
+func @testRemoteDevice() -> tensor<2x2xi32> {
+^bb0:
+  %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+  %1 = constant dense<1> : tensor<2xi32>
+  %2 = "tf.Add"(%0, %1) {device = "/job:remote_worker/replica:123/task:456/CPU:0", name = "add"} : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  // CHECK:         [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[\[}}1, 2], {{\[}}3, 4]]> : tensor<2x2xi32>
+  // CHECK-NEXT:    return [[cst]] : tensor<2x2xi32>
+  return %2: tensor<2x2xi32>
+}
+
 // Tests ops that variable shapes are correctly evaluated on static types.
 // CHECK-LABEL: func @testVariableShape
 func @testVariableShape(%arg0: tensor<!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32> {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
index 833d35c..a3c487f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
@@ -105,10 +105,6 @@
     // The TFE_Context is created without an accompanying delete due to current
     // lifetime. This does not result in memory leaks reported (see totw/110).
     TFE_ContextOptions* opts = TFE_NewContextOptions();
-    // Input tensors are placed on the host CPU so use the explicit device
-    // policy to fail if no CPU kernels are available for the op.
-    TFE_ContextOptionsSetDevicePlacementPolicy(opts,
-                                               TFE_DEVICE_PLACEMENT_EXPLICIT);
     auto ctx = TFE_NewContext(opts, status);
     TFE_DeleteContextOptions(opts);
     TF_DeleteStatus(status);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
index b9d09e7..cca6981 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
@@ -55,6 +55,26 @@
   return false;
 }
 
+// Update node_def's device attribute (if any) to use a local device, that is
+// /job:localhost/replica:0/task:0/{DEVICE_TYPE}:{DEVICE_ID}.
+// This is because EvaluateOperation only has access to local devices but the
+// given node may carry a device assignment to a remote device. In that case,
+// evaluation would fail even if we have a device of same type locally. By
+// altering device assignment to a local one, we could successfully evaluate in
+// that case.
+void ForceUseLocalhostDevice(NodeDef* node_def) {
+  DeviceNameUtils::ParsedName parsed_name;
+
+  if (!DeviceNameUtils::ParseFullName(node_def->device(), &parsed_name)) return;
+
+  if (parsed_name.has_job) parsed_name.job = "localhost";
+  if (parsed_name.has_replica) parsed_name.replica = 0;
+  if (parsed_name.has_task) parsed_name.task = 0;
+
+  *node_def->mutable_device() =
+      DeviceNameUtils::ParsedNameToString(parsed_name);
+}
+
 mlir::LogicalResult EvaluateOperation(
     mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
     TFE_Context* context, llvm::SmallVectorImpl<mlir::Attribute>* results) {
@@ -84,14 +104,13 @@
   RETURN_FAILURE_IF_ERROR(node_def_or.status());
   const auto& node_def = node_def_or.ValueOrDie();
 
-  // Note that we don't set device for this op based on the assigned device
-  // attribute of the op. We want to evaluate operation on the host CPU as the
-  // assigned device might be remote, not available yet or compilation only
-  // on demand device which may create a recursion. The eager runtime executes
-  // the op on the device input tensors are placed which is host CPU here.
+  ForceUseLocalhostDevice(node_def.get());
+
   TFE_Op* op = TFE_NewOp(context, node_def->op().c_str(), status);
   RETURN_FAILURE_IF_ERROR(status);
   auto clean_op = MakeCleanup([op] { TFE_DeleteOp(op); });
+  TFE_OpSetDevice(op, node_def->device().c_str(), status);
+  RETURN_FAILURE_IF_ERROR(status);
   for (const auto& attr : node_def->attr()) {
     SetOpAttrValueScalar(context, op, attr.second, attr.first.c_str(), status);
     RETURN_FAILURE_IF_ERROR(status);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
index e3e14af..4130e72 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
@@ -25,10 +25,8 @@
 namespace tensorflow {
 
 // Attempts to evaluates an MLIR Operation in TensorFlow eager mode with the
-// specified operands. The op is always executed on the local host CPU
-// irrespective of the device attribute of the given op. If there is a CPU
-// kernel registered for the op and is executed successfully, this fills in the
-// results vector.  If not, results vector is unspecified.
+// specified operands. If successful, this fills in the results vector. If not,
+// results vector is unspecified.
 //
 mlir::LogicalResult EvaluateOperation(
     mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c13517b..9186ae0 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -791,7 +791,6 @@
     name = "listdiff_op_test",
     size = "small",
     srcs = ["listdiff_op_test.py"],
-    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_cuda_asan",  # times out
@@ -1391,7 +1390,6 @@
     name = "unary_ops_test",
     size = "medium",
     srcs = ["unary_ops_test.py"],
-    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_cuda_asan",  # times out
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index ab71f6f..94b34cf 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1009,6 +1009,7 @@
           np.array([], dtype=dtype).reshape((0, 3)),
           expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testMatMul(self):
     self._testMatMul(math_ops.matmul, self.float_types | {np.float64})
 
@@ -1046,6 +1047,7 @@
     self._testMatMul(SparseMatmulWrapperFT, self.float_types)
     self._testMatMul(SparseMatmulWrapperTT, self.float_types)
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testBatchMatMul(self):
     # Tests with batches of matrices.
     for dtype in self.float_types | {np.float64}:
@@ -1097,6 +1099,7 @@
             x,
             expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testExpandDims(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1364,6 +1367,7 @@
               np.reshape(np.array([16, 18, 8], dtype=dtype), (3, 1)),
               (1, 2, 3, 1)))
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testReshape(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1495,6 +1499,7 @@
                [1, 2]],
               dtype=dtype))
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testTranspose(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1513,6 +1518,7 @@
           np.array([1, 0], dtype=np.int32),
           expected=np.array([[1, 3], [2, 4]], dtype=dtype))
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testConjugateTranspose(self):
     for dtype in self.complex_types:
       self._testBinary(
@@ -1549,6 +1555,7 @@
           np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
           expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testBroadcastArgs(self):
     self._testBinary(array_ops.broadcast_dynamic_shape,
                      np.array([2, 3, 5], dtype=np.int32),
@@ -1609,7 +1616,6 @@
                        np.array([4, 5, 6], dtype=np.int32),
                        expected=None)
 
-  @test_util.disable_mlir_bridge("TODO(b/175721108): Legalize BroadcastArgs op")
   def testBroadcastTo(self):
     for dtype in self.all_types:
       x = np.random.randint(0, high=100, size=[2, 3])
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 41107d0..5d2b8a6 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -368,6 +368,7 @@
         ans = self.evaluate(packed)
         self.assertAllEqual(ans, [2, 3, 5])
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testEmpty(self):
     with self.session():
       with self.test_scope():
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
index 66be684..d0f6229 100644
--- a/tensorflow/compiler/tests/nary_ops_test.py
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -159,8 +159,7 @@
                     np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
         self.assertAllEqual(output, expected)
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/344771933): Fix canonicalization failure")
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testStridedSlice(self):
     self._testNAry(lambda x: array_ops.strided_slice(*x),
                    [np.array([[], [], []], dtype=np.float32),
@@ -205,8 +204,7 @@
                              dtype=np.float32)],
                    expected=np.array([[4], [5], [6]], dtype=np.float32))
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/344771933): Fix canonicalization failure")
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testStridedSliceGrad(self):
     # Tests cases where input shape is empty.
     self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index b890960..fe1d2c5 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -98,22 +98,27 @@
   ]
   ONES = [np.ones([34000, 2])]
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceSumF32(self, index_dtype):
     self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA,
                         index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceSumC64(self, index_dtype):
     self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
                         self.COMPLEX_DATA, index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceProdF32(self, index_dtype):
     self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
                         self.REAL_DATA, index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceProdC64(self, index_dtype):
     self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
                         self.COMPLEX_DATA, index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceMin(self, index_dtype):
 
     def reference_min(dtype, inp, axis):
@@ -131,6 +136,7 @@
                           functools.partial(reference_min, dtype), dtype,
                           self.REAL_DATA, index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceMax(self, index_dtype):
 
     def reference_max(dtype, inp, axis):
@@ -149,6 +155,7 @@
                           functools.partial(reference_max, dtype), dtype,
                           self.REAL_DATA, index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceMeanF32(self, index_dtype):
     # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
     # reducing across zero inputs.
@@ -164,10 +171,12 @@
     self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
                         self.NONEMPTY_COMPLEX_DATA, index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceAll(self, index_dtype):
     self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA,
                         index_dtype)
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testReduceAny(self, index_dtype):
     self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
                         index_dtype)
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 01df466..aa72f47 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -91,8 +91,7 @@
       for reverse in [True, False]:
         self._compare(x, axis, exclusive, reverse)
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/175721108): Fix lowering to not generate illegal ReduceWindow op")
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testEmpty(self):
     for dtype in self.valid_dtypes:
       x = np.zeros([0]).astype(dtype)
@@ -172,8 +171,7 @@
       for reverse in [True, False]:
         self._compare(x, axis, exclusive, reverse)
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/175721108): Fix lowering to not generate illegal ReduceWindow op")
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testEmpty(self):
     for dtype in self.valid_dtypes:
       x = np.zeros([0]).astype(dtype)
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index 7e817c6..76b7e18 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -248,18 +248,19 @@
         outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
                  [[4, 41], [6, 61]]])
 
-  @test_util.disable_mlir_bridge("TODO(b/344771933): Fix failure")
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testDirect0(self):
     # Test with zero-size remaining dimension.
     self._testDirect(
         input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testDirect1(self):
     # Test with zero-size blocked dimension.
     self._testDirect(
         input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
 
-  @test_util.disable_mlir_bridge("TODO(b/344771933): Fix failure")
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testDirect2(self):
     # Test with padding up from zero size.
     self._testDirect(
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 4109fdc..3d310dd 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -63,6 +63,7 @@
     self.assertEqual(result[-1], expected[-1])
     self.assertEqual(result[0], expected[0])
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testRange(self):
     self._testTernary(
         math_ops.range,
@@ -182,6 +183,7 @@
           np.array([8, 9], dtype=dtype),
           expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
 
+  @test_util.disable_mlir_bridge('TODO(b/172473885)')
   def testSlice(self):
     for dtype in self.numeric_types:
       self._testTernary(
diff --git a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
index ca50916..84aa725 100644
--- a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
+++ b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
@@ -193,6 +193,7 @@
   def test1x1(self):
     self._test(diags=[[0], [3], [0]], rhs=[6], expected=[2])
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def test0x0(self):
     self._test(
         diags=np.zeros(shape=(3, 0), dtype=np.float32),
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 60cdbf5..f3f6fa8 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -625,7 +625,6 @@
           np.array([-1, -0.5, 0, 0.3], dtype=dtype),
           expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
 
-  @test_util.disable_mlir_bridge("TODO(b/344771933): Fix failure")
   def testComplexOps(self):
     for dtype in self.complex_types:
 
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index 271bf66..254f9ac 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -24,6 +24,7 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_control_flow_ops
 from tensorflow.python.platform import test
@@ -31,6 +32,7 @@
 
 class XlaDeviceTest(xla_test.XLATestCase):
 
+  @test_util.disable_mlir_bridge("TODO(b/172473885)")
   def testCopies(self):
     """Tests that copies onto and off XLA devices work."""
     shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],