blob: 068943baab9d5b91699928735c6fd147986f2c6f [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
extern "C" {
#include <THNN.h>
}
namespace caffe2 {
namespace {
// RAII handle for a THFloatTensor*: destruction calls THFloatTensor_free,
// which releases the TH tensor struct and drops its storage refcount
// (the C-API-deleter unique_ptr pattern).
using UniqueTHFloatTensor =
    std::unique_ptr<THFloatTensor, decltype(&THFloatTensor_free)>;
// Wraps a caffe2 TensorCPU's buffer in a THFloatTensor WITHOUT copying.
// The returned TH tensor aliases the caffe2-owned float data; the
// TH_STORAGE_FREEMEM flag is cleared so TH never frees memory it does not
// own. A zero-dim tensor maps to a fresh empty TH tensor.
UniqueTHFloatTensor aliasFromTensorCPU(TensorCPU* tensor) {
  if (!tensor->ndim()) {
    // No shape yet (e.g. a not-yet-allocated output): hand THNN an empty
    // tensor and let it resize/allocate as needed.
    return UniqueTHFloatTensor(THFloatTensor_new(), THFloatTensor_free);
  }
  // Mirror the caffe2 shape into a TH shape descriptor.
  THLongStorage* thshape = THLongStorage_newWithSize(tensor->ndim());
  for (int i = 0; i < tensor->ndim(); ++i) {
    THLongStorage_set(thshape, i, tensor->dim(i));
  }
  // Wrap the existing float buffer, then clear FREEMEM so this storage is a
  // non-owning view of caffe2 memory.
  THFloatStorage* storage = THFloatStorage_newWithData(
      tensor->template mutable_data<float>(), tensor->size());
  THFloatStorage_clearFlag(storage, TH_STORAGE_FREEMEM);
  // nullptr strides: TH derives contiguous strides from the shape.
  auto* th = THFloatTensor_newWithStorage(storage, 0, thshape, nullptr);
  // newWithStorage retained the storage, so drop our local references to
  // both the storage and the shape descriptor.
  THFloatStorage_free(storage);
  THLongStorage_free(thshape);
  // Sanity check: the TH view must still point at the caffe2 buffer.
  CAFFE_ENFORCE_EQ(
      THFloatTensor_storage(th)->data, tensor->template mutable_data<float>());
  return UniqueTHFloatTensor(th, THFloatTensor_free);
}
// Moves the contents of a TH tensor back into a caffe2 TensorCPU.
// If THNN never reallocated (the TH storage still aliases the caffe2 buffer
// and the shape is unchanged), this only re-marks the storage as non-owning;
// otherwise the caffe2 tensor is resized and the data is copied back.
// Takes ownership of `th` (consumed and freed on return).
void copyToTensorCPU(UniqueTHFloatTensor th, TensorCPU* tensor) {
  // TODO - if th and tensor point to the same data and have the same
  // size, elide the copy!
  // newContiguous retains the tensor when already contiguous, so the
  // reassignment below frees the old handle safely either way.
  th = UniqueTHFloatTensor(
      THFloatTensor_newContiguous(th.get()), THFloatTensor_free);
  const auto dims = std::vector<TIndex>(
      th->size, th->size + THFloatTensor_nDimension(th.get()));
  auto* storage = THFloatTensor_storage(th.get());
  // Short-circuit if we never reallocated in TH.
  if (dims == tensor->dims() &&
      storage->data == tensor->template data<float>()) {
    // Re-clear FREEMEM so TH does not free the caffe2-owned buffer when the
    // TH tensor handle is destroyed.
    THFloatStorage_clearFlag(storage, TH_STORAGE_FREEMEM);
    return;
  }
  // TH allocated fresh memory (or the shape changed): resize and copy back.
  tensor->Resize(dims);
  CPUContext ctx;
  ctx.Copy<float, CPUContext, CPUContext>(
      tensor->size(), storage->data, tensor->mutable_data<float>());
}
// _Everything_ below here can be autogenerated with the TBD
// THNN/THCUNN schema. This is just a proof of concept.
// ELU forward pass on CPU, implemented by delegating to THNN's
// THNN_FloatELU_updateOutput.
class THNNELUCPUOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  using Operator<CPUContext>::Operator;
  bool RunOnDevice() override {
    // TODO - we can autogenerate this from a schema.
    auto input = aliasFromTensorCPU(const_cast<TensorCPU*>(&Input(0)));
    auto output = aliasFromTensorCPU(Output(0));
    const float alpha = GetSingleArgument<float>("alpha", 1.0);
    // In-place when the caller wired the output to the same blob as the
    // input.
    const bool inplace = &Input(0) == Output(0);
    THNN_FloatELU_updateOutput(
        nullptr, input.get(), output.get(), alpha, inplace);
    copyToTensorCPU(std::move(output), Output(0));
    return true;
  }
};
// ELU backward pass on CPU, delegating to THNN's
// THNN_FloatELU_updateGradInput.
// Inputs: X (forward input), Y (forward output), dY (gradient of Y);
// output: dX (gradient of X) — as wired up by GetELUGradient below.
class THNNELUCPUGradientOp final : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  using Operator<CPUContext>::Operator;
  bool RunOnDevice() override {
    // TODO - we can autogenerate this from a schema.
    auto input = aliasFromTensorCPU(const_cast<TensorCPU*>(&Input(0)));
    auto output = aliasFromTensorCPU(const_cast<TensorCPU*>(&Input(1)));
    auto gradOutput = aliasFromTensorCPU(const_cast<TensorCPU*>(&Input(2)));
    auto gradInput = aliasFromTensorCPU(Output(0));
    const float alpha = GetSingleArgument<float>("alpha", 1.0);
    const bool inplace = &Input(2) == Output(0);
    THNN_FloatELU_updateGradInput(
        nullptr,
        input.get(),
        gradOutput.get(),
        gradInput.get(),
        output.get(),
        alpha,
        inplace /* inplace */);
    copyToTensorCPU(std::move(gradInput), Output(0));
    return true;
  }
};
// Register these implementations under the THNN engine for the existing
// ELU / ELUGradient operators.
REGISTER_CPU_OPERATOR_WITH_ENGINE(ELU, THNN, THNNELUCPUOp);
REGISTER_CPU_OPERATOR_WITH_ENGINE(ELUGradient, THNN, THNNELUCPUGradientOp);
// Emits ELUGradient(X, Y, dY) -> dX for the forward op ELU(X) -> Y,
// forwarding the forward op's arguments (e.g. "alpha") via Def().arg().
class GetELUGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    const vector<string> gradInputs{I(0), O(0), GO(0)};
    const vector<string> gradOutputs{GI(0)};
    return SingleGradientDef(
        "ELUGradient", "", gradInputs, gradOutputs, Def().arg());
  }
};
// Hook the gradient maker above into autodiff for the ELU operator.
REGISTER_GRADIENT(ELU, GetELUGradient);
}
}