blob: 61177bfd4f21b0f1f0424f292b4f1cea92aa86d3 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include <stdlib.h>
#include <string.h>
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
const google::cloud::Status& gcs_status, TF_Status* status) {
TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
gcs_status.message().c_str());
}
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string& bucket, std::string& object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't start with 'gs://'.");
return;
}
size_t bucket_end = fname.find("/", scheme_end + 1);
if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name.");
return;
}
bucket = std::move(fname.substr(scheme_end + 1, bucket_end - scheme_end - 1));
object = std::move(fname.substr(bucket_end + 1));
if (object.empty() && !object_empty_ok) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
}
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(vnvo2409): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
TempFile outfile;
bool sync_need;
} GCSFile;
static void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// TODO(vnvo2409): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(vnvo2409): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
(*gcs_client) = client.value();
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
plugin_memory_free(filesystem->plugin_filesystem);
}
// TODO(vnvo2409): Implement later
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, bucket, object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::out), true});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
}
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, bucket, object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
auto gcs_status = gcs_client->DownloadToFile(bucket, object, temp_file_name);
TF_SetStatusFromGCSStatus(gcs_status, status);
auto status_code = TF_GetCode(status);
if (status_code != TF_OK && status_code != TF_NOT_FOUND) {
return;
}
// If this file does not exist on server, we will need to sync it.
bool sync_need = (status_code == TF_NOT_FOUND);
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need});
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 1;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "gs");
}