[tf-mlir-translate] refactored subtype shape parsing
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
new file mode 100644
index 0000000..07fcf03
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_with_subtype.pbtxt
@@ -0,0 +1,16 @@
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ input: "input0"
+ input: "input1"
+ input: "input2"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+versions {
+ producer: 27
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
index 0514048..a443f40 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
@@ -26,6 +26,7 @@
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -110,6 +111,23 @@
return Status::OK();
}
+StatusOr<std::vector<int>> ParseShapeStr(absl::string_view node_shapes_str) {
+ std::vector<int> dims;
+ for (const absl::string_view dim_str :
+ absl::StrSplit(node_shapes_str, ',')) {
+ // Treats empty input shape as scalar
+ if (dim_str.empty()) continue;
+ if (dim_str == "?") {
+ dims.push_back(-1);
+ continue;
+ }
+ int size;
+ TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
+ dims.push_back(size);
+ }
+ return dims;
+}
+
static Status
HandleSubtype(absl::string_view subtype, ArrayInfo::SubTypeInfo *result) {
std::vector<std::string> shape_and_type = absl::StrSplit(subtype, ':');
@@ -120,9 +138,7 @@
"Invalid argument, the subtype and shape have to be separated with a ':'");
} else if(shape_and_type.size() == 2) {
const auto &shape_str = shape_and_type[0];
- auto s = ParseSubtypeShape(shape_str, dims);
- if (!s.ok())
- return s;
+ dims = ParseShapeStr(shape_str).ValueOrDie();
}
const auto &subtype_str = shape_and_type.back();
@@ -228,20 +244,11 @@
shapes_vector.push_back(llvm::None);
continue;
}
- std::vector<int> dims;
- for (const absl::string_view dim_str :
- absl::StrSplit(node_shapes_str[i], ',')) {
- // Treats empty input shape as scalar
- if (dim_str.empty()) continue;
- if (dim_str == "?") {
- dims.push_back(-1);
- continue;
- }
- int size;
- TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
- dims.push_back(size);
+ auto s = ParseShapeStr(node_shapes_str[i]);
+ if (!s.ok()) {
+ return s.status();
}
- shapes_vector.push_back(dims);
+ shapes_vector.push_back(s.ValueOrDie());
}
}
return Status::OK();
@@ -253,11 +260,34 @@
return Status::OK();
}
+std::vector<std::string> ParseDTypesHelper(absl::string_view data_types_str) {
+ bool inside_subtype = false;
+ int cur_pos = 0;
+ std::vector<std::string> dtypes;
+ for (auto it: llvm::enumerate(data_types_str)) {
+ char c = it.value();
+ int i = it.index();
+ if (c == '(') {
+ inside_subtype = true;
+ } else if (c == ')') {
+ inside_subtype = false;
+ }
+ if (inside_subtype) continue;
+ if (c == ',') {
+ dtypes.push_back(std::string(data_types_str.substr(cur_pos, i)));
+ cur_pos = i+1;
+ } else if (i == data_types_str.size()-1) {
+ dtypes.push_back(std::string(data_types_str.substr(cur_pos, i+1)));
+ }
+ }
+ return dtypes;
+}
+
Status ParseNodeDataTypes(absl::string_view data_types_str,
std::vector<std::string>& data_type_vector) {
data_type_vector.clear();
if (!data_types_str.empty()) {
- data_type_vector = absl::StrSplit(data_types_str, ',');
+ data_type_vector = ParseDTypesHelper(data_types_str);
}
return Status::OK();
}