[tf-mlir-translate] added tests for handling data types with subtypes
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type-with-subtype.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type-with-subtype.pbtxt
new file mode 100644
index 0000000..ca94474
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type-with-subtype.pbtxt
@@ -0,0 +1,51 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=p,x -tf-input-data-types="DT_INT32,DT_RESOURCE(DT_INT32)" -tf-output-arrays=p,x -o - | FileCheck %s -check-prefix=CHECK-NO-SHAPE
+
+# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=p,x -tf-input-shapes=512,1024: -tf-input-data-types="DT_INT32,DT_RESOURCE(512,1024:DT_INT32)" -tf-output-arrays=p,x -o - | FileCheck %s -check-prefix=CHECK-SHAPE
+
+# Test the handling of the input data types. In particular, if the data type
+# for an input graph node is specified via command line options, use it;
+# otherwise, use the data type of the node in the graph.
+
+node {
+  name: "p"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+}
+node {
+  name: "x"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_RESOURCE
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        unknown_rank: true
+      }
+    }
+  }
+}
+versions {
+  producer: 216
+}
+
+# CHECK-NO-SHAPE: func @main(%arg0: tensor<i32>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) -> (tensor<i32>, tensor<!tf_type.resource<tensor<i32>>>)
+
+# CHECK-SHAPE: func @main(%arg0: tensor<512x1024xi32>, %arg1: tensor<!tf_type.resource<tensor<512x1024xi32>>>) -> (tensor<512x1024xi32>, tensor<!tf_type.resource<tensor<512x1024xi32>>>)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
deleted file mode 100644
index 07fcf03..0000000
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
+++ /dev/null
@@ -1,16 +0,0 @@
-node {
-  name: "Placeholder"
-  op: "Placeholder"
-  input: "input0"
-  input: "input1"
-  input: "input2"
-  attr {
-    key: "T"
-    value {
-      type: DT_INT32
-    }
-  }
-}
-versions {
-  producer: 27
-}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
index a443f40..d202152 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
@@ -202,8 +202,8 @@
 
     ArrayInfo& info = it_inserted_pair.first->second;
     // Splitting the type and subtype into parts
-    std::vector<std::string> parts = absl::StrSplit(type, absl::ByAnyChar("()")); // TODO make this better
-    /// If type has subtype, parts will have three members, part[0] = type, part[1] = subtype, part[2] = ""
+    std::vector<std::string> parts = absl::StrSplit(type, absl::ByAnyChar("()"));
+    /// If type has subtypes then parts[0] = type, parts[1] = subtypes, parts[2] = ""
     if (parts.size() != 3 && parts.size() != 1) {
       return errors::InvalidArgument("Invalid type '", type, "'");
     } else if (parts.size() == 3) {
@@ -260,26 +260,40 @@
   return Status::OK();
 }
 
-std::vector<std::string> ParseDTypesHelper(absl::string_view data_types_str) {
+StatusOr<std::vector<std::string>> ParseDTypesHelper(absl::string_view data_types_str) {
   bool inside_subtype = false;
   int cur_pos = 0;
   std::vector<std::string> dtypes;
   for (auto it: llvm::enumerate(data_types_str)) {
     char c = it.value();
     int i = it.index();
+    // Track "(...)" so commas inside a subtype list are not treated as separators; nested or unbalanced parentheses are errors
     if (c == '(') {
+      if (inside_subtype) {
+        return errors::FailedPrecondition(
+          absl::StrCat("Syntax error in data types '", data_types_str, "'"));
+      }
       inside_subtype = true;
     } else if (c == ')') {
+      if (!inside_subtype) {
+        return errors::FailedPrecondition(
+          absl::StrCat("Syntax error in data types '", data_types_str, "'"));
+      }
       inside_subtype = false;
     }
     if (inside_subtype) continue;
     if (c == ',') {
-      dtypes.push_back(std::string(data_types_str.substr(cur_pos, i)));
+      dtypes.push_back(std::string(data_types_str.substr(cur_pos, i - cur_pos)));
       cur_pos = i+1;
-    } else if (i == data_types_str.size()-1) {
-      dtypes.push_back(std::string(data_types_str.substr(cur_pos, i+1)));
     }
   }
+  if (inside_subtype) {
+    return errors::FailedPrecondition(
+      absl::StrCat("Syntax error in data types '", data_types_str, "'"));
+  }
+  if (!data_types_str.empty()) {
+    dtypes.push_back(std::string(data_types_str.substr(cur_pos)));
+  }
   return dtypes;
 }
 
@@ -287,7 +301,11 @@
                           std::vector<std::string>& data_type_vector) {
   data_type_vector.clear();
   if (!data_types_str.empty()) {
-    data_type_vector = ParseDTypesHelper(data_types_str);
+    auto s = ParseDTypesHelper(data_types_str);
+    if (!s.ok()) {
+      return s.status();
+    }
+    data_type_vector = s.ValueOrDie();
   }
   return Status::OK();
 }