[tf-mlir-translate] refactored subtype shape parsing
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
new file mode 100644
index 0000000..07fcf03
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
@@ -0,0 +1,16 @@
+node {
+  name: "Placeholder"
+  op: "Placeholder"
+  input: "input0"
+  input: "input1"
+  input: "input2"
+  attr {
+    key: "T"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+versions {
+  producer: 27
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
index 0514048..a443f40 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
@@ -26,6 +26,7 @@
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
@@ -110,6 +111,23 @@
   return Status::OK();
 }
 
+StatusOr<std::vector<int>> ParseShapeStr(absl::string_view node_shapes_str) {
+  std::vector<int> dims;
+  for (const absl::string_view dim_str :
+        absl::StrSplit(node_shapes_str, ',')) {
+    // Treats empty input shape as scalar
+    if (dim_str.empty()) continue;
+    if (dim_str == "?") {
+      dims.push_back(-1);
+      continue;
+    }
+    int size;
+    TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
+    dims.push_back(size);
+  }
+  return dims;
+}
+
 static Status 
 HandleSubtype(absl::string_view subtype, ArrayInfo::SubTypeInfo *result) {
   std::vector<std::string> shape_and_type = absl::StrSplit(subtype, ':');
@@ -120,9 +138,7 @@
       "Invalid argument, the subtype and shape have to be separated with a ':'");
   } else if(shape_and_type.size() == 2) {
     const auto &shape_str = shape_and_type[0];
-    auto s = ParseSubtypeShape(shape_str, dims);
-    if (!s.ok())
-      return s;
+    dims = ParseShapeStr(shape_str).ValueOrDie();
   }
 
   const auto &subtype_str = shape_and_type.back();
@@ -228,20 +244,11 @@
         shapes_vector.push_back(llvm::None);
         continue;
       }
-      std::vector<int> dims;
-      for (const absl::string_view dim_str :
-           absl::StrSplit(node_shapes_str[i], ',')) {
-        // Treats empty input shape as scalar
-        if (dim_str.empty()) continue;
-        if (dim_str == "?") {
-          dims.push_back(-1);
-          continue;
-        }
-        int size;
-        TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
-        dims.push_back(size);
+      auto s = ParseShapeStr(node_shapes_str[i]);
+      if (!s.ok()) {
+        return s.status();
       }
-      shapes_vector.push_back(dims);
+      shapes_vector.push_back(s.ValueOrDie());
     }
   }
   return Status::OK();
@@ -253,11 +260,34 @@
   return Status::OK();
 }
 
+std::vector<std::string> ParseDTypesHelper(absl::string_view data_types_str) {
+  bool inside_subtype = false;
+  int cur_pos = 0;
+  std::vector<std::string> dtypes;
+  for (auto it: llvm::enumerate(data_types_str)) {
+    char c = it.value();
+    int i = it.index();
+    if (c == '(') {
+      inside_subtype = true;
+    } else if (c == ')') {
+      inside_subtype = false;
+    }
+    if (inside_subtype) continue;
+    if (c == ',') {
+      dtypes.push_back(std::string(data_types_str.substr(cur_pos, i)));
+      cur_pos = i+1;
+    } else if (i == data_types_str.size()-1) {
+      dtypes.push_back(std::string(data_types_str.substr(cur_pos, i+1)));
+    }
+  }
+  return dtypes;
+}
+
 Status ParseNodeDataTypes(absl::string_view data_types_str,
                           std::vector<std::string>& data_type_vector) {
   data_type_vector.clear();
   if (!data_types_str.empty()) {
-    data_type_vector = absl::StrSplit(data_types_str, ',');
+    data_type_vector = ParseDTypesHelper(data_types_str);
   }
   return Status::OK();
 }