blob: 64ecc0a52e46763e8b22ec4b4206289f8f3a36ba [file] [log] [blame]
#ifndef THP_EXPAND_UTILS_H
#define THP_EXPAND_UTILS_H
#include <sstream>
#include <Python.h>
template <typename ExpandType>
ExpandType *newForExpand(LIBRARY_STATE_TYPE_NOARGS);
template <typename TensorType>
void expand(LIBRARY_STATE_TYPE TensorType *r, TensorType *tensor, THLongStorage *sizes);
template <typename TensorType>
void expand2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
TensorType *e1, TensorType *e2);
template <typename TensorType>
void expand3(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
TensorType *e1, TensorType *e2, TensorType *e3);
template <typename ExpandType, typename TensorType>
void check_backincompat_expand_warn(ExpandType *to_expand, TensorType *tensor,
char *to_expand_name, char *tensor_name, bool fallback,
ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem) {
if (fallback && getBackCompatBroadcastWarn()) {
bool same_shape = THSize_isSameSizeAs(tensor->size, tensor->nDimension,
to_expand->size, to_expand->nDimension);
if (!same_shape && (tensor_nElem == to_expand_nElem)) {
std::ostringstream warn;
warn << tensor_name << " and " << to_expand_name << " do not have the same shape, but are "
<< "broadcastable, and have the same number of elements. Changing behavior in a backwards incompatible "
<< "manner to broadcasting rather than viewing as 1-dimensional.";
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
}
}
}
template <typename ExpandType, typename TensorType>
void expand_inplace(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
char *to_expand_name, char *tensor_name, bool fallback,
THLongStorage *tensor_size, ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem,
bool warn_pointwise_fallback) {
try {
expand<ExpandType>(LIBRARY_STATE r, to_expand, tensor_size);
} catch (std::exception &e) {
if (warn_pointwise_fallback) {
std::ostringstream warn;
warn << to_expand_name << " is not broadcastable to " << tensor_name
<< ", but they have the same number of elements. Falling back to deprecated pointwise behavior.";
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
}
throw;
}
}
template <typename ExpandType, typename TensorType>
void expand_inplace1(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
char *to_expand_name, char *tensor_name, bool fallback) {
ptrdiff_t to_expand_nElem = THSize_nElement(to_expand->nDimension, to_expand->size);
ptrdiff_t tensor_nElem = THSize_nElement(tensor->nDimension, tensor->size);
bool to_expand_warn = fallback && (to_expand_nElem == tensor_nElem) && to_expand_nElem != 0;
THLongStoragePtr tensor_size(THLongStorage_newWithSize(tensor->nDimension));
THLongStorage_rawCopy(tensor_size.get(), tensor->size);
expand_inplace(LIBRARY_STATE r, to_expand, tensor, to_expand_name, tensor_name, fallback,
tensor_size, to_expand_nElem, tensor_nElem, to_expand_warn);
check_backincompat_expand_warn<ExpandType, TensorType>(to_expand, tensor, to_expand_name, tensor_name, fallback,
to_expand_nElem, tensor_nElem);
}
template <typename TensorType>
void expand_inplace2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
TensorType *to_expand1, TensorType *to_expand2, TensorType *tensor,
char *to_expand1_name, char *to_expand2_name, char *tensor_name, bool fallback) {
ptrdiff_t tensor_nElem = THSize_nElement(tensor->nDimension, tensor->size);
ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size);
ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size);
bool to_expand1_warn = fallback && (tensor_nElem == to_expand1_nElem) && tensor_nElem != 0;
bool to_expand2_warn = fallback && (tensor_nElem == to_expand2_nElem) && tensor_nElem != 0;
THLongStoragePtr tensor_size(THLongStorage_newWithSize(tensor->nDimension));
THLongStorage_rawCopy(tensor_size.get(), tensor->size);
expand_inplace(LIBRARY_STATE r1, to_expand1, tensor, to_expand1_name, tensor_name, fallback,
tensor_size, to_expand1_nElem, tensor_nElem, to_expand1_warn && to_expand2_warn);
expand_inplace(LIBRARY_STATE r2, to_expand2, tensor, to_expand2_name, tensor_name, fallback,
tensor_size, to_expand2_nElem, tensor_nElem, to_expand1_warn && to_expand2_warn);
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, tensor, to_expand1_name, tensor_name, fallback,
to_expand1_nElem, tensor_nElem);
check_backincompat_expand_warn<TensorType, TensorType>(to_expand2, tensor, to_expand2_name, tensor_name, fallback,
to_expand2_nElem, tensor_nElem);
}
template <typename TensorType>
void expand_outplace2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
TensorType *to_expand1, TensorType *to_expand2,
char *to_expand1_name, char *to_expand2_name, bool fallback) {
ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size);
ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size);
bool expand_warn = fallback && (to_expand1_nElem == to_expand2_nElem) && to_expand1_nElem != 0;
try {
expand2<TensorType>(LIBRARY_STATE r1, r2, to_expand1, to_expand2);
} catch (std::exception &e) {
if (expand_warn) {
std::ostringstream warn;
warn << to_expand1_name << " and " << to_expand2_name << " not broadcastable, but have the same number of "
<< "elements. Falling back to deprecated pointwise behavior.";
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
}
throw;
}
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand2, to_expand1_name, to_expand2_name,
fallback, to_expand1_nElem, to_expand2_nElem);
}
template <typename TensorType>
void expand_outplace3(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
TensorType *to_expand1, TensorType *to_expand2, TensorType *to_expand3,
char *to_expand1_name, char *to_expand2_name, char *to_expand3_name, bool fallback) {
ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size);
ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size);
ptrdiff_t to_expand3_nElem = THSize_nElement(to_expand3->nDimension, to_expand3->size);
bool to_expand2_warn = fallback && (to_expand1_nElem == to_expand2_nElem) && to_expand1_nElem != 0;
bool to_expand3_warn = fallback && (to_expand1_nElem == to_expand3_nElem) && to_expand1_nElem != 0;
try {
expand3<TensorType>(LIBRARY_STATE r1, r2, r3, to_expand1, to_expand2, to_expand3);
} catch (std::exception &e) {
if(to_expand2_warn && to_expand3_warn) {
std::ostringstream warn;
warn << to_expand1_name << ", " << to_expand2_name << ", and " << to_expand3_name << " not broadcastable,"
<< " but have the same number of elements. Falling back to deprecated pointwise behavior.";
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
}
throw;
}
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand2, to_expand1_name, to_expand2_name,
fallback, to_expand1_nElem, to_expand2_nElem);
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand3, to_expand1_name, to_expand3_name,
fallback, to_expand1_nElem, to_expand3_nElem);
}
#endif