blob: f8fbf0ae5137d059a55cc32e90b92015fcb3f345 [file] [log] [blame]
#pragma once
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
/// Operator that adjusts the leading (batch) dimension of a tensor.
///
/// The mode is selected by the number of inputs:
///  - Two inputs (truncate): Input(0) is the data tensor and Input(1) holds
///    the real batch size as an int64_t value. Output(0) has the same shape as
///    Input(0) except that dimension 0 is set to the real batch size, and it
///    receives the leading items of Input(0)'s buffer.
///  - One input (pad): Output(0) has the same shape as Input(0) except that
///    dimension 0 is set to the `max_batch_size` argument. The output is
///    zero-filled and then the input is copied to its beginning. If a second
///    output is requested, it receives the original batch size as a
///    1-element int64_t tensor.
template <class Context>
class AdjustBatchOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /// Reads the optional `max_batch_size` argument (defaults to -1, which is
  /// rejected at run time when the padding mode is used).
  AdjustBatchOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        max_batch_size_(
            this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}

  /// Runs the truncate or pad mode described on the class.
  ///
  /// Enforces that the input has at least one dimension and that the
  /// requested batch size is consistent with the input's leading dimension.
  /// @return true on success; enforcement failures are raised through the
  ///         CAFFE_ENFORCE macros.
  bool RunOnDevice() override {
    auto& input = Input(0);
    vector<int64_t> output_dims(input.sizes().vec());
    CAFFE_ENFORCE(!output_dims.empty());
    if (InputSize() > 1) {
      // TODO: if we have a second input and we have max_batch_size set, check
      // the batch size of the two inputs for consistency
      auto& batch_size = Input(1);
      // Guard the dereference below against an empty batch-size tensor.
      CAFFE_ENFORCE_GE(
          batch_size.numel(),
          1,
          "The batch size input should hold at least one element.");
      const int64_t real_batch_size = *batch_size.template data<int64_t>();
      const int64_t max_batch_size = output_dims[0];
      CAFFE_ENFORCE_GE(
          real_batch_size,
          0,
          "Real batch size should be non-negative. Got ",
          real_batch_size);
      CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
      output_dims[0] = real_batch_size;
      auto* output = Output(0, output_dims, input.dtype());
      // The number of items to copy is exactly the size of the truncated
      // output. Using output->numel() instead of
      // input.numel() * real_batch_size / max_batch_size avoids a division by
      // zero when the input's leading dimension is 0 and avoids overflow of
      // the intermediate product.
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          output->numel(),
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));
    } else {
      // Pad to max batch size
      CAFFE_ENFORCE_GT(
          max_batch_size_,
          0,
          "max_batch_size should be larger than 0. Got ",
          max_batch_size_);
      // TODO: ideally we can support the case when input batch is larger than
      // the max_batch_size, as we can just pad to the multiple of
      // max_batch_size.
      CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());
      const int64_t real_batch_size = output_dims[0];
      output_dims[0] = max_batch_size_;
      auto* output = Output(0, output_dims, input.dtype());
      // Fetch the writable buffer once through the mutable accessor and use
      // it for both the zero fill and the copy.
      void* output_data = output->raw_mutable_data(input.dtype());
      // Zero the whole output so that the rows past the real batch are
      // well-defined padding.
      math::Set(
          output->nbytes(),
          static_cast<char>(0),
          static_cast<char*>(output_data),
          &context_);
      this->context_.template CopyItems<Context, Context>(
          input.dtype(), input.numel(), input.raw_data(), output_data);
      if (OutputSize() > 1) {
        // Report the pre-padding batch size to the caller.
        auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
        real_batch_tensor->template mutable_data<int64_t>()[0] =
            real_batch_size;
      }
    }
    return true;
  }

 private:
  // Target batch size for the padding mode; -1 when the argument is absent.
  int64_t max_batch_size_;
};
} // namespace caffe2