// blob: 09c24dff2622b2acb8c8c1b96d835ede99582bb1 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_DROPOUT_OP_H_
#define CAFFE2_OPERATORS_DROPOUT_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/**
 * Forward dropout operator, templated on the element type T and the device
 * Context.
 *
 * Input: X; Output: Y, mask.
 *
 * Operator arguments read at construction time:
 *   - "ratio"   (float, default 0.5): must satisfy 0 <= ratio < 1.
 *   - "is_test" (int, default 0): stored as a boolean flag.
 *     NOTE(review): presumably selects inference behaviour; confirm against
 *     the RunOnDevice() definition, which is not part of this header.
 *
 * RunOnDevice() is only declared here; each Context supplies its definition
 * elsewhere.
 */
template <typename T, class Context>
class DropoutOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  DropoutOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ratio_(OperatorBase::GetSingleArgument<float>("ratio", 0.5)),
        is_test_(OperatorBase::GetSingleArgument<int>("is_test", 0)) {
    // "ratio" comes from the operator definition, i.e. from user input.
    // DCHECK_* is compiled out of optimized (NDEBUG) builds, which would let
    // an out-of-range value through unchecked, so validate with the
    // always-on CHECK_* variants instead.
    CHECK_GE(ratio_, 0);
    CHECK_LT(ratio_, 1);
  }

  bool RunOnDevice() override;

 protected:
  // Value of the "ratio" argument; guaranteed to lie in [0, 1).
  float ratio_;
  // Value of the "is_test" argument, converted from int to bool.
  bool is_test_;
};
/**
 * Gradient operator for dropout, templated on the element type T and the
 * device Context.
 *
 * Input: dY, mask; Output: dX
 *
 * Operator arguments read at construction time:
 *   - "ratio"   (float, default 0.5): must satisfy 0 <= ratio < 1.
 *   - "is_test" (int, default 0): stored as a boolean flag.
 *     NOTE(review): presumably selects inference behaviour; confirm against
 *     the RunOnDevice() definition, which is not part of this header.
 *
 * RunOnDevice() is only declared here; each Context supplies its definition
 * elsewhere.
 */
template <typename T, class Context>
class DropoutGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  DropoutGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ratio_(OperatorBase::GetSingleArgument<float>("ratio", 0.5)),
        is_test_(OperatorBase::GetSingleArgument<int>("is_test", 0)) {
    // "ratio" comes from the operator definition, i.e. from user input.
    // DCHECK_* is compiled out of optimized (NDEBUG) builds, which would let
    // an out-of-range value through unchecked, so validate with the
    // always-on CHECK_* variants instead.
    CHECK_GE(ratio_, 0);
    CHECK_LT(ratio_, 1);
  }

  bool RunOnDevice() override;

 protected:
  // Value of the "ratio" argument; guaranteed to lie in [0, 1).
  float ratio_;
  // Value of the "is_test" argument, converted from int to bool.
  bool is_test_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_DROPOUT_OP_H_