/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/nn_ops.cc.
#include "tensorflow/core/kernels/nth_element_op.h"
#include <algorithm>
#include <iostream>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
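// NthElementOp selects the nth order statistic along the last dimension of
// its input. For example, an input of shape [2, 5] with n = 2 and
// reverse = false yields an output of shape [2], where output[i] is the
// third-smallest entry of row i (n is 0-based); with reverse = true it is
// the third-largest.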
template <typename Device, typename T>
class NthElementOp : public OpKernel {
public:
explicit NthElementOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("reverse", &reverse_));
}
void Compute(OpKernelContext* context) override {
// The second input is n, which must be a non-negative scalar.
const auto& n_in = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()),
errors::InvalidArgument("N must be scalar, got shape ",
n_in.shape().DebugString()));
int n = n_in.scalar<int32>()();
OP_REQUIRES(context, n >= 0,
errors::InvalidArgument("Need n >= 0, got ", n));
// The first input is the tensor from which to select; it must have at least
// one dimension.
const Tensor& input_in = context->input(0);
const int num_dims = input_in.dims();
OP_REQUIRES(context, num_dims >= 1,
errors::InvalidArgument("Input must be >= 1-D, got shape ",
input_in.shape().DebugString()));
// The last dimension of the input tensor must have more than n entries.
OP_REQUIRES(
context, input_in.dim_size(num_dims - 1) > n,
errors::InvalidArgument("Input must have at least n+1 columns, got shape ",
input_in.shape().DebugString(), " and n = ", n));
// std::nth_element only supports selecting the nth-smallest element, so for
// the nth-largest (reverse) we select the mirrored index from the end.
if (reverse_) {
n = input_in.dim_size(num_dims - 1) - n - 1;
}
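// For example, with last_dim = 5 and n = 1 (the second-largest requested),
// the mapping above gives n = 5 - 1 - 1 = 3: the fourth-smallest element of
// a row of five is exactly its second-largest.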
// The input shape is [d1, d2, ..., dk]; the output shape is
// [d1, d2, ..., d(k-1)].
TensorShape out_shape;
for (int i = 0; i < num_dims - 1; ++i) {
out_shape.AddDim(input_in.dim_size(i));
}
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, out_shape, &output_tensor));
functor::NthElementFunctor<Device, T> nthElementFunc;
nthElementFunc(context, input_in, *output_tensor, n, reverse_);
}
private:
bool reverse_;
};
namespace functor {
template <typename T>
struct NthElementFunctor<CPUDevice, T> {
void operator()(OpKernelContext* context, const Tensor& input_tensor,
Tensor& output_tensor, int n, bool reverse) {
const T* input = input_tensor.flat<T>().data();
T* output = output_tensor.flat<T>().data();
// The input shape is [d1, d2, ..., dk] and the output shape is
// [d1, d2, ..., d(k-1)], so num_rows = d1 * d2 * ... * d(k-1) and
// last_dim = dk.
const int num_rows = output_tensor.NumElements();
const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1);
// Distribute the rows across shards; each shard handles a contiguous range.
auto SubNthElement = [&, input, output, last_dim, n](int start, int limit) {
// std::nth_element rearranges the range in place, so copy each row into a
// scratch buffer (allocated once per shard and reused across rows).
std::vector<T> buf(last_dim);
for (int b = start; b < limit; ++b) {
// Copy one row of elements into the buffer.
const T* input_start = input + b * last_dim;
const T* input_end = input + (b + 1) * last_dim;
std::copy(input_start, input_end, buf.begin());
std::nth_element(buf.begin(), buf.begin() + n, buf.end());
// The element placed at the nth position is exactly the element that
// would occur at this position if the range were fully sorted.
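// For example, for buf = {3, 1, 4, 1, 5} and n = 2, std::nth_element leaves
// buf[2] == 3, since the sorted row is {1, 1, 3, 4, 5}.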
output[b] = buf[n];
}
};
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
// std::nth_element is typically implemented with partition-based selection
// (e.g. introselect): its average time complexity is O(last_dim), though a
// plain quickselect degrades to O(last_dim^2) in the worst case. Here, 20 is
// an empirical factor for cost_per_unit.
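// Shard splits [0, num_rows) into contiguous ranges and invokes
// SubNthElement(start, limit) on the worker threads; cost_per_unit
// (20 * last_dim) is the estimated per-row cost it uses to size the shards.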
Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
20 * last_dim, SubNthElement);
}
};
} // namespace functor
#define REGISTER_NTHOP(T) \
REGISTER_KERNEL_BUILDER( \
Name("NthElement").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
NthElementOp<CPUDevice, T>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_NTHOP);
#undef REGISTER_NTHOP
} // end namespace tensorflow