Improve error message of RaggedTensor by showing data type explicitly
While working on writing a tf.data pipeline with RaggedTensor the following
error showed up:
```
def raise_from(value, from_value):
> raise value
E InvalidArgumentError: Expected splits Tensor dtype: 9, found: 3 [Op:RaggedTensorFromVariant]
/usr/local/lib/python2.7/dist-packages/six.py:737: InvalidArgumentError
```
It is not very obvious about the exact type that needs. Until found out in
`tensorflow/core/framework/types.proto` that `3` is `int32` and `9` is `int64`.
This PR enhance the error message by explictily print out the DataType in string,
so the message will be:
```
E InvalidArgumentError: Expected splits Tensor dtype: int64, found: int32 [Op:RaggedTensorFromVariant]
```
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index e2bebf3..f83bcb3 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -97,8 +97,8 @@
}
if (values_tensor->dtype() != value_dtype) {
return errors::InvalidArgument(
- "Expected values Tensor dtype: ", value_dtype,
- ", found: ", values_tensor->dtype());
+ "Expected values Tensor dtype: ", DataTypeString(value_dtype),
+ ", found: ", DataTypeString(values_tensor->dtype()));
}
if (values_tensor->dims() < 1) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc
index 0be3609..d5626dc 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc
@@ -605,7 +605,7 @@
input_ragged_rank, output_ragged_rank, TensorShape({1}),
{variant_component_1});
EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(),
- "Expected values Tensor dtype: 7, found: 3"));
+ "Expected values Tensor dtype: string, found: int32"));
}
TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankNotGreaterThanOne) {