blob: c82bf8ba0ae043af71634649586b30ba12f6d8a1 [file] [log] [blame]
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtilsMulti.h"
#include <cctype>
namespace at { namespace native {
// sumproduct_pair computes `(left*right).sum(sumdims)` by means of permutation and
// batch matrix multiplication.
// Its main purpose is to provide a pairwise reduction for einsum (it is also used by _trilinear).
//
// Assumes that the tensors have been pre-unsqueezed: both have the same number of
// dimensions, and every dimension either has the same size in both operands or has
// size 1 in at least one of them (broadcasting). No other assumption is made about
// the order of dimensions.
//   sum_dims_ : dimensions to sum over; converted with dim_list_to_bitset
//               (NOTE(review): that helper lives outside this file; it is expected to
//               wrap negative dims and reject duplicates).
//   keepdim   : if true, summed dimensions are kept with size 1, otherwise squeezed out.
static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntList sum_dims_, bool keepdim) {
  AT_CHECK(left_.dim()==right_.dim(), "number of dimensions must match");
  if (sum_dims_.size() == 0)
    return at::mul(left_, right_);
  int64_t dim = left_.dim();
  auto sum_dims = dim_list_to_bitset(sum_dims_, dim);
  // Non-negative, ascending list of the summed dimensions. It is used below instead of
  // sum_dims_ because sum_dims_ may contain negative (wrap-around) dims, which would
  // index opermutation out of bounds.
  std::vector<int64_t> sum_dims_list;
  for (int64_t i = 0; i < dim; i++) {
    if (sum_dims[i]) sum_dims_list.push_back(i);
  }
  // dimensions that will be part of the output (i.e. not summed over) in three vectors:
  // dims in lro appear in left, right and output; lo: left and output; ro: right and output.
  // The sizes are kept track of for reshaping.
  std::vector<int64_t> lro, lo, ro;
  int64_t lro_size = 1, lo_size = 1, ro_size = 1, sum_size = 1;
  Tensor left = left_;
  Tensor right = right_;
  for (int64_t i = 0; i < dim; i++) {
    // Only size-1 dimensions are broadcast dimensions. A size-0 dimension is a real
    // dimension and has to be carried into the reshape/bmm below (treating it as
    // broadcast would make the reshape fail with a numel mismatch).
    auto sl = left.size(i) != 1;
    auto sr = right.size(i) != 1;
    if (sum_dims[i]) { // first dimensions that will be summed over after multiplication
      if (sl && sr) { // dimensions nontrivially in both left and right must be of the same size
        AT_CHECK(left.size(i)==right.size(i), "non-broadcast dimensions must match");
        sum_size *= left.size(i);
      } else if (sl) { // if it is only in one of left and right, we can sum right away
        left = left.sum(i, true);
      } else if (sr) {
        right = right.sum(i, true);
      }
    } else if (sl && sr) { // now deal with dimensions that will be in the output
      // dimensions nontrivially in both left and right must be of the same size
      AT_CHECK(left.size(i)==right.size(i), "non-broadcast dimensions must match");
      lro.push_back(i);
      lro_size *= left.size(i);
    } else if (sl) { // keep track of dimensions appearing only once
      lo.push_back(i);
      lo_size *= left.size(i);
    } else { // only in right, or size 1 in both
      ro.push_back(i);
      ro_size *= right.size(i);
    }
  }
  // we now work with the following permutations / shapes.
  // the pipeline is permute inputs -> reshape inputs -> batch matrix mul -> reshape(view) output -> permute output
  // output: "lro, lo, 1-for-summed-dims, ro" with original shape dimensions
  // left: "lro, lo, summed" permuted with lpermutation and the three flattened
  // right: "lro, summed, ro" permuted with rpermutation and the three flattened
  // then the permuted output is a view of bmm(left, right)
  // finally, opermutation reverts the permutation to the original order of dimensions
  std::vector<int64_t> out_size;
  for (auto d : lro) out_size.push_back(left.size(d));
  for (auto d : lo) out_size.push_back(left.size(d));
  out_size.insert(out_size.end(), sum_dims_list.size(), static_cast<int64_t>(1));
  for (auto d : ro) out_size.push_back(right.size(d));
  std::vector<int64_t> lpermutation(lro);
  lpermutation.insert(lpermutation.end(), lo.begin(), lo.end());
  lpermutation.insert(lpermutation.end(), sum_dims_list.begin(), sum_dims_list.end());
  lpermutation.insert(lpermutation.end(), ro.begin(), ro.end());
  std::vector<int64_t> rpermutation(lro);
  rpermutation.insert(rpermutation.end(), sum_dims_list.begin(), sum_dims_list.end());
  rpermutation.insert(rpermutation.end(), ro.begin(), ro.end());
  rpermutation.insert(rpermutation.end(), lo.begin(), lo.end());
  // The output order "lro, lo, summed, ro" is exactly lpermutation, so the permutation
  // restoring the original dimension order is the inverse of lpermutation.
  std::vector<int64_t> opermutation(dim, -1);
  for (int64_t p = 0; p < dim; p++) {
    opermutation[lpermutation[p]] = p;
  }
  // now we can execute the operations above
  left = left.permute(lpermutation).reshape({lro_size, lo_size, sum_size});
  right = right.permute(rpermutation).reshape({lro_size, sum_size, ro_size});
  Tensor result = at::bmm(left, right);
  result = result.view(out_size).permute(opermutation);
  // finally squeeze summed dimensions if desired (back to front so indices stay valid)
  if (! keepdim) {
    for (int64_t i = dim-1; i >= 0; i--) {
      if (sum_dims[i])
        result.squeeze_(i);
    }
  }
  return result;
}
// einsum evaluates an Einstein summation over `tensors` as described by `eqn`
// (e.g. "ij,jk->ik" is a matrix product, "ii" is a trace).
//
// Equation syntax:
//   - The left hand side (before "->", or the whole string if there is no "->") is a
//     comma-separated list of terms, one per tensor. Whitespace in it is ignored.
//   - Each term has one lowercase letter 'a'..'z' per tensor dimension and at most one
//     ellipsis "...", which stands for all dimensions not covered by letters. All
//     ellipses must stand for the same number of dimensions.
//   - A letter repeated within a term takes the diagonal over those dimensions.
//   - The optional right hand side lists the output indices (whitespace is skipped).
//     Without it, the output is the ellipsis dimensions (if any) followed by the letters
//     that occur exactly once in the left hand side, in alphabetical order.
//   - Indices that are not in the output are summed over.
// Dimensions sharing an index must have exactly equal sizes.
// Malformed equations and size mismatches are reported through AT_CHECK.
Tensor einsum(std::string eqn, TensorList tensors) {
  constexpr size_t number_of_letters = 26;
  // isspace has undefined behavior for negative char values (e.g. non-ASCII bytes where
  // char is signed), so convert to unsigned char first.
  auto is_space = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };
  std::string in_eqn;
  size_t pos;
  // The equation is given in terms of single lowercase letters ('a'..'z') and potentially an ellipsis.
  // Internally, we represent it using indices from 0 to num_total_dimensions, with each letter
  // mapped to an index and the ellipsis ('...') being mapped to a number of consecutive indices.
  // The mapping of letters to internal indices is given in letter_mapping. A value of -1 means that
  // the letter has not been assigned an index yet (because it has not been seen).
  // The ellipsis is defined by first_ell_idx (the first index) and num_ell_idxes (the number of indices).
  // A value of -1 for num_ell_idxes specifies that we have not seen an ellipsis yet.
  // Note: The internal indices are NOT the dimensions used internally. There is a mapping to them below.
  std::array<std::int64_t, number_of_letters> letter_mapping; // map letter to internal (numerical) label
  letter_mapping.fill(-1);
  int64_t num_ell_idxes = -1;
  int64_t first_ell_idx = 0;
  // The internal representation of the left hand side of the equation (with ellipsis expanded) is stored in input_op_idxes.
  // For each operand, we have a vector mapping each dimension to an internal index.
  // We also keep track of the number of occurrences for each letter (to infer a right hand side if not given) and
  // of the last occurrence of each index.
  std::vector<std::vector<int64_t>> input_op_idxes; // the parsed operand indices
  std::array<std::int64_t, number_of_letters> num_letter_occurrences; // number of occurrence in the equation of this letter
  num_letter_occurrences.fill(0);
  // the last operand (left to right) using this index; invariant: exactly one entry per internal index
  std::vector<std::int64_t> last_idx_occurrence;
  if ((pos = eqn.find("->")) != std::string::npos) { // check whether we have a right hand side. in_eqn is the left hand side
    in_eqn = eqn.substr(0, pos);
  } else {
    in_eqn = eqn;
  }
  // remove spaces for einsum compatibility (#9929)
  in_eqn.erase(std::remove_if(in_eqn.begin(), in_eqn.end(), is_space), in_eqn.end());
  // next we parse in_eqn (the left hand side) by iterating. It is a string of comma separated terms per index
  int64_t operand = 0;
  std::stringstream eqn_stream(in_eqn);
  std::string term;
  int64_t num_total_idxes = 0;
  while (! eqn_stream.eof()) {
    std::getline(eqn_stream, term, ','); // term = string with indices of current term
    // we cannot have a longer equation than operands. We need to check here before we use the dimension
    AT_CHECK(static_cast<int64_t>(tensors.size())>operand, "more operands in equation than tensors");
    int64_t ell_char_count = 0; // handling of ellipsis '...' is a bit tedious, we count the '.'
    // if there is an ellipsis, the number of dimensions it represents must be total dim - letter dimensions.
    // Computed in signed arithmetic so that "more letters than dimensions" yields a negative value.
    int64_t candidate_num_ell_idxes = tensors[operand].dim() - static_cast<int64_t>(term.size()) + 3;
    int64_t dims_in_term = 0; // dimensions we have seen
    std::vector<int64_t> current_op_idxes; // mapping of operand dimensions to indices for current term
    for (auto &c : term) { // c = character with a single letter or '.'
      if (c == '.') {
        ell_char_count++;
        AT_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in term ", operand, " of the equation");
        if (ell_char_count == 3) { // this completes the ellipsis
          AT_CHECK(candidate_num_ell_idxes >= 0, "dimension mismatch for operand ", operand,
                   ": equation has more indices than the tensor's ", tensors[operand].dim(), " dimensions");
          if (num_ell_idxes == -1) { // if we have not seen an ellipsis before, keep track of indices and size
            first_ell_idx = num_total_idxes;
            num_ell_idxes = candidate_num_ell_idxes;
            num_total_idxes += num_ell_idxes;
            for (int64_t i = 0; i < num_ell_idxes; ++i) { // new indices get a new last-occurrence entry
              last_idx_occurrence.push_back(operand);
            }
          } else { // we have seen an ellipsis before, so we check compatibility
            AT_CHECK(candidate_num_ell_idxes == num_ell_idxes,
                     "ellipsis must represent ", num_ell_idxes, " dimensions in all terms");
            // the ellipsis indices already exist: update their last occurrence in place.
            // (Appending here would misalign last_idx_occurrence with the index numbering
            // and leave the ellipsis indices marked with an earlier operand, causing them
            // to be summed before all operands using them have been multiplied.)
            for (int64_t i = 0; i < num_ell_idxes; ++i) {
              last_idx_occurrence[first_ell_idx + i] = operand;
            }
          }
          for (int64_t i = 0; i < num_ell_idxes; ++i) { // map ellipsis dimensions in operand to indices
            current_op_idxes.push_back(first_ell_idx + i);
          }
          dims_in_term += num_ell_idxes; // keep track of dimensions
        }
      } else { // a letter (hopefully)
        AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis, operand ", operand);
        AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
        int64_t letter_num = c-'a'; // letter_num = position in letter_mapping
        if (letter_mapping[letter_num] == -1) { // new letter, add internal index and mapping
          letter_mapping[letter_num] = num_total_idxes;
          num_total_idxes++;
          last_idx_occurrence.push_back(operand);
        } else { // letter we have already seen
          last_idx_occurrence[letter_mapping[letter_num]] = operand;
        }
        num_letter_occurrences[letter_num]++;
        current_op_idxes.push_back(letter_mapping[letter_num]);
        dims_in_term++;
      }
    }
    // an ellipsis that was started at the end of the term must also be complete
    AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis, operand ", operand);
    AT_CHECK(dims_in_term == tensors[operand].dim(), "dimension mismatch for operand ", operand, ": equation ", dims_in_term, " tensor ", tensors[operand].dim());
    input_op_idxes.push_back(std::move(current_op_idxes));
    operand++;
  }
  // in the check below, we need ==, but > is captured above, so the error message can be specific that it is <.
  AT_CHECK(static_cast<int64_t>(tensors.size())==operand, "more tensors than operands in equation");
  // the following parses or infers output (right hand side)
  // it also assigns the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors)
  // for the output indices. -1 means that the index has not been assigned a dimension yet
  std::vector<int64_t> idxes_to_preprocessed_dims(num_total_idxes, -1); // the position of the index in the tensor dimensions
  int64_t num_output_dims = 0;
  if (pos != std::string::npos) { // parse the user provided right hand side
    const std::string out_eqn = eqn.substr(pos+2);
    int64_t ell_char_count = 0;
    for (char c : out_eqn) {
      if (c == '.') { // '.' as part of ellipsis
        ell_char_count++;
        AT_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in right hand side of the equation");
        if (ell_char_count == 3) { // ellipsis complete
          AT_CHECK(num_ell_idxes >= 0, "ellipsis '...' may only appear in right hand side if it does in left hand side");
          for (int64_t i = 0; i < num_ell_idxes; ++i) {
            idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims;
            num_output_dims++;
          }
        }
      } else if (! is_space(c)) { // letter (hopefully)
        AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis in the right hand side");
        AT_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices");
        int64_t letter_num = c-'a';
        // a letter never seen on the left has no internal index (-1) and cannot be indexed with
        AT_CHECK(letter_mapping[letter_num] != -1, "index ", c, " occurs in output but not in any input");
        AT_CHECK(idxes_to_preprocessed_dims[letter_mapping[letter_num]] == -1, "index ", c, " occurs twice in output");
        idxes_to_preprocessed_dims[letter_mapping[letter_num]] = num_output_dims;
        num_output_dims++;
      }
    }
    AT_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis in the right hand side");
  } else { // create an inferred right hand side
    // the ellipsis (if in the lhs) comes first
    if (num_ell_idxes >= 0) {
      for (int64_t i = 0; i < num_ell_idxes; ++i) {
        idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims;
        num_output_dims++;
      }
    }
    // then the indices that occur exactly once in alphabetic order
    for (size_t idx = 0; idx < number_of_letters; idx++) {
      if (num_letter_occurrences[idx] == 1) {
        idxes_to_preprocessed_dims[letter_mapping[idx]] = num_output_dims;
        num_output_dims++;
      }
    }
  }
  // now we assign the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors)
  // for the non-output indices - those that are eventually summed over
  int64_t position = num_output_dims;
  for (int64_t i = 0; i < num_total_idxes; i++) {
    if (idxes_to_preprocessed_dims[i]==-1) {
      idxes_to_preprocessed_dims[i] = position;
      position++;
    }
  }
  // we now "homogenize the dimensions", i.e.
  // - take diagonals for duplicated indices
  // - permute the dimensions to match the order given by idxes_to_preprocessed_dims
  // - unsqueeze to create all dimensions for each index in each tensor where they are missing
  // we also check that sizes match
  // after this, all operands will have compatible shapes (i.e. all dimensions are aligned and broadcastable)
  std::vector<Tensor> preprocessed_operands;
  std::vector<std::int64_t> size_of_dims(num_total_idxes, -1); // keep track of sizes for each index, -1 means we have not seen a size yet
  for (int64_t op = 0; op < static_cast<int64_t>(tensors.size()); op++) {
    auto preprocessed_op = tensors[op];
    std::vector<int64_t> idx_to_dim(num_total_idxes, -1); // the dimension which the index refers to in the original tensor, -1 means it does not appear
    std::vector<int64_t>& current_op_input_idxes = input_op_idxes[op];
    int64_t dim = 0; // there are two dimension indices: dim is after taking diagonals, i is in input
    for (size_t i = 0; i < current_op_input_idxes.size(); i++) {
      auto idx = current_op_input_idxes[i];
      auto dim_out = idxes_to_preprocessed_dims[idx];
      if (idx_to_dim[dim_out] == -1) { // first appearance
        idx_to_dim[dim_out] = dim;
        if (size_of_dims[idx] == -1) { // keep track of sizes
          size_of_dims[idx] = preprocessed_op.size(dim);
        }
        else {
          AT_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i);
        }
        dim++;
      } else { // duplicate dimension in tensor --> take diagonal of idx_to_dim[dim_out] and dim and put the diagonal dimension to idx_to_dim[dim_out]
        AT_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i);
        preprocessed_op = preprocessed_op.diagonal(0, idx_to_dim[dim_out], dim);
        // diagonal moves the diagonal dimension to the back
        // now we permute the last dim back to idx_to_dim[dim_out]
        std::vector<int64_t> perm(preprocessed_op.dim(), 0);
        for (int64_t d = 0; d < preprocessed_op.dim(); d++) {
          if (d == idx_to_dim[dim_out]) {
            perm[d] = preprocessed_op.dim() - 1;
          } else {
            perm[d] = d - (d > idx_to_dim[dim_out]);
          }
        }
        preprocessed_op = preprocessed_op.permute(perm);
      }
    }
    // now we permute the dimensions in the right order
    std::vector<int64_t> permutation; // permutation for this tensor
    for (auto &d : idx_to_dim) {
      if (d > -1) {
        permutation.push_back(d);
      }
    }
    preprocessed_op = preprocessed_op.permute(permutation);
    // finally, we insert dimensions for idxes not in the operand
    for (size_t dim = 0; dim < idx_to_dim.size(); dim++) {
      if (idx_to_dim[dim] == -1) {
        preprocessed_op = preprocessed_op.unsqueeze(dim);
      }
    }
    preprocessed_operands.push_back(preprocessed_op);
  }
  // now we reduce the indices from left to right
  // numpy allows to optimize the path using various
  // algorithms (see eigen_path in numpy docs)
  // we start with the leftmost operator and reduce indices that
  // appear only there
  Tensor result = preprocessed_operands[0];
  for (int64_t idx = 0; idx < num_total_idxes; idx++) {
    if ((last_idx_occurrence[idx] == 0)
        && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) {
      result = result.sum(idxes_to_preprocessed_dims[idx], true);
    }
  }
  // now we process each tensor using sumproduct_pair; an index is summed as soon as
  // the last operand using it has been multiplied in
  for (int64_t i = 1; i < static_cast<int64_t>(preprocessed_operands.size()); i++) {
    std::vector<int64_t> sum_dims;
    for (int64_t idx = 0; idx < num_total_idxes; idx++) {
      if ((last_idx_occurrence[idx] == i)
          && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) {
        sum_dims.push_back(idxes_to_preprocessed_dims[idx]);
      }
    }
    result = at::native::sumproduct_pair(result, preprocessed_operands[i], sum_dims, true);
  }
  // finally, we squeeze out all non-result dimensions
  for (int64_t dim = num_total_idxes-1; dim >= num_output_dims; dim--)
    result.squeeze_(dim);
  return result;
}
// _trilinear computes a trilinear einstein sum with an unrolled dimension.
// The result is `(i1.unsqueeze(expand1)*i2.unsqueeze(expand2)*i3.unsqueeze(expand3)).sum(sumdim)`,
// where unsqueeze(list) inserts a size-1 dimension at each listed position.
// The computation is unrolled (looped over) in the unroll_dim dimension.
// Its main purpose is to unify the computations in bilinear and bilinear_backward.
//
//   i1_, i2_, i3_   : operands; after inserting their expand dims, all three must have
//                     total_dim = i1_.dim() + expand1_.size() dimensions.
//   expand1_/2_/3_  : positions (in the total_dim-dimensional space) where a size-1
//                     dimension is inserted into the corresponding operand.
//   sumdim_         : dimensions that are summed over and removed from the result.
//   unroll_dim      : dimension in [0, total_dim) that is processed slice by slice.
Tensor _trilinear(const Tensor& i1_, const Tensor& i2_, const Tensor& i3_,
                  IntList expand1_, IntList expand2_, IntList expand3_,
                  IntList sumdim_, int64_t unroll_dim) {
  int64_t total_dim = i1_.dim() + static_cast<int64_t>(expand1_.size());
  AT_CHECK((unroll_dim >= 0) && (unroll_dim < total_dim), "unroll_dim must be in [0,", total_dim-1, "]");
  AT_CHECK(i2_.dim() + static_cast<int64_t>(expand2_.size()) == total_dim &&
           i3_.dim() + static_cast<int64_t>(expand3_.size()) == total_dim,
           "_trilinear(): number of dimensions after expansion must match: got ", total_dim, ", ",
           i2_.dim() + static_cast<int64_t>(expand2_.size()), " and ",
           i3_.dim() + static_cast<int64_t>(expand3_.size()));
  auto expand1 = dim_list_to_bitset(expand1_, total_dim);
  auto expand2 = dim_list_to_bitset(expand2_, total_dim);
  auto expand3 = dim_list_to_bitset(expand3_, total_dim);
  auto sumdim = dim_list_to_bitset(sumdim_, total_dim);
  Tensor i1 = i1_;
  Tensor i2 = i2_;
  Tensor i3 = i3_;
  std::vector<int64_t> output_size;
  // dims summed in the first product (i1*i2; only possible where i3 is expanded)
  // and in the second product (with i3); unroll_dim is handled by the loop below
  std::vector<int64_t> sum_dims_12, sum_dims_23;
  int64_t unroll_size = -1;
  for (int64_t i = 0; i < total_dim; i++) {
    // Size of dimension i of the product. If i is expanded in all three operands,
    // every factor has size 1 there, so the product has size 1 (not 0).
    int64_t s = 1;
    if (expand1[i]) {
      i1 = i1.unsqueeze(i);
    } else {
      s = i1.size(i);
    }
    if (expand2[i]) {
      i2 = i2.unsqueeze(i);
    } else {
      s = i2.size(i);
    }
    if (expand3[i]) {
      i3 = i3.unsqueeze(i);
      if (sumdim[i] && (i != unroll_dim))
        sum_dims_12.push_back(i);
    } else {
      s = i3.size(i);
      if (sumdim[i] && (i != unroll_dim))
        sum_dims_23.push_back(i);
    }
    output_size.push_back(sumdim[i] ? 1 : s);
    if (i == unroll_dim)
      unroll_size = s;
  }
  // an operand expanded at unroll_dim has size 1 there, so it always uses slice 0
  int64_t slicemul1 = (expand1[unroll_dim] ? 0 : 1);
  int64_t slicemul2 = (expand2[unroll_dim] ? 0 : 1);
  int64_t slicemul3 = (expand3[unroll_dim] ? 0 : 1);
  auto output = i1.type().tensor(output_size).zero_();
  const bool sum_over_unroll_dim = sumdim[unroll_dim];
  for (int64_t k = 0; k < unroll_size; k++) {
    Tensor buf = at::native::sumproduct_pair(i1.narrow(unroll_dim, k * slicemul1, 1),
                                             i2.narrow(unroll_dim, k * slicemul2, 1),
                                             sum_dims_12, true);
    buf = at::native::sumproduct_pair(buf, i3.narrow(unroll_dim, k * slicemul3, 1), sum_dims_23, true);
    if (sum_over_unroll_dim) {
      // unroll_dim is summed: accumulate every slice into the single size-1 output slice
      output.add_(buf);
    } else {
      output.narrow(unroll_dim, k, 1).add_(buf);
    }
  }
  // drop the summed (size-1) dimensions, back to front so indices stay valid
  for (int64_t i = output.dim()-1; i >= 0; i--)
    if (sumdim[i])
      output.squeeze_(i);
  return output;
}
// bilinear computes, for every leading (batch) index "...",
//   output[..., o] = sum_{i,j} input1[..., i] * weight[o, i, j] * input2[..., j]  (+ bias[o])
// input1 and input2 must have the same number of dimensions and agree in all but the last one;
// weight must be 3-dimensional with weight.size(1) == input1.size(-1) and
// weight.size(2) == input2.size(-1); if bias is defined, bias.size(0) must equal weight.size(0).
// The result has the leading sizes of input1 followed by weight.size(0).
Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias) {
  AT_CHECK(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got ", input1.dim(), " and ", input2.dim());
  for (int64_t i = 0; i < input1.dim() - 1; i++) {
    AT_CHECK(input1.size(i) == input2.size(i),
             "bilinear(): input batch dimensions do not match at dim ", i, ": got ", input1.size(i), " and ", input2.size(i));
  }
  AT_CHECK(weight.dim() == 3,
           "bilinear(): weight must be 3-dimensional: got ", weight.dim(), " dimensions");
  AT_CHECK(input1.size(input1.dim() - 1) == weight.size(1),
           "bilinear(): input1 size does not match weight size: got ",
           input1.size(input1.dim() - 1), " but expected ", weight.size(1));
  AT_CHECK(input2.size(input2.dim() - 1) == weight.size(2),
           "bilinear(): input2 size does not match weight size: got ",
           input2.size(input2.dim() - 1), " but expected ", weight.size(2));
  AT_CHECK(!bias.defined() || bias.size(0) == weight.size(0),
           "bilinear(): bias size does not match weight size: got ",
           bias.size(0), " but expected ", weight.size(0));
  std::vector<int64_t> output_size;
  auto size1 = input1.sizes();
  output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
  output_size.push_back(weight.size(0));
  // reshape rather than view: the inputs need not be contiguous (e.g. transposed tensors),
  // in which case view() cannot flatten the batch dimensions.
  auto input1_flattened = input1.reshape({-1, input1.size(-1)});
  auto input2_flattened = input2.reshape({-1, input2.size(-1)});
  // (N,in1) -> (N,1,in1,1); weight (out,in1,in2) -> (1,out,in1,in2); (N,in2) -> (N,1,1,in2);
  // summing dims 2 and 3 yields (N,out)
  Tensor output = at::_trilinear(input1_flattened, weight, input2_flattened, {1,3}, {0}, {1,2}, {2,3}).reshape(output_size);
  if (bias.defined()) {
    output = output + bias;
  }
  return output;
}
}} // namespace at::native