#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <c10/util/irange.h>
#include <cstring>
#include <memory>
#include <sstream>
#include <vector>
namespace at { namespace native {
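
// Forward lookup: rows of `weight` are gathered at `indices`. For a weight of shape
// [num_embeddings, D] and indices of shape [B, T], the result has shape [B, T, D];
// 1-D indices are handled directly with index_select.
//
// Illustrative usage sketch (not part of this file; plain ATen calls):
//   auto weight  = at::randn({10, 3});                      // 10 embeddings of size 3
//   auto indices = at::arange(4, at::kLong).view({2, 2});
//   auto out     = at::embedding(weight, indices);          // shape [2, 2, 3]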
Tensor embedding(const Tensor & weight, const Tensor & indices,
                 int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
  TORCH_CHECK(weight.dim() >= 1, "'weight' must be at least 1-D");
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding", indices_arg, {kLong, kInt});

  // TODO: use tensor.index() after improving perf
  if (indices.dim() == 1) {
    return weight.index_select(0, indices);
  }

  auto size = indices.sizes().vec();
  for (auto d : weight.sizes().slice(1)) {
    size.push_back(d);
  }
  return weight.index_select(0, indices.reshape(-1)).view(size);
}
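
// Gradient dispatcher: picks the sparse or dense backward kernel depending on how
// the embedding was configured (sparse=True yields a sparse COO gradient for the
// weight, which can be much cheaper when only a small fraction of rows is touched).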
Tensor embedding_backward(
    const Tensor & grad, const Tensor & indices, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
  if (sparse) {
    return at::embedding_sparse_backward(
        grad, indices, num_weights, padding_idx, scale_grad_by_freq);
  } else {
    return at::embedding_dense_backward(
        grad, indices, num_weights, padding_idx, scale_grad_by_freq);
  }
}
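
// Sparse backward: rather than materializing a dense [num_weights, num_features]
// gradient, the incoming gradient rows are packed into a COO tensor whose indices
// are the (flattened) lookup indices; rows that hit padding_idx are filtered out.
//
// Sketch of the result (illustrative only): for indices = [3, 1, 3] with per-row
// gradients g0, g1, g2, the output has COO indices [[3, 1, 3]] and values
// [g0, g1, g2]; coalescing later sums the two contributions to row 3.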
Tensor embedding_sparse_backward(
    const Tensor & grad_, const Tensor & indices_, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {
  auto indices_arg = TensorArg(indices_, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  // TODO: implement scale_grad_by_freq
  if (scale_grad_by_freq) {
    AT_ERROR(
        "embedding_backward: scale_grad_by_freq not supported with sparse gradients");
  }

  Tensor indices = indices_;
  Tensor grad = grad_;
  if (padding_idx != -1) {
    torch::List<c10::optional<Tensor>> c({indices != padding_idx});
    indices = indices.index(c);
    grad = grad.index(c);
  }

  int64_t num_features = grad_.size(-1);
  auto weight_size = std::array<int64_t, 2>{{ num_weights, num_features }};
  auto dense_options = grad.options();

  // If every gradient row was filtered out above (i.e. all indices hit padding_idx),
  // return an empty sparse gradient of the right shape.
  if (grad.numel() == 0) {
    return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)),
                                         at::empty({0, num_features}, dense_options),
                                         weight_size);
  }
  auto index = indices.reshape({1, -1});
  auto values = grad.reshape({-1, num_features});
  return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size);
}
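
// Dense backward: scatter-add each incoming gradient row into grad_weight at its
// index, skipping padding_idx and (optionally) dividing by the index frequency.
//
// Rough ATen-level equivalent, ignoring padding_idx and scale_grad_by_freq
// (illustrative sketch only, not what this kernel literally does):
//   auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options());
//   grad_weight.index_add_(0, indices.reshape(-1), grad.reshape({-1, grad.size(-1)}));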
Tensor embedding_dense_backward_cpu(
    const Tensor & grad_, const Tensor & indices, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {
  auto indices_arg = TensorArg(indices, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
  auto indices_contig = indices.contiguous();
  int64_t numel = indices.numel();
  auto grad = grad_.contiguous().view({numel, grad_.size(-1)});

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () {
    auto indices_data = indices_contig.data_ptr<index_t>();

    std::unique_ptr<index_t[]> counts;
    if (scale_grad_by_freq) {
      counts.reset(new index_t[num_weights]);
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]] = 0;
      }
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]]++;
      }
    }
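
    // Each task owns a contiguous slab [start, end) of grad_weight rows and scans
    // all indices; two threads never touch the same row, so no atomics are needed.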
    auto parallel_section = [&](index_t start, index_t end) {
      for (int64_t i = 0; i < numel; i++) {
        if (indices_data[i] != padding_idx) {
          index_t k = indices_data[i];
          if (k >= start && k < end) {
            double scale = 1.0;
            if (scale_grad_by_freq) {
              scale /= counts[k];
            }
            grad_weight[k].add_(grad[i], scale);
          }
        }
      }
    };
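
    // Only parallelize for reasonably large index counts; below this threshold the
    // threading overhead presumably outweighs any speedup.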
    if (numel > 1000) {
      at::parallel_for(0, num_weights, 0, parallel_section);
    } else {
      parallel_section(0, num_weights);
    }
  });

  return grad_weight;
}
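
// In-place renorm (backs the max_norm option of nn.Embedding): every row of `self`
// that is referenced by `indices` and whose norm_type-norm exceeds max_norm is
// rescaled to (just under) max_norm.
//
// Illustrative usage sketch (not part of this file):
//   auto weight  = at::randn({10, 3}) * 5.0;                  // some rows exceed the cap
//   auto indices = at::tensor(std::vector<int64_t>{0, 2, 2});
//   at::embedding_renorm_(weight, indices, /*max_norm=*/1.0, /*norm_type=*/2.0);
//   // rows 0 and 2 of weight now have L2 norm <= 1.0; row 2 is only renormalized once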
Tensor & embedding_renorm_cpu_(
    Tensor & self, const Tensor & indices, double max_norm, double norm_type) {
  auto self_arg = TensorArg(self, "self", 1);
  auto indices_arg = TensorArg(indices, "indices", 2);
  checkDim("embedding_renorm_", self_arg, 2);
  checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt});

  auto indices_contig = indices.contiguous();
  auto num_indices = indices.numel();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() {
    auto data_ptr = indices_contig.data_ptr<index_t>();
    auto sorted_indices = std::vector<index_t>(data_ptr, data_ptr + num_indices);
    std::sort(sorted_indices.begin(), sorted_indices.end());
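
    // Sorting makes duplicate indices adjacent, so each referenced row is
    // renormalized at most once in the loop below.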
    // Note that we cannot use at::parallel_for here because we perform operations on
    // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details.
    for (const auto i : c10::irange(num_indices)) {
      if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) {
        continue;
      }
      auto row = self[sorted_indices[i]];
      auto norm = row.norm(norm_type).item<double>();
      if (norm > max_norm) {
        auto scale = max_norm / (norm + 1e-7);
        row *= scale;
      }
    }
  });

  return self;
}
}} // namespace at::native