| #pragma once |
| |
| #include "TH/TH.h" |
| |
| namespace at { |
| |
| // make a fake storage out of a size, pointer pair... |
| // used as an argument where THSize and THStride are passed into TH |
| class THLongStorageView { |
| public: |
| // zero_dim_to_one converts an empty ArrayRef into [1] |
| // empty_to_null converts an empty ArrayRef into a null THLongStorage |
| static THLongStorageView make(ArrayRef<int64_t> ref, bool zero_dim_to_one = false, bool empty_to_null = false) { |
| assert(!(zero_dim_to_one && empty_to_null)); |
| return THLongStorageView(ref, zero_dim_to_one, empty_to_null); |
| } |
| operator THLongStorage*() { |
| if (storage.size == 0 && empty_to_null) { |
| return nullptr; |
| } |
| return &storage; |
| } |
| private: |
| THLongStorageView(ArrayRef<int64_t> ref, bool zero_dim_to_one, bool empty_to_null) |
| : empty_to_null(empty_to_null) |
| { |
| if(zero_dim_to_one && ref.size() == 0) { |
| // make storage of size 0 actually a 1-length storage with 1 element |
| // so that our 0-dim tensors get allocated as 1-dim inside TH |
| one = 1; |
| storage.data = &one; |
| storage.size = 1; |
| } else { |
| storage.data = (int64_t*)(ref.data()); |
| storage.size = ref.size(); |
| } |
| storage.refcount = 0; |
| storage.flag = 0; |
| storage.allocator = nullptr; |
| storage.allocatorContext = nullptr; |
| } |
| int64_t one; |
| THLongStorage storage; |
| bool empty_to_null; |
| }; |
| |
| } |