|  | #pragma once | 
|  |  | 
|  | #include "caffe2/core/operator.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <typename T> | 
|  | struct FtrlParams { | 
|  | explicit FtrlParams(OperatorBase* op) | 
|  | : alphaInv(1.0 / op->GetSingleArgument<float>("alpha", 0.005f)), | 
|  | beta(op->GetSingleArgument<float>("beta", 1.0f)), | 
|  | lambda1(op->GetSingleArgument<float>("lambda1", 0.001f)), | 
|  | lambda2(op->GetSingleArgument<float>("lambda2", 0.001f)) {} | 
|  | T alphaInv; | 
|  | T beta; | 
|  | T lambda1; | 
|  | T lambda2; | 
|  | }; | 
|  |  | 
|  | // TODO(dzhulgakov): implement GPU version if necessary | 
|  | template <typename T, class Context> | 
|  | class FtrlOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | FtrlOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws), params_(this) { | 
|  | CAFFE_ENFORCE( | 
|  | !HasArgument("alpha") || ALPHA >= InputSize(), | 
|  | "Cannot specify alpha by both input and argument"); | 
|  | } | 
|  | bool RunOnDevice() override; | 
|  |  | 
|  | protected: | 
|  | FtrlParams<T> params_; | 
|  | INPUT_TAGS(VAR, N_Z, GRAD, ALPHA); | 
|  | OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z); | 
|  | }; | 
|  |  | 
|  | template <typename T> | 
|  | class SparseFtrlOp final : public Operator<CPUContext> { | 
|  | public: | 
|  | SparseFtrlOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<CPUContext>(operator_def, ws), params_(this) { | 
|  | CAFFE_ENFORCE( | 
|  | !HasArgument("alpha") || ALPHA >= InputSize(), | 
|  | "Cannot specify alpha by both input and argument"); | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | // run time learning rate override | 
|  | if (ALPHA < InputSize()) { | 
|  | CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued"); | 
|  | params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>()); | 
|  | } | 
|  | // Use run-time polymorphism | 
|  | auto& indices = Input(INDICES); | 
|  | if (indices.template IsType<int32_t>()) { | 
|  | DoRun<int32_t>(); | 
|  | } else if (indices.template IsType<int64_t>()) { | 
|  | DoRun<int64_t>(); | 
|  | } else { | 
|  | LOG(FATAL) << "Unsupported type of INDICES in SparseFtrlOp: " | 
|  | << indices.dtype().name(); | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | protected: | 
|  | FtrlParams<T> params_; | 
|  | INPUT_TAGS(VAR, N_Z, INDICES, GRAD, ALPHA); | 
|  | OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z); | 
|  |  | 
|  | private: | 
|  | template <typename SIndex> | 
|  | void DoRun(); | 
|  | }; | 
|  |  | 
|  | } |