blob: 9127389dfbc1352eade3844089af6a4cc94eb345 [file] [log] [blame]
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/cpu/WeightNormKernel.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_weight_norm_differentiable_backward_native.h>
#include <ATen/ops/_weight_norm_interface.h>
#include <ATen/ops/_weight_norm_interface_backward_native.h>
#include <ATen/ops/_weight_norm_interface_native.h>
#include <ATen/ops/_weight_norm_native.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/norm_except_dim.h>
#include <ATen/ops/norm_except_dim_native.h>
#endif
#include <vector>
namespace at {
namespace native {
DEFINE_DISPATCH(weight_norm_stub);
DEFINE_DISPATCH(weight_norm_backward_stub);
// Staying faithful to the Python for now for clarity, look for optimizations later
// (e.g., single return statement for RVO)
Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim)
{
// I assume tensor.contiguous(), view(), norm(), etc. here will dispatch through VariableType.
if (dim == -1) {
return v.norm(pow);
} else if (dim == 0) {
std::vector<int64_t> output_size(v.dim(), 1);
output_size[0] = v.size(0);
return v.contiguous().view({v.size(0), -1}).norm(pow, 1).view(output_size);
} else if (dim == v.dim() - 1) {
std::vector<int64_t> output_size(v.dim(), 1);
output_size[v.dim() - 1] = v.size(v.dim() - 1);
return v.contiguous().view({-1, v.size(v.dim() - 1)}).norm(pow, 0).view(output_size);
} else {
// To consider: at::native::norm_except_dim is probably fine as well,
// and would avoid an additional dynamic dispatch.
return at::norm_except_dim(v.transpose(0, dim), pow, 0).transpose(0, dim); // optimize?
}
}
std::tuple<Tensor,Tensor> weight_norm_cpu(
const Tensor& v,
const Tensor& g,
int64_t dim) {
auto w = at::empty_like(v, at::MemoryFormat::Contiguous);
// align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'
const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ?
at::ScalarType::Float : g.scalar_type();
auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype));
weight_norm_stub(kCPU, w, norm, v, g, dim);
return std::tuple<Tensor, Tensor>{w, norm};
}
std::tuple<Tensor, Tensor> weight_norm_backward_cpu(
const Tensor& grad_w,
const Tensor& saved_v,
const Tensor& saved_g,
const Tensor& saved_norm,
int64_t dim) {
TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
TORCH_CHECK(saved_norm.is_contiguous(), "saved_norm must be contiguous");
auto grad_v = at::empty_like(saved_v, at::MemoryFormat::Contiguous);
auto grad_g = at::empty_like(saved_g, at::MemoryFormat::Contiguous);
weight_norm_backward_stub(kCPU, grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, dim);
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}
Tensor _weight_norm
(const Tensor & v_in,
const Tensor & g_in,
int64_t dim)
{
TORCH_CHECK(
v_in.device() == g_in.device(),
"weight_norm: expected v_in and g_in to be on the same device, but v_in is "
"on ", v_in.device(), " and g_in is on ", g_in.device());
auto v = v_in.contiguous();
auto g = g_in.contiguous();
auto has_half_dtype = v.scalar_type() == at::ScalarType::Half
|| g.scalar_type() == at::ScalarType::Half;
bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1));
if (can_use_fused) {
// weight_norm does not have a derivative defined for it, so this will route back through
// VariableType.cpp, and construct a WeightNormFusedBackward object in the autograd graph.
return std::get<0>(at::_weight_norm_interface(v, g, dim));
} else {
// Double-differentiable primitive ops
// at::native::norm_except_dim would probably be fine as well.
return v*(g/at::norm_except_dim(v, 2, dim));
}
}
// Differentiable backward path, an alternative to weight_norm_backward, to be used
// when backward is itself creating a graph.
// The GradMode::is_enabled() check must be performed within Functions.cpp; that's why we
// define a separate function here, instead of inlining it in weight_norm_cuda_backward.
std::tuple<Tensor, Tensor> _weight_norm_differentiable_backward
(const Tensor & grad_w,
const Tensor & saved_v,
const Tensor & saved_g,
const Tensor & saved_norms,
int64_t dim)
{
// In Functions.cpp, the HardshrinkBackward object supplies "grad.contiguous()"
// as the first argument, so grad_w should be contiguous here.
// All these checks should succeed:
TORCH_CHECK(grad_w.is_contiguous(), "grad_w must be contiguous");
TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
int64_t last_dim = saved_v.dim() - 1;
int64_t last_size = saved_v.size(last_dim);
// Like weight_norm_fused_backward, weight_norm_differentiable_backward should only ever be called
// through a WeightNormFusedBackward object, so we expect that dim == 0 || dim == saved_v.size(-1)
TORCH_CHECK(dim == 0 || dim == last_dim, "Expected dim to be the first or last dimension");
// saved_g and saved_norms are already shaped to broadcast over the correct dimensions
// ...but saved_norms might be Float when saved_g and saved_v are half.
// To consider: saved_norms.to(..., True /*non_blocking*/);
auto norms = saved_norms.to(saved_g.scalar_type());
std::vector<int64_t> bcast_size(saved_v.dim(), 1);
// Analytic backward path using differentiable primitive ops
if (dim == 0) {
bcast_size[0] = saved_v.size(0);
auto per_dim_sums = (grad_w*saved_v).view({saved_v.size(0), -1}).sum(1).view(bcast_size);
auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
auto grad_g = per_dim_sums/norms;
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
} else { // dim == last_dim
bcast_size[last_dim] = last_size;
auto per_dim_sums = (grad_w*saved_v).view({-1, last_size}).sum(0).view(bcast_size);
auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
auto grad_g = per_dim_sums/norms;
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}
}
} // namespace native
} // namespace at