#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cpu/utils.h>

#include <algorithm>
#include <numeric>

namespace at {
namespace native {

namespace {

// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
  return source.defined() ? source.contiguous() : source;
}

// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline scalar_t* optional_data(const Tensor& source) {
  return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
}

inline void check_inputs_nll_loss2d(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight) {
  TORCH_CHECK(
      target.dim() == 3,
      "only batches of spatial targets supported (3D tensors)"
      " but got targets of dimension: ",
      target.dim());
  TORCH_CHECK(
      input.dim() == 4,
      "only batches of spatial inputs supported (4D tensors), "
      "but got input of dimension: ",
      input.dim());
  TORCH_CHECK(
      !weight.defined() || weight.numel() == input.size(1),
      "weight tensor should be defined either for all or no classes");

  const int64_t input0 = input.size(0);
  const int64_t input2 = input.size(2);
  const int64_t input3 = input.size(3);
  const int64_t target0 = target.size(0);
  const int64_t target1 = target.size(1);
  const int64_t target2 = target.size(2);
  TORCH_CHECK(
      input0 == target0 && input2 == target1 && input3 == target2,
      "size mismatch (got input: ",
      input.sizes(),
      ", target: ",
      target.sizes(),
      ")");
}
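
// The checks above expect input of shape [N, C, H, W] and target of shape
// [N, H, W] holding class indices in [0, C) (or ignore_index).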

inline void check_gradout_shape_nll_loss2d(
    const Tensor& grad_output,
    const Tensor& target) {
  TORCH_CHECK(
      grad_output.dim() == 3,
      "grad_output must have the same dimension as target (3) but got dimension: ",
      grad_output.dim());

  const int64_t grad_output0 = grad_output.size(0);
  const int64_t grad_output1 = grad_output.size(1);
  const int64_t grad_output2 = grad_output.size(2);
  const int64_t target0 = target.size(0);
  const int64_t target1 = target.size(1);
  const int64_t target2 = target.size(2);
  TORCH_CHECK(
      grad_output0 == target0 && grad_output1 == target1 &&
          grad_output2 == target2,
      "size mismatch (got grad_output: ",
      grad_output.sizes(),
      ", target: ",
      target.sizes(),
      ")");
}

template <typename scalar_t>
static void nll_loss2d_forward_out_frame(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  const int64_t n_classes = input.size(1);

  scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
  *total_weight_data = 0;

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<scalar_t>(weight_contiguous);

  if (reduction == Reduction::None) {
    const int64_t batch_size = input.size(0);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    output.resize_({batch_size, H, W});
    auto input_acc = input.accessor<scalar_t, 4>();
    auto output_acc = output.accessor<scalar_t, 3>();
    auto target_acc = target.accessor<int64_t, 3>();

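    // Unreduced case: each spatial position gets its own loss value,
    //   out[b][h][w] = -weight[target] * input[b][target][h][w],
    // and positions whose target equals ignore_index are written as zero.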
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (int64_t b = start; b < end; b++) {
        for (int64_t h = 0; h < H; h++) {
          for (int64_t w = 0; w < W; w++) {
            const int64_t cur_target = target_acc[b][h][w];

            if (cur_target == ignore_index) {
              output_acc[b][h][w] = static_cast<scalar_t>(0);
              continue;
            }

            TORCH_CHECK_INDEX(
                cur_target >= 0 && cur_target < n_classes,
                "Target ",
                cur_target,
                " is out of bounds.");

            // load optional weight value
            const scalar_t cur_weight = weight_data != nullptr
                ? weight_data[cur_target]
                : static_cast<scalar_t>(1);
            output_acc[b][h][w] = -input_acc[b][cur_target][h][w] * cur_weight;
          }
        }
      }
    });

    return;
  }

  // produce scalar outputs for the reduction case
  output.resize_({});

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  const scalar_t* input_data = input_contiguous.data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.data_ptr<int64_t>();

  const int64_t batch_size = input.size(0);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;
  const int64_t numiter = batch_size * map_size;

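  // The single-pass reduction below uses a small cascade summation: every
  // element is accumulated into level 0, and a level is flushed into the next
  // one every `level_step` iterations. Keeping the partial sums short reduces
  // floating-point accumulation error relative to one long running sum, which
  // matters for the lower-precision types (e.g. BFloat16) dispatched here.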
  constexpr int64_t cascade_sum_num_levels = 8;
  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  int64_t num_ignored = 0;
  for (int64_t b = 0; b < batch_size; b++) {
    for (int64_t elem = 0; elem < map_size; elem++) {
      const int64_t cur_target = target_data[b * map_size + elem];
      if (cur_target == ignore_index) {
        ++num_ignored;
        continue;
      }

      TORCH_CHECK_INDEX(
          cur_target >= 0 && cur_target < n_classes,
          "Target ",
          cur_target,
          " is out of bounds.");

      const auto data =
          input_data[b * sample_size + cur_target * map_size + elem];
      if (weight_data) {
        const scalar_t weight_val = weight_data[cur_target];
        loss_partial_sums[0] -= data * weight_val;
        weight_partial_sums[0] += weight_val;
      } else {
        loss_partial_sums[0] -= data;
      }

      const int64_t linear_idx = b * map_size + elem;
      for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
        const auto mask = (level_mask << (j * level_power));
        if (C10_LIKELY((linear_idx & mask) != 0)) {
          break;
        }

        weight_partial_sums[j + 1] += weight_partial_sums[j];
        loss_partial_sums[j + 1] += loss_partial_sums[j];

        weight_partial_sums[j] = 0;
        loss_partial_sums[j] = 0;
      }
    }
  }

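  // total_weight is the Reduction::Mean normalizer: the sum of the class
  // weights of all non-ignored targets or, when no weight tensor was given,
  // simply the number of non-ignored elements.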
  const scalar_t total_weight_val = !weight_data ?
      static_cast<scalar_t>(numiter - num_ignored) :
      std::accumulate(std::begin(weight_partial_sums),
                      std::end(weight_partial_sums),
                      scalar_t{0});

  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
                                        std::end(loss_partial_sums),
                                        scalar_t{0});

  if (reduction == Reduction::Mean &&
      (total_weight_val != 0 || input.numel() == 0)) {
    // allow NaN result for total_weight_val == 0 case, see #15870
    output_val /= total_weight_val;
  }

  *total_weight_data = total_weight_val;
  *output.data_ptr<scalar_t>() = output_val;
}

void nll_loss2d_forward_out_cpu_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  check_inputs_nll_loss2d(input, target, weight);
  total_weight.resize_({});

  AT_DISPATCH_FLOATING_TYPES_AND(
      ScalarType::BFloat16,
      input.scalar_type(),
      "nll_loss2d_forward_out_frame",
      [&] {
        nll_loss2d_forward_out_frame<scalar_t>(
            output,
            total_weight,
            input,
            target,
            weight,
            reduction,
            ignore_index);
      });
}

template <typename scalar_t>
static void nll_loss2d_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<scalar_t>(weight_contiguous);

  if (reduction == at::Reduction::None) {
    check_gradout_shape_nll_loss2d(grad_output, target);

    const int64_t batch_size = input.size(0);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    auto grad_input_acc = grad_input.accessor<scalar_t, 4>();
    auto grad_output_acc = grad_output.accessor<scalar_t, 3>();
    auto target_acc = target.accessor<int64_t, 3>();

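    // Unreduced backward: only the target-class slot gets a gradient,
    //   grad_input[b][target][h][w] = -weight[target] * grad_output[b][h][w];
    // ignored positions keep the zeros grad_input was initialized with.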
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (int64_t b = start; b < end; b++) {
        for (int64_t h = 0; h < H; h++) {
          for (int64_t w = 0; w < W; w++) {
            const int64_t cur_target = target_acc[b][h][w];
            if (cur_target == ignore_index) {
              continue;
            }
            const scalar_t value =
                -(weight_data ? weight_data[cur_target]
                              : static_cast<scalar_t>(1));
            const scalar_t grad_output_value = grad_output_acc[b][h][w];
            grad_input_acc[b][cur_target][h][w] = value * grad_output_value;
          }
        }
      }
    });

    return;
  }

  const scalar_t total_weight_value = *total_weight.data_ptr<scalar_t>();
  if (total_weight_value <= 0) {
    return;
  }

  TORCH_CHECK(
      grad_output.dim() <= 1 && grad_output.numel() == 1,
      "Expected a single element grad_output tensor, but got: ",
      grad_output.sizes());

  const scalar_t grad_output_value = *grad_output.data_ptr<scalar_t>();

  const auto target_contiguous = target.contiguous();
  const int64_t* target_data = target_contiguous.data_ptr<int64_t>();

  scalar_t* grad_input_data = grad_input.data_ptr<scalar_t>();

  const int64_t batch_size = input.size(0);
  const int64_t n_classes = input.size(1);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;

  const scalar_t normalize = (reduction == at::Reduction::Mean)
      ? total_weight_value
      : static_cast<scalar_t>(1);

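  // Reduced backward: every non-ignored target position receives
  // -weight[target] * grad_output, divided by total_weight when the forward
  // pass averaged (Reduction::Mean), mirroring the forward normalization.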
  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    for (int64_t b = start; b < end; b++) {
      for (int64_t elem = 0; elem < map_size; elem++) {
        const int64_t cur_target = target_data[b * map_size + elem];

        if (cur_target == ignore_index) {
          continue;
        }

        TORCH_CHECK_INDEX(
            cur_target >= 0 && cur_target < n_classes,
            "Target ",
            cur_target,
            " is out of bounds.");

        const int64_t index = b * sample_size + cur_target * map_size + elem;
        const scalar_t w = weight_data != nullptr ? weight_data[cur_target]
                                                  : static_cast<scalar_t>(1);
        grad_input_data[index] = -w / normalize * grad_output_value;
      }
    }
  });
}

void nll_loss2d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  check_inputs_nll_loss2d(input, target, weight);
  grad_input.resize_as_(input);
  grad_input.zero_();
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  AT_DISPATCH_FLOATING_TYPES_AND(
      ScalarType::BFloat16,
      input.scalar_type(),
      "nll_loss2d_backward_out_frame",
      [&] {
        nll_loss2d_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input,
            target,
            weight,
            reduction,
            ignore_index,
            total_weight);
      });
}

} // namespace

std::tuple<Tensor&, Tensor&> nll_loss2d_forward_out_cpu(const Tensor& self,
    const Tensor& target, const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  nll_loss2d_forward_out_cpu_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
  return std::tuple<Tensor&, Tensor&>(output, total_weight);
}

std::tuple<Tensor, Tensor> nll_loss2d_forward_cpu(
    const Tensor& self,
    const Tensor& target, const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto output = at::empty({0}, self.options());
  auto total_weight = at::empty({0}, self.options());
  at::native::nll_loss2d_forward_out_cpu(
      self, target, weight, reduction, ignore_index, output, total_weight);
  return std::make_tuple(output, total_weight);
}

Tensor& nll_loss2d_backward_out_cpu(const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target, const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight,
    Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  nll_loss2d_backward_out_cpu_template(
      grad_input,
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight);
  return grad_input;
}

Tensor nll_loss2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target, const c10::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto grad_input = at::zeros_like(self);
  at::native::nll_loss2d_backward_out_cpu(
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight,
      grad_input);
  return grad_input;
}

Tensor& nll_loss2d_out(const Tensor& self, const Tensor& target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor total_weight = at::empty({0}, self.options());
  return std::get<0>(at::nll_loss2d_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
}

Tensor nll_loss2d(const Tensor& self, const Tensor& target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  return std::get<0>(at::nll_loss2d_forward(self, target, weight, reduction, ignore_index));
}
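
// Usage sketch (illustrative only, not part of this file's API surface): the
// kernels above back the 4D case of at::nll_loss2d. Assuming the usual ATen
// factory functions, a call could look roughly like
//
//   at::Tensor input = at::log_softmax(at::randn({2, 3, 5, 5}), /*dim=*/1);
//   at::Tensor target = at::randint(0, 3, {2, 5, 5}, at::kLong);
//   at::Tensor loss = at::nll_loss2d(input, target);
//
// which yields a scalar with the default Reduction::Mean; Reduction::None
// would instead produce a [2, 5, 5] per-position loss.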

} // namespace native
} // namespace at