Internal change
PiperOrigin-RevId: 347935883
Change-Id: I88d826b6b14135d51009adb174dbdcf2a1e00e87
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index b5455bb..5731b93 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -300,6 +300,19 @@
// CHECK-NEXT: return %[[CST]], %[[CST1]]
}
+// Tests that ops with a non-local device assignment are correctly evaluated
+// when a local device of the same type (CPU) is available.
+// CHECK-LABEL: func @testRemoteDevice() -> tensor<2x2xi32>
+func @testRemoteDevice() -> tensor<2x2xi32> {
+^bb0:
+ %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+ %1 = constant dense<1> : tensor<2xi32>
+ %2 = "tf.Add"(%0, %1) {device = "/job:remote_worker/replica:123/task:456/CPU:0", name = "add"} : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+ // CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[\[}}1, 2], {{\[}}3, 4]]> : tensor<2x2xi32>
+ // CHECK-NEXT: return [[cst]] : tensor<2x2xi32>
+ return %2 : tensor<2x2xi32>
+}
+
// Tests that variable shape ops are correctly evaluated given static types.
// CHECK-LABEL: func @testVariableShape
func @testVariableShape(%arg0: tensor<!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32> {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
index 833d35c..a3c487f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
@@ -105,10 +105,6 @@
// The TFE_Context is created without an accompanying delete due to its current
// lifetime. This is not reported as a memory leak (see totw/110).
TFE_ContextOptions* opts = TFE_NewContextOptions();
- // Input tensors are placed on the host CPU so use the explicit device
- // policy to fail if no CPU kernels are available for the op.
- TFE_ContextOptionsSetDevicePlacementPolicy(opts,
- TFE_DEVICE_PLACEMENT_EXPLICIT);
auto ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
index b9d09e7..cca6981 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc
@@ -55,6 +55,26 @@
return false;
}
+// Updates node_def's device attribute (if any) to use a local device, that
+// is, /job:localhost/replica:0/task:0/{DEVICE_TYPE}:{DEVICE_ID}.
+// This is needed because EvaluateOperation only has access to local devices,
+// but the given node may carry a device assignment to a remote device. In
+// that case, evaluation would fail even if a device of the same type is
+// available locally. Rewriting the device assignment to a local one allows
+// evaluation to succeed.
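+// For example, "/job:remote_worker/replica:123/task:456/CPU:0" becomes
+// "/job:localhost/replica:0/task:0/CPU:0".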
+void ForceUseLocalhostDevice(NodeDef* node_def) {
+ DeviceNameUtils::ParsedName parsed_name;
+
+ if (!DeviceNameUtils::ParseFullName(node_def->device(), &parsed_name)) return;
+
+ if (parsed_name.has_job) parsed_name.job = "localhost";
+ if (parsed_name.has_replica) parsed_name.replica = 0;
+ if (parsed_name.has_task) parsed_name.task = 0;
+
+ *node_def->mutable_device() =
+ DeviceNameUtils::ParsedNameToString(parsed_name);
+}
+
mlir::LogicalResult EvaluateOperation(
mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
TFE_Context* context, llvm::SmallVectorImpl<mlir::Attribute>* results) {
@@ -84,14 +104,13 @@
RETURN_FAILURE_IF_ERROR(node_def_or.status());
const auto& node_def = node_def_or.ValueOrDie();
- // Note that we don't set device for this op based on the assigned device
- // attribute of the op. We want to evaluate operation on the host CPU as the
- // assigned device might be remote, not available yet or compilation only
- // on demand device which may create a recursion. The eager runtime executes
- // the op on the device input tensors are placed which is host CPU here.
+ ForceUseLocalhostDevice(node_def.get());
+
TFE_Op* op = TFE_NewOp(context, node_def->op().c_str(), status);
RETURN_FAILURE_IF_ERROR(status);
auto clean_op = MakeCleanup([op] { TFE_DeleteOp(op); });
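+  // Run the op on the node's device, which ForceUseLocalhostDevice has
+  // rewritten to a local one above.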
+ TFE_OpSetDevice(op, node_def->device().c_str(), status);
+ RETURN_FAILURE_IF_ERROR(status);
for (const auto& attr : node_def->attr()) {
SetOpAttrValueScalar(context, op, attr.second, attr.first.c_str(), status);
RETURN_FAILURE_IF_ERROR(status);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
index e3e14af..4130e72 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h
@@ -25,10 +25,8 @@
namespace tensorflow {
// Attempts to evaluate an MLIR Operation in TensorFlow eager mode with the
-// specified operands. The op is always executed on the local host CPU
-// irrespective of the device attribute of the given op. If there is a CPU
-// kernel registered for the op and is executed successfully, this fills in the
-// results vector. If not, results vector is unspecified.
+// specified operands. If successful, this fills in the results vector. If
+// not, the results vector is unspecified.
//
mlir::LogicalResult EvaluateOperation(
mlir::Operation* inst, llvm::ArrayRef<mlir::ElementsAttr> operands,
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c13517b..9186ae0 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -791,7 +791,6 @@
name = "listdiff_op_test",
size = "small",
srcs = ["listdiff_op_test.py"],
- enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_cuda_asan", # times out
@@ -1391,7 +1390,6 @@
name = "unary_ops_test",
size = "medium",
srcs = ["unary_ops_test.py"],
- enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_cuda_asan", # times out
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index ab71f6f..94b34cf 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1009,6 +1009,7 @@
np.array([], dtype=dtype).reshape((0, 3)),
expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testMatMul(self):
self._testMatMul(math_ops.matmul, self.float_types | {np.float64})
@@ -1046,6 +1047,7 @@
self._testMatMul(SparseMatmulWrapperFT, self.float_types)
self._testMatMul(SparseMatmulWrapperTT, self.float_types)
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testBatchMatMul(self):
# Tests with batches of matrices.
for dtype in self.float_types | {np.float64}:
@@ -1097,6 +1099,7 @@
x,
expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testExpandDims(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -1364,6 +1367,7 @@
np.reshape(np.array([16, 18, 8], dtype=dtype), (3, 1)),
(1, 2, 3, 1)))
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testReshape(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -1495,6 +1499,7 @@
[1, 2]],
dtype=dtype))
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testTranspose(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -1513,6 +1518,7 @@
np.array([1, 0], dtype=np.int32),
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testConjugateTranspose(self):
for dtype in self.complex_types:
self._testBinary(
@@ -1549,6 +1555,7 @@
np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testBroadcastArgs(self):
self._testBinary(array_ops.broadcast_dynamic_shape,
np.array([2, 3, 5], dtype=np.int32),
@@ -1609,7 +1616,6 @@
np.array([4, 5, 6], dtype=np.int32),
expected=None)
- @test_util.disable_mlir_bridge("TODO(b/175721108): Legalize BroadcastArgs op")
def testBroadcastTo(self):
for dtype in self.all_types:
x = np.random.randint(0, high=100, size=[2, 3])
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 41107d0..5d2b8a6 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -368,6 +368,7 @@
ans = self.evaluate(packed)
self.assertAllEqual(ans, [2, 3, 5])
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testEmpty(self):
with self.session():
with self.test_scope():
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
index 66be684..d0f6229 100644
--- a/tensorflow/compiler/tests/nary_ops_test.py
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -159,8 +159,7 @@
np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
self.assertAllEqual(output, expected)
- @test_util.disable_mlir_bridge(
- "TODO(b/344771933): Fix canonicalization failure")
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testStridedSlice(self):
self._testNAry(lambda x: array_ops.strided_slice(*x),
[np.array([[], [], []], dtype=np.float32),
@@ -205,8 +204,7 @@
dtype=np.float32)],
expected=np.array([[4], [5], [6]], dtype=np.float32))
- @test_util.disable_mlir_bridge(
- "TODO(b/344771933): Fix canonicalization failure")
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testStridedSliceGrad(self):
# Tests cases where input shape is empty.
self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index b890960..fe1d2c5 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -98,22 +98,27 @@
]
ONES = [np.ones([34000, 2])]
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceSumF32(self, index_dtype):
self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA,
index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceSumC64(self, index_dtype):
self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
self.COMPLEX_DATA, index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceProdF32(self, index_dtype):
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
self.REAL_DATA, index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceProdC64(self, index_dtype):
self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
self.COMPLEX_DATA, index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceMin(self, index_dtype):
def reference_min(dtype, inp, axis):
@@ -131,6 +136,7 @@
functools.partial(reference_min, dtype), dtype,
self.REAL_DATA, index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceMax(self, index_dtype):
def reference_max(dtype, inp, axis):
@@ -149,6 +155,7 @@
functools.partial(reference_max, dtype), dtype,
self.REAL_DATA, index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceMeanF32(self, index_dtype):
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
# reducing across zero inputs.
@@ -164,10 +171,12 @@
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
self.NONEMPTY_COMPLEX_DATA, index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceAll(self, index_dtype):
self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA,
index_dtype)
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testReduceAny(self, index_dtype):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
index_dtype)
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 01df466..aa72f47 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -91,8 +91,7 @@
for reverse in [True, False]:
self._compare(x, axis, exclusive, reverse)
- @test_util.disable_mlir_bridge(
- "TODO(b/175721108): Fix lowering to not generate illegal ReduceWindow op")
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testEmpty(self):
for dtype in self.valid_dtypes:
x = np.zeros([0]).astype(dtype)
@@ -172,8 +171,7 @@
for reverse in [True, False]:
self._compare(x, axis, exclusive, reverse)
- @test_util.disable_mlir_bridge(
- "TODO(b/175721108): Fix lowering to not generate illegal ReduceWindow op")
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testEmpty(self):
for dtype in self.valid_dtypes:
x = np.zeros([0]).astype(dtype)
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index 7e817c6..76b7e18 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -248,18 +248,19 @@
outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
[[4, 41], [6, 61]]])
- @test_util.disable_mlir_bridge("TODO(b/344771933): Fix failure")
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testDirect0(self):
# Test with zero-size remaining dimension.
self._testDirect(
input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testDirect1(self):
# Test with zero-size blocked dimension.
self._testDirect(
input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
- @test_util.disable_mlir_bridge("TODO(b/344771933): Fix failure")
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testDirect2(self):
# Test with padding up from zero size.
self._testDirect(
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 4109fdc..3d310dd 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -63,6 +63,7 @@
self.assertEqual(result[-1], expected[-1])
self.assertEqual(result[0], expected[0])
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testRange(self):
self._testTernary(
math_ops.range,
@@ -182,6 +183,7 @@
np.array([8, 9], dtype=dtype),
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
+ @test_util.disable_mlir_bridge('TODO(b/172473885)')
def testSlice(self):
for dtype in self.numeric_types:
self._testTernary(
diff --git a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
index ca50916..84aa725 100644
--- a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
+++ b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
@@ -193,6 +193,7 @@
def test1x1(self):
self._test(diags=[[0], [3], [0]], rhs=[6], expected=[2])
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def test0x0(self):
self._test(
diags=np.zeros(shape=(3, 0), dtype=np.float32),
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 60cdbf5..f3f6fa8 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -625,7 +625,6 @@
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
- @test_util.disable_mlir_bridge("TODO(b/344771933): Fix failure")
def testComplexOps(self):
for dtype in self.complex_types:
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index 271bf66..254f9ac 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -24,6 +24,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
@@ -31,6 +32,7 @@
class XlaDeviceTest(xla_test.XLATestCase):
+ @test_util.disable_mlir_bridge("TODO(b/172473885)")
def testCopies(self):
"""Tests that copies onto and off XLA devices work."""
shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],