| #include "caffe2/utils/math_utils.h" |
| |
| #include <algorithm> |
| #include <functional> |
| #include <numeric> |
| #include <vector> |
| |
| #include "caffe2/core/logging.h" |
| |
| namespace caffe2 { |
| namespace math { |
| namespace utils { |
| |
// Advances the n-dimensional `index` by one position in place, treating the
// last axis as the fastest-changing one. Starting from the last axis, each
// entry is incremented; if it reaches dims[axis] it is reduced by dims[axis]
// and the carry moves on to the previous axis, otherwise the update stops.
void IncreaseIndexInDims(const int n, const int* dims, int* index) {
  int axis = n;
  while (axis-- > 0) {
    int& digit = index[axis];
    digit += 1;
    if (digit < dims[axis]) {
      // No carry needed: the remaining (earlier) axes are left untouched.
      return;
    }
    digit -= dims[axis];
  }
}
| |
// Folds the n-dimensional `index` into a single flat offset, with the last
// axis varying fastest. Axes whose extent in `dims` is not greater than 1 are
// skipped entirely: they neither scale the offset nor contribute index[axis].
int GetIndexFromDims(const int n, const int* dims, const int* index) {
  int flat = 0;
  for (int axis = 0; axis < n; ++axis) {
    const int extent = dims[axis];
    if (extent <= 1) {
      continue;
    }
    flat = flat * extent + index[axis];
  }
  return flat;
}
| |
// Returns true iff perm[i] == i for every i in [0, n). An empty range
// (n <= 0) is reported as the identity.
bool IsIdentityPermutation(const int n, const int* perm) {
  int pos = 0;
  // Walk forward while each entry maps to its own position.
  while (pos < n && perm[pos] == pos) {
    ++pos;
  }
  return pos >= n;
}
| |
// Tests whether reducing A (shape A_dims) to B (shape B_dims) collapses only
// a trailing run of axes.
//
// *cols receives the product of A_dims over the maximal trailing run of axes
// where B_dims is 1. Every remaining leading axis must satisfy
// A_dims[i] == B_dims[i]; their product is stored in *rows. Returns false at
// the first mismatching leading axis (scanned from back to front), in which
// case *rows holds only the partial product accumulated so far.
bool IsRowwiseReduce(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* rows,
    int* cols) {
  // `tail` is one past the last axis that has not been consumed yet.
  int tail = ndim;
  *cols = 1;
  while (tail > 0 && B_dims[tail - 1] == 1) {
    --tail;
    *cols *= A_dims[tail];
  }
  *rows = 1;
  while (tail > 0) {
    --tail;
    if (A_dims[tail] != B_dims[tail]) {
      return false;
    }
    *rows *= A_dims[tail];
  }
  return true;
}
| |
// Tests whether reducing A (shape A_dims) to B (shape B_dims) collapses only
// a leading run of axes.
//
// *rows receives the product of A_dims over the maximal leading run of axes
// where B_dims is 1. Every remaining trailing axis must satisfy
// A_dims[i] == B_dims[i]; their product is stored in *cols. Returns false at
// the first mismatching trailing axis, in which case *cols holds only the
// partial product accumulated so far.
bool IsColwiseReduce(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* rows,
    int* cols) {
  int head = 0;
  *rows = 1;
  while (head < ndim && B_dims[head] == 1) {
    *rows *= A_dims[head];
    ++head;
  }
  *cols = 1;
  while (head < ndim) {
    if (A_dims[head] != B_dims[head]) {
      return false;
    }
    *cols *= A_dims[head];
    ++head;
  }
  return true;
}
| |
// Tests whether reducing A (shape A_dims) to B (shape B_dims) collapses a
// leading run and a trailing run of axes while keeping the middle intact.
//
// *nxt is the product of A_dims over the maximal trailing run where B_dims
// is 1; *pre is the product of A_dims over the leading run where B_dims is 1
// (not overlapping the trailing run); *mid is the product of the axes in
// between, each of which must satisfy A_dims[i] == B_dims[i]. Returns false
// at the first mismatching middle axis, leaving *mid as the partial product.
bool IsBothEndsReduce(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* pre,
    int* mid,
    int* nxt) {
  // The middle section is tracked as the half-open range [lo, hi).
  int hi = ndim;
  *nxt = 1;
  while (hi > 0 && B_dims[hi - 1] == 1) {
    --hi;
    *nxt *= A_dims[hi];
  }
  int lo = 0;
  *pre = 1;
  while (lo < hi && B_dims[lo] == 1) {
    *pre *= A_dims[lo];
    ++lo;
  }
  *mid = 1;
  for (int axis = lo; axis < hi; ++axis) {
    if (A_dims[axis] != B_dims[axis]) {
      return false;
    }
    *mid *= A_dims[axis];
  }
  return true;
}
| |
| void ComputeBroadcastBinaryOpDims( |
| const int A_ndim, |
| const int* A_dims, |
| const int B_ndim, |
| const int* B_dims, |
| int* A_broadcast_dims, |
| int* B_broadcast_dims, |
| int* C_broadcast_dims) { |
| const int ndim = std::max(A_ndim, B_ndim); |
| std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1); |
| std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1); |
| std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim); |
| std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim); |
| for (int i = 0; i < ndim; ++i) { |
| CAFFE_ENFORCE( |
| A_broadcast_dims[i] == B_broadcast_dims[i] || |
| A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1); |
| if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) { |
| C_broadcast_dims[i] = 0; |
| } else { |
| C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]); |
| } |
| } |
| } |
| |
// Tests whether a binary op on equally-ranked shapes A_dims and B_dims is a
// broadcast of one operand along leading axes only.
//
// Each shape's run of leading 1s is measured. If both runs have the same
// length (or ndim == 0) the function returns false without touching the
// outputs. Otherwise the operand with the longer run is the broadcast one:
// *broadcast_1st is true when that operand is A. *rows is the product of the
// other operand's extents over the axes between the two run lengths, and
// *cols is the product of the remaining trailing axes, which must match
// between A and B (false is returned at the first mismatch).
bool IsRowwiseBroadcastBinaryOp(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* rows,
    int* cols,
    bool* broadcast_1st) {
  if (ndim == 0) {
    return false;
  }
  // Length of the run of leading extents equal to 1.
  const auto leading_ones = [ndim](const int* dims) {
    int count = 0;
    while (count < ndim && dims[count] == 1) {
      ++count;
    }
    return count;
  };
  const int A_skip = leading_ones(A_dims);
  const int B_skip = leading_ones(B_dims);
  if (A_skip == B_skip) {
    return false;
  }
  const bool first_is_broadcast = A_skip > B_skip;
  const int* full_dims = first_is_broadcast ? B_dims : A_dims;
  const int row_begin = std::min(A_skip, B_skip);
  const int col_begin = std::max(A_skip, B_skip);
  int row_count = 1;
  for (int i = row_begin; i < col_begin; ++i) {
    row_count *= full_dims[i];
  }
  *rows = row_count;
  *broadcast_1st = first_is_broadcast;
  *cols = 1;
  for (int i = col_begin; i < ndim; ++i) {
    if (A_dims[i] != B_dims[i]) {
      return false;
    }
    *cols *= A_dims[i];
  }
  return true;
}
| |
// Tests whether a binary op on equally-ranked shapes A_dims and B_dims is a
// broadcast of one operand along trailing axes only.
//
// Each shape is trimmed of its run of trailing 1s. If both trimmed lengths
// are equal (or ndim == 0) the function returns false without touching the
// outputs. Otherwise the operand with the shorter trimmed length is the
// broadcast one: *broadcast_1st is true when that operand is A. *cols is the
// product of the other operand's extents between the two trimmed lengths, and
// *rows is the product of the shared leading axes, which must match between
// A and B (false is returned at the first mismatch).
bool IsColwiseBroadcastBinaryOp(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* rows,
    int* cols,
    bool* broadcast_1st) {
  if (ndim == 0) {
    return false;
  }
  // Number of axes left once the trailing extents equal to 1 are dropped.
  const auto trimmed_length = [ndim](const int* dims) {
    int length = ndim;
    while (length > 0 && dims[length - 1] == 1) {
      --length;
    }
    return length;
  };
  const int A_end = trimmed_length(A_dims);
  const int B_end = trimmed_length(B_dims);
  if (A_end == B_end) {
    return false;
  }
  const bool first_is_broadcast = A_end < B_end;
  const int* full_dims = first_is_broadcast ? B_dims : A_dims;
  const int col_begin = std::min(A_end, B_end);
  const int col_end = std::max(A_end, B_end);
  int col_count = 1;
  for (int i = col_begin; i < col_end; ++i) {
    col_count *= full_dims[i];
  }
  *cols = col_count;
  *broadcast_1st = first_is_broadcast;
  *rows = 1;
  for (int i = 0; i < col_begin; ++i) {
    if (A_dims[i] != B_dims[i]) {
      return false;
    }
    *rows *= A_dims[i];
  }
  return true;
}
| |
// Tests whether a binary op on equally-ranked shapes A_dims and B_dims is a
// broadcast of one operand along both a leading and a trailing run of axes.
//
// For each shape the run of leading 1s and the length left after trimming
// trailing 1s are measured. One operand must have strictly more leading 1s
// AND a strictly shorter trimmed length than the other; otherwise (or when
// ndim == 0) the function returns false without touching the outputs.
// *broadcast_1st is true when A is that operand. *pre and *nxt are the
// products of the other operand's extents over the leading and trailing gaps
// respectively, and *mid is the product of the axes in between, which must
// match between A and B (false is returned at the first mismatch).
bool IsBothEndsBroadcastBinaryOp(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    int* pre,
    int* mid,
    int* nxt,
    bool* broadcast_1st) {
  if (ndim == 0) {
    return false;
  }
  // Length of the run of leading extents equal to 1.
  const auto leading_ones = [ndim](const int* dims) {
    int count = 0;
    while (count < ndim && dims[count] == 1) {
      ++count;
    }
    return count;
  };
  // Number of axes left once the trailing extents equal to 1 are dropped.
  const auto trimmed_length = [ndim](const int* dims) {
    int length = ndim;
    while (length > 0 && dims[length - 1] == 1) {
      --length;
    }
    return length;
  };
  // Product of dims over the half-open range [begin, end); 1 if empty.
  const auto product = [](const int* dims, const int begin, const int end) {
    int result = 1;
    for (int i = begin; i < end; ++i) {
      result *= dims[i];
    }
    return result;
  };

  const int A_head = leading_ones(A_dims);
  const int B_head = leading_ones(B_dims);
  const int A_tail = trimmed_length(A_dims);
  const int B_tail = trimmed_length(B_dims);

  const bool first_is_inner = A_head > B_head && A_tail < B_tail;
  const bool second_is_inner = A_head < B_head && A_tail > B_tail;
  if (!first_is_inner && !second_is_inner) {
    // Covers equal heads/tails as well as mixed orderings.
    return false;
  }

  const int* outer_dims = first_is_inner ? B_dims : A_dims;
  const int mid_begin = std::max(A_head, B_head);
  const int mid_end = std::min(A_tail, B_tail);
  *pre = product(outer_dims, std::min(A_head, B_head), mid_begin);
  *nxt = product(outer_dims, mid_end, std::max(A_tail, B_tail));
  *broadcast_1st = first_is_inner;

  *mid = 1;
  for (int i = mid_begin; i < mid_end; ++i) {
    if (A_dims[i] != B_dims[i]) {
      return false;
    }
    *mid *= A_dims[i];
  }
  return true;
}
| |
// Builds in `transpose_axes` an ordering of the axes [0, num_dims) in which
// the axes listed in `reduce_axes` occupy the last num_reduce_axes slots (in
// ascending order) and all other axes occupy the leading slots (also in
// ascending order).
void ComputeTransposeAxesForReduceOp(
    const int num_dims,
    const int num_reduce_axes,
    const int* reduce_axes,
    int* transpose_axes) {
  const int num_kept = num_dims - num_reduce_axes;
  int* const reduced_begin = transpose_axes + num_kept;
  int* const reduced_end = transpose_axes + num_dims;
  // Stage the reduced axes at the back and sort them so that they can be
  // merged against the running axis counter below.
  std::copy(reduce_axes, reduce_axes + num_reduce_axes, reduced_begin);
  std::sort(reduced_begin, reduced_end);
  int* kept_out = transpose_axes;
  const int* next_reduced = reduced_begin;
  for (int axis = 0; axis < num_dims; ++axis) {
    if (next_reduced < reduced_end && *next_reduced == axis) {
      // This axis is reduced and is already in place at the back.
      ++next_reduced;
      continue;
    }
    *kept_out++ = axis;
  }
}
| |
// Writes into `strides` the strides of the shape `dims` after its axes are
// permuted by `axes`: strides[i] is the stride of original axis axes[i],
// where the original strides are computed with the last axis having stride 1
// and each earlier axis having the product of all later extents.
void ComputeTransposedStrides(
    const int ndim,
    const int* dims,
    const int* axes,
    int* strides) {
  std::vector<int> source_strides(ndim);
  int running = 1;
  for (int i = ndim; i-- > 0;) {
    source_strides[i] = running;
    running *= dims[i];
  }
  // Gather the source strides in the permuted axis order.
  std::transform(axes, axes + ndim, strides, [&source_strides](const int axis) {
    return source_strides[axis];
  });
}
| |
| } // namespace utils |
| } // namespace math |
| } // namespace caffe2 |