// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
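
// Batch rules for loss operators under vmap. The elementwise losses
// (mse_loss, huber_loss, smooth_l1_loss) share loss_batch_rule_helper below;
// binary_cross_entropy and its backward use manual plumbing so that the
// optional weight and the reduction can be applied at the logical level.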
namespace at::functorch {

// Flattens out all dims except the batch dim, and moves the batch dim
// (if it exists) to the front.
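// For example, a tensor of physical shape [B, N, C] with bdim = 0 becomes
// [B, N*C]; an unbatched tensor of shape [N, C] becomes [N*C].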
static at::Tensor flatten_logical(const Tensor& tensor, optional<int64_t> bdim) {
if (bdim.has_value()) {
auto result = moveBatchDimToFront(tensor, bdim);
if (result.dim() > 1) {
return result.flatten(1);
} else {
return result;
}
} else {
return tensor.flatten();
}
}

// Shared batch rule for elementwise losses (mse_loss, huber_loss,
// smooth_l1_loss): compute the unreduced loss on logically-flattened inputs,
// then apply the requested reduction per batch element.
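// Shape walkthrough (assuming self of logical shape [N, C] vmapped over a
// batch of size B): self_ and target_ are [B, N*C]; the unreduced loss is
// [B, N*C]; Reduction::None reshapes it back to [B, N, C], while Sum and Mean
// reduce over dim -1 to give a [B] result, returned with bdim 0.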
template <typename Func>
static std::tuple<at::Tensor,optional<int64_t>>
loss_batch_rule_helper(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction,
Func loss_fn) {
auto self_ = flatten_logical(self, self_bdim);
auto target_ = flatten_logical(target, target_bdim);
auto result = loss_fn(self_, target_, Reduction::None);
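// A 1-D unreduced result means each example's loss is already a scalar
// (e.g. 0-dim logical inputs), so return it directly with bdim 0.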
if (result.dim() == 1) {
return std::make_tuple(result, 0);
} else if (reduction == Reduction::None) {
const auto batched_elem = self_bdim.has_value() ?
moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
} else if (reduction == Reduction::Sum) {
return std::make_tuple(result.sum(-1), 0);
} else if (reduction == Reduction::Mean) {
return std::make_tuple(result.mean(-1), 0);
}
TORCH_INTERNAL_ASSERT(false);
}

static std::tuple<at::Tensor,optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction) {
return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
reduction, [](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
return at::mse_loss(self, target, reduction);
});
}

static std::tuple<at::Tensor,optional<int64_t>>
huber_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction, double delta) {
return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
reduction, [delta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
return at::huber_loss(self, target, reduction, delta);
});
}

static std::tuple<at::Tensor,optional<int64_t>>
smooth_l1_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction, double beta) {
return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
reduction, [beta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
return at::smooth_l1_loss(self, target, reduction, beta);
});
}

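// Applies the requested reduction to an unreduced loss; used by the
// binary_cross_entropy plumbing below.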
static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
if (reduction == at::Reduction::Mean) {
return unreduced.mean();
} else if (reduction == at::Reduction::Sum) {
return unreduced.sum();
}
return unreduced;
}
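
// binary_cross_entropy is handled with manual plumbing rather than a batch
// rule: compute the unreduced BCE on unwrapped (and, if needed, batch-expanded)
// inputs, rewrap the result, then apply the optional weight and the reduction
// at the logical level so the existing batch rules take care of them.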
static Tensor binary_cross_entropy_plumbing(
const Tensor& self, const Tensor& target,
const optional<Tensor>& weight, int64_t reduction) {
auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "binary_cross_entropy_plumbing");
int64_t cur_level = maybe_layer->layerId();
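// Fast path: if nothing is batched at this level, skip the batching logic
// and call the op directly with the Batched key excluded.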
if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)
&& !isBatchedAtLevel(weight, cur_level)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::binary_cross_entropy(self, target, weight, reduction);
}
auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level);
Tensor result;
if (self_bdim || target_bdim) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto bdim_size = get_bdim_size2(self_value, self_bdim, target_value, target_bdim);
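// Move any batch dims to the front and materialize a batch dim of size
// bdim_size on whichever input lacks one, so self_ and target_ line up.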
auto self_ = moveBatchDimToFront(self_value, self_bdim);
auto target_ = moveBatchDimToFront(target_value, target_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), bdim_size);
target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
result = at::binary_cross_entropy(self_, target_, nullopt, Reduction::None);
result = makeBatched(result, 0, cur_level);
} else {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
result = at::binary_cross_entropy(self_value, target_value, nullopt, Reduction::None);
}
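// Apply the optional weight as a separate multiply on the rewrapped result,
// letting the existing mul batch rule handle a possibly-batched weight.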
if (weight.has_value() && weight->defined()) {
result = result * weight.value();
}
return apply_loss_reduction(result, reduction);
}
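
// Mirrors the forward plumbing: compute the unreduced grad_input on unwrapped
// inputs (expanding a reduced grad to input's shape first), then apply the
// optional weight and the Mean scaling at the logical level.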
static Tensor binary_cross_entropy_backward_plumbing(
const Tensor& grad, const Tensor& input, const Tensor& target,
const c10::optional<Tensor>& weight_opt, int64_t reduction) {
auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "binary_cross_entropy_backward_plumbing");
int64_t cur_level = maybe_layer->layerId();
if (!areAnyBatchedAtLevel({grad, input, target, weight_opt}, cur_level)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::binary_cross_entropy_backward(grad, input, target, weight_opt, reduction);
}
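// For Sum/Mean reductions the incoming grad is a scalar; expand it to input's
// logical shape before unwrapping so it lines up elementwise with input.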
auto [grad_value, grad_bdim] = unwrapTensorAtLevel(
reduction == Reduction::None ? grad : grad.expand_as(input), cur_level);
auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level);
auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level);
Tensor grad_input;
if (grad_bdim || input_bdim || target_bdim) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto bdim_size = get_bdim_size3(
grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim);
auto grad_ = moveBatchDimToFront(grad_value, grad_bdim);
auto input_ = moveBatchDimToFront(input_value, input_bdim);
auto target_ = moveBatchDimToFront(target_value, target_bdim);
grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bdim_size);
input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
grad_input = at::binary_cross_entropy_backward(
grad_, input_, target_, nullopt, Reduction::None);
grad_input = makeBatched(grad_input, 0, cur_level);
} else {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
grad_input = at::binary_cross_entropy_backward(
grad_value, input_value, target_value, nullopt, Reduction::None);
}
if (weight_opt.has_value() && weight_opt->defined()) {
grad_input = grad_input * weight_opt.value();
}
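// For Mean, fold the 1/N factor of the reduction into grad_input; when input
// is batched at this level, input.numel() is the logical (per-example) numel.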
if (reduction == Reduction::Mean) {
grad_input.div_(input.numel());
}
return grad_input;
}
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
// mse_loss_backward uses a decomposition for its batch rule
VMAP_SUPPORT(huber_loss, huber_loss_batch_rule);
// huber_loss_backward uses a decomposition for its batch rule
VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
// smooth_l1_loss_backward uses a decomposition for its batch rule
m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
}
} // namespace at::functorch