Updated contrib/bigtable/ to use tstring.
Added Cord copy-ctor/assignment for tstring.
This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 265451696
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc
index 01cedd8..1365855 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc
@@ -67,9 +67,9 @@
strings::StrCat("Error reading from Cloud Bigtable: ", status.message()));
}
-string RegexFromStringSet(const std::vector<string>& strs) {
+string RegexFromStringSet(const std::vector<tstring>& strs) {
CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty.";
- std::unordered_set<string> uniq(strs.begin(), strs.end());
+ std::unordered_set<tstring> uniq(strs.begin(), strs.end());
if (uniq.size() == 1) {
return *uniq.begin();
}
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
index 085dc75..ce2bea0 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -25,7 +25,7 @@
Status GcpStatusToTfStatus(const ::google::cloud::Status& status);
-string RegexFromStringSet(const std::vector<string>& strs);
+string RegexFromStringSet(const std::vector<tstring>& strs);
class BigtableClientResource : public ResourceBase {
public:
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index 6f1f880..a699362 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -29,11 +29,11 @@
core::RefCountPtr<BigtableTableResource> table;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
- std::vector<string> column_families;
- std::vector<string> columns;
- OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
- &column_families));
- OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
+ std::vector<tstring> column_families;
+ std::vector<tstring> columns;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "column_families",
+ &column_families));
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "columns", &columns));
OP_REQUIRES(
ctx, column_families.size() == columns.size(),
errors::InvalidArgument("len(columns) != len(column_families)"));
@@ -58,8 +58,8 @@
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
BigtableTableResource* table,
- std::vector<string> column_families,
- std::vector<string> columns,
+ std::vector<tstring> column_families,
+ std::vector<tstring> columns,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
: DatasetBase(DatasetContext(ctx)),
@@ -112,8 +112,8 @@
private:
static ::google::cloud::bigtable::Filter MakeFilter(
- const std::vector<string>& column_families,
- const std::vector<string>& columns) {
+ const std::vector<tstring>& column_families,
+ const std::vector<tstring>& columns) {
string column_family_regex = RegexFromStringSet(column_families);
string column_regex = RegexFromStringSet(columns);
@@ -210,7 +210,7 @@
for (auto cell_itr = row.cells().begin();
!found_column && cell_itr != row.cells().end(); ++cell_itr) {
if (cell_itr->family_name() == dataset()->column_families_[i] &&
- string(cell_itr->column_qualifier()) ==
+ tstring(cell_itr->column_qualifier()) ==
dataset()->columns_[i]) {
col_tensor.scalar<tstring>()() = tstring(cell_itr->value());
found_column = true;
@@ -232,8 +232,8 @@
const DatasetBase* const input_;
BigtableTableResource* table_;
- const std::vector<string> column_families_;
- const std::vector<string> columns_;
+ const std::vector<tstring> column_families_;
+ const std::vector<tstring> columns_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
const ::google::cloud::bigtable::Filter filter_;
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index 51ccd83..6af5c6d 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -26,8 +26,8 @@
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- string prefix;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+ tstring prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index 2bc642f..22f7ddf 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -26,11 +26,11 @@
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- string start_key;
+ tstring start_key;
OP_REQUIRES_OK(ctx,
- ParseScalarArgument<string>(ctx, "start_key", &start_key));
- string end_key;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+ ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
+ tstring end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
@@ -108,7 +108,7 @@
const ::google::cloud::bigtable::Row& row,
std::vector<Tensor>* out_tensors) override {
Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
- output_tensor.scalar<tstring>()() = string(row.row_key());
+ output_tensor.scalar<tstring>()() = tstring(row.row_key());
out_tensors->emplace_back(std::move(output_tensor));
return Status::OK();
}
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index 2659097..08bf35f 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -27,14 +27,14 @@
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- string prefix;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+ tstring prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
- string start_key;
+ tstring start_key;
OP_REQUIRES_OK(ctx,
- ParseScalarArgument<string>(ctx, "start_key", &start_key));
- string end_key;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+ ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
+ tstring end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index 1118caf..f449830 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -103,7 +103,7 @@
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
TensorShape({}));
out_tensors->back().scalar<tstring>()() =
- string(row_keys_[index_].row_key);
+ tstring(row_keys_[index_].row_key);
*end_of_sequence = false;
index_++;
} else {
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index b6beaf3..d2b6959 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -26,13 +26,13 @@
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- string prefix;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
- string start_key;
+ tstring prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
+ tstring start_key;
OP_REQUIRES_OK(ctx,
- ParseScalarArgument<string>(ctx, "start_key", &start_key));
- string end_key;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+ ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
+ tstring end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()),
errors::InvalidArgument(
@@ -46,11 +46,11 @@
"If prefix is specified, end_key must be empty."));
}
- std::vector<string> column_families;
- std::vector<string> columns;
- OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
- &column_families));
- OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
+ std::vector<tstring> column_families;
+ std::vector<tstring> columns;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "column_families",
+ &column_families));
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "columns", &columns));
OP_REQUIRES(
ctx, column_families.size() == columns.size(),
errors::InvalidArgument("len(columns) != len(column_families)"));
@@ -90,8 +90,8 @@
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix, string start_key, string end_key,
- std::vector<string> column_families,
- std::vector<string> columns, float probability,
+ std::vector<tstring> column_families,
+ std::vector<tstring> columns, float probability,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
: DatasetBase(DatasetContext(ctx)),
@@ -180,7 +180,7 @@
std::vector<Tensor>* out_tensors) override {
out_tensors->reserve(dataset()->columns_.size() + 1);
Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
- row_key_tensor.scalar<tstring>()() = string(row.row_key());
+ row_key_tensor.scalar<tstring>()() = tstring(row.row_key());
out_tensors->emplace_back(std::move(row_key_tensor));
if (row.cells().size() > 2 * dataset()->columns_.size()) {
@@ -196,9 +196,9 @@
for (auto cell_itr = row.cells().begin();
!found_column && cell_itr != row.cells().end(); ++cell_itr) {
if (cell_itr->family_name() == dataset()->column_families_[i] &&
- string(cell_itr->column_qualifier()) ==
+ tstring(cell_itr->column_qualifier()) ==
dataset()->columns_[i]) {
- col_tensor.scalar<string>()() = string(cell_itr->value());
+ col_tensor.scalar<tstring>()() = tstring(cell_itr->value());
found_column = true;
}
}
@@ -217,8 +217,8 @@
const string prefix_;
const string start_key_;
const string end_key_;
- const std::vector<string> column_families_;
- const std::vector<string> columns_;
+ const std::vector<tstring> column_families_;
+ const std::vector<tstring> columns_;
const string column_family_regex_;
const string column_regex_;
const float probability_;
diff --git a/tensorflow/core/platform/tstring.h b/tensorflow/core/platform/tstring.h
index d7c8275..30d2547 100644
--- a/tensorflow/core/platform/tstring.h
+++ b/tensorflow/core/platform/tstring.h
@@ -32,6 +32,12 @@
class string_view;
}
+#ifdef PLATFORM_GOOGLE
+// TODO(dero): Move above to 'namespace absl' when absl moves Cord out of global
+// namespace.
+class Cord;
+#endif // PLATFORM_GOOGLE
+
namespace tensorflow {
// tensorflow::tstring is the scalar type for DT_STRING tensors.
@@ -77,6 +83,12 @@
T>::type* = nullptr>
explicit tstring(const T& str) : str_(str.data(), str.size()) {}
+#ifdef PLATFORM_GOOGLE
+ template <typename T, typename std::enable_if<std::is_same<T, Cord>::value,
+ T>::type* = nullptr>
+ explicit tstring(const T& cord) : str_(string(cord)) {}
+#endif // PLATFORM_GOOGLE
+
tstring(tstring&&) noexcept = default;
~tstring() = default;
@@ -98,6 +110,16 @@
return *this;
}
+#ifdef PLATFORM_GOOGLE
+ template <typename T, typename std::enable_if<std::is_same<T, Cord>::value,
+ T>::type* = nullptr>
+ tstring& operator=(const T& cord) {
+ str_ = string(cord);
+
+ return *this;
+ }
+#endif // PLATFORM_GOOGLE
+
tstring& operator=(const char* str) {
str_ = str;