blob: ce8972f1e7fd59235556cb9514011f0b836077de [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h"
#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h"
#include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
IgniteDatasetIterator::IgniteDatasetIterator(
const Params& params, string host, int32 port, string cache_name,
bool local, int32 part, int32 page_size, string username, string password,
string certfile, string keyfile, string cert_password,
std::vector<int32> schema, std::vector<int32> permutation)
: DatasetIterator<IgniteDataset>(params),
cache_name_(std::move(cache_name)),
local_(local),
part_(part),
page_size_(page_size),
username_(std::move(username)),
password_(std::move(password)),
schema_(std::move(schema)),
permutation_(std::move(permutation)),
remainder_(-1),
cursor_id_(-1),
last_page_(false),
valid_state_(true) {
Client* p_client = new PlainClient(std::move(host), port, false);
if (certfile.empty())
client_ = std::unique_ptr<Client>(p_client);
else
client_ = std::unique_ptr<Client>(
new SslWrapper(std::unique_ptr<Client>(p_client), std::move(certfile),
std::move(keyfile), std::move(cert_password), false));
LOG(INFO) << "Ignite Dataset Iterator created";
}
IgniteDatasetIterator::~IgniteDatasetIterator() {
Status status = CloseConnection();
if (!status.ok()) LOG(ERROR) << status.ToString();
LOG(INFO) << "Ignite Dataset Iterator destroyed";
}
Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
mutex_lock l(mutex_);
if (valid_state_) {
Status status =
GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence);
if (!status.ok()) valid_state_ = false;
return status;
}
return errors::Unknown("Iterator is invalid");
}
Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) {
return errors::Unimplemented(
"Iterator for IgniteDataset does not support 'SaveInternal'");
}
Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) {
return errors::Unimplemented(
"Iterator for IgniteDataset does not support 'RestoreInternal')");
}
Status IgniteDatasetIterator::GetNextInternalWithValidState(
IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
if (remainder_ == 0 && last_page_) {
cursor_id_ = -1;
*end_of_sequence = true;
return Status::OK();
} else {
TF_RETURN_IF_ERROR(EstablishConnection());
if (remainder_ == -1) {
TF_RETURN_IF_ERROR(ScanQuery());
} else if (remainder_ == 0) {
TF_RETURN_IF_ERROR(LoadNextPage());
}
uint8_t* initial_ptr = ptr_;
std::vector<Tensor> tensors;
std::vector<int32_t> types;
TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse key
TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse val
remainder_ -= (ptr_ - initial_ptr);
TF_RETURN_IF_ERROR(CheckTypes(types));
for (size_t i = 0; i < tensors.size(); i++)
out_tensors->push_back(tensors[permutation_[i]]);
*end_of_sequence = false;
return Status::OK();
}
*end_of_sequence = true;
return Status::OK();
}
Status IgniteDatasetIterator::EstablishConnection() {
if (!client_->IsConnected()) {
TF_RETURN_IF_ERROR(client_->Connect());
Status status = Handshake();
if (!status.ok()) {
Status disconnect_status = client_->Disconnect();
if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
return status;
}
}
return Status::OK();
}
Status IgniteDatasetIterator::CloseConnection() {
if (cursor_id_ != -1 && !last_page_) {
TF_RETURN_IF_ERROR(EstablishConnection());
TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength));
TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode));
TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Resource ID
int32_t res_len;
TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
if (res_len < kMinResLength)
return errors::Unknown("Close Resource Response is corrupted");
int64_t req_id;
TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
int32_t status;
TF_RETURN_IF_ERROR(client_->ReadInt(&status));
if (status != 0) {
uint8_t err_msg_header;
TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
if (err_msg_header == kStringVal) {
int32_t err_msg_length;
TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
uint8_t* err_msg_c = new uint8_t[err_msg_length];
auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
return errors::Unknown("Close Resource Error [status=", status,
", message=", err_msg, "]");
}
return errors::Unknown("Close Resource Error [status=", status, "]");
}
cursor_id_ = -1;
return client_->Disconnect();
} else {
LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed";
}
return client_->IsConnected() ? client_->Disconnect() : Status::OK();
}
Status IgniteDatasetIterator::Handshake() {
int32_t msg_len = kHandshakeReqDefaultLength;
if (username_.empty())
msg_len += 1;
else
msg_len += 5 + username_.length(); // 1 byte header, 4 bytes length.
if (password_.empty())
msg_len += 1;
else
msg_len += 5 + password_.length(); // 1 byte header, 4 bytes length.
TF_RETURN_IF_ERROR(client_->WriteInt(msg_len));
TF_RETURN_IF_ERROR(client_->WriteByte(1));
TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion));
TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion));
TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion));
TF_RETURN_IF_ERROR(client_->WriteByte(2));
if (username_.empty()) {
TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
} else {
TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
TF_RETURN_IF_ERROR(client_->WriteInt(username_.length()));
TF_RETURN_IF_ERROR(
client_->WriteData(reinterpret_cast<const uint8_t*>(username_.c_str()),
username_.length()));
}
if (password_.empty()) {
TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
} else {
TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
TF_RETURN_IF_ERROR(client_->WriteInt(password_.length()));
TF_RETURN_IF_ERROR(
client_->WriteData(reinterpret_cast<const uint8_t*>(password_.c_str()),
password_.length()));
}
int32_t handshake_res_len;
TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len));
uint8_t handshake_res;
TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res));
if (handshake_res != 1) {
int16_t serv_ver_major;
TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major));
int16_t serv_ver_minor;
TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor));
int16_t serv_ver_patch;
TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch));
uint8_t header;
TF_RETURN_IF_ERROR(client_->ReadByte(&header));
if (header == kStringVal) {
int32_t length;
TF_RETURN_IF_ERROR(client_->ReadInt(&length));
uint8_t* err_msg_c = new uint8_t[length];
auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length));
string err_msg(reinterpret_cast<char*>(err_msg_c), length);
return errors::Unknown("Handshake Error [result=", handshake_res,
", version=", serv_ver_major, ".", serv_ver_minor,
".", serv_ver_patch, ", message='", err_msg, "']");
} else if (header == kNullVal) {
return errors::Unknown("Handshake Error [result=", handshake_res,
", version=", serv_ver_major, ".", serv_ver_minor,
".", serv_ver_patch, "]");
} else {
return errors::Unknown("Handshake Error [result=", handshake_res,
", version=", serv_ver_major, ".", serv_ver_minor,
".", serv_ver_patch, "]");
}
}
return Status::OK();
}
Status IgniteDatasetIterator::ScanQuery() {
TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength));
TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode));
TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
TF_RETURN_IF_ERROR(
client_->WriteInt(JavaHashCode(cache_name_))); // Cache name
TF_RETURN_IF_ERROR(client_->WriteByte(0)); // Flags
TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); // Filter object
TF_RETURN_IF_ERROR(client_->WriteInt(page_size_)); // Cursor page size
TF_RETURN_IF_ERROR(client_->WriteInt(part_)); // part_ition to query
TF_RETURN_IF_ERROR(client_->WriteByte(local_)); // local_ flag
uint64 wait_start = Env::Default()->NowMicros();
int32_t res_len;
TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
int64_t wait_stop = Env::Default()->NowMicros();
LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms";
if (res_len < kMinResLength)
return errors::Unknown("Scan Query Response is corrupted");
int64_t req_id;
TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
int32_t status;
TF_RETURN_IF_ERROR(client_->ReadInt(&status));
if (status != 0) {
uint8_t err_msg_header;
TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
if (err_msg_header == kStringVal) {
int32_t err_msg_length;
TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
uint8_t* err_msg_c = new uint8_t[err_msg_length];
auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
return errors::Unknown("Scan Query Error [status=", status,
", message=", err_msg, "]");
}
return errors::Unknown("Scan Query Error [status=", status, "]");
}
TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_));
int32_t row_cnt;
TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
int32_t page_size = res_len - kScanQueryResHeaderLength;
return ReceivePage(page_size);
}
Status IgniteDatasetIterator::LoadNextPage() {
TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength));
TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode));
TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Cursor ID
uint64 wait_start = Env::Default()->NowMicros();
int32_t res_len;
TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
uint64 wait_stop = Env::Default()->NowMicros();
LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000
<< " ms";
if (res_len < kMinResLength)
return errors::Unknown("Load Next Page Response is corrupted");
int64_t req_id;
TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
int32_t status;
TF_RETURN_IF_ERROR(client_->ReadInt(&status));
if (status != 0) {
uint8_t err_msg_header;
TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
if (err_msg_header == kStringVal) {
int32_t err_msg_length;
TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
uint8_t* err_msg_c = new uint8_t[err_msg_length];
auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
return errors::Unknown("Load Next Page Error [status=", status,
", message=", err_msg, "]");
}
return errors::Unknown("Load Next Page Error [status=", status, "]");
}
int32_t row_cnt;
TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
int32_t page_size = res_len - kLoadNextPageResHeaderLength;
return ReceivePage(page_size);
}
Status IgniteDatasetIterator::ReceivePage(int32_t page_size) {
remainder_ = page_size;
page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]);
ptr_ = page_.get();
uint64 start = Env::Default()->NowMicros();
TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_));
uint64 stop = Env::Default()->NowMicros();
double size_in_mb = 1.0 * remainder_ / 1024 / 1024;
double time_in_s = 1.0 * (stop - start) / 1000 / 1000;
LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
<< " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
uint8_t last_page_b;
TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b));
last_page_ = !last_page_b;
return Status::OK();
}
Status IgniteDatasetIterator::CheckTypes(const std::vector<int32_t>& types) {
if (schema_.size() != types.size())
return errors::Unknown("Object has unexpected schema");
for (size_t i = 0; i < schema_.size(); i++) {
if (schema_[i] != types[permutation_[i]])
return errors::Unknown("Object has unexpected schema");
}
return Status::OK();
}
int32_t IgniteDatasetIterator::JavaHashCode(string str) const {
int32_t h = 0;
for (char& c : str) {
h = 31 * h + c;
}
return h;
}
} // namespace tensorflow