|  | #ifndef CAFFE2_OPERATORS_DATASET_OPS_H_ | 
|  | #define CAFFE2_OPERATORS_DATASET_OPS_H_ | 
|  |  | 
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "caffe2/core/blob.h"
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/tensor.h"
|  |  | 
|  | namespace caffe2 { | 
|  | namespace dataset_ops { | 
|  |  | 
// Element type of the lengths tensors stored in the dataset.
// Spelled with the std:: name from <cstdint>, which is the only spelling the
// standard guarantees once that header is included.
using TLength = std::int32_t;
// Type used for all internal dataset operations (offsets, sizes to read,
// etc.).
using TOffset = std::int64_t;
|  |  | 
|  | /** | 
|  | * Provides functionality to iterate across a list of tensors where some | 
|  | * of those tensors represent lengths in a hierarchical structure. | 
|  | */ | 
|  | class TreeIterator { | 
|  | public: | 
|  | struct FieldDesc { | 
|  | int id; | 
|  | int lengthFieldId = -1; | 
|  | std::string name; | 
|  | }; | 
|  |  | 
|  | explicit TreeIterator(const std::vector<std::string>& fields); | 
|  |  | 
|  | void advance( | 
|  | const std::vector<const TLength*>& lengths, | 
|  | std::vector<TOffset>& offsets, | 
|  | std::vector<TOffset>& sizes, | 
|  | std::vector<TOffset>& limits, | 
|  | TOffset num); | 
|  |  | 
|  | // Corresponds to the number of fields that have "length" as its last name | 
|  | int numLengthFields() const { | 
|  | return lengthFieldIds_.size(); | 
|  | } | 
|  |  | 
|  | // Corresponds to the number of length fields + 1 (for the top-level domain) | 
|  | int numOffsetFields() const { | 
|  | return numLengthFields() + 1; | 
|  | } | 
|  |  | 
|  | // Get lengthField description for the given field | 
|  | const FieldDesc* lengthFieldFor(const FieldDesc& desc) { | 
|  | return (desc.lengthFieldId == -1) | 
|  | ? nullptr | 
|  | : &fields_.at(lengthFieldIds_.at(desc.lengthFieldId)); | 
|  | } | 
|  |  | 
|  | // Get lengthField description for the given lengthFieldId, where | 
|  | // 0 <= lengthFieldId < numLengthFields() | 
|  | const FieldDesc& lengthField(int lengthFieldId) { | 
|  | return fields_.at(lengthFieldIds_.at(lengthFieldId)); | 
|  | } | 
|  |  | 
|  | // Returns the index into the 'offset' vector for the given field. | 
|  | int offsetFieldIdFor(const FieldDesc& fieldDesc) { | 
|  | return fieldDesc.lengthFieldId + 1; | 
|  | } | 
|  |  | 
|  | // Returns the field description for all fields. | 
|  | const std::vector<FieldDesc>& fields() { | 
|  | return fields_; | 
|  | } | 
|  |  | 
|  | const std::vector<int>& lengthFieldIds() const { | 
|  | return lengthFieldIds_; | 
|  | } | 
|  |  | 
|  | private: | 
|  | // Description of each field | 
|  | std::vector<FieldDesc> fields_; | 
|  | // Index into fields_ above for the fields that are lengths. | 
|  | std::vector<int> lengthFieldIds_; | 
|  | }; | 
|  |  | 
|  | class TreeCursor { | 
|  | public: | 
|  | explicit TreeCursor(const TreeIterator& iterator) : it(iterator) {} | 
|  | std::vector<TOffset> offsets; | 
|  | std::mutex mutex_; | 
|  | TreeIterator it; | 
|  | }; | 
|  |  | 
|  | /** | 
|  | * Simple wrapper class allowing an easy traversal of the tensors representing | 
|  | * the hirerarchical structure. | 
|  | */ | 
|  | class TreeWalker { | 
|  | public: | 
|  | TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor); | 
|  |  | 
|  | // Returns the number of records in a dataset | 
|  | inline TOffset size() const { | 
|  | return limits_.at(0); | 
|  | } | 
|  |  | 
|  | void advance(); | 
|  |  | 
|  | private: | 
|  | inline const TensorCPU& input(int32_t idx) const { | 
|  | return inputs_[idx]->Get<TensorCPU>(); | 
|  | } | 
|  |  | 
|  | // TODO: Change to fieldDesc | 
|  | inline const TreeIterator::FieldDesc& field(int idx) const { | 
|  | return cursor_.it.fields().at(idx); | 
|  | } | 
|  |  | 
|  | inline int lengthIdx(int fieldId) const { | 
|  | return field(fieldId).lengthFieldId + 1; | 
|  | } | 
|  |  | 
|  | inline TOffset offset(int fieldId) const { | 
|  | return prevOffsets_[lengthIdx(fieldId)]; | 
|  | } | 
|  |  | 
|  | std::vector<int64_t> fieldDim(int fieldId) const; | 
|  |  | 
|  | void* fieldPtr(int fieldId) const; | 
|  |  | 
|  | public: | 
|  | // Simple Proxy class to expose nicer API for field access | 
|  | class Field { | 
|  | public: | 
|  | Field(TreeWalker& walker, int fieldId) | 
|  | : walker_(walker), fieldId_(fieldId) {} | 
|  |  | 
|  | inline std::vector<int64_t> dim() const { | 
|  | return walker_.fieldDim(fieldId_); | 
|  | } | 
|  |  | 
|  | inline int64_t size() const { | 
|  | int64_t size = 1; | 
|  | for (const auto d : dim()) { | 
|  | size *= d; | 
|  | } | 
|  | return size; | 
|  | } | 
|  |  | 
|  | inline const TypeMeta meta() const { | 
|  | return walker_.input(fieldId_).dtype(); | 
|  | } | 
|  |  | 
|  | inline void* ptr() const { | 
|  | return walker_.fieldPtr(fieldId_); | 
|  | } | 
|  |  | 
|  | int fieldId() const { | 
|  | return fieldId_; | 
|  | } | 
|  |  | 
|  | inline TOffset offset() const { | 
|  | return walker_.offset(fieldId_); | 
|  | } | 
|  |  | 
|  | private: | 
|  | const TreeWalker& walker_; | 
|  | const int fieldId_; | 
|  | }; | 
|  |  | 
|  | // Notice that a reference is returned. If advance() is called the fields will | 
|  | // be updated to represent the new state. | 
|  | inline const std::vector<Field>& fields() const { | 
|  | return fields_; | 
|  | } | 
|  |  | 
|  | private: | 
|  | void gatherLengthData(); | 
|  |  | 
|  | void gatherSizeLimits(); | 
|  |  | 
|  | const vector<const Blob*>& inputs_; | 
|  | TreeCursor& cursor_; | 
|  | std::vector<Field> fields_; | 
|  |  | 
|  | std::vector<const TLength*> lengths_; | 
|  | std::vector<TOffset> limits_; | 
|  | std::vector<TOffset> sizes_; | 
|  | std::vector<TOffset> offsets_; | 
|  | std::vector<TOffset> prevOffsets_; | 
|  | }; | 
|  |  | 
|  | using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>; | 
|  |  | 
|  | using Shared2DTensorVectorPtr = | 
|  | std::shared_ptr<std::vector<std::vector<caffe2::TensorCPU>>>; | 
|  |  | 
|  | using Tensor2DVector = std::vector<std::vector<caffe2::TensorCPU>>; | 
|  |  | 
|  | using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>; | 
|  |  | 
|  | class SharedTensorVectorPtrSerializer : public BlobSerializerBase { | 
|  | public: | 
|  | void Serialize( | 
|  | const void* pointer, | 
|  | TypeMeta typeMeta, | 
|  | const string& name, | 
|  | BlobSerializerBase::SerializationAcceptor acceptor) override; | 
|  | }; | 
|  |  | 
|  | class SharedTensorVectorPtrDeserializer : public BlobDeserializerBase { | 
|  | public: | 
|  | void Deserialize(const BlobProto& proto, Blob* blob) override; | 
|  | }; | 
|  |  | 
|  | } // namespace dataset_ops | 
|  | } // namespace caffe2 | 
|  |  | 
|  | #endif // CAFFE2_OPERATORS_DATASET_OPS_H_ |