blob: 4d53035b150f4d349f20e7e7229f0045470e4f31 [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "batch_permutation_op.h"
#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif
namespace caffe2 {
#ifdef CAFFE2_USE_MKLDNN
// Compiled only when CAFFE2_USE_MKLDNN is defined: registers BatchPermutation
// as an IDEEP operator whose implementation is the float CPU op wrapped in
// IDEEPFallbackOp.
// NOTE(review): presumably the wrapper runs the CPU kernel on behalf of the
// IDEEP device -- confirm in operator_fallback_ideep.h.
REGISTER_IDEEP_OPERATOR(
BatchPermutation,
IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif
// Register the float/CPUContext instantiations of the forward op and of its
// gradient op under the names BatchPermutation and BatchPermutationGradient.
// NOTE(review): only the forward op's RunOnDevice is specialized in this file;
// the gradient op's CPU behavior is defined elsewhere (presumably in the
// header) -- confirm before relying on it.
REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
BatchPermutationGradient,
BatchPermutationGradientOp<float, CPUContext>);
// Schema of the forward op: two inputs (X, indices) and one output (Y).
// The text passed to SetDoc/Input/Output is the user-facing operator
// documentation.
OPERATOR_SCHEMA(BatchPermutation)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Permute the batch elements of the input tensor X according to the permutation
specified in the input indices.
Warning: this op does not verify that indices is a valid permutation; gradient
computation is only correct if indices is a permutation.
)DOC")
    .Input(
        0,
        "X",
        "Tensor of at least 1D shape (N, D0, D1, ...).")
    .Input(
        1,
        "indices",
        "1D tensor of type int with shape (N, ) specifying a valid permutation "
        "of the indices in [0, N - 1] (inclusive).")
    .Output(
        0,
        "Y",
        "Tensor with the same shape as X where the (D0, D1, ...) dimensional "
        "batch elements of X are permuted according to the input indices.");
// Schema of the backward op: it takes the forward op's indices together with
// the gradient of Y (dY) and produces the gradient of X (dX).
OPERATOR_SCHEMA(BatchPermutationGradient)
    .NumInputs(2)
    .NumOutputs(1)
    .Input(0, "indices", "See BatchPermutation.")
    .Input(1, "dY", "Gradient of forward output 0 (Y).")
    .Output(0, "dX", "Gradient of forward input 0 (X).");
// CPU forward pass of BatchPermutation for float tensors.
//
// Inputs:
//   Input(0) "X":       tensor with at least one dimension; its leading
//                       dimension has size N.
//   Input(1) "indices": 1-d int tensor with N entries.
// Output:
//   Output(0) "Y":      resized like X; row i of Y is a copy of row
//                       indices[i] of X, where a "row" is the block of all
//                       elements behind the leading dimension.
//
// A CAFFE_ENFORCE fails if X has no dimensions, if indices is not 1-d, if its
// length differs from X.dim32(0), or if any entry lies outside [0, N).
// Duplicate entries are NOT rejected, so indices is still not verified to be
// a true permutation. Returns true when no enforce fires.
template <>
bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
  const auto& X = Input(0);
  const auto& indices = Input(1);
  auto* Y = Output(0);

  CAFFE_ENFORCE_EQ(indices.dim(), 1, "indices must be 1-d");
  CAFFE_ENFORCE_GE(X.dim(), 1, "X must be at least 1-d");
  CAFFE_ENFORCE_EQ(
      X.dim32(0), indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  Y->ResizeLike(X);

  const int N = X.dim32(0);
  // Elements per batch row: the product of every dimension after the first.
  // Using size_from_dim instead of dim32(1..3) removes the implicit 4-D
  // requirement, matching the "(N, D0, D1, ...)" shape the schema advertises.
  const size_t row_size = static_cast<size_t>(X.size_from_dim(1));
  const size_t row_bytes = row_size * sizeof(float);

  // Fetch the raw pointers once instead of on every loop iteration.
  const int* perm = indices.template data<int>();
  const float* src = X.template data<float>();
  float* dst = Y->template mutable_data<float>();

  // Range-check serially, before the parallel region: an out-of-range index
  // would make memcpy read outside X, and an exception must not escape from
  // inside an OpenMP parallel loop.
  for (int i = 0; i < N; ++i) {
    CAFFE_ENFORCE(
        perm[i] >= 0 && perm[i] < N,
        "indices[", i, "] = ", perm[i],
        " is out of range [0, ", N, ")");
  }

  // Nothing to copy for empty rows; also avoids handing memcpy a pointer
  // into an empty buffer.
  if (row_bytes == 0) {
    return true;
  }

#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd
#else
#pragma omp parallel for
#endif
#endif
  for (int i = 0; i < N; i++) {
    // Offsets are computed in size_t so large tensors cannot overflow int.
    const size_t dst_offset = static_cast<size_t>(i) * row_size;
    const size_t src_offset = static_cast<size_t>(perm[i]) * row_size;
    std::memcpy(dst + dst_offset, src + src_offset, row_bytes);
  }
  return true;
}
// Gradient maker for BatchPermutation: produces one BatchPermutationGradient
// operator definition with an empty operator name.
class GetBatchPermutationGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    // Ordering follows the BatchPermutationGradient schema: the inputs are
    // (indices, dY) and the only output is dX.
    // NOTE(review): presumably I(1) names forward input 1, GO(0) the gradient
    // of forward output 0 and GI(0) the gradient of forward input 0 --
    // confirm against GradientMakerBase.
    const vector<string> grad_inputs{I(1), GO(0)};
    const vector<string> grad_outputs{GI(0)};
    return SingleGradientDef(
        "BatchPermutationGradient", "", grad_inputs, grad_outputs);
  }
};
// Associate the BatchPermutation op with GetBatchPermutationGradient as its
// gradient maker.
REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);
} // namespace caffe2