[tf-mlir-translate] added tests for handling data types with subtypes
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type-with-subtype.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type-with-subtype.pbtxt
new file mode 100644
index 0000000..ca94474
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type-with-subtype.pbtxt
@@ -0,0 +1,51 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=p,x -tf-input-data-types="DT_INT32,DT_RESOURCE(DT_INT32)" -tf-output-arrays=p,x -o - | FileCheck %s -check-prefix=CHECK-NO-SHAPE
+
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=p,x -tf-input-shapes=512,1024: -tf-input-data-types="DT_INT32,DT_RESOURCE(512,1024:DT_INT32)" -tf-output-arrays=p,x -o - | FileCheck %s -check-prefix=CHECK-SHAPE
+
+# Test the handling of the input data types. In particular, if the data type
+# for an input graph node is specified via command line options, use it.
+# otherwise, use the data type of the node in the graph.
+
+node {
+ name: "p"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+}
+node {
+ name: "x"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_RESOURCE
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+}
+versions {
+ producer: 216
+}
+
+# CHECK-NO-SHAPE: func @main(%arg0: tensor<i32>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) -> (tensor<i32>, tensor<!tf_type.resource<tensor<i32>>>)
+
+# CHECK-SHAPE: func @main(%arg0: tensor<512x1024xi32>, %arg1: tensor<!tf_type.resource<tensor<512x1024xi32>>>) -> (tensor<512x1024xi32>, tensor<!tf_type.resource<tensor<512x1024xi32>>>)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
deleted file mode 100644
index 07fcf03..0000000
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
+++ /dev/null
@@ -1,16 +0,0 @@
-node {
- name: "Placeholder"
- op: "Placeholder"
- input: "input0"
- input: "input1"
- input: "input2"
- attr {
- key: "T"
- value {
- type: DT_INT32
- }
- }
-}
-versions {
- producer: 27
-}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
index a443f40..d202152 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
@@ -202,8 +202,8 @@
ArrayInfo& info = it_inserted_pair.first->second;
// Splitting the type and subtype into parts
- std::vector<std::string> parts = absl::StrSplit(type, absl::ByAnyChar("()")); // TODO make this better
- /// If type has subtype, parts will have three members, part[0] = type, part[1] = subtype, part[2] = ""
+ std::vector<std::string> parts = absl::StrSplit(type, absl::ByAnyChar("()"));
+ /// If type has subtypes then part[0] = type, part[1] = subtypes, part[2] = ""
if (parts.size() != 3 && parts.size() != 1) {
return errors::InvalidArgument("Invalid type '", type, "'");
} else if (parts.size() == 3) {
@@ -260,26 +260,40 @@
return Status::OK();
}
-std::vector<std::string> ParseDTypesHelper(absl::string_view data_types_str) {
+StatusOr<std::vector<std::string>> ParseDTypesHelper(absl::string_view data_types_str) {
bool inside_subtype = false;
int cur_pos = 0;
std::vector<std::string> dtypes;
for (auto it: llvm::enumerate(data_types_str)) {
char c = it.value();
int i = it.index();
+ // Skip parsing the subtypes of a type
if (c == '(') {
+ if (inside_subtype) {
+ return errors::FailedPrecondition(
+ absl::StrCat("Syntax error in data types '", data_types_str, "'"));
+ }
inside_subtype = true;
} else if (c == ')') {
+ if (!inside_subtype) {
+ return errors::FailedPrecondition(
+ absl::StrCat("Syntax error in data types '", data_types_str, "'"));
+ }
inside_subtype = false;
}
if (inside_subtype) continue;
if (c == ',') {
dtypes.push_back(std::string(data_types_str.substr(cur_pos, i)));
cur_pos = i+1;
- } else if (i == data_types_str.size()-1) {
- dtypes.push_back(std::string(data_types_str.substr(cur_pos, i+1)));
}
}
+ if (inside_subtype) {
+ return errors::FailedPrecondition(
+ absl::StrCat("Syntax error in data types '", data_types_str, "'"));
+ }
+ if (!data_types_str.empty()) {
+ dtypes.push_back(std::string(data_types_str.substr(cur_pos, data_types_str.size())));
+ }
return dtypes;
}
@@ -287,7 +301,11 @@
std::vector<std::string>& data_type_vector) {
data_type_vector.clear();
if (!data_types_str.empty()) {
- data_type_vector = ParseDTypesHelper(data_types_str);
+ auto s = ParseDTypesHelper(data_types_str);
+ if (!s.ok()) {
+ return s.status();
+ }
+ data_type_vector = s.ValueOrDie();
}
return Status::OK();
}