blob: 2dce3ef5ee2957cf2d0e8de5b018e62b54e91cff [file] [log] [blame]
#pragma once
#include "ArrayRef.h"
#include "ATenGeneral.h"
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <typeinfo>
#include <vector>
namespace at {
// AT_ASSERT(cond, format, ...): when `cond` evaluates to false, calls
// at::runtime_error with the given printf-style format string and arguments.
// The body is wrapped in do { ... } while (0) so the macro behaves as exactly
// one statement. A bare `if` body breaks `if (x) AT_ASSERT(c, "..."); else ...`:
// the trailing `;` ends the if-chain and leaves a dangling `else`.
// Call sites still need the usual trailing semicolon.
#define AT_ASSERT(cond, ...)            \
  do {                                  \
    if (!(cond)) {                      \
      at::runtime_error(__VA_ARGS__);   \
    }                                   \
  } while (0)

// Reports an error described by a printf-style `format` string and its
// variadic arguments. Declared [[noreturn]], so it never returns to the caller.
// NOTE(review): the definition is not in this header; presumably it throws.
// Confirm before relying on exception type.
[[noreturn]]
ATen_CLASS void runtime_error(const char *format, ...);
// Downcasts `expr` to the concrete type T. Errors are reported through
// runtime_error, which is declared [[noreturn]].
//
//   expr      - object to cast. It may be null only when `allowNull` is true.
//   name, pos - argument name and position, used only in error messages.
//   allowNull - when true, a null `expr` yields a null T* instead of an error.
//
// Returns `expr` as a T*. The check is an exact type match (typeid equality),
// so an object whose dynamic type is a subclass of T is rejected.
// T must provide a static typeString() and Base must provide type().toString().
// NOTE(review): both are passed to %s, so they presumably return const char*.
// Confirm.
template <typename T, typename Base>
static inline T* checked_cast(Base* expr, const char * name, int pos, bool allowNull) {
  if (!expr) {
    if (allowNull) {
      return nullptr;
    }
    runtime_error("Expected a Tensor of type %s but found an undefined Tensor for argument #%d '%s'",
                  T::typeString(), pos, name);
  }
  // typeid on the dereferenced pointer compares the dynamic type
  // (Base must be polymorphic for this).
  if (typeid(*expr) != typeid(T)) {
    runtime_error("Expected object of type %s but found type %s for argument #%d '%s'",
                  T::typeString(), expr->type().toString(), pos, name);
  }
  return static_cast<T*>(expr);
}
// Converts a TensorList (i.e. ArrayRef<Tensor>) to a vector of the underlying
// TH* tensor pointers.
//
// Each element's pImpl must be non-null and must dynamic_cast to T*. The
// resulting object's `tensor` member is collected. Because dynamic_cast is
// used, objects of a subclass of T are accepted here, unlike checked_cast's
// exact typeid match. Errors are reported through runtime_error, which is
// declared [[noreturn]].
//
//   tensors   - the list to convert.
//   name, pos - argument name and position, used only in error messages.
//
// Returns one TH* per input element, in order.
template <typename T, typename TBase, typename TH>
static inline std::vector<TH*> tensor_list_checked_cast(ArrayRef<TBase> tensors, const char * name, int pos) {
  std::vector<TH*> casted(tensors.size());
  // A size_t index (printed with %zu) matches tensors.size() without narrowing.
  for (size_t i = 0; i < tensors.size(); ++i) {
    auto *expr = tensors[i].pImpl;
    if (!expr) {
      runtime_error("Expected a Tensor of type %s but found an undefined Tensor for sequence element %zu"
                    " in sequence argument at position #%d '%s'",
                    T::typeString(), i, pos, name);
    }
    auto result = dynamic_cast<T*>(expr);
    if (!result) {
      runtime_error("Expected a Tensor of type %s but found a type %s for sequence element %zu"
                    " in sequence argument at position #%d '%s'",
                    T::typeString(), expr->type().toString(), i, pos, name);
    }
    casted[i] = result->tensor;
  }
  return casted;
}
// Normalizes an int-list argument to exactly N values.
//
//   list      - values supplied by the caller. If empty, `def` is used instead.
//   name, pos - argument name and position, used only in error messages.
//   def       - fallback list used when `list` is empty (empty by default).
//
// When N > 1, a single value is broadcast to all N slots. Otherwise the
// (possibly defaulted) list must contain exactly N elements, or runtime_error
// (declared [[noreturn]]) is called.
template <size_t N>
std::array<int64_t, N> check_intlist(ArrayRef<int64_t> list, const char * name, int pos, ArrayRef<int64_t> def={}) {
  if (list.empty()) {
    list = def;
  }
  std::array<int64_t, N> res{};
  if (list.size() == 1 && N > 1) {
    res.fill(list[0]);
    return res;
  }
  if (list.size() != N) {
    // Both counts are size_t, so the matching conversion is %zu (not the signed %zd).
    runtime_error("Expected a list of %zu ints but got %zu for argument #%d '%s'",
                  N, list.size(), pos, name);
  }
  std::copy_n(list.begin(), N, res.begin());
  return res;
}
} // at