blob: cf4b62404dfa9f6fa7aa3bf88adaaa6ee66ebaa3 [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/operators/elementwise_op.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
struct SigmoidCPUFunctor {
template <typename T>
inline void
operator()(const int n, const T* x, T* y, CPUContext* /*device_context*/) {
ConstEigenVectorArrayMap<T> xM(x, n);
EigenVectorArrayMap<T>(y, n) = 1. / (1. + (-xM).exp());
}
};
struct SigmoidGradientCPUFunctor {
template <typename T>
inline void Run(
const int n,
const T* y,
const T* dy,
T* dx,
CPUContext* /*device_context*/) {
ConstEigenVectorArrayMap<T> yM(y, n), dyM(dy, n);
EigenVectorArrayMap<T>(dx, n) = dyM * yM * (1. - yM);
}
};
REGISTER_CPU_OPERATOR(
Sigmoid, UnaryElementwiseOp<
TensorTypes<float>, CPUContext, SigmoidCPUFunctor>);
REGISTER_CPU_OPERATOR(
SigmoidGradient,
BinaryElementwiseOp<
TensorTypes<float>,
CPUContext,
WithoutBroadcast<SigmoidGradientCPUFunctor>>);
// Input: X, output: Y
OPERATOR_SCHEMA(Sigmoid)
.NumInputs(1)
.NumOutputs(1)
.AllowInplace({{0, 0}})
.IdenticalTypeAndShape()
.SetDoc(R"DOC(
Sigmoid takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the
tensor elementwise.
)DOC")
.Input(0, "X", "1D input tensor")
.Output(0, "Y", "1D output tensor");
// Input: Y, dY, output: dX
OPERATOR_SCHEMA(SigmoidGradient)
.NumInputs(2)
.NumOutputs(1)
.AllowInplace({{1, 0}})
.SetDoc(R"DOC(
SigmoidGradient takes both Y and dY and uses this to update dX according to the
chain rule and derivatives of the sigmoid function.
)DOC");
class GetSigmoidGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
"SigmoidGradient", "",
vector<string>{O(0), GO(0)},
vector<string>{GI(0)});
}
};
REGISTER_GRADIENT(Sigmoid, GetSigmoidGradient);
} // namespace caffe2