blob: b7dc6752c475a25337f9c6f1fae77304276abedd [file] [log] [blame]
#pragma once
#include "TH/TH.h"
namespace at {
// Makes a fake, non-owning THLongStorage out of a (pointer, size) pair.
// Used as an argument where THSize and THStride are passed into TH.
//
// The view does not allocate or copy the elements: the THLongStorage it hands
// out aliases either the memory behind the ArrayRef it was built from or the
// view's own `one` member. A pointer obtained from the view is therefore only
// usable while the view itself (and the array it refers to) is still alive.
class THLongStorageView {
public:
  // Builds a view over `ref`.
  //   zero_dim_to_one converts an empty ArrayRef into [1]
  //   empty_to_null   converts an empty ArrayRef into a null THLongStorage
  // The two options are mutually exclusive; this is checked with assert.
  static THLongStorageView make(ArrayRef<int64_t> ref, bool zero_dim_to_one = false, bool empty_to_null = false) {
    assert(!(zero_dim_to_one && empty_to_null));
    return THLongStorageView(ref, zero_dim_to_one, empty_to_null);
  }

  // Implicit conversion so the view can be passed directly where a
  // THLongStorage* is expected. Returns nullptr for an empty view created
  // with empty_to_null, otherwise a pointer to the embedded storage.
  operator THLongStorage*() {
    if (storage.size == 0 && empty_to_null) {
      return nullptr;
    }
    if (aliases_one) {
      // This object may be a copy (or a moved-to instance) of the one the
      // constructor ran on; in that case the copied data pointer still refers
      // to the *source* object's `one`. Re-point it at our own member so the
      // storage never dangles.
      storage.data = &one;
    }
    return &storage;
  }

private:
  THLongStorageView(ArrayRef<int64_t> ref, bool zero_dim_to_one, bool empty_to_null)
    : one(1)
    , storage()  // value-initialize so any field not assigned below starts zeroed
    , empty_to_null(empty_to_null)
    , aliases_one(zero_dim_to_one && ref.size() == 0)
  {
    if (aliases_one) {
      // make storage of size 0 actually a 1-length storage with 1 element
      // so that our 0-dim tensors get allocated as 1-dim inside TH
      storage.data = &one;
      storage.size = 1;
    } else {
      // THLongStorage exposes a mutable data pointer while ArrayRef only
      // gives const access, so constness has to be cast away explicitly.
      storage.data = const_cast<int64_t*>(ref.data());
      storage.size = ref.size();
    }
    storage.refcount = 0;
    storage.flag = 0;
    storage.allocator = nullptr;
    storage.allocatorContext = nullptr;
  }

  int64_t one;            // backing element for the zero_dim_to_one case
  THLongStorage storage;  // the fake storage handed out by the conversion operator
  bool empty_to_null;     // report an empty view as nullptr instead of &storage
  bool aliases_one;       // true when storage.data must point at `one`
};
}