#pragma once
#include <immintrin.h>

#include <cmath>
#include <limits>

#include <c10/util/irange.h>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
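
// Fused operator: concatenates inputs [2, InputSize()) along axis 1, adds
// Input(0) and multiplies by Input(1) elementwise, replaces any NaN with 0,
// and clips the result to [clip_min, clip_max]. Output(0) is the fused
// result; Output(1) holds each input's size along the concat axis.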
template <class Context>
class ConcatAddMulReplaceNaNClipOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
  ConcatAddMulReplaceNaNClipOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {
    // Clip bounds default to the full float range when the corresponding
    // argument is absent, so read both unconditionally.
    min_ = this->template GetSingleArgument<float>(
        "clip_min", std::numeric_limits<float>::lowest());
    max_ = this->template GetSingleArgument<float>(
        "clip_max", std::numeric_limits<float>::max());
  }
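
  // RunOnDevice proceeds in three stages: (1) concatenate inputs along the
  // channel axis, (2) fused (x + add) * mul with NaN replacement and clipping
  // over 8-wide AVX lanes, (3) a scalar tail for the leftover elements.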
  bool RunOnDevice() override {
    // Inputs 0 and 1 are the add/mul operands; the tensors to concatenate
    // start at input index 2, and concatenation is along axis 1.
    const int concat_input_start = 2;
    const int axis = 1;
    // Output(1) records each input's size along the concat axis, like the
    // split-info output of the regular Concat op.
    Tensor* split = Output(
        1,
        vector<int64_t>(1, InputSize() - concat_input_start),
        at::dtype<int>());
    int* axis_data = split->template mutable_data<int>();
    auto& add_input = Input(0);
    auto& mul_input = Input(1);
    auto& concat_input_0 = Input(2);
    int adj_size = concat_input_0.dim();
    int canonical_axis = canonical_axis_index_(axis, adj_size);
    CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
for (int i = concat_input_start + 1; i < InputSize(); ++i) {
CAFFE_ENFORCE(
Input(i).dtype() == concat_input_0.dtype(),
"All inputs must have the same type, expected: ",
concat_input_0.dtype().name(),
" but got: ",
Input(i).dtype().name(),
" for input: ",
i);
}
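    // Compute the row-major extents before and after the concat axis; each
    // input then contributes a contiguous (axis_dim * after) slab per
    // "before" row.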
int before = 1, after = 1;
vector<int64_t> output_dims(concat_input_0.sizes().vec());
for (const auto i : c10::irange(concat_input_0.dim())) {
if (i == canonical_axis) {
continue;
}
int dim = concat_input_0.dim32(i);
if (i < canonical_axis) {
before *= dim;
} else { // i > canonical_axis
after *= dim;
}
// check the input dims are compatible.
for (const auto j : c10::irange(concat_input_start, InputSize())) {
int dim_j = Input(j).dim32(i);
CAFFE_ENFORCE(
dim == dim_j,
"Expect dimension = ",
dim,
" got ",
dim_j,
" at axis = ",
i,
" for input: ",
j,
". The input tensors can only have different dimensions "
"when arg 'add_axis' = 0 and along the axis = ",
canonical_axis,
" <",
Input(0).sizes(),
"> vs <",
Input(j).sizes(),
">.");
}
}
CAFFE_ENFORCE(
concat_input_0.dim() <= 2,
"Cannot handle fused concat with dim > 2, please update your fusion logic");
int output_channels = 0;
for (const auto i : c10::irange(concat_input_start, InputSize())) {
axis_data[i - concat_input_start] = Input(i).dim32(canonical_axis);
output_channels += Input(i).dim32(canonical_axis);
}
output_dims[canonical_axis] = output_channels;
auto* output = Output(0, output_dims, at::dtype<float>());
size_t output_offset = 0;
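    // Stage 1: interleave the inputs into Output(0). Each CopyMatrix call
    // copies `before` rows of (axis_dim * after) elements into the matching
    // column range of the output.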
for (const auto i : c10::irange(concat_input_start, InputSize())) {
auto& input = Input(i);
auto axis_dim = input.dim32(canonical_axis);
math::CopyMatrix<Context>(
input.itemsize(),
before,
axis_dim * after,
input.raw_data(),
axis_dim * after,
static_cast<char*>(output->raw_mutable_data(concat_input_0.dtype())) +
output_offset,
output_channels * after,
&context_,
concat_input_0.dtype().copy());
output_offset += axis_dim * after * input.itemsize();
}
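
    // Stage 2: fused elementwise pass over the concatenated result. The add
    // and mul operands are indexed by the inner offset only, so they are
    // expected to hold (output_channels * after) elements, reused for every
    // "before" row.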
float* output_data = output->template mutable_data<float>();
const float* add_input_data = add_input.template data<float>();
const float* mul_input_data = mul_input.template data<float>();
    // Splat the clip bounds and zero across all eight AVX lanes.
    const auto _max_bound = _mm256_set1_ps(max_);
    const auto _min_bound = _mm256_set1_ps(min_);
    const auto _zeros = _mm256_set1_ps(0.f);
    output_offset = 0;
    for (const auto outer : c10::irange(before)) {
      auto axis_dim = output->dim32(canonical_axis);
      const size_t inner_size = axis_dim * after;
      // Vectorized pass: 8 floats at a time; the remainder (inner_size % 8)
      // is handled by the scalar loop below.
      size_t inner = 0;
      for (; inner + 8 <= inner_size; inner += 8) {
        auto elem = _mm256_loadu_ps(&output_data[output_offset + inner]);
        auto add_elem = _mm256_loadu_ps(&add_input_data[inner]);
        auto mul_elem = _mm256_loadu_ps(&mul_input_data[inner]);
        auto added = _mm256_add_ps(elem, add_elem);
        auto mulled = _mm256_mul_ps(added, mul_elem);
        // Ordered non-signaling self-compare is false only for NaN lanes,
        // so the blend replaces NaNs with zero.
        auto mask = _mm256_cmp_ps(mulled, mulled, _CMP_EQ_OQ);
        auto removed_nan = _mm256_blendv_ps(_zeros, mulled, mask);
        // Clamp to [min_, max_].
        auto out_val =
            _mm256_max_ps(_mm256_min_ps(_max_bound, removed_nan), _min_bound);
        _mm256_storeu_ps(&output_data[output_offset + inner], out_val);
      }
      // Scalar tail for the remaining (inner_size % 8) elements.
      for (const auto inner_omp : c10::irange(inner, inner_size)) {
        float elem = output_data[output_offset + inner_omp];
        float add_elem = add_input_data[inner_omp];
        float mul_elem = mul_input_data[inner_omp];
        float clipped = (elem + add_elem) * mul_elem;
        if (std::isnan(clipped)) {
          clipped = 0.f;
        }
        if (clipped > max_) {
          clipped = max_;
        } else if (clipped < min_) {
          clipped = min_;
        }
        output_data[output_offset + inner_omp] = clipped;
      }
output_offset += axis_dim * after;
}
return true;
}
protected:
float min_;
float max_;
};
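
// A minimal registration sketch for the CPU backend (normally placed in the
// accompanying .cc file; the operator name "ConcatAddMulReplaceNaNClip" is an
// assumption, not confirmed by this header):
//
//   REGISTER_CPU_OPERATOR(
//       ConcatAddMulReplaceNaNClip,
//       ConcatAddMulReplaceNaNClipOp<CPUContext>);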
} // namespace caffe2