#ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
#define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_

#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/proto/hsm.pb.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

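// Common base for the hierarchical-softmax operators below. The constructor
// parses the serialized HierarchyProto passed via the "hierarchy" argument and
// builds a map from each word id to its PathProto, i.e. the chain of tree
// nodes (and the target child at each node) leading to that word.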
template <typename T, typename Context>
class HSoftmaxOpBase : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit HSoftmaxOpBase(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {
    HierarchyProto hierarchy;
    CAFFE_ENFORCE(hierarchy.ParseFromString(
        this->template GetSingleArgument<string>("hierarchy", "")));
    for (const auto& path : hierarchy.paths()) {
      hierarchy_all_map_.emplace(path.word_id(), path);
    }
  }

 protected:
  std::unordered_map<int, PathProto> hierarchy_all_map_;
  c10::optional<Tensor> scale_;
  c10::optional<Tensor> sum_multiplier_;
  c10::optional<Tensor> bias_multiplier_;
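  // Lower bound applied to probabilities before taking logs, guarding against
  // log(0).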
  static constexpr T kLOG_THRESHOLD() {
    return 1e-20f;
  }
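  // Returns the subset of hierarchy_all_map restricted to the distinct labels
  // in the M-element `labels` array, enforcing that every label has a path in
  // the hierarchy.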
  static std::unordered_map<int, PathProto> getHierarchyForLabels(
      int M,
      const int* labels,
      const std::unordered_map<int, PathProto>& hierarchy_all_map) {
    std::unordered_map<int, PathProto> hierarchy_map;
    std::set<int> label_set = std::set<int>(labels, labels + M);
    for (const auto& label : label_set) {
      auto search = hierarchy_all_map.find(label);
      CAFFE_ENFORCE(search != hierarchy_all_map.end(), "incorrect label.");
      hierarchy_map.emplace(search->first, search->second);
    }
    return hierarchy_map;
  }
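  // Number of floats of scratch space needed for a batch of M labels: each
  // node on a label's path stores both its FC output and its softmax output
  // over the node's children, hence 2 * node.length() per node.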
  int getIntermediateOutputSize(
      const int* labels,
      int M,
      std::unordered_map<int, PathProto>& hierarchy) const {
    int size = 0;
    for (const auto label : c10::irange(M)) {
      int word_id = labels[label];
      const auto& path = hierarchy[word_id];
      size += std::accumulate(
          path.path_nodes().begin(),
          path.path_nodes().end(),
          0,
          // Output of FC + Output of Softmax
          [](int sz, PathNodeProto node) { return sz + 2 * node.length(); });
    }
    return size;
  }
};

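// Forward pass of the hierarchical softmax: for every example it walks the
// label's path through the tree, applying an FC layer followed by a softmax
// over each node's children, and writes the per-node results into an
// intermediate output consumed by the gradient operator.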
template <typename T, class Context>
class HSoftmaxOp : public HSoftmaxOpBase<T, Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using HSoftmaxOpBase<T, Context>::HSoftmaxOpBase;

  bool RunOnDevice() override;

 protected:
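  // Handles one tree node for one example: an FC of width w_length over the
  // K-dimensional input X (weights W, bias b, using bias_multiplier) followed
  // by a softmax. Both results are written to `output` starting at
  // output_offset, which is advanced; the returned float is this node's loss
  // term for the `target` child.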
  float RunForwardSingle(
      const float* X,
      const float* W,
      const float* b,
      int target,
      float* output,
      const float* bias_multiplier,
      int w_length,
      int K,
      int& output_offset);
};

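// Gradient of HSoftmax: backpropagates through the same per-node FC + softmax
// chain, producing gradients for the input X, the weights W, and the bias b.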
template <typename T, class Context>
class HSoftmaxGradientOp final : public HSoftmaxOpBase<T, Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using HSoftmaxOpBase<T, Context>::HSoftmaxOpBase;
  bool RunOnDevice() override;

 private:
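  // Backward counterpart of RunForwardSingle for one tree node of one example:
  // given the forward intermediate results in int_output and the upstream
  // gradient dY, accumulates into dX, dW, db and dOutput, updating
  // output_offset as it walks the intermediate buffer.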
  void RunBackwardSingle(
      const float* X,
      const float* dY,
      const float* W,
      int target,
      const float* int_output,
      float* dX,
      float* dW,
      float* db,
      float* dOutput,
      int dim_in,
      int w_length,
      int& output_offset);
};

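// Inference-time search over a trained hierarchical softmax: given the tree
// structure (serialized TreeProto passed via the "tree" argument), performs a
// pruned traversal controlled by the "beam" argument and reports the "topN"
// highest-scoring candidates.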
template <typename T, class Context>
class HSoftmaxSearchOp final : public HSoftmaxOp<T, Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit HSoftmaxSearchOp(Args&&... args)
      : HSoftmaxOp<T, Context>(std::forward<Args>(args)...),
        top_n_(this->template GetSingleArgument<int>("topN", 5)),
        beam_(this->template GetSingleArgument<float>("beam", 0.01f)) {
    CAFFE_ENFORCE(tree_.ParseFromString(
        this->template GetSingleArgument<string>("tree", "")));
  }
  bool RunOnDevice() override;

 private:
  int top_n_;
  float beam_;
  TreeProto tree_;
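  // Scores the children of src_node for one sample (FC + softmax over X with
  // W and b) and copies candidates into dst_node, pruning branches whose
  // score falls outside the beam around parent_score.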
  bool pruning(
      const float* X,
      int sample,
      int K,
      const float* W,
      const float* b,
      const NodeProto& src_node,
      NodeProto& dst_node,
      float parent_score,
      float beam);
  bool extractNodes(
      const NodeProto& node,
      std::vector<std::pair<string, float>>& info);
};

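// Builds a Huffman tree hierarchy from label frequencies, so that frequent
// labels get short paths; the resulting serialized hierarchy can be fed to the
// operators above through their "hierarchy" argument.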
template <typename T, class Context>
class HuffmanTreeHierarchyOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit HuffmanTreeHierarchyOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        num_classes_(this->template GetSingleArgument<int>("num_classes", -1)) {
  }
  bool RunOnDevice() override;

 private:
  // Internal huffman tree data.
  struct Node {
    Node(T l, int count)
        : label(l), count(count), left_ch_index(-1), right_ch_index(-1) {}
    T label;
    int count;
    int left_ch_index;
    int right_ch_index;
  };

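  // Comparing counts in descending order turns a std::priority_queue into a
  // min-heap over counts, so the least frequent nodes can be merged first, as
  // Huffman construction requires.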
  struct NodeComparator {
    bool operator()(const Node& node_a, const Node& node_b) {
      return node_a.count > node_b.count;
    }
  };

  int num_classes_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_H_SOFTMAX_OP_H_