blob: f1027b1ebdf80ca689f23e44c9cc8abd72f9724a [file] [log] [blame]
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
namespace at { namespace native {
Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, bool size_average, bool reduce) {
auto zeros = at::zeros_like(self);
auto margin_clamp = (margin - self).clamp_min_(0);
auto output_margin = at::where(target != 1, margin_clamp, zeros);
auto output_self = at::where(target != -1, self, zeros);
auto output = output_margin + output_self;
if (reduce && size_average) {
return output.sum() / self.numel();
} else if (reduce) {
return output.sum();
}
return output;
}
}} // namespace at::native