|  | #pragma once | 
|  |  | 
#include <algorithm>
#include <cmath>
#include <map>
#include <utility>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
|  |  | 
|  | namespace caffe2 { | 
|  | template <typename F, typename T, class Context> | 
|  | class NGramFromCategoricalOp : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | template <class... Args> | 
|  | explicit NGramFromCategoricalOp(Args&&... args) | 
|  | : Operator<Context>(std::forward<Args>(args)...), | 
|  | col_ids_(this->template GetRepeatedArgument<int>("col_ids")), | 
|  | categorical_limits_( | 
|  | this->template GetRepeatedArgument<int>("categorical_limits")), | 
|  | vals_(this->template GetRepeatedArgument<int>("vals")) { | 
|  | col_num_ = col_ids_.size(); | 
|  | max_col_id_ = *std::max_element(col_ids_.begin(), col_ids_.end()); | 
|  | CAFFE_ENFORCE_EQ(col_num_, categorical_limits_.size()); | 
|  | int expected_vals_size = 0; | 
|  | for (auto& l : categorical_limits_) { | 
|  | CAFFE_ENFORCE_GT(l, 0); | 
|  | expected_vals_size += l; | 
|  | } | 
|  | CAFFE_ENFORCE_EQ(expected_vals_size, vals_.size()); | 
|  | // compute ngram maps with small end | 
|  | for (auto& j : col_ids_) { | 
|  | CAFFE_ENFORCE_GE(j, 0); | 
|  | ngram_maps_.push_back(std::map<int, int>()); | 
|  | } | 
|  | int base = 1; | 
|  | int idx = 0; | 
|  | for (int k = 0; k < col_num_; k++) { | 
|  | int l = categorical_limits_[k]; | 
|  | for (int m = 0; m < l; m++) { | 
|  | int v = vals_[idx++]; | 
|  | ngram_maps_[k][v] = m * base; | 
|  | } | 
|  | base *= l; | 
|  | } | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | auto& floats = Input(0); | 
|  | auto N = floats.size(0); | 
|  | auto D = floats.size_from_dim(1); | 
|  | const F* floats_data = floats.template data<F>(); | 
|  |  | 
|  | auto* output = Output(0, {N}, at::dtype<T>()); | 
|  | auto* output_data = output->template mutable_data<T>(); | 
|  | math::Set<T, Context>(output->numel(), 0, output_data, &context_); | 
|  |  | 
|  | CAFFE_ENFORCE_GT(D, max_col_id_); | 
|  | for (int i = 0; i < N; i++) { | 
|  | for (int k = 0; k < col_num_; k++) { | 
|  | int j = col_ids_[k]; | 
|  | int v = round(floats_data[i * D + j]); | 
|  | // for out-of-vocabulary values, we always treat them the same as the | 
|  | // first value specified in vals; if we want to mimic the behavior as | 
|  | // sigrid NGram transform, just push front a random/impossible value at | 
|  | // each segments of vals | 
|  | output_data[i] += ngram_maps_[k].find(v) == ngram_maps_[k].end() | 
|  | ? 0 | 
|  | : ngram_maps_[k][v]; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | std::vector<int> col_ids_; | 
|  | std::vector<int> categorical_limits_; | 
|  | std::vector<int> vals_; | 
|  | std::vector<std::map<int, int>> ngram_maps_; | 
|  | int col_num_; | 
|  | int max_col_id_; | 
|  | }; | 
|  | } // namespace caffe2 |