#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
#define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_

#include <cstdio>
#include <map>
#include <set>
#include <unordered_map>
#include <vector>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

namespace {
// Tracks how much of a (possibly chunked) blob has been loaded so far, so
// that we can verify at the end that every chunk arrived.
struct BlobState {
  // Expected total: number of chunks, or number of tensor elements.
  int64_t total_size;
  // How much has been loaded so far, in the same unit as total_size.
  int64_t current_size;
  bool is_tensor;
  std::set<int32_t> seen_chunks_ids;

  explicit BlobState(
      int64_t total_size = 0,
      int64_t current_size = 0,
      bool is_tensor = false)
      : total_size(total_size),
        current_size(current_size),
        is_tensor(is_tensor) {}
};
} // namespace

using db::Cursor;
using db::DB;
using db::Transaction;

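// DBExistsOp writes a single bool into its output tensor indicating whether
// a database of the given type exists at the given path. A minimal sketch of
// constructing this operator in C++, assuming the CreateOperatorDef and
// MakeArgument helpers from caffe2/utils/proto_utils.h (included above);
// blob and path names are illustrative:
//
//   OperatorDef def = CreateOperatorDef(
//       "DBExists",
//       "",
//       {},
//       {"exists"},
//       {MakeArgument<string>("db_name", "/tmp/my_db"),
//        MakeArgument<string>("db_type", "leveldb"),
//        MakeArgument<int>("absolute_path", 1)});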
template <class Context>
class DBExistsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  DBExistsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        db_name_(OperatorBase::GetSingleArgument<string>("db_name", "")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")) {}

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    auto* output = Output(0);
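    // Resize() with no arguments yields a zero-dim (scalar) tensor, which
    // holds exactly one element.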
    output->Resize();
    bool* exists = output->template mutable_data<bool>();

    *exists = caffe2::db::DBExists(db_type_, full_db_name);
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  std::string db_name_;
  std::string db_type_;
};

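// LoadOp deserializes blobs from a db into the workspace. The db is either
// opened from the "db"/"db_type" arguments or, if the op has an input, taken
// from a DBReader blob. With load_all=1 every blob found in the db is
// created in the workspace; otherwise only the blobs named by the outputs
// (or by source_blob_names) are loaded. A minimal sketch, assuming the
// CreateOperatorDef and MakeArgument helpers from caffe2/utils/proto_utils.h
// (blob and path names are illustrative):
//
//   OperatorDef def = CreateOperatorDef(
//       "Load",
//       "",
//       {},
//       {"fc1_w", "fc1_b"},
//       {MakeArgument<string>("db", "/tmp/params.minidb"),
//        MakeArgument<string>("db_type", "minidb"),
//        MakeArgument<int>("absolute_path", 1)});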
template <class Context>
class LoadOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  LoadOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        add_prefix_(OperatorBase::GetSingleArgument<string>("add_prefix", "")),
        strip_prefix_(
            OperatorBase::GetSingleArgument<string>("strip_prefix", "")),
        db_name_(OperatorBase::GetSingleArgument<string>("db", "")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")),
        keep_device_(OperatorBase::GetSingleArgument<int>("keep_device", 0)),
        load_all_(OperatorBase::GetSingleArgument<int>("load_all", 0)),
        allow_incomplete_(
            OperatorBase::GetSingleArgument<bool>("allow_incomplete", false)),
        blob_names_(OperatorBase::GetRepeatedArgument<string>(
            "source_blob_names")) {
    if (InputSize() == 0) {
      CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
      CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
    }
    CAFFE_ENFORCE(
        blob_names_.empty() || blob_names_.size() == OutputSize(),
        "Number of output blobs and source_blob_names mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and source_blob_names are mutually exclusive.");
    CAFFE_ENFORCE(
        blob_names_.empty() || !load_all_,
        "cannot load_all_ while using source_blob_names.");
    if (!load_all_) {
      // If the source_blob_names argument is not given, blob_names_ is
      // inferred from the operator's outputs; otherwise it holds the names
      // of the source blobs as stored in the file/db.
      if (blob_names_.empty()) {
        for (const string& name : operator_def.output()) {
          blob_names_.push_back(name);
        }
      }
      int idx = 0;
      std::set<std::string> name_set;
      for (const string& name : blob_names_) {
        CAFFE_ENFORCE(
            name_set.insert(name).second,
            "Duplicated source blob name: ",
            name);
        output_indices_[name] = idx++;
      }
    }
  }

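  // When keep_device is not set, rewrites the device info in `proto` to the
  // device this operator runs on. Declared here; expected to be specialized
  // per Context in the corresponding implementation files.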
  void SetCurrentDevice(BlobProto* proto);

  bool RunOnDevice() override {
    if (InputSize() == 1) {
      const db::DBReader& reader = OperatorBase::Input<db::DBReader>(0);
      extract(reader.cursor());
    } else {
      string full_db_name =
          absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
      std::unique_ptr<DB> in_db(
          caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
      CAFFE_ENFORCE(in_db.get(), "Cannot open db: ", db_name_);
      std::unique_ptr<Cursor> cursor(in_db->NewCursor());
      extract(cursor.get());
    }

    return true;
  }

 private:
  void extract(Cursor* cursor) {
    if (load_all_) {
      extractAll(cursor);
    } else {
      extractFrom(cursor, OperatorBase::Outputs());
    }
  }

  void extractAll(Cursor* cursor) {
    CAFFE_ENFORCE(cursor, "cursor is not valid");
    std::unordered_map<string, BlobState> blob_states;
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      BlobProto proto;
      CAFFE_ENFORCE(
          proto.ParseFromString(cursor->value()), "Couldn't parse Proto");
      if (!keep_device_) {
        // If we are not keeping the device as the one specified in the
        // proto, we will set the current device.
        SetCurrentDevice(&proto);
      }

      Blob* blob = ws_->CreateBlob(key);
      ProcessBlob(blob, proto, &blob_states, key, &loaded_blobs);
    }

    VLOG(1) << "Loaded " << loaded_blobs << " blobs from db";
    validateBlobStates(blob_states);
  }

  void extractFrom(Cursor* cursor, const vector<Blob*>& outputs) {
    CAFFE_ENFORCE(cursor);
    std::unordered_map<string, BlobState> blob_states;
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (!output_indices_.count(key)) {
        VLOG(1) << "Key " << key << " not used. Skipping.";
      } else {
        VLOG(2) << "Deserializing blob " << key;
        BlobProto proto;
        CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
        if (!keep_device_) {
          // If we are not keeping the device as the one specified in the
          // proto, we will set the current device.
          SetCurrentDevice(&proto);
        }
        auto blobIndex = output_indices_[key];
        Blob* blob = outputs.at(blobIndex);
        ProcessBlob(blob, proto, &blob_states, key, &loaded_blobs);

        if (loaded_blobs == OutputSize()) {
          VLOG(1) << "Read all required blobs";
          break;
        }
      }
    }

    validateBlobStates(blob_states);
    VLOG(1) << "Fully loaded " << blob_states.size() << " blobs";

    if (loaded_blobs != OutputSize()) {
      if (allow_incomplete_ && loaded_blobs < OutputSize()) {
        VLOG(1) << "Loaded " << loaded_blobs << " blobs out of "
                << OutputSize() << " blobs from db.";
        return;
      }
      for (const string& output_name : this->debug_def().output()) {
        if (blob_states.count(output_name) == 0) {
          LOG(ERROR) << "Failed to load blob: " << output_name;
        }
      }
      CAFFE_THROW(
          "Expected to load ",
          OutputSize(),
          " blobs, got ",
          loaded_blobs,
          " only.\n");
    }
  }

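  // Maps a db key to the name of the blob to load into: the chunk-id suffix
  // (if any) is dropped, then strip_prefix_ and add_prefix_ are applied. For
  // example, with strip_prefix_ = "gpu_0/" and add_prefix_ = "model/", the
  // key "gpu_0/fc1_w" becomes "model/fc1_w".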
  string buildBlobNameFromDbKey(const string& dbKey) {
    string key = dbKey.substr(0, dbKey.find(kChunkIdSeparator));
    if (!strip_prefix_.empty()) {
      auto match_pos = key.find(strip_prefix_);
      if (match_pos != string::npos) {
        key = key.substr(match_pos + strip_prefix_.size());
      }
    }
    key = add_prefix_ + key;
    return key;
  }

 private:
  // We track the sizes of already-read tensor parts while reading data
  // chunks, so that at the end we can verify that all chunks were loaded.
  void ProcessBlob(
      Blob* blob,
      const BlobProto& proto,
      std::unordered_map<string, BlobState>* blob_states_ptr,
      const string& key,
      int* loaded_blobs) {
    auto& blob_states = *blob_states_ptr;
    if (blob_states.count(key) == 0) {
      // We reset the blob so that any existing content is destroyed. This
      // is to guarantee correct device placement: if we are deserializing
      // into a TensorCUDA, without an explicit Reset we might be loading
      // data into an existing TensorCUDA that has pre-allocated memory on a
      // different GPU.
      blob->Reset();
    }
    blob->Deserialize(proto);
    if (proto.has_content_num_chunks()) {
      if (!blob_states.count(key)) {
        blob_states[key] = BlobState(proto.content_num_chunks());
      }
      CAFFE_ENFORCE(
          blob_states[key]
              .seen_chunks_ids.insert(proto.content_chunk_id())
              .second,
          "Chunk with the same id has occurred twice for: ",
          key);
      CAFFE_ENFORCE(
          proto.content_chunk_id() >= 0 &&
              proto.content_chunk_id() < blob_states[key].total_size,
          "Chunk id must be non-negative and less than content_num_chunks "
          "for key: ",
          key);
      blob_states[key].current_size++;
      CAFFE_ENFORCE(
          !blob_states[key].is_tensor,
          "A proto with content_chunks cannot store a tensor: ",
          key);
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Found an extra part for an already filled blob: ",
          key);
      if (blob_states[key].current_size == blob_states[key].total_size) {
        (*loaded_blobs)++;
      }
      return;
    }
    if (!proto.has_tensor()) {
      // If a blob is divided into chunks, content_num_chunks has to be set;
      // only tensors may otherwise appear multiple times as chunks.
      CAFFE_ENFORCE(blob_states.count(key) == 0, "Blob duplicated: ", key);
      blob_states[key] = BlobState();
      (*loaded_blobs)++;
      return;
    }
    CAFFE_ENFORCE(proto.has_tensor());
    if (blob_states.count(key)) {
      CAFFE_ENFORCE(blob_states[key].is_tensor, "Must be tensor ", key);
      CAFFE_ENFORCE(
          blob_states[key].current_size < blob_states[key].total_size,
          "Found an extra part for an already filled tensor: ",
          key);
      CAFFE_ENFORCE(
          proto.tensor().has_segment(),
          "Partial tensor must have a segment: ",
          key);
      blob_states[key].current_size +=
          proto.tensor().segment().end() - proto.tensor().segment().begin();
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Tensor parts are bigger than target size for tensor: ",
          key);
    } else {
      const auto& dims = proto.tensor().dims();
      int64_t total_size = 1;
      for (const auto& dim : dims) {
        total_size *= dim;
      }
      auto current_size = total_size;
      if (proto.tensor().has_segment()) {
        current_size =
            proto.tensor().segment().end() - proto.tensor().segment().begin();
      }
      blob_states[key] =
          BlobState(total_size, current_size, true /* is_tensor */);
    }

    if (blob_states[key].current_size == blob_states[key].total_size) {
      (*loaded_blobs)++;
    }
  }

  void validateBlobStates(
      const std::unordered_map<string, BlobState>& blob_states) {
    for (const auto& iter : blob_states) {
      const BlobState& blob_state = iter.second;
      CAFFE_ENFORCE(
          blob_state.current_size == blob_state.total_size,
          "Data size mismatch for blob ",
          iter.first,
          ". Expected: ",
          blob_state.total_size,
          " Read: ",
          blob_state.current_size);
    }
  }

  Workspace* ws_;
  bool absolute_path_;
  string add_prefix_;
  string strip_prefix_;
  string db_name_;
  string db_type_;
  bool keep_device_;
  bool load_all_;
  bool allow_incomplete_;
  std::map<string, int> output_indices_;
  std::vector<std::string> blob_names_;
};

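// SaveOp serializes all of its input blobs into a freshly created db. A
// minimal sketch, assuming the CreateOperatorDef and MakeArgument helpers
// from caffe2/utils/proto_utils.h (blob and path names are illustrative):
//
//   OperatorDef def = CreateOperatorDef(
//       "Save",
//       "",
//       {"fc1_w", "fc1_b"},
//       {},
//       {MakeArgument<string>("db", "/tmp/params.minidb"),
//        MakeArgument<string>("db_type", "minidb"),
//        MakeArgument<int>("absolute_path", 1)});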
template <class Context>
class SaveOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SaveOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        strip_prefix_(
            OperatorBase::GetSingleArgument<string>("strip_prefix", "")),
        db_name_(OperatorBase::GetSingleArgument<string>("db", "")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")),
        blob_names_(
            OperatorBase::GetRepeatedArgument<string>("blob_name_overrides")) {
    CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
    CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
    CAFFE_ENFORCE(
        blob_names_.empty() ||
            blob_names_.size() == OperatorBase::Inputs().size(),
        "Number of blobs and blob_name_overrides mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and blob_name_overrides are mutually exclusive.");

    if (blob_names_.empty()) {
      std::set<std::string> input_names;
      blob_names_.resize(OperatorBase::Inputs().size());
      for (int i = 0; i < blob_names_.size(); ++i) {
        std::string name;
        if (strip_prefix_.empty()) {
          name = operator_def.input(i);
        } else {
          auto match_pos = operator_def.input(i).find(strip_prefix_);
          if (match_pos == string::npos) {
            name = operator_def.input(i);
          } else {
            name = operator_def.input(i).substr(
                match_pos + strip_prefix_.size(), string::npos);
          }
        }
        CAFFE_ENFORCE(
            input_names.insert(name).second, "Duplicated input: ", name);
        blob_names_[i] = name;
      }
    }
  }

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    std::unique_ptr<DB> out_db(
        caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::NEW));
    CAFFE_ENFORCE(out_db.get(), "Cannot open db for writing: ", full_db_name);

    BlobSerializerBase::SerializationAcceptor acceptor =
        [&](const std::string& blobName, const std::string& data) {
          // The transaction should take care of locking.
          VLOG(2) << "Sending " << blobName << " blob's data of size "
                  << data.size() << " to db";
          auto transaction = out_db->NewTransaction();
          transaction->Put(blobName, data);
          transaction->Commit();
        };

    const vector<const Blob*>& inputs = OperatorBase::Inputs();
    for (int i = 0; i < inputs.size(); ++i) {
      inputs[i]->Serialize(blob_names_[i], acceptor);
    }
    out_db->Close();
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  string strip_prefix_;
  string db_name_;
  string db_type_;
  std::vector<std::string> blob_names_;
};

template <typename... Ts>
string FormatString(const string& pattern, Ts... values) {
  // Run snprintf once with a null buffer to compute the required size, then
  // format into a correctly sized buffer. This avoids the overflow risk of
  // writing into a fixed-size buffer with sprintf.
  int required = std::snprintf(nullptr, 0, pattern.c_str(), values...);
  if (required < 0) {
    LOG(FATAL) << "FormatString fails: encoding error for pattern " << pattern;
  }
  std::vector<char> buffer(required + 1);
  std::snprintf(buffer.data(), buffer.size(), pattern.c_str(), values...);
  return string(buffer.data());
}
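// For example, FormatString("/path/to/checkpoint_%d.pb", 5) yields
// "/path/to/checkpoint_5.pb".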

// CheckpointOp is a wrapper over a SaveOp that allows flexible naming over
// iterations. The "db" argument should be a format string that can be
// passed to snprintf with an int argument specifying the current iteration.
// An example:
//   "/path/to/my/checkpoint/checkpoint_at_%d.pb"
template <class Context>
class CheckpointOp final : public Operator<Context> {
 public:
  CheckpointOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        db_pattern_(OperatorBase::GetSingleArgument<string>("db", "")),
        every_(OperatorBase::GetSingleArgument<int>("every", 1)),
        ws_(ws),
        save_op_def_(operator_def) {
    CAFFE_ENFORCE_GT(
        db_pattern_.size(), 0, "Must specify a checkpoint file pattern.");
    CAFFE_ENFORCE_GT(every_, 0, "Checkpoint interval should be positive.");
    if (every_ == 1) {
      // Just issue a warning, but it's totally legal so we don't do anything.
      LOG(WARNING) << "It seems that we are checkpointing every iteration. "
                   << "Is that intended?";
    }
    save_op_def_.set_type("Save");
  }

  bool RunOnDevice() override {
    int64_t iter =
        OperatorBase::Input<TensorCPU>(0).template data<int64_t>()[0];
    if (iter % every_ == 0) {
      GetMutableArgument("db", true, &save_op_def_)
          ->set_s(FormatString(db_pattern_, iter));
      SaveOp<Context> sub_op(save_op_def_, ws_);
      return sub_op.Run();
    } else {
      return true;
    }
  }

 private:
  string db_pattern_;
  int every_;
  Workspace* ws_;
  OperatorDef save_op_def_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_