#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

namespace at {
namespace native {

namespace {
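
// Per-sample kernel for MultiLabelMarginLoss. `target_data` holds class
// indices, padded with -1 after the last valid entry. The first pass marks
// every valid target class in `is_target_data` (this mask is reused by the
// backward pass); the second pass accumulates the hinge terms
//   max(0, 1 - input[target] + input[d])
// over all (target, non-target) pairs. Normalization by `dim` happens in the
// caller.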
template <typename scalar_t>
inline scalar_t multilabel_margin_loss_forward_inner_sum_cpu(
    scalar_t* input_data,
    int64_t* target_data,
    scalar_t* is_target_data,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  accscalar_t sum = 0;
  for (int64_t ddt = 0; ddt < dim; ddt++) {
    int64_t target_idx = target_data[ddt];
    if (target_idx < 0) {
      break;
    }
    is_target_data[target_idx] = 1;
  }
  for (int64_t dt = 0; dt < dim; dt++) {
    int64_t target_idx = target_data[dt];
    if (target_idx < 0) {
      break;
    }
    scalar_t input_target = input_data[target_idx];
    for (int64_t d = 0; d < dim; d++) {
      if (!is_target_data[d]) {
        scalar_t z = 1 - input_target + input_data[d];
        if (z > 0) {
          sum += z;
        }
      }
    }
  }
  return sum;
}
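
// Runs the per-sample kernel above over all `nframe` rows. With a reduction
// (or a 0-dim output) the per-sample sums are folded into a single scalar,
// divided by `dim`, and for Reduction::Mean additionally by `nframe`; with
// Reduction::None each row writes its own dim-normalized loss to output[t].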
template <typename scalar_t>
static void multilabel_margin_loss_forward_out_frame(
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction,
    int64_t nframe,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  scalar_t* input_data = input_contiguous.data_ptr<scalar_t>();
  int64_t* target_data = target_contiguous.data_ptr<int64_t>();
  scalar_t* is_target_data = is_target.data_ptr<scalar_t>();

  if (reduction != Reduction::None || output.dim() == 0) {
    scalar_t* output_data = output.data_ptr<scalar_t>();
    accscalar_t sum = 0;
    for (int64_t t = 0; t < nframe; t++) {
      sum += multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);
      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }
    sum /= dim;
    if (reduction == Reduction::Mean) {
      sum /= nframe;
    }
    *output_data = sum; // write scalar output value
  } else {
    auto output_acc = output.accessor<scalar_t, 1>();
    for (int64_t t = 0; t < nframe; t++) {
      scalar_t sum = multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);
      sum /= dim;
      output_acc[t] = sum;
      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }
  }
}
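
// Forward shape checking and dispatch. Accepts a 0-d/1-d input (one sample of
// `dim` scores) or a 2-d input (`nframe` x `dim`); `target` must match
// elementwise and contain indices in [-1, dim), where -1 terminates each
// per-sample target list.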
static void multilabel_margin_loss_forward_out_cpu_template(
    const Tensor& input,
    const Tensor& target,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction) {
  auto target_arg = TensorArg(target, "target", 2);
  const auto ndims = input.dim();
  TORCH_CHECK(
      input.numel() > 0 && ndims <= 2,
      "non-empty vector or matrix expected, got size: ",
      input.sizes());

  int64_t nframe, dim;
  if (ndims <= 1) {
    nframe = 1;
    dim = ndims == 0 ? 1 : input.size(0);
    TORCH_CHECK(
        target.numel() > 0 && target.dim() <= 1 && target.numel() == dim,
        "inconsistent size ",
        target.sizes(),
        " for ",
        target_arg);
  } else {
    nframe = input.size(0);
    dim = input.size(1);
    TORCH_CHECK(
        target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe &&
            target.size(1) == dim,
        "inconsistent size ",
        target.sizes(),
        " for ",
        target_arg);
  }
  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  is_target.resize_as_(target);
  TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
  is_target.zero_();

  // special case target.dim() <= 1: produce scalar output for scalar inputs
  // even if reduction == Reduction::None
  if (reduction != Reduction::None || target.dim() <= 1) {
    output.resize_({});
  } else {
    output.resize_({nframe});
  }

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] {
        multilabel_margin_loss_forward_out_frame<scalar_t>(
            input_contiguous,
            target_contiguous,
            output,
            is_target,
            reduction,
            nframe,
            dim);
      });
}
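
// Backward kernel. For every (target, non-target) pair whose hinge term was
// active in the forward pass (1 - input[target] + input[d] > 0), subtract g
// from the gradient at the target index and add g at the non-target index,
// where g = 1/dim (or 1/(nframe * dim) under Reduction::Mean). The result is
// then scaled by grad_output: a single scalar for reduced outputs, one value
// per row otherwise.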
template <typename scalar_t>
static void multilabel_margin_loss_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    int64_t reduction,
    const Tensor& is_target_contiguous,
    int64_t nframe,
    int64_t dim) {
  auto is_target_arg = TensorArg(is_target_contiguous, "is_target", 5);
  TORCH_CHECK(
      is_target_contiguous.min().item<scalar_t>() >= 0,
      is_target_arg,
      " is out of range");
  TORCH_CHECK(
      is_target_contiguous.max().item<scalar_t>() <= 1,
      is_target_arg,
      " is out of range");

  scalar_t* input_data = input_contiguous.data_ptr<scalar_t>();
  int64_t* target_data = target_contiguous.data_ptr<int64_t>();
  scalar_t* is_target_data = is_target_contiguous.data_ptr<scalar_t>();
  scalar_t g = static_cast<scalar_t>(
      reduction == Reduction::Mean ? 1. / (nframe * dim) : 1. / dim);

  scalar_t* grad_input_row_data = grad_input.data_ptr<scalar_t>();
  for (int64_t t = 0; t < nframe; t++) {
    for (int64_t dt = 0; dt < dim; dt++) {
      int64_t target_idx = target_data[dt];
      if (target_idx < 0) {
        break;
      }
      scalar_t input_target = input_data[target_idx];
      for (int64_t d = 0; d < dim; d++) {
        if (!is_target_data[d]) {
          scalar_t z = 1 - input_target + input_data[d];
          if (z > 0) {
            grad_input_row_data[target_idx] -= g;
            grad_input_row_data[d] += g;
          }
        }
      }
    }
    input_data += dim;
    target_data += dim;
    is_target_data += dim;
    grad_input_row_data += dim;
  }

  scalar_t* grad_input_data = grad_input.data_ptr<scalar_t>();
  if (reduction != Reduction::None || grad_output.dim() == 0) {
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 || nframe == 1);
    const auto d = *grad_output.data_ptr<scalar_t>();
    for (int64_t t = 0; t < nframe * dim; t++) {
      grad_input_data[t] *= d;
    }
  } else {
    check_dim_size(grad_output, 1, 0, nframe);
    auto grad_output_acc = grad_output.accessor<scalar_t, 1>();
    for (int64_t t = 0; t < nframe; t++) {
      for (int64_t d = 0; d < dim; d++) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}
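
// Backward shape checking and dispatch; mirrors the forward template and
// additionally validates that `is_target` (produced by the forward pass)
// matches `target` in size.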
static void multilabel_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  CheckedFrom c = "multilabel_margin_loss_backward_cpu_template";
  auto target_arg = TensorArg(target, "target", 3);
  auto is_target_arg = TensorArg(is_target, "is_target", 5);
  const auto ndims = input.dim();

  TORCH_CHECK(
      input.numel() > 0 && ndims <= 2,
      "non-empty vector or matrix expected, got size: ",
      input.sizes());

  int64_t nframe, dim;
  if (ndims <= 1) {
    nframe = 1;
    dim = ndims == 0 ? 1 : input.size(0);
    TORCH_CHECK(
        target.numel() > 0 && target.dim() <= 1 && target.numel() == dim,
        "inconsistent size ",
        target.sizes(),
        " for ",
        target_arg);
  } else {
    nframe = input.size(0);
    dim = input.size(1);
    TORCH_CHECK(
        target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe &&
            target.size(1) == dim,
        "inconsistent size ",
        target.sizes(),
        " for ",
        target_arg);
  }
  checkSameSize(c, target_arg, is_target_arg);
  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  auto is_target_contiguous = is_target.contiguous();

  grad_input.resize_as_(input);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  grad_input.zero_();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] {
        multilabel_margin_loss_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input_contiguous,
            target_contiguous,
            reduction,
            is_target_contiguous,
            nframe,
            dim);
      });
}
} // namespace
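
// Public entry points: thin wrappers that allocate outputs and forward to the
// templates above. The forward variants also return the is_target mask that
// the backward pass consumes.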
std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cpu(
    Tensor& output,
    Tensor& is_target,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  multilabel_margin_loss_forward_out_cpu_template(
      self, target, output, is_target, reduction);
  return std::tuple<Tensor&, Tensor&>(output, is_target);
}

std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cpu(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  auto output = at::empty({0}, self.options());
  auto is_target = at::empty({0}, self.options());
  multilabel_margin_loss_forward_out_cpu(
      output, is_target, self, target, reduction);
  return std::make_tuple(output, is_target);
}

Tensor& multilabel_margin_loss_backward_cpu_out(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  multilabel_margin_loss_backward_out_cpu_template(
      grad_input, grad_output, self, target, reduction, is_target);
  return grad_input;
}

Tensor multilabel_margin_loss_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  multilabel_margin_loss_backward_cpu_out(
      grad_input, grad_output, self, target, reduction, is_target);
  return grad_input;
}

Tensor& multilabel_margin_loss_out(
    Tensor& output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  Tensor is_target = at::empty({0}, self.options());
  return std::get<0>(at::multilabel_margin_loss_forward_out(
      output, is_target, self, target, reduction));
}

Tensor multilabel_margin_loss(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  return std::get<0>(
      at::multilabel_margin_loss_forward(self, target, reduction));
}

} // namespace native
} // namespace at