Stop using THLongStorage for sizes/strides, remove THLongStorageView.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10219
Reviewed By: cpuhrsch
Differential Revision: D9159550
Pulled By: gchanan
fbshipit-source-id: 745a6d335613688ed41b32369ee4938907ce8cbb
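
The gist of the migration, as a minimal before/after sketch (adapted from the
indexSelect hunk below; not an exact excerpt, and real call sites vary):

    // Before: sizes travel through a heap-allocated THLongStorage,
    // with a NULL THLongStorage* meaning "no strides given".
    THLongStorage *newSize = THLongStorage_newWithSize(src->dim());
    THLongStorage_rawCopy(newSize, THTensor_getSizePtr(src));
    THLongStorage_data(newSize)[dim] = numel;
    THTensor_(resize)(tensor, newSize, NULL);
    THLongStorage_free(newSize);

    // After: sizes and strides are passed as at::IntList (here backed by a
    // std::vector<int64_t>), with {} meaning "no strides given".
    std::vector<int64_t> newSize = src->sizes().vec();
    newSize[dim] = numel;
    THTensor_(resize)(tensor, newSize, {});

In the generated Type code, the old THLongStorageView casts are replaced by
get_intlist_size_th/get_intlist_stride_th (see THSizeStrideCompat.h), which keep
the "0-dim size becomes [1]" and "empty strides mean inferred strides" behavior
that the view class used to encode.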
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 1b1a456..33eb98d 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -23,9 +23,9 @@
cpu_half: True
arguments:
- THTensor* self
- - arg: THSize* size
+ - arg: IntListSize size
long_args: True
- - CONSTANT NULL
+ - CONSTANT {}
]]
[[
name: set_
@@ -43,23 +43,23 @@
scalar_check: False
arguments:
- THTensor* self
- - CONSTANT NULL, 0, THLongStorageView({0}, THLongStorageViewKind::SIZE), NULL
+ - CONSTANT NULL, 0, {0}, {}
- cname: setStorage
scalar_check: False
arguments:
- THTensor* self
- THStorage* source
- CONSTANT 0
- - CONSTANT __storage_size.get()
- - CONSTANT NULL
+ - CONSTANT {static_cast<int64_t>(source.pImpl()->size())}
+ - CONSTANT {}
- cname: setStorage
arguments:
- THTensor* self
- THStorage* source
- long storage_offset
- - THSize* size
- - arg: THStride* stride
- default: NULL
+ - IntListSize size
+ - arg: IntListStride stride
+ default: {}
]]
[[
name: _fill_
@@ -171,7 +171,7 @@
return: THTensor*
arguments:
- THTensor* self
- - arg: THSize* size
+ - arg: IntListSize size
long_args: True
]]
[[
@@ -3393,8 +3393,8 @@
arguments: []
- cname: newWithSize
arguments:
- - THSize* size
- - CONSTANT NULL
+ - IntListSize size
+ - CONSTANT {}
]]
[[
name: tensor
@@ -3404,15 +3404,15 @@
options:
- cname: newWithSize
arguments:
- - THSize* size
- - arg: THStride* stride
+ - IntListSize size
+ - arg: IntListStride stride
- cname: newWithStorage
arguments:
- THStorage* storage
- int64_t storageOffset
- - THSize* size
- - arg: THStride* stride
- default: NULL
+ - IntListSize size
+ - arg: IntListStride stride
+ default: {}
]]
# In theory, this could be a part of the above declaration. But in
diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h
new file mode 100644
index 0000000..6bf5382
--- /dev/null
+++ b/aten/src/ATen/InferSize.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <ATen/optional.h>
+#include <ATen/ScalarType.h>
+#include <sstream>
+#include <vector>
+
+namespace at {
+
+// Infers the size of a dim with size -1, if it exists. Also checks that new
+// shape is compatible with the number of elements.
+static std::vector<int64_t> infer_size(IntList shape, int64_t numel) {
+ auto res = shape.vec();
+ int64_t newsize = 1;
+ auto infer_dim = at::optional<int64_t>();
+ for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
+ if (shape[dim] == -1) {
+ if (infer_dim) {
+ throw std::runtime_error("only one dimension can be inferred");
+ }
+ infer_dim = dim;
+ } else if (shape[dim] >= 0) {
+ newsize *= shape[dim];
+ } else {
+ AT_ERROR("invalid shape dimension ", shape[dim]);
+ }
+ }
+
+ if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
+ if (infer_dim) {
+ // we have a degree of freedom here to select the dimension size; follow NumPy semantics
+ // and just bail.
+ AT_CHECK(newsize != 0, "cannot reshape tensor of 0 elements into shape ", shape);
+ res[*infer_dim] = numel / newsize;
+ }
+ return res;
+ }
+
+ std::ostringstream ss;
+ ss << "shape '" << shape << "' is invalid for input of size " << numel;
+ throw std::runtime_error(ss.str());
+}
+
+}
diff --git a/aten/src/ATen/THLongStorageView.h b/aten/src/ATen/THLongStorageView.h
deleted file mode 100644
index 3d50678..0000000
--- a/aten/src/ATen/THLongStorageView.h
+++ /dev/null
@@ -1,86 +0,0 @@
-#pragma once
-
-#include <ATen/StorageImpl.h>
-#include "TH/TH.h"
-#include "TH/THStorageFunctions.hpp"
-#include "TH/THTypeConversion.hpp"
-
-namespace at {
-
-enum class THLongStorageViewKind {
- SIZE,
- STRIDE,
- LENGTH,
-};
-
-// make a fake storage out of a size, pointer pair...
-// used as an argument where THSize and THStride are passed into TH
-class THLongStorageView {
-public:
- operator StorageImpl*() {
- if (storage.pImpl()->size() == 0 && zero_dim_to_null) {
- return nullptr;
- }
- return storage.pImpl();
- }
-
- /*
- // This is done as an enum, and not as static constructors, as there
- // is no move/copy constructor for THLongStorageView
-
- static THLongStorageView makeFromSize(ArrayRef<int64_t> ref) {
- ...
- }
-
- static THLongStorageView makeFromLength(ArrayRef<int64_t> ref) {
- ...
- }
- */
-
- THLongStorageView(ArrayRef<int64_t> ref, THLongStorageViewKind kind)
- : storage(nullptr), zero_dim_to_null(false) {
- // zero_dim_to_one converts an empty ArrayRef into [1]
- // zero_dim_to_null converts an empty ArrayRef into a null THLongStorage
- bool zero_dim_to_one = false;
- switch (kind) {
- case THLongStorageViewKind::SIZE:
- zero_dim_to_one = true;
- break;
- case THLongStorageViewKind::STRIDE:
- zero_dim_to_null = true;
- break;
- case THLongStorageViewKind::LENGTH:
- break;
- }
-
- if (zero_dim_to_one && ref.size() == 0) {
- // make storage of size 0 actually a 1-length storage with 1 element
- // so that our 0-dim tensors get allocated as 1-dim inside TH
-
- one = 1;
- storage.set_pImpl(new StorageImpl(
- at::CTypeToScalarType<th::from_type<int64_t>>::to(),
- 1,
- {&one, kCPU}, // non-owning
- nullptr,
- false));
- } else {
- storage.set_pImpl(new StorageImpl(
- at::CTypeToScalarType<th::from_type<int64_t>>::to(),
- ref.size(),
- {const_cast<void*>(static_cast<const void*>(ref.data())),
- kCPU}, // non-owning
- nullptr,
- false));
- }
- }
-private:
- int64_t one;
- // NB: The lifetime of objects like one are tied to the lifetime of an
- // instance of this class. That means if storage is used after an instance of
- // this class dies, it'll be corrupted.
- Storage storage;
- bool zero_dim_to_null;
-};
-
-}
diff --git a/aten/src/ATen/THSizeStrideCompat.h b/aten/src/ATen/THSizeStrideCompat.h
new file mode 100644
index 0000000..bbcdd77
--- /dev/null
+++ b/aten/src/ATen/THSizeStrideCompat.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include <ATen/ScalarType.h>
+#include <vector>
+
+// NOTE: these functions exist for compatibility with TH functions that take sizes and strides.
+// We should just write the TH functions that don't require this, but that involves two steps:
+// 1) first class scalar support (for sizes)
+// 2) differentiating between nullptr/non-nullptr strides (the former "infers" strides).
+
+namespace at {
+
+static inline at::IntList get_intlist_size_th(IntList sizes) {
+ static int64_t one = 1;
+ if (sizes.size() == 0) {
+ // fake scalar
+ return IntList(&one, 1);
+ } else {
+ return sizes;
+ }
+}
+
+static inline IntList get_intlist_stride_th(IntList strides) {
+ if (strides.size() == 0) {
+ // differentiating between nullptr/non-nullptr strides (the former "infers" strides)
+ return IntList();
+ } else {
+ return strides;
+ }
+}
+
+}
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 05e045d..8b111be 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -210,8 +210,8 @@
'THDenseIndexTensor*': 'Tensor &',
'THStorage*': 'Storage &',
'THGenerator*': 'Generator *',
- 'THSize*': 'IntList',
- 'THStride*': 'IntList',
+ 'IntListSize': 'IntList',
+ 'IntListStride': 'IntList',
'accreal': 'Scalar',
'real': 'Scalar',
'long': 'int64_t',
@@ -227,8 +227,8 @@
'THDenseIndexTensor*': 'IndexTensor',
'THStorage*': 'Storage',
'THGenerator*': 'Generator*',
- 'THSize*': 'IntList',
- 'THStride*': 'IntList',
+ 'IntListSize': 'IntList',
+ 'IntListStride': 'IntList',
'accreal': 'accreal',
'real': 'real',
'long': 'int64_t',
@@ -297,9 +297,8 @@
CodeTemplate(
'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(backend()))'),
# This is a cast done via direct-construction
- 'THSize*': CodeTemplate('THLongStorageView ${result_name}(${arg_name}, THLongStorageViewKind::SIZE);'),
- # This is a cast done via direct-construction
- 'THStride*': CodeTemplate('THLongStorageView ${result_name}(${arg_name}, THLongStorageViewKind::STRIDE);'),
+ 'IntListSize': CodeTemplate('at::IntList ${result_name} = get_intlist_size_th(${arg_name});'),
+ 'IntListStride': CodeTemplate('at::IntList ${result_name} = get_intlist_stride_th(${arg_name});'),
'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
'TensorList': CodeTemplate(
@@ -309,7 +308,7 @@
'IntList': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})')
}
-DIRECT_CONSTRUCTION_CHECKED_CAST = {'THSize*', 'THStride*'}
+DIRECT_CONSTRUCTION_CHECKED_CAST = {'IntListSize', 'IntListStride'}
CHECKED_USE = {
'THTensor*': '{}_->tensor',
@@ -349,8 +348,6 @@
# Replacements for constants when calling into TH
CONSTANT_REPLACEMENTS = [
('AS_REAL', '${AS_REAL}'),
- ('__storage_size.get\\(\\)',
- 'THLongStorageView(static_cast<int64_t>(source.pImpl()->size()), THLongStorageViewKind::LENGTH)'),
('__last_dim', 'self.ndimension()-1'),
]
@@ -1327,7 +1324,7 @@
output_count = 0
# scalar_check is the heuristic conditions when a result may be a scalar_check
- # if there is a THSize* argument, then its dimensions are used to determine scalar.
+ # if there is an IntListSize argument, then its dimensions are used to determine scalar.
# otherwise, it is true if all the input tensors are scalars,
scalar_check_is_from_size = False
scalar_check_is_from_option = False
@@ -1343,7 +1340,7 @@
for arg in option['arguments']:
if is_real_argument_to_wrapper(arg):
count += 1
- if arg['type'] == 'THSize*' and not scalar_check_is_from_option:
+ if arg['type'] == 'IntListSize' and not scalar_check_is_from_option:
scalar_check_is_from_size = True
scalar_check = '{}.size() == 0'.format(arg['name'])
if arg['type'] == 'TensorList':
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 4357e52..c70c4a9 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1,6 +1,7 @@
#include <TH/THTensor.hpp>
#include "ATen/ATen.h"
#include "ATen/ExpandUtils.h"
+#include "ATen/InferSize.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/core/Error.h"
@@ -217,40 +218,6 @@
return result;
}
-// Infers the size of a dim with size -1, if it exists. Also checks that new
-// shape is compatible with the number of elements.
-static std::vector<int64_t> infer_size(IntList shape, int64_t numel) {
- auto res = shape.vec();
- int64_t newsize = 1;
- auto infer_dim = at::optional<int64_t>();
- for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
- if (shape[dim] == -1) {
- if (infer_dim) {
- throw std::runtime_error("only one dimension can be inferred");
- }
- infer_dim = dim;
- } else if (shape[dim] >= 0) {
- newsize *= shape[dim];
- } else {
- AT_ERROR("invalid shape dimension ", shape[dim]);
- }
- }
-
- if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
- if (infer_dim) {
- // we have a degree of freedom here to select the dimension size; follow NumPy semantics
- // and just bail.
- AT_CHECK(newsize != 0, "cannot reshape tensor of 0 elements into shape ", shape);
- res[*infer_dim] = numel / newsize;
- }
- return res;
- }
-
- std::ostringstream ss;
- ss << "shape '" << shape << "' is invalid for input of size " << numel;
- throw std::runtime_error(ss.str());
-}
-
Tensor reshape(const Tensor& self, IntList proposed_shape) {
if (self.type().is_sparse()) {
AT_ERROR("reshape is not implemented for sparse tensors");
diff --git a/aten/src/ATen/templates/SparseTypeDerived.cpp b/aten/src/ATen/templates/SparseTypeDerived.cpp
index 73c43aa..4901938 100644
--- a/aten/src/ATen/templates/SparseTypeDerived.cpp
+++ b/aten/src/ATen/templates/SparseTypeDerived.cpp
@@ -13,7 +13,7 @@
#include "ATen/Allocator.h"
#include "ATen/DeviceGuard.h"
#include "ATen/NativeFunctions.h"
-#include "ATen/THLongStorageView.h"
+#include "ATen/THSizeStrideCompat.h"
#include "ATen/UndefinedTensor.h"
#include "ATen/Utils.h"
#include "ATen/WrapDimUtils.h"
diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp
index 11d2a29..8806fed 100644
--- a/aten/src/ATen/templates/TypeDerived.cpp
+++ b/aten/src/ATen/templates/TypeDerived.cpp
@@ -14,7 +14,7 @@
#include "ATen/Allocator.h"
#include "ATen/DeviceGuard.h"
#include "ATen/NativeFunctions.h"
-#include "ATen/THLongStorageView.h"
+#include "ATen/THSizeStrideCompat.h"
#include "ATen/UndefinedTensor.h"
#include "ATen/Utils.h"
#include "ATen/WrapDimUtils.h"
diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt
index 0f8aa4a..ab9f534 100644
--- a/aten/src/TH/CMakeLists.txt
+++ b/aten/src/TH/CMakeLists.txt
@@ -122,6 +122,7 @@
generic/THStorageCopy.h
generic/THTensor.cpp
generic/THTensor.h
+ generic/THTensor.hpp
generic/THTensorConv.cpp
generic/THTensorConv.h
generic/THTensorCopy.cpp
diff --git a/aten/src/TH/THStorageFunctions.cpp b/aten/src/TH/THStorageFunctions.cpp
index f328a5f..9d117a6 100644
--- a/aten/src/TH/THStorageFunctions.cpp
+++ b/aten/src/TH/THStorageFunctions.cpp
@@ -14,6 +14,15 @@
#include "generic/THStorageCopy.cpp"
#include "THGenerateHalfType.h"
+THStorage* THStorage_new(at::ScalarType scalar_type) {
+ THStorage* storage = new THStorage(
+ scalar_type,
+ 0,
+ getTHDefaultAllocator(),
+ true);
+ return storage;
+}
+
// Free a non-weak pointer to THStorage
void THStorage_free(THStorage* storage) {
if (!storage) {
@@ -40,40 +49,6 @@
return nullptr;
}
-THDescBuff THLongStorage_sizeDesc(const THLongStorage *size) {
- return _THSizeDesc(THLongStorage_data(size), size->size());
-}
-
-THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement)
-{
- ptrdiff_t total_size = (size->size() > 0 ? 1 : 0);
- ptrdiff_t dim_infer = -1;
- ptrdiff_t i;
- for (i = 0; i < size->size(); i++) {
- if (THLongStorage_data(size)[i] == -1) {
- THArgCheck(dim_infer == -1, 1, "only one dimension can be inferred");
- dim_infer = i;
- } else {
- total_size *= THLongStorage_data(size)[i];
- }
- }
- if (dim_infer != -1) {
- THDescBuff buf = THLongStorage_sizeDesc(size);
- THArgCheck(total_size > 0 && nElement % total_size == 0, 2,
- "size '%s' is invalid for input with %td elements", buf.str, nElement);
- } else {
- THDescBuff buf = THLongStorage_sizeDesc(size);
- THArgCheck(nElement == total_size, 2,
- "size '%s' is invalid for input with %td elements", buf.str, nElement);
- }
- THLongStorage* copy = THLongStorage_newWithSize(size->size());
- THLongStorage_copy(copy, size);
- if (dim_infer != -1) {
- THLongStorage_data(copy)[dim_infer] = nElement / total_size;
- }
- return copy;
-}
-
ptrdiff_t THStorage_size(const THStorage *self)
{
return self->size();
diff --git a/aten/src/TH/THStorageFunctions.h b/aten/src/TH/THStorageFunctions.h
index edd1b6b..6d32207 100644
--- a/aten/src/TH/THStorageFunctions.h
+++ b/aten/src/TH/THStorageFunctions.h
@@ -20,6 +20,3 @@
// This exists to have a data-type independent way of freeing (necessary for THPPointer).
TH_API void THStorage_free(THStorage *storage);
TH_API void THStorage_weakFree(THStorage *storage);
-
-TH_API THDescBuff THLongStorage_sizeDesc(const THLongStorage *size);
-TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement);
diff --git a/aten/src/TH/THStorageFunctions.hpp b/aten/src/TH/THStorageFunctions.hpp
index b82f0d5..117259e 100644
--- a/aten/src/TH/THStorageFunctions.hpp
+++ b/aten/src/TH/THStorageFunctions.hpp
@@ -33,6 +33,7 @@
// If it is not, you must report that the storage is dead.
//
+TH_CPP_API THStorage* THStorage_new(at::ScalarType scalar_type);
TH_API ptrdiff_t THStorage_size(const THStorage *self);
TH_API void THStorage_retain(THStorage *storage);
diff --git a/aten/src/TH/THTensor.cpp b/aten/src/TH/THTensor.cpp
index 5f3b6ed..e737471 100644
--- a/aten/src/TH/THTensor.cpp
+++ b/aten/src/TH/THTensor.cpp
@@ -14,6 +14,128 @@
self->release();
}
+void THTensor_setStorage(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, at::IntList size_, at::IntList stride_) {
+ if(size_.data() && stride_.data())
+ THArgCheck(size_.size() == stride_.size(), 5, "inconsistent size/stride sizes");
+
+ AT_CHECK(size_.data(), "size must not be null");
+#ifdef DEBUG
+ THAssert(size_.size() <= INT_MAX);
+#endif
+ THTensor_setStorageNd(self,
+ storage_,
+ storageOffset_,
+ size_.size(),
+ size_.data(),
+ stride_.data());
+}
+
+void THTensor_setStorageNd(THTensor *self, THStorage *storage, ptrdiff_t storageOffset, int nDimension, const int64_t *size, const int64_t *stride)
+{
+ /* storage */
+ if(THTensor_getStoragePtr(self) != storage)
+ {
+ if (!THTensor_getStoragePtr(self)) {
+ THError("Tensor: invalid null storage");
+ }
+ auto scalar_type = THTensor_getStoragePtr(self)->scalar_type();
+ THStorage_free(THTensor_getStoragePtr(self));
+ if(storage)
+ {
+ THTensor_stealAndSetStoragePtr(self, storage);
+ THStorage_retain(THTensor_getStoragePtr(self));
+ }
+ else {
+ THTensor_stealAndSetStoragePtr(self, THStorage_new(scalar_type));
+ }
+ }
+
+ /* storageOffset */
+ if(storageOffset < 0)
+ THError("Tensor: invalid storage offset");
+ THTensor_setStorageOffset(self, storageOffset);
+
+ /* size and stride */
+ THTensor_resizeNd(self, nDimension, size, stride);
+}
+
+void THTensor_resize(THTensor *self, at::IntList size, at::IntList stride)
+{
+ THArgCheck(size.data() != NULL, 2, "invalid size");
+ if(stride.data())
+ THArgCheck(stride.size() == size.size(), 3, "invalid stride");
+
+#ifdef DEBUG
+ THAssert(size.size() <= INT_MAX);
+#endif
+ THTensor_resizeNd(self, size.size(), size.data(), stride.data());
+}
+
+void THTensor_resizeNd(THTensor *self, int nDimension, const int64_t *size, const int64_t *stride)
+{
+ int d;
+ ptrdiff_t totalSize;
+ bool hascorrectsize = true;
+
+#ifndef USE_TH_SCALAR
+ AT_CHECK(nDimension > 0, "resizeNd nDimension must be greater than 0");
+#else
+ AT_CHECK(nDimension >= 0, "resizeNd nDimension must be non-negative");
+#endif
+
+ for(d = 0; d < nDimension; d++)
+ {
+ if((self->dim() > d) && (size[d] != self->size(d))) {
+ hascorrectsize = false;
+ }
+
+ // NB: this used to test that stride[d] was >= 0
+ if((self->dim() > d) && stride && (stride[d] != self->stride(d))) {
+ hascorrectsize = false;
+ }
+ }
+
+ if(nDimension != self->dim()) {
+ hascorrectsize = false;
+ }
+
+ if(hascorrectsize) {
+ return;
+ }
+
+ if(nDimension != self->dim())
+ {
+ THTensor_resizeDim(self, nDimension);
+ }
+
+ totalSize = 1;
+ for(d = nDimension-1; d >= 0; d--)
+ {
+ THTensor_setSizeAtDim(self, d, size[d]);
+ if(stride && (stride[d] >= 0) ) {
+ THTensor_setStrideAtDim(self, d, stride[d]);
+ } else {
+ if(d == nDimension-1) {
+ THTensor_setStrideAtDim(self, d, 1);
+ } else {
+ // Keep stride monotonically increasing to match NumPy.
+ THTensor_setStrideAtDim(self, d, std::max<int64_t>(self->size(d+1), 1)*self->stride(d+1));
+ }
+ }
+ totalSize += (self->size(d)-1)*self->stride(d);
+ }
+
+ if(totalSize+self->storage_offset() > 0)
+ {
+ if(!THTensor_getStoragePtr(self)) {
+ THTensor_stealAndSetStoragePtr(self, THStorage_new(self->scalar_type()));
+ }
+ if(totalSize+self->storage_offset() > THTensor_getStoragePtr(self)->size()) {
+ THStorage_resize(THTensor_getStoragePtr(self), totalSize+self->storage_offset());
+ }
+ }
+}
+
// On a high level,
// 1. separate oldshape chunks of dimensions, where the dimensions are
// ``contiguous'' in each chunk, i.e., oldstride[i] = oldshape[i+1] * oldstride[i+1]
diff --git a/aten/src/TH/THTensor.hpp b/aten/src/TH/THTensor.hpp
index 71021ec..d895b84 100644
--- a/aten/src/TH/THTensor.hpp
+++ b/aten/src/TH/THTensor.hpp
@@ -84,11 +84,11 @@
return strides_[d];
}
- inline at::IntList sizes() {
+ inline at::IntList sizes() const {
return sizes_;
}
- inline at::IntList strides() {
+ inline at::IntList strides() const {
return strides_;
}
@@ -211,5 +211,16 @@
}
TH_API void THTensor_free(THTensor *self);
+TH_API void THTensor_setStorageNd(THTensor *self, THStorage *storage, ptrdiff_t storageOffset, int nDimension, const int64_t *size, const int64_t *stride);
+TH_API void THTensor_resizeNd(THTensor *self, int nDimension, const int64_t *size, const int64_t *stride);
+
+TH_CPP_API void THTensor_resize(THTensor *self, at::IntList size, at::IntList stride);
+TH_CPP_API void THTensor_setStorage(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, at::IntList size_, at::IntList stride_);
TH_CPP_API at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
at::IntList newshape);
+
+#include "generic/THTensor.hpp"
+#include "THGenerateAllTypes.h"
+
+#include "generic/THTensor.hpp"
+#include "THGenerateHalfType.h"
diff --git a/aten/src/TH/generic/THStorage.cpp b/aten/src/TH/generic/THStorage.cpp
index 1b52aa2..6a3c178 100644
--- a/aten/src/TH/generic/THStorage.cpp
+++ b/aten/src/TH/generic/THStorage.cpp
@@ -21,12 +21,7 @@
THStorage* THStorage_(new)(void)
{
- THStorage* storage = new THStorage(
- at::CTypeToScalarType<th::from_type<real>>::to(),
- 0,
- getTHDefaultAllocator(),
- true);
- return storage;
+ return THStorage_new(at::CTypeToScalarType<th::from_type<real>>::to());
}
THStorage* THStorage_(newWithSize)(ptrdiff_t size)
diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp
index e48fda5..6a52d90 100644
--- a/aten/src/TH/generic/THTensor.cpp
+++ b/aten/src/TH/generic/THTensor.cpp
@@ -2,6 +2,7 @@
#define TH_GENERIC_FILE "generic/THTensor.cpp"
#else
+#include <ATen/InferSize.h>
#include <new>
/**** access methods ****/
@@ -39,20 +40,6 @@
return self->stride(dim);
}
-THLongStorage *THTensor_(newSizeOf)(THTensor *self)
-{
- THLongStorage *size = THLongStorage_newWithSize(self->dim());
- THLongStorage_rawCopy(size, THTensor_getSizePtr(self));
- return size;
-}
-
-THLongStorage *THTensor_(newStrideOf)(THTensor *self)
-{
- THLongStorage *stride = THLongStorage_newWithSize(self->dim());
- THLongStorage_rawCopy(stride, THTensor_getStridePtr(self));
- return stride;
-}
-
real *THTensor_(data)(const THTensor *self) {
return self->data<real>();
}
@@ -79,29 +66,11 @@
}
/* Storage init */
-THTensor *THTensor_(newWithStorage)(THStorage *storage, ptrdiff_t storageOffset, THLongStorage *size, THLongStorage *stride)
-{
- if(size && stride) {
- THArgCheck(size->size() == stride->size(), 4, "inconsistent size");
+THTensor *THTensor_(newWithStorage)(THStorage *storage, ptrdiff_t storageOffset, at::IntList sizes, at::IntList strides) {
+ if (sizes.data() && strides.data()) {
+ AT_CHECK(sizes.size() == strides.size(), "number of sizes and strides must match");
}
- AT_CHECK(size, "size must not be null");
-
- THTensor *self = new THTensor(THStorage_(new)());
-#ifdef DEBUG
- THAssert(size->size() <= INT_MAX);
-#endif
- THTensor_(setStorageNd)(self,
- storage,
- storageOffset,
- size->size(),
- THLongStorage_data(size),
- (stride ? THLongStorage_data(stride) : NULL));
-
- return self;
-}
-
-THTensor *THTensor_(newWithStorageIntLists)(THStorage *storage, ptrdiff_t storageOffset, at::IntList sizes, at::IntList strides) {
- AT_CHECK(sizes.size() == strides.size(), "number of sizes and strides must match");
+ AT_CHECK(sizes.data(), "size must not be null");
THTensor *self = new THTensor(THStorage_(new)());
THTensor_(setStorageNd)(self, storage, storageOffset, sizes.size(),
const_cast<int64_t*>(sizes.data()), const_cast<int64_t*>(strides.data()));
@@ -112,14 +81,14 @@
THTensor *THTensor_(newWithStorage1d)(THStorage *storage, ptrdiff_t storageOffset,
int64_t size0, int64_t stride0)
{
- return THTensor_(newWithStorageIntLists)(storage, storageOffset, {size0}, {stride0});
+ return THTensor_(newWithStorage)(storage, storageOffset, {size0}, {stride0});
}
THTensor *THTensor_(newWithStorage2d)(THStorage *storage, ptrdiff_t storageOffset,
int64_t size0, int64_t stride0,
int64_t size1, int64_t stride1)
{
- return THTensor_(newWithStorageIntLists)(storage, storageOffset, {size0, size1}, {stride0, stride1});
+ return THTensor_(newWithStorage)(storage, storageOffset, {size0, size1}, {stride0, stride1});
}
THTensor *THTensor_(newWithStorage3d)(THStorage *storage, ptrdiff_t storageOffset,
@@ -127,7 +96,7 @@
int64_t size1, int64_t stride1,
int64_t size2, int64_t stride2)
{
- return THTensor_(newWithStorageIntLists)(storage, storageOffset, {size0, size1, size2}, {stride0, stride1, stride2});
+ return THTensor_(newWithStorage)(storage, storageOffset, {size0, size1, size2}, {stride0, stride1, stride2});
}
THTensor *THTensor_(newWithStorage4d)(THStorage *storage, ptrdiff_t storageOffset,
@@ -136,41 +105,34 @@
int64_t size2, int64_t stride2,
int64_t size3, int64_t stride3)
{
- return THTensor_(newWithStorageIntLists)(storage, storageOffset,
+ return THTensor_(newWithStorage)(storage, storageOffset,
{size0, size1, size2, size3},
{stride0, stride1, stride2, stride3});
}
-THTensor *THTensor_(newWithSize)(THLongStorage *size, THLongStorage *stride)
+THTensor *THTensor_(newWithSize)(at::IntList size, at::IntList stride)
{
return THTensor_(newWithStorage)(NULL, 0, size, stride);
}
-THTensor *THTensor_(newWithSizeIntList)(at::IntList sizes) {
- THTensor *self = new THTensor(THStorage_(new)());
- THTensor_(resizeNd)(self, sizes.size(), const_cast<int64_t*>(sizes.data()), nullptr);
-
- return self;
-}
-
THTensor *THTensor_(newWithSize1d)(int64_t size0)
{
- return THTensor_(newWithSizeIntList)({size0});
+ return THTensor_(newWithSize)({size0}, {});
}
THTensor *THTensor_(newWithSize2d)(int64_t size0, int64_t size1)
{
- return THTensor_(newWithSizeIntList)({size0, size1});
+ return THTensor_(newWithSize)({size0, size1}, {});
}
THTensor *THTensor_(newWithSize3d)(int64_t size0, int64_t size1, int64_t size2)
{
- return THTensor_(newWithSizeIntList)({size0, size1, size2});
+ return THTensor_(newWithSize)({size0, size1, size2}, {});
}
THTensor *THTensor_(newWithSize4d)(int64_t size0, int64_t size1, int64_t size2, int64_t size3)
{
- return THTensor_(newWithSizeIntList)({size0, size1, size2, size3});
+ return THTensor_(newWithSize)({size0, size1, size2, size3}, {});
}
THTensor *THTensor_(newClone)(THTensor *self)
@@ -220,37 +182,26 @@
return self;
}
-THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size)
+THTensor *THTensor_(newView)(THTensor *tensor, at::IntList size)
{
ptrdiff_t numel = THTensor_(nElement)(tensor);
THTensor *self = THTensor_(new)();
- THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
+ auto inferred_size = at::infer_size(size, numel);
auto stride = THTensor_compute_stride(tensor->sizes(),
tensor->strides(),
- at::IntList(inferred_size->data<int64_t>(), inferred_size->size()));
+ inferred_size);
THArgCheck(stride.has_value(), 2, "view size is "
"not compatible with input tensor's size and stride (at least one dimension spans "
"across two contiguous subspaces). Call .contiguous() before .view().");
auto stride_value = *stride;
- THLongStorage *new_stride = THLongStorage_newWithSize(stride_value.size());
- THLongStorage_rawCopy(new_stride, stride_value.data());
- THTensor_(setStorage)(self, THTensor_getStoragePtr(tensor), tensor->storage_offset(), inferred_size, new_stride);
- THLongStorage_free(inferred_size);
- THLongStorage_free(new_stride);
+ THTensor_setStorage(self, THTensor_getStoragePtr(tensor), tensor->storage_offset(), inferred_size, stride_value);
return self;
}
/* Resize */
-void THTensor_(resize)(THTensor *self, THLongStorage *size, THLongStorage *stride)
+void THTensor_(resize)(THTensor *self, at::IntList size, at::IntList stride)
{
- THArgCheck(size != NULL, 2, "invalid size");
- if(stride)
- THArgCheck(stride->size() == size->size(), 3, "invalid stride");
-
-#ifdef DEBUG
- THAssert(size->size() <= INT_MAX);
-#endif
- THTensor_(resizeNd)(self, size->size(), THLongStorage_data(size), (stride ? THLongStorage_data(stride) : NULL));
+ return THTensor_resize(self, size, stride);
}
void THTensor_(resizeAs)(THTensor *self, THTensor *src)
@@ -300,46 +251,25 @@
THTensor_getStridePtr(src));
}
-void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_)
+void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, at::IntList size_, at::IntList stride_)
{
- if(size_ && stride_)
- THArgCheck(size_->size() == stride_->size(), 5, "inconsistent size/stride sizes");
-
- AT_CHECK(size_, "size must not be null");
-#ifdef DEBUG
- THAssert(size_ <= INT_MAX);
-#endif
- THTensor_(setStorageNd)(self,
- storage_,
- storageOffset_,
- size_->size(),
- THLongStorage_data(size_),
- (stride_ ? THLongStorage_data(stride_) : NULL));
-}
-
-void THTensor_(setStorageIntLists)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
- at::IntList sizes, at::IntList strides)
-{
- AT_CHECK(sizes.size() == strides.size(), "number of sizes and strides must match");
-
- THTensor_(setStorageNd)(self, storage_, storageOffset_, sizes.size(),
- const_cast<int64_t *>(sizes.data()), const_cast<int64_t *>(strides.data()));
+ THTensor_setStorage(self, storage_, storageOffset_, size_, stride_);
}
void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_)
{
- THTensor_(setStorageIntLists)(self, storage_, storageOffset_,
- {size0_}, {stride0_});
+ THTensor_(setStorage)(self, storage_, storageOffset_,
+ {size0_}, {stride0_});
}
void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_,
int64_t size1_, int64_t stride1_)
{
- THTensor_(setStorageIntLists)(self, storage_, storageOffset_,
- {size0_, size1_},
- {stride0_, stride1_});
+ THTensor_(setStorage)(self, storage_, storageOffset_,
+ {size0_, size1_},
+ {stride0_, stride1_});
}
void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
@@ -347,9 +277,9 @@
int64_t size1_, int64_t stride1_,
int64_t size2_, int64_t stride2_)
{
- THTensor_(setStorageIntLists)(self, storage_, storageOffset_,
- {size0_, size1_, size2_},
- {stride0_, stride1_, stride2_});
+ THTensor_(setStorage)(self, storage_, storageOffset_,
+ {size0_, size1_, size2_},
+ {stride0_, stride1_, stride2_});
}
void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
@@ -362,7 +292,7 @@
int64_t size[4] = {size0_, size1_, size2_, size3_};
int64_t stride[4] = {stride0_, stride1_, stride2_, stride3_};
- THTensor_(setStorageIntLists)(self, storage_, storageOffset_, size, stride);
+ THTensor_(setStorage)(self, storage_, storageOffset_, size, stride);
}
@@ -664,93 +594,12 @@
void THTensor_(setStorageNd)(THTensor *self, THStorage *storage, ptrdiff_t storageOffset, int nDimension, int64_t *size, int64_t *stride)
{
- /* storage */
- if(THTensor_getStoragePtr(self) != storage)
- {
- if(THTensor_getStoragePtr(self))
- THStorage_(free)(THTensor_getStoragePtr(self));
-
- if(storage)
- {
- THTensor_stealAndSetStoragePtr(self, storage);
- THStorage_(retain)(THTensor_getStoragePtr(self));
- }
- else
- THTensor_stealAndSetStoragePtr(self, THStorage_(new)());
- }
-
- /* storageOffset */
- if(storageOffset < 0)
- THError("Tensor: invalid storage offset");
- THTensor_setStorageOffset(self, storageOffset);
-
- /* size and stride */
- THTensor_(resizeNd)(self, nDimension, size, stride);
+ return THTensor_setStorageNd(self, storage, storageOffset, nDimension, size, stride);
}
void THTensor_(resizeNd)(THTensor *self, int nDimension, int64_t *size, int64_t *stride)
{
- int d;
- ptrdiff_t totalSize;
- bool hascorrectsize = true;
-
-#ifndef USE_TH_SCALAR
- AT_CHECK(nDimension > 0, "resizeNd nDimension must be greater than 0");
-#else
- AT_CHECK(nDimension >= 0, "resizeNd nDimension must be non-negative");
-#endif
-
- for(d = 0; d < nDimension; d++)
- {
- if((self->dim() > d) && (size[d] != self->size(d))) {
- hascorrectsize = false;
- }
-
- // NB: this used to test that stride[d] was >= 0
- if((self->dim() > d) && stride && (stride[d] != self->stride(d))) {
- hascorrectsize = false;
- }
- }
-
- if(nDimension != self->dim()) {
- hascorrectsize = false;
- }
-
- if(hascorrectsize) {
- return;
- }
-
- if(nDimension != self->dim())
- {
- THTensor_resizeDim(self, nDimension);
- }
-
- totalSize = 1;
- for(d = nDimension-1; d >= 0; d--)
- {
- THTensor_setSizeAtDim(self, d, size[d]);
- if(stride && (stride[d] >= 0) ) {
- THTensor_setStrideAtDim(self, d, stride[d]);
- } else {
- if(d == nDimension-1) {
- THTensor_setStrideAtDim(self, d, 1);
- } else {
- // Keep stride monotonically increasing to match NumPy.
- THTensor_setStrideAtDim(self, d, std::max<int64_t>(self->size(d+1), 1)*self->stride(d+1));
- }
- }
- totalSize += (self->size(d)-1)*self->stride(d);
- }
-
- if(totalSize+self->storage_offset() > 0)
- {
- if(!THTensor_getStoragePtr(self)) {
- THTensor_stealAndSetStoragePtr(self, THStorage_(new)());
- }
- if(totalSize+self->storage_offset() > THTensor_getStoragePtr(self)->size()) {
- THStorage_(resize)(THTensor_getStoragePtr(self), totalSize+self->storage_offset());
- }
- }
+ return THTensor_resizeNd(self, nDimension, size, stride);
}
void THTensor_(set1d)(THTensor *tensor, int64_t x0, real value)
@@ -832,9 +681,7 @@
}
THDescBuff THTensor_(sizeDesc)(const THTensor *tensor) {
- THLongStorage *size = THTensor_(newSizeOf)((THTensor*)tensor);
- THDescBuff buf = THLongStorage_sizeDesc(size);
- THLongStorage_free(size);
+ THDescBuff buf = _THSizeDesc(tensor->sizes().data(), tensor->sizes().size());
return buf;
}
diff --git a/aten/src/TH/generic/THTensor.h b/aten/src/TH/generic/THTensor.h
index decb9ce..53adc1a 100644
--- a/aten/src/TH/generic/THTensor.h
+++ b/aten/src/TH/generic/THTensor.h
@@ -29,16 +29,12 @@
TH_API int THTensor_(nDimensionLegacyAll)(const THTensor *self);
TH_API int64_t THTensor_(size)(const THTensor *self, int dim);
TH_API int64_t THTensor_(stride)(const THTensor *self, int dim);
-TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self);
-TH_API THLongStorage *THTensor_(newStrideOf)(THTensor *self);
TH_API real *THTensor_(data)(const THTensor *self);
/**** creation methods ****/
TH_API THTensor *THTensor_(new)(void);
TH_API THTensor *THTensor_(newWithTensor)(THTensor *tensor);
-/* stride might be NULL */
-TH_API THTensor *THTensor_(newWithStorage)(THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
TH_API THTensor *THTensor_(newWithStorage1d)(THStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_);
TH_API THTensor *THTensor_(newWithStorage2d)(THStorage *storage_, ptrdiff_t storageOffset_,
@@ -55,7 +51,6 @@
int64_t size3_, int64_t stride3_);
/* stride might be NULL */
-TH_API THTensor *THTensor_(newWithSize)(THLongStorage *size_, THLongStorage *stride_);
TH_API THTensor *THTensor_(newWithSize1d)(int64_t size0_);
TH_API THTensor *THTensor_(newWithSize2d)(int64_t size0_, int64_t size1_);
TH_API THTensor *THTensor_(newWithSize3d)(int64_t size0_, int64_t size1_, int64_t size2_);
@@ -67,12 +62,10 @@
TH_API THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, int64_t firstIndex_, int64_t size_);
TH_API THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_);
TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, int64_t size_, int64_t step_);
-TH_API THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size);
// resize* methods simply resize the storage. So they may not retain the current data at current indices.
// This is especially likely to happen when the tensor is not contiguous. In general, if you still need the
// values, unless you are doing some size and stride tricks, do not use resize*.
-TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride);
TH_API void THTensor_(resizeNd)(THTensor *tensor, int nDimension, int64_t *size, int64_t *stride);
TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);
TH_API void THTensor_(resize1d)(THTensor *tensor, int64_t size0_);
@@ -83,7 +76,6 @@
// Note: these are legacy resize functions that treat sizes as size->size() == 0 and size->data<int64_t>() as being 0-terminated.
TH_API void THTensor_(set)(THTensor *self, THTensor *src);
-TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
TH_API void THTensor_(setStorageNd)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, int nDimension, int64_t *size, int64_t *stride);
TH_API void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_);
diff --git a/aten/src/TH/generic/THTensor.hpp b/aten/src/TH/generic/THTensor.hpp
new file mode 100644
index 0000000..845ebac
--- /dev/null
+++ b/aten/src/TH/generic/THTensor.hpp
@@ -0,0 +1,21 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/THTensor.hpp"
+#else
+
+// STOP!!! Thinking of including this header directly? Please
+// read Note [TH abstraction violation]
+
+// NOTE: functions exist here only to support dispatch via Declarations.cwrap. You probably don't want to put
+// new functions in here; they should probably be un-genericized.
+
+TH_CPP_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
+ at::IntList size_, at::IntList stride_);
+TH_CPP_API THTensor *THTensor_(newView)(THTensor *tensor, at::IntList size);
+/* strides.data() might be NULL */
+TH_CPP_API THTensor *THTensor_(newWithStorage)(THStorage *storage, ptrdiff_t storageOffset,
+ at::IntList sizes, at::IntList strides);
+
+TH_CPP_API void THTensor_(resize)(THTensor *self, at::IntList size, at::IntList stride);
+TH_CPP_API THTensor *THTensor_(newWithSize)(at::IntList size, at::IntList stride);
+
+#endif
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
index a2dfc38..fcd7d1d 100644
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -144,7 +144,6 @@
void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
{
ptrdiff_t i, numel;
- THLongStorage *newSize;
THTensor *tSlice, *sSlice;
int64_t *index_data;
real *tensor_data, *src_data;
@@ -154,14 +153,12 @@
numel = THLongTensor_nElement(index);
- newSize = THLongStorage_newWithSize(src->dim());
- THLongStorage_rawCopy(newSize, THTensor_getSizePtr(src));
+ std::vector<int64_t> newSize = src->sizes().vec();
#ifdef DEBUG
THAssert(numel <= LONG_MAX);
#endif
- THLongStorage_data(newSize)[dim] = numel;
- THTensor_(resize)(tensor,newSize,NULL);
- THLongStorage_free(newSize);
+ newSize[dim] = numel;
+ THTensor_(resize)(tensor,newSize,{});
index = THLongTensor_newContiguous(index);
index_data = THLongTensor_data(index);
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index e0dd8dc..7e1976e 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -73,19 +73,16 @@
void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);
int in_dims = THTensor_(nDimensionLegacyAll)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(values_, dim, NULL);
- THLongTensor_resize(indices_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(values_, dim, {});
+ THLongTensor_resize(indices_, dim, {});
// two implementations optimized for data locality
if (THTensor_strideLegacyNoScalars(t, dimension) == 1) {
@@ -157,19 +154,16 @@
void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);
int in_dims = THTensor_(nDimensionLegacyAll)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(values_, dim, NULL);
- THLongTensor_resize(indices_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(values_, dim, {});
+ THLongTensor_resize(indices_, dim, {});
// two implementations optimized for data locality
if (THTensor_strideLegacyNoScalars(t, dimension) == 1) {
@@ -241,16 +235,13 @@
void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(r_, dim, {});
int serial_path = 0;
#ifdef _OPENMP
@@ -321,16 +312,13 @@
void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(r_, dim, {});
int serial_path = 0;
#ifdef _OPENMP
@@ -904,9 +892,8 @@
THTensor_(copy)(rt_, t);
{
- THLongStorage *size = THTensor_(newSizeOf)(t);
- THLongTensor_resize(ri_, size, NULL);
- THLongStorage_free(size);
+ std::vector<int64_t> size = t->sizes().vec();
+ THLongTensor_resize(ri_, size, {});
}
if(descendingOrder)
@@ -1053,7 +1040,6 @@
void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
- THLongStorage *dim;
THTensor *temp_;
THLongTensor *tempi_;
real *temp__data;
@@ -1065,11 +1051,10 @@
int in_dims = THTensor_(nDimensionLegacyAll)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(values_, dim, NULL);
- THLongTensor_resize(indices_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(values_, dim, {});
+ THLongTensor_resize(indices_, dim, {});
t_size_dim = THTensor_sizeLegacyNoScalars(t, dimension);
@@ -1121,7 +1106,6 @@
void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim)
{
- THLongStorage *dim;
THTensor *temp_;
THLongTensor *tempi_;
real *temp__data;
@@ -1134,11 +1118,10 @@
int in_dims = THTensor_(nDimensionLegacyAll)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(values_, dim, NULL);
- THLongTensor_resize(indices_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(values_, dim, {});
+ THLongTensor_resize(indices_, dim, {});
t_size_dim = THTensor_sizeLegacyNoScalars(t, dimension);
@@ -1197,11 +1180,10 @@
THLongTensor_resize1d(tmpIndices, sliceSize);
int64_t *tmpi__data = THLongTensor_data(tmpIndices);
- THLongStorage *topKSize = THTensor_(newSizeOf)(t);
- THLongStorage_set(topKSize, dim, k);
- THTensor_(resize)(rt_, topKSize, NULL);
- THLongTensor_resize(ri_, topKSize, NULL);
- THLongStorage_free(topKSize);
+ std::vector<int64_t> topKSize = t->sizes().vec();
+ topKSize[dim] = k;
+ THTensor_(resize)(rt_, topKSize, {});
+ THLongTensor_resize(ri_, topKSize, {});
if (dir) {
/* k largest elements, descending order (optional: see sorted) */
@@ -1379,15 +1361,15 @@
}
// Compute the size of the result
- THLongStorage *size = THLongStorage_newWithSize(nDims);
+ std::vector<int64_t> size(nDims);
for (int dim = 0; dim < nDims; dim++) {
int64_t result_dim_size = notSkippedTensor->size(dim);
if (dim == dimension) {
result_dim_size = cat_dim_size;
}
- THLongStorage_data(size)[dim] = result_dim_size;
+ size[dim] = result_dim_size;
}
- THTensor_(resize)(result, size, NULL);
+ THTensor_(resize)(result, size, {});
// Check contiguity of all inputs and result
bool allContiguous = true;
@@ -1429,7 +1411,6 @@
}
}
}
- THLongStorage_free(size);
}
int THTensor_(equal)(THTensor *ta, THTensor* tb)
@@ -1646,16 +1627,13 @@
void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(r_, dim, {});
int serial_path = 0;
#ifdef _OPENMP
@@ -1726,16 +1704,13 @@
void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(r_, dim, {});
int serial_path = 0;
#ifdef _OPENMP
@@ -1881,16 +1856,13 @@
void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d",
dimension + TH_INDEX_BASE);
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(r_, dim, {});
TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension,
// Uses Welford's algorithm for numeric stability
@@ -1925,16 +1897,13 @@
void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d",
dimension + TH_INDEX_BASE);
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(r_, dim, {});
TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension,
// Uses Welford's algorithm for numeric stability
@@ -1969,16 +1938,13 @@
void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d",
dimension + TH_INDEX_BASE);
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = t->sizes().vec();
+ dim[dimension] = 1;
+ THTensor_(resize)(r_, dim, {});
#define DIM_REDUCE(reduce, transform) \
TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, \
diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt
index ac445f7..51f196f 100644
--- a/aten/src/THC/CMakeLists.txt
+++ b/aten/src/THC/CMakeLists.txt
@@ -124,6 +124,7 @@
generic/THCTensor.cpp
generic/THCTensor.cu
generic/THCTensor.h
+ generic/THCTensor.hpp
generic/THCStorageCopy.cpp
generic/THCStorageCopy.cu
generic/THCStorageCopy.h
diff --git a/aten/src/THC/THCReduce.cuh b/aten/src/THC/THCReduce.cuh
index df1ad7c..1a72ae6 100644
--- a/aten/src/THC/THCReduce.cuh
+++ b/aten/src/THC/THCReduce.cuh
@@ -486,10 +486,9 @@
state, out, THCTensor_nDimensionLegacyAll(state, in), dim, keepdim);
// Resize out
- THLongStorage* sizes = THCTensor_newSizeOf(state, in);
- THLongStorage_set(sizes, dim, 1);
- THCTensor_resize(state, out, sizes, NULL);
- THLongStorage_free(sizes);
+ std::vector<int64_t> sizes = in->sizes().vec();
+ sizes[dim] = 1;
+ THCTensor_resize(state, out, sizes, {});
// It is possible that the tensor dimensions are able to be collapsed,
// and thus we can reduce the actual code complexity of the copy by
diff --git a/aten/src/THC/THCTensor.cpp b/aten/src/THC/THCTensor.cpp
index da00ca5..80b49df 100644
--- a/aten/src/THC/THCTensor.cpp
+++ b/aten/src/THC/THCTensor.cpp
@@ -36,12 +36,6 @@
return THTensor_strideLegacyNoScalars(self, dim);
}
-THLongStorage *THCTensor_newSizeOf(THCState *state, THCTensor *self) {
- THLongStorage *size = THLongStorage_newWithSize(self->dim());
- THLongStorage_rawCopy(size, THTensor_getSizePtr(self));
- return size;
-}
-
THCTensor *THCTensor_new(THCState *state, at::ScalarType scalar_type) {
switch(scalar_type) {
case at::ScalarType::Byte:
@@ -67,12 +61,15 @@
}
}
-void THCTensor_resize(THCState *state, THCTensor *self, THLongStorage *size, THLongStorage *stride) {
- THArgCheck(size != NULL, 2, "invalid size");
- if(stride)
- THArgCheck(stride->size() == size->size(), 3, "invalid stride");
+void THCTensor_resize(THCState *state, THCTensor *self, at::IntList size, at::IntList stride) {
+ THArgCheck(size.data() != NULL, 2, "invalid size");
+ if(stride.data())
+ THArgCheck(stride.size() == size.size(), 3, "invalid stride");
- THCTensor_resizeNd(state, self, size->size(), THLongStorage_data(size), (stride ? THLongStorage_data(stride) : NULL));
+#ifdef DEBUG
+ THAssert(size.size() <= INT_MAX);
+#endif
+ THCTensor_resizeNd(state, self, size.size(), size.data(), stride.data());
}
void THCTensor_resizeAs(THCState *state, THCTensor *self, THCTensor *src) {
@@ -95,7 +92,7 @@
THCTensor_resizeNd(state, self, src->dim(), THTensor_getSizePtr(src), NULL);
}
-void THCTensor_resizeNd(THCState *state, THCTensor *self, int nDimension, int64_t *size, int64_t *stride)
+void THCTensor_resizeNd(THCState *state, THCTensor *self, int nDimension, const int64_t *size, const int64_t *stride)
{
int d;
ptrdiff_t totalSize;
@@ -172,7 +169,22 @@
THTensor_getStridePtr(src));
}
-void THCTensor_setStorageNd(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, int64_t *size, int64_t *stride)
+void THCTensor_setStorage(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntList size_, at::IntList stride_)
+{
+ if(size_.data() && stride_.data())
+ THArgCheck(size_.size() == stride_.size(), 5, "inconsistent size/stride sizes");
+
+ AT_CHECK(size_.data(), "size must not be null");
+ THCTensor_setStorageNd(state,
+ self,
+ storage_,
+ storageOffset_,
+ size_.size(),
+ size_.data(),
+ stride_.data());
+}
+
+void THCTensor_setStorageNd(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, const int64_t *size, const int64_t *stride)
{
/* storage */
if(THTensor_getStoragePtr(self) != storage)
diff --git a/aten/src/THC/THCTensor.hpp b/aten/src/THC/THCTensor.hpp
index 2ed8cbf..b6aadb3 100644
--- a/aten/src/THC/THCTensor.hpp
+++ b/aten/src/THC/THCTensor.hpp
@@ -19,16 +19,16 @@
THC_API int64_t THCTensor_sizeLegacyNoScalars(THCState *state, const THCTensor *self, int dim);
THC_API int64_t THCTensor_stride(THCState *state, const THCTensor *self, int dim);
THC_API int64_t THCTensor_strideLegacyNoScalars(THCState *state, const THCTensor *self, int dim);
-THC_API THLongStorage *THCTensor_newSizeOf(THCState *state, THCTensor *self);
THC_API THCTensor *THCTensor_new(THCState *state, at::ScalarType scalar_type);
-THC_API void THCTensor_resize(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride);
-THC_API void THCTensor_resizeNd(THCState *state, THCTensor *tensor, int nDimension, int64_t *size, int64_t *stride);
+THC_API void THCTensor_resize(THCState *state, THCTensor *tensor, at::IntList size, at::IntList stride);
+THC_API void THCTensor_resizeNd(THCState *state, THCTensor *tensor, int nDimension, const int64_t *size, const int64_t *stride);
THC_API void THCTensor_resizeAs(THCState *state, THCTensor *tensor, THCTensor *src);
THC_API void THCTensor_set(THCState *state, THCTensor *self, THCTensor *src);
-THC_API void THCTensor_setStorageNd(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, int64_t *size, int64_t *stride);
+THC_API void THCTensor_setStorage(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntList size_, at::IntList stride_);
+THC_API void THCTensor_setStorageNd(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, const int64_t *size, const int64_t *stride);
THC_API void THCTensor_squeeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension_);
THC_API void THCTensor_unsqueeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension_);
@@ -53,3 +53,6 @@
/* has more than one index that references the same datapoint, */
/* true otherwise. */
THC_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor* t);
+
+#include "generic/THCTensor.hpp"
+#include "THCGenerateAllTypes.h"
diff --git a/aten/src/THC/THCTensorMathCompare.cuh b/aten/src/THC/THCTensorMathCompare.cuh
index 9fac608..93b9852 100644
--- a/aten/src/THC/THCTensorMathCompare.cuh
+++ b/aten/src/THC/THCTensorMathCompare.cuh
@@ -73,9 +73,7 @@
TensorTypeOut *self_,
TensorType *src,
Op op) {
- THLongStorage* st = THCTensor_newSizeOf(state, src);
- THCTensor_resize(state, self_, st, NULL);
- THLongStorage_free(st);
+ THCTensor_resize(state, self_, src->sizes(), {});
if (!THC_pointwiseApply2<ScalarTypeOut, ScalarType>(state, self_, src, op)) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
diff --git a/aten/src/THC/THCTensorMathCompareT.cuh b/aten/src/THC/THCTensorMathCompareT.cuh
index 9b1fb4e..6b7d1d6 100644
--- a/aten/src/THC/THCTensorMathCompareT.cuh
+++ b/aten/src/THC/THCTensorMathCompareT.cuh
@@ -56,9 +56,7 @@
TensorType *src1,
TensorType *src2,
Op op) {
- THLongStorage* st = THCTensor_newSizeOf(state, src1);
- THCTensor_resize(state, self_, st, NULL);
- THLongStorage_free(st);
+ THCTensor_resize(state, self_, src1->sizes(), {});
THArgCheck(THCTensor_nElement(state, src1) ==
THCTensor_nElement(state, src2), 3,
diff --git a/aten/src/THC/THCTensorMathReduce.cuh b/aten/src/THC/THCTensorMathReduce.cuh
index cc3c7cb..9796b10 100644
--- a/aten/src/THC/THCTensorMathReduce.cuh
+++ b/aten/src/THC/THCTensorMathReduce.cuh
@@ -666,11 +666,10 @@
THCTensor_preserveReduceDimSemantics(
state, tgt2_, src_dims, dimension, keepdim);
- THLongStorage *dim = THCTensor_newSizeOf(state, src);
- THLongStorage_set(dim, dimension, 1);
- THCTensor_resize(state, tgt1_, dim, NULL);
- THCTensor_resize(state, tgt2_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = src->sizes().vec();
+ dim[dimension] = 1;
+ THCTensor_resize(state, tgt1_, dim, {});
+ THCTensor_resize(state, tgt2_, dim, {});
TensorTypeK *tgt1 = (TensorTypeK*)THCTensor_newContiguous<ScalarTypeK>(state, tgt1_);
TensorTypeIndex *tgt2 = (TensorTypeIndex*)THCTensor_newContiguous<ScalarTypeIndex>(state, tgt2_);
diff --git a/aten/src/THC/generic/THCTensor.cpp b/aten/src/THC/generic/THCTensor.cpp
index 2290067..e9380fc 100644
--- a/aten/src/THC/generic/THCTensor.cpp
+++ b/aten/src/THC/generic/THCTensor.cpp
@@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/THCTensor.cpp"
#else
+#include <ATen/InferSize.h>
+
/**** access methods ****/
THCStorage *THCTensor_(storage)(THCState *state, const THCTensor *self)
{
@@ -43,18 +45,6 @@
return THTensor_strideLegacyNoScalars(self, dim);
}
-THLongStorage *THCTensor_(newSizeOf)(THCState *state, THCTensor *self)
-{
- return THCTensor_newSizeOf(state, self);
-}
-
-THLongStorage *THCTensor_(newStrideOf)(THCState *state, THCTensor *self)
-{
- THLongStorage *stride = THLongStorage_newWithSize(self->dim());
- THLongStorage_rawCopy(stride, THTensor_getStridePtr(self));
- return stride;
-}
-
real *THCTensor_(data)(THCState *state, const THCTensor *self)
{
if(THTensor_getStoragePtr(self))
@@ -86,26 +76,11 @@
}
/* Storage init */
-THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, THLongStorage *size, THLongStorage *stride)
-{
- if(size && stride)
- THArgCheck(size->size() == stride->size(), 4, "inconsistent size");
-
- AT_CHECK(size, "size must not be null");
- THCTensor *self = new THCTensor(THCStorage_(new)(state));
- THCTensor_(setStorageNd)(state,
- self,
- storage,
- storageOffset,
- size->size(),
- THLongStorage_data(size),
- (stride ? THLongStorage_data(stride) : NULL));
-
- return self;
-}
-
-THCTensor *THCTensor_(newWithStorageIntLists)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, at::IntList sizes, at::IntList strides) {
- AT_CHECK(sizes.size() == strides.size(), "number of sizes and strides must match");
+THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, at::IntList sizes, at::IntList strides) {
+ if (sizes.data() && strides.data()) {
+ AT_CHECK(sizes.size() == strides.size(), "number of sizes and strides must match");
+ }
+ AT_CHECK(sizes.data(), "size must not be null");
THCTensor *self = new THCTensor(THCStorage_(new)(state));
THCTensor_(setStorageNd)(state, self, storage, storageOffset, sizes.size(),
const_cast<int64_t*>(sizes.data()), const_cast<int64_t*>(strides.data()));
@@ -116,14 +91,14 @@
THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset,
int64_t size0, int64_t stride0)
{
- return THCTensor_(newWithStorageIntLists)(state, storage, storageOffset, {size0}, {stride0});
+ return THCTensor_(newWithStorage)(state, storage, storageOffset, {size0}, {stride0});
}
THCTensor *THCTensor_(newWithStorage2d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset,
int64_t size0, int64_t stride0,
int64_t size1, int64_t stride1)
{
- return THCTensor_(newWithStorageIntLists)(state, storage, storageOffset, {size0, size1}, {stride0, stride1});
+ return THCTensor_(newWithStorage)(state, storage, storageOffset, {size0, size1}, {stride0, stride1});
}
THCTensor *THCTensor_(newWithStorage3d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset,
@@ -131,7 +106,7 @@
int64_t size1, int64_t stride1,
int64_t size2, int64_t stride2)
{
- return THCTensor_(newWithStorageIntLists)(state, storage, storageOffset, {size0, size1, size2}, {stride0, stride1, stride2});
+ return THCTensor_(newWithStorage)(state, storage, storageOffset, {size0, size1, size2}, {stride0, stride1, stride2});
}
THCTensor *THCTensor_(newWithStorage4d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset,
@@ -140,41 +115,34 @@
int64_t size2, int64_t stride2,
int64_t size3, int64_t stride3)
{
- return THCTensor_(newWithStorageIntLists)(state, storage, storageOffset,
+ return THCTensor_(newWithStorage)(state, storage, storageOffset,
{size0, size1, size2, size3},
{stride0, stride1, stride2, stride3});
}
-THCTensor *THCTensor_(newWithSize)(THCState *state, THLongStorage *size, THLongStorage *stride)
+THCTensor *THCTensor_(newWithSize)(THCState *state, at::IntList size, at::IntList stride)
{
return THCTensor_(newWithStorage)(state, NULL, 0, size, stride);
}
-THCTensor *THCTensor_(newWithSizeIntList)(THCState *state, at::IntList sizes) {
- THCTensor *self = new THCTensor(THCStorage_(new)(state));
- THCTensor_(resizeNd)(state, self, sizes.size(), const_cast<int64_t*>(sizes.data()), nullptr);
-
- return self;
-}
-
THCTensor *THCTensor_(newWithSize1d)(THCState *state, int64_t size0)
{
- return THCTensor_(newWithSizeIntList)(state, {size0});
+ return THCTensor_(newWithSize)(state, {size0}, {});
}
THCTensor *THCTensor_(newWithSize2d)(THCState *state, int64_t size0, int64_t size1)
{
- return THCTensor_(newWithSizeIntList)(state, {size0, size1});
+ return THCTensor_(newWithSize)(state, {size0, size1}, {});
}
THCTensor *THCTensor_(newWithSize3d)(THCState *state, int64_t size0, int64_t size1, int64_t size2)
{
- return THCTensor_(newWithSizeIntList)(state, {size0, size1, size2});
+ return THCTensor_(newWithSize)(state, {size0, size1, size2}, {});
}
THCTensor *THCTensor_(newWithSize4d)(THCState *state, int64_t size0, int64_t size1, int64_t size2, int64_t size3)
{
- return THCTensor_(newWithSizeIntList)(state, {size0, size1, size2, size3});
+ return THCTensor_(newWithSize)(state, {size0, size1, size2, size3}, {});
}
THCTensor *THCTensor_(newClone)(THCState *state, THCTensor *self)
@@ -223,23 +191,19 @@
return self;
}
-THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size)
+THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, at::IntList size)
{
ptrdiff_t numel = THCTensor_(nElement)(state, tensor);
THCTensor *self = THCTensor_(new)(state);
- THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
+ auto inferred_size = at::infer_size(size, numel);
auto stride = THTensor_compute_stride(tensor->sizes(),
tensor->strides(),
- at::IntList(inferred_size->data<int64_t>(), inferred_size->size()));
+ inferred_size);
THArgCheck(stride.has_value(), 2, "view size is "
"not compatible with input tensor's size and stride (at least one dimension spans "
"across two contiguous subspaces). Call .contiguous() before .view().");
auto stride_value = *stride;
- THLongStorage *new_stride = THLongStorage_newWithSize(stride_value.size());
- THLongStorage_rawCopy(new_stride, stride_value.data());
- THCTensor_(setStorage)(state, self, THTensor_getStoragePtr(tensor), tensor->storage_offset(), inferred_size, new_stride);
- THLongStorage_free(inferred_size);
- THLongStorage_free(new_stride);
+ THCTensor_setStorage(state, self, THTensor_getStoragePtr(tensor), tensor->storage_offset(), inferred_size, stride_value);
return self;
}
@@ -250,18 +214,17 @@
THArgCheck(in_dims >= 2, 1, "Tensor needs to have at least two dimensions");
THArgCheck(THCTensor_(isContiguous)(state, input), 1,
"Tensor must be contiguous");
- THLongStorage *newSize = THLongStorage_newWithSize(in_dims - 1);
- THLongStorage_data(newSize)[0] = THCTensor_(size)(state, input, 0) * THCTensor_(size)(state, input, 1);
+ std::vector<int64_t> new_size(in_dims - 1);
+ new_size[0] = THCTensor_(size)(state, input, 0) * THCTensor_(size)(state, input, 1);
for (int i = 2; i < in_dims; i++) {
- THLongStorage_data(newSize)[i - 1] = THCTensor_(size)(state, input, i);
+ new_size[i - 1] = THCTensor_(size)(state, input, i);
}
- THCTensor *output = THCTensor_(newView)(state, input, newSize);
- THLongStorage_free(newSize);
+ THCTensor *output = THCTensor_(newView)(state, input, new_size);
return output;
}
/* Resize */
-void THCTensor_(resize)(THCState *state, THCTensor *self, THLongStorage *size, THLongStorage *stride)
+void THCTensor_(resize)(THCState *state, THCTensor *self, at::IntList size, at::IntList stride)
{
THCTensor_resize(state, self, size, stride);
}
@@ -306,44 +269,24 @@
THCTensor_set(state, self, src);
}
-void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_)
-{
- if(size_ && stride_)
- THArgCheck(size_->size() == stride_->size(), 5, "inconsistent size/stride sizes");
-
- AT_CHECK(size_, "size must not be null");
- THCTensor_(setStorageNd)(state,
- self,
- storage_,
- storageOffset_,
- size_->size(),
- THLongStorage_data(size_),
- (stride_ ? THLongStorage_data(stride_) : NULL));
-}
-
-void THCTensor_(setStorageIntLists)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
- at::IntList sizes, at::IntList strides)
-{
- AT_CHECK(sizes.size() == strides.size(), "number of sizes and strides must match");
-
- THCTensor_(setStorageNd)(state, self, storage_, storageOffset_, sizes.size(),
- const_cast<int64_t*>(sizes.data()), const_cast<int64_t*>(strides.data()));
+void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntList size_, at::IntList stride_) {
+ THCTensor_setStorage(state, self, storage_, storageOffset_, size_, stride_);
}
void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_)
{
- THCTensor_(setStorageIntLists)(state, self, storage_, storageOffset_,
- {size0_}, {stride0_});
+ THCTensor_(setStorage)(state, self, storage_, storageOffset_,
+ {size0_}, {stride0_});
}
void THCTensor_(setStorage2d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_,
int64_t size1_, int64_t stride1_)
{
- THCTensor_(setStorageIntLists)(state, self, storage_, storageOffset_,
- {size0_, size1_},
- {stride0_, stride1_});
+ THCTensor_(setStorage)(state, self, storage_, storageOffset_,
+ {size0_, size1_},
+ {stride0_, stride1_});
}
void THCTensor_(setStorage3d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
@@ -351,9 +294,9 @@
int64_t size1_, int64_t stride1_,
int64_t size2_, int64_t stride2_)
{
- THCTensor_(setStorageIntLists)(state, self, storage_, storageOffset_,
- {size0_, size1_, size2_},
- {stride0_, stride1_, stride2_});
+ THCTensor_(setStorage)(state, self, storage_, storageOffset_,
+ {size0_, size1_, size2_},
+ {stride0_, stride1_, stride2_});
}
void THCTensor_(setStorage4d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
@@ -366,7 +309,7 @@
int64_t size[4] = {size0_, size1_, size2_, size3_};
int64_t stride[4] = {stride0_, stride1_, stride2_, stride3_};
- THCTensor_(setStorageIntLists)(state, self, storage_, storageOffset_, size, stride);
+ THCTensor_(setStorage)(state, self, storage_, storageOffset_, size, stride);
}
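The recurring pattern in the hunks above and below replaces the newSizeOf / THLongStorage_set / THLongStorage_free triplet with a std::vector<int64_t> copy of the sizes. A minimal sketch, where the helper and its arguments are hypothetical stand-ins:

#include <vector>
// Copy the size list, adjust one entry, and resize; {} means "no explicit strides".
static void resize_for_reduce_sketch(THCState* state, THCTensor* dst,
                                     THCTensor* src, int dimension) {
  std::vector<int64_t> reduced = src->sizes().vec();
  reduced[dimension] = 1;
  THCTensor_resize(state, dst, reduced, {});
}

Relatedly, at::infer_size (now used by newView above) resolves a single -1 entry against the element count, so a requested shape of {2, -1} for 6 elements becomes {2, 3}.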
diff --git a/aten/src/THC/generic/THCTensor.h b/aten/src/THC/generic/THCTensor.h
index 3f6b94e..b25e408 100644
--- a/aten/src/THC/generic/THCTensor.h
+++ b/aten/src/THC/generic/THCTensor.h
@@ -29,8 +29,6 @@
THC_API int64_t THCTensor_(sizeLegacyNoScalars)(THCState *state, const THCTensor *self, int dim);
THC_API int64_t THCTensor_(stride)(THCState *state, const THCTensor *self, int dim);
THC_API int64_t THCTensor_(strideLegacyNoScalars)(THCState *state, const THCTensor *self, int dim);
-THC_API THLongStorage *THCTensor_(newSizeOf)(THCState *state, THCTensor *self);
-THC_API THLongStorage *THCTensor_(newStrideOf)(THCState *state, THCTensor *self);
THC_API real *THCTensor_(data)(THCState *state, const THCTensor *self);
THC_API void THCTensor_(setFlag)(THCState *state, THCTensor *self, const char flag);
@@ -40,8 +38,6 @@
/**** creation methods ****/
THC_API THCTensor *THCTensor_(new)(THCState *state);
THC_API THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor);
-/* stride might be NULL */
-THC_API THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
THC_API THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_);
THC_API THCTensor *THCTensor_(newWithStorage2d)(THCState *state, THCStorage *storage_, ptrdiff_t storageOffset_,
@@ -58,7 +54,6 @@
int64_t size3_, int64_t stride3_);
/* stride might be NULL */
-THC_API THCTensor *THCTensor_(newWithSize)(THCState *state, THLongStorage *size_, THLongStorage *stride_);
THC_API THCTensor *THCTensor_(newWithSize1d)(THCState *state, int64_t size0_);
THC_API THCTensor *THCTensor_(newWithSize2d)(THCState *state, int64_t size0_, int64_t size1_);
THC_API THCTensor *THCTensor_(newWithSize3d)(THCState *state, int64_t size0_, int64_t size1_, int64_t size2_);
@@ -70,13 +65,11 @@
THC_API THCTensor *THCTensor_(newNarrow)(THCState *state, THCTensor *tensor, int dimension_, int64_t firstIndex_, int64_t size_);
THC_API THCTensor *THCTensor_(newTranspose)(THCState *state, THCTensor *tensor, int dimension1_, int dimension2_);
THC_API THCTensor *THCTensor_(newUnfold)(THCState *state, THCTensor *tensor, int dimension_, int64_t size_, int64_t step_);
-THC_API THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size);
THC_API THCTensor *THCTensor_(newFoldBatchDim)(THCState *state, THCTensor *input);
// resize* methods simply resize the storage. So they may not retain the current data at current indices.
// This is especially likely to happen when the tensor is not contiguous. In general, if you still need the
// values, unless you are doing some size and stride tricks, do not use resize*.
-THC_API void THCTensor_(resize)(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride);
THC_API void THCTensor_(resizeNd)(THCState *state, THCTensor *tensor, int nDimension, int64_t *size, int64_t *stride);
THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src);
THC_API void THCTensor_(resize1d)(THCState *state, THCTensor *tensor, int64_t size0_);
@@ -86,7 +79,6 @@
THC_API void THCTensor_(resize5d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_, int64_t size3_, int64_t size4_);
THC_API void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src);
-THC_API void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
THC_API void THCTensor_(setStorageNd)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, int64_t *size, int64_t *stride);
THC_API void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
int64_t size0_, int64_t stride0_);
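The resize* caveat noted above is easy to trip over; a hedged sketch of the safe pattern, written as it would appear inside a generic THC translation unit (the tensor and state arguments are assumed, not from this diff):

// After a resize the old contents may be gone, so treat the tensor as
// uninitialized and write before reading.
static void resize_then_init_sketch(THCState* state, THCTensor* t) {
  THCTensor_(resize1d)(state, t, 16);
  THCTensor_(zero)(state, t);
}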
diff --git a/aten/src/THC/generic/THCTensor.hpp b/aten/src/THC/generic/THCTensor.hpp
new file mode 100644
index 0000000..fee43d1
--- /dev/null
+++ b/aten/src/THC/generic/THCTensor.hpp
@@ -0,0 +1,21 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensor.hpp"
+#else
+
+// STOP!!! Thinking of including this header directly? Please
+// read Note [TH abstraction violation]
+
+// NOTE: functions exist here only to support dispatch via Declarations.cwrap. You probably don't want to put
+// new functions in here; they should probably be un-genericized.
+
+THC_API void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
+ at::IntList size_, at::IntList stride_);
+THC_API THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, at::IntList size);
+/* strides.data() might be nullptr */
+THC_API THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset,
+ at::IntList sizes, at::IntList strides);
+
+THC_API void THCTensor_(resize)(THCState *state, THCTensor *self, at::IntList size, at::IntList stride);
+THC_API THCTensor *THCTensor_(newWithSize)(THCState *state, at::IntList size, at::IntList stride);
+
+#endif
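A small sketch of the {} convention the header above relies on; this illustrates at::IntList (ArrayRef) behavior and is not code from this diff:

#include <ATen/ScalarType.h>  // assumed to provide at::IntList (ArrayRef<int64_t>) here
#include <cassert>

static void intlist_convention_sketch() {
  // An empty brace list default-constructs the ArrayRef: null data, zero length.
  // This is what lets newWithStorage treat {} as "no strides given" by checking
  // strides.data() == nullptr.
  at::IntList no_strides = {};
  assert(no_strides.data() == nullptr);
  assert(no_strides.size() == 0);
}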
diff --git a/aten/src/THC/generic/THCTensorCopy.cpp b/aten/src/THC/generic/THCTensorCopy.cpp
index f108ca1..3cc88ed 100644
--- a/aten/src/THC/generic/THCTensorCopy.cpp
+++ b/aten/src/THC/generic/THCTensorCopy.cpp
@@ -32,13 +32,11 @@
if(THCTypeIdx_(Real) == THCTypeIdx_(TYPEC)) { \
THCTensor_(copyCPU)(state, self, (THTensor*) src); /* cast just removes warnings */ \
} else { \
- THLongStorage *size = TH##TYPEC##Tensor_newSizeOf(src); \
- THTensor *srcf = THTensor_(newWithSize)(size, NULL); \
+ THTensor *srcf = THTensor_(newWithSize)(src->sizes(), {}); \
\
THTensor_(copy##TYPEC)(srcf, src); \
THCTensor_(copyCPU)(state, self, srcf); \
\
- THLongStorage_free(size); \
THTensor_(free)(srcf); \
} \
}
@@ -82,13 +80,11 @@
if(THCTypeIdx_(Real) == THCTypeIdx_(TYPEC)) { \
THTensor_(copyCuda)(state, (THTensor*) self, src); /* cast just removes compiler warning */ \
} else { \
- THLongStorage *size = THCTensor_(newSizeOf)(state, src); \
- THTensor *srcf = THTensor_(newWithSize)(size, NULL); \
+ THTensor *srcf = THTensor_(newWithSize)(src->sizes(), {}); \
\
THTensor_(copyCuda)(state, srcf, src); \
TH_CONCAT_4(TH,TYPEC,Tensor_copy,Real)(self, srcf); \
\
- THLongStorage_free(size); \
THTensor_(free)(srcf); \
} \
}
diff --git a/aten/src/THC/generic/THCTensorIndex.cu b/aten/src/THC/generic/THCTensorIndex.cu
index 111ce61..3abe739 100644
--- a/aten/src/THC/generic/THCTensorIndex.cu
+++ b/aten/src/THC/generic/THCTensorIndex.cu
@@ -535,12 +535,9 @@
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
- THLongStorage *newSize;
-
- newSize = THCTensor_(newSizeOf)(state, src);
- THLongStorage_set(newSize, dim, numIndices);
- THCTensor_(resize)(state, dst, newSize, NULL);
- THLongStorage_free(newSize);
+ std::vector<int64_t> newSize = src->sizes().vec();
+ newSize[dim] = numIndices;
+ THCTensor_(resize)(state, dst, newSize, {});
ptrdiff_t dstTotalSize = THCTensor_(nElement)(state, dst);
if (dstTotalSize == 0) {
diff --git a/aten/src/THC/generic/THCTensorMasked.cu b/aten/src/THC/generic/THCTensorMasked.cu
index 80c1344..d941134 100644
--- a/aten/src/THC/generic/THCTensorMasked.cu
+++ b/aten/src/THC/generic/THCTensorMasked.cu
@@ -25,9 +25,7 @@
THCTensor *tensor, THByteTensor *mask, real value)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, tensor));
- THLongStorage* maskSizes = THByteTensor_newSizeOf(mask);
- THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, maskSizes, NULL);
- THLongStorage_free(maskSizes);
+ THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, mask->sizes(), {});
THCudaByteTensor_copyByte(state, maskCuda, mask);
THCTensor_(maskedFill)(state, tensor, maskCuda, value);
THCudaByteTensor_free(state, maskCuda);
@@ -59,14 +57,13 @@
// iterator prefix sums? Convert `mask` to the same datatype as what
// we're accumulating the prefix sum in (int64_t) to get around it
THCudaLongTensor* maskLong = THCudaLongTensor_new(state);
- THLongStorage* maskSizes = THCudaByteTensor_newSizeOf(state, mask);
- THCudaLongTensor_resize(state, maskLong, maskSizes, NULL);
+ at::IntList maskSizes = mask->sizes();
+ THCudaLongTensor_resize(state, maskLong, maskSizes, {});
THCudaLongTensor_copyCudaByte(state, maskLong, mask);
// Use a prefix sum to determine the output locations of the masked elements
THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state);
- THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, NULL);
- THLongStorage_free(maskSizes);
+ THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, {});
THCThrustAllocator thrustAlloc(state);
thrust::device_ptr<int64_t>
@@ -105,9 +102,7 @@
THCTensor_(maskedCopyByte)(THCState* state,
THCTensor *tensor, THByteTensor *mask, THCTensor *src) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
- THLongStorage* maskSizes = THByteTensor_newSizeOf(mask);
- THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, maskSizes, NULL);
- THLongStorage_free(maskSizes);
+ THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, mask->sizes(), {});
THCudaByteTensor_copyByte(state, maskCuda, mask);
THCTensor_(maskedCopy)(state, tensor, maskCuda, src);
THCudaByteTensor_free(state, maskCuda);
@@ -134,14 +129,13 @@
// iterator prefix sums? Convert `mask` to the same datatype as what
// we're accumulating the prefix sum in (int64_t) to get around it
THCudaLongTensor* maskLong = THCudaLongTensor_new(state);
- THLongStorage* maskSizes = THCudaByteTensor_newSizeOf(state, mask);
- THCudaLongTensor_resize(state, maskLong, maskSizes, NULL);
+ at::IntList maskSizes = mask->sizes();
+ THCudaLongTensor_resize(state, maskLong, maskSizes, {});
THCudaLongTensor_copyCudaByte(state, maskLong, mask);
// Use a prefix sum to determine the output locations of the masked elements
THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state);
- THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, NULL);
- THLongStorage_free(maskSizes);
+ THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, {});
THCThrustAllocator thrustAlloc(state);
thrust::device_ptr<int64_t>
@@ -182,9 +176,7 @@
THCTensor *tensor, THCTensor *src, THByteTensor *mask)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
- THLongStorage* maskSizes = THByteTensor_newSizeOf(mask);
- THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, maskSizes, NULL);
- THLongStorage_free(maskSizes);
+ THCudaByteTensor* maskCuda = THCudaByteTensor_newWithSize(state, mask->sizes(), {});
THCudaByteTensor_copyByte(state, maskCuda, mask);
THCTensor_(maskedSelect)(state, tensor, src, maskCuda);
THCudaByteTensor_free(state, maskCuda);
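The prefix-sum trick described in the comments above can be illustrated with a plain CPU stand-in; this is not the Thrust/CUDA code the file actually runs:

#include <numeric>
#include <vector>

// Inclusive scan over the (int64_t-converted) mask; under this convention a
// selected element at position i lands in output slot prefix[i] - 1.
static std::vector<int64_t> mask_prefix_sum_sketch(const std::vector<int64_t>& mask) {
  std::vector<int64_t> prefix(mask.size());
  std::partial_sum(mask.begin(), mask.end(), prefix.begin());
  return prefix;  // e.g. {0, 1, 0, 1, 1} -> {0, 1, 1, 2, 3}
}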
diff --git a/aten/src/THC/generic/THCTensorMath.cu b/aten/src/THC/generic/THCTensorMath.cu
index 7fa0dbe..54fb093 100644
--- a/aten/src/THC/generic/THCTensorMath.cu
+++ b/aten/src/THC/generic/THCTensorMath.cu
@@ -96,7 +96,6 @@
// to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
// size (i.e. other empty sizes are not skipped).
// FIXME: warn if this is the case
- THLongStorage *size;
int i, j, cohortMax;
int64_t offset;
bool hasSkippedInput = false;
@@ -122,7 +121,7 @@
THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension);
- size = THLongStorage_newWithSize(nDims);
+ std::vector<int64_t> size(nDims);
// Compute size of the result in the cat dimension
int64_t cat_dim_size = 0;
@@ -141,10 +140,9 @@
if (dim == dimension) {
result_dim_size = cat_dim_size;
}
- THLongStorage_data(size)[dim] = result_dim_size;
+ size[dim] = result_dim_size;
}
- THCTensor_(resize)(state, result, size, NULL);
- THLongStorage_free(size);
+ THCTensor_(resize)(state, result, size, {});
// We parallelize the copy if all 6 conditions pass:
//
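For the result-size computation above, a worked example with hypothetical shapes:

#include <vector>
// Two inputs of size {2, 3} and {4, 3} concatenated along dimension 0.
static std::vector<int64_t> cat_result_size_sketch() {
  std::vector<int64_t> size(2);  // nDims == 2
  size[0] = 2 + 4;               // cat dimension: input sizes are summed -> 6
  size[1] = 3;                   // other dimensions are copied from the inputs
  return size;                   // then passed to THCTensor_(resize)(..., size, {})
}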
diff --git a/aten/src/THC/generic/THCTensorMathPairwise.cu b/aten/src/THC/generic/THCTensorMathPairwise.cu
index 62c57a0..788d53b 100644
--- a/aten/src/THC/generic/THCTensorMathPairwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPairwise.cu
@@ -256,8 +256,7 @@
// 1 if the two tensors are equal at a position, otherwise 0. If the minimum value
// in this buffer is 1, the two tensors are equal, otherwise they are not
- THLongStorage *size = THCTensor_(newSizeOf)(state, self_);
- THCudaByteTensor *buf = THCudaByteTensor_newWithSize(state, size, NULL);
+ THCudaByteTensor *buf = THCudaByteTensor_newWithSize(state, self_->sizes(), {});
if (!THC_pointwiseApply3<uint8_t, real, real>(state, buf, self_, src_, TensorEQOp<real, unsigned char>())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
@@ -265,7 +264,6 @@
unsigned char min = THCudaByteTensor_minall(state, buf);
- THLongStorage_free(size);
THCudaByteTensor_free(state, buf);
return min != 0;
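The buffer-plus-minall equality test above boils down to the following CPU illustration (names and values are hypothetical):

#include <algorithm>
#include <vector>

// buf holds 1 where the two tensors agree and 0 where they differ, so
// "all elements equal" is exactly "the minimum over a non-empty buf is nonzero".
static bool all_equal_sketch(const std::vector<unsigned char>& buf) {
  return *std::min_element(buf.begin(), buf.end()) != 0;
}
// all_equal_sketch({1, 1, 1, 1}) -> true; all_equal_sketch({1, 0, 1, 1}) -> false.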
diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu
index 614ce49..be6527e 100644
--- a/aten/src/THC/generic/THCTensorMathReduce.cu
+++ b/aten/src/THC/generic/THCTensorMathReduce.cu
@@ -95,10 +95,9 @@
THCTensor_preserveReduceDimSemantics(
state, self_, THCTensor_(nDimensionLegacyAll)(state, src), dimension, keepdim);
- THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
- THLongStorage_set(dim, dimension, 1);
- THCTensor_(resize)(state, self_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = src->sizes().vec();
+ dim[dimension] = 1;
+ THCTensor_(resize)(state, self_, dim, {});
THCTensor *self = THCTensor_(newContiguous)(state, self_);
src = THCTensor_(newContiguous)(state, src);
@@ -124,10 +123,9 @@
THCTensor_preserveReduceDimSemantics(
state, self_, THCTensor_(nDimensionLegacyAll)(state, src), dimension, keepdim);
- THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
- THLongStorage_set(dim, dimension, 1);
- THCTensor_(resize)(state, self_, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = src->sizes().vec();
+ dim[dimension] = 1;
+ THCTensor_(resize)(state, self_, dim, {});
THCTensor *self = THCTensor_(newContiguous)(state, self_);
src = THCTensor_(newContiguous)(state, src);
@@ -375,10 +373,7 @@
nelem = THCTensor_(nElement)(state, self);
k = (nelem-1) >> 1;
- THLongStorage *size = THLongStorage_newWithSize1(nelem);
- THCTensor *view = THCTensor_(newView)(state, self, size);
-
- THLongStorage_free(size);
+ THCTensor *view = THCTensor_(newView)(state, self, {nelem});
THCTensor *sorted = THCTensor_(new)(state);
THCudaLongTensor *indices = THCudaLongTensor_new(state);
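A quick arithmetic check of the median index computed above (values are illustrative):

#include <cstdint>
// k = (nelem - 1) >> 1 is the zero-based lower-median position of the sorted,
// flattened data, which is what the 1-D view {nelem} is created for.
static int64_t median_index_sketch(int64_t nelem) {
  return (nelem - 1) >> 1;  // nelem = 5 -> 2 (middle), nelem = 6 -> 2 (lower median)
}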
diff --git a/aten/src/THC/generic/THCTensorMode.cu b/aten/src/THC/generic/THCTensorMode.cu
index 1903995..b0b5530 100644
--- a/aten/src/THC/generic/THCTensorMode.cu
+++ b/aten/src/THC/generic/THCTensorMode.cu
@@ -161,7 +161,6 @@
THCTensor *input,
int dimension,
int keepdim) {
- THLongStorage *dim;
THCTensor *transposed, *contiguous, *valuesTransposed;
THLongStorage *position;
THCudaLongStorage *sortBuffer;
@@ -184,11 +183,10 @@
state, values, ndim, dimension, keepdim);
THCTensor_preserveReduceDimSemantics(
state, indices, ndim, dimension, keepdim);
- dim = THCTensor_(newSizeOf)(state, input);
- THLongStorage_set(dim, dimension, 1);
- THCTensor_(resize)(state, values, dim, NULL);
- THCudaLongTensor_resize(state, indices, dim, NULL);
- THLongStorage_free(dim);
+ std::vector<int64_t> dim = input->sizes().vec();
+ dim[dimension] = 1;
+ THCTensor_(resize)(state, values, dim, {});
+ THCudaLongTensor_resize(state, indices, dim, {});
// If sliceSize is 1, copy input to values and set indices
if (sliceSize == 1) {
diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu
index af81898..b6bcf6a 100644
--- a/aten/src/THC/generic/THCTensorSort.cu
+++ b/aten/src/THC/generic/THCTensorSort.cu
@@ -290,9 +290,7 @@
// Make sure sufficient output space is allocated
THCTensor_(resizeAs)(state, sorted, input);
- THLongStorage *inputSize = THCTensor_(newSizeOf)(state, input);
- THCudaLongTensor_resize(state, indices, inputSize, NULL);
- THLongStorage_free(inputSize);
+ THCudaLongTensor_resize(state, indices, input->sizes(), {});
// How large are the slices that we are sorting?
int64_t sliceSize = THCTensor_(sizeLegacyNoScalars)(state, input, dim);
diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu
index c3b3c55..888679d 100644
--- a/aten/src/THC/generic/THCTensorTopK.cu
+++ b/aten/src/THC/generic/THCTensorTopK.cu
@@ -24,11 +24,10 @@
// Build the output size, which is the dim being selected set to
// size k
- THLongStorage* topKSize = THCTensor_(newSizeOf)(state, input);
- THLongStorage_set(topKSize, dim, k);
- THCTensor_(resize)(state, topK, topKSize, NULL);
- THCudaLongTensor_resize(state, indices, topKSize, NULL);
- THLongStorage_free(topKSize);
+ std::vector<int64_t> topKSize = input->sizes().vec();
+ topKSize[dim] = k;
+ THCTensor_(resize)(state, topK, topKSize, {});
+ THCudaLongTensor_resize(state, indices, topKSize, {});
#define RUN_K(INDEX_T, DIM, DIR) \
gatherTopK<real, INDEX_T, DIM, DIR> \
diff --git a/aten/src/THCUNN/generic/GatedLinearUnit.cu b/aten/src/THCUNN/generic/GatedLinearUnit.cu
index 9bd59ee..f38fb2f 100644
--- a/aten/src/THCUNN/generic/GatedLinearUnit.cu
+++ b/aten/src/THCUNN/generic/GatedLinearUnit.cu
@@ -16,9 +16,9 @@
THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
dim + TH_INDEX_BASE, nIn);
const int64_t inputSize = THCTensor_(size)(state, input, dim) / 2;
- THLongStorage *newSizes = THCTensor_(newSizeOf)(state, input);
- THLongStorage_set(newSizes, dim, inputSize);
- THCTensor_(resize)(state, output, newSizes, NULL);
+ std::vector<int64_t> newSizes = input->sizes().vec();
+ newSizes[dim] = inputSize;
+ THCTensor_(resize)(state, output, newSizes, {});
// halve tensor
THCTensor *firstHalf = THCTensor_(newNarrow)(state, input, dim, 0, inputSize);
@@ -27,7 +27,6 @@
// x = x1:cmul( sigmoid(x2) )
THC_pointwiseApply3<real, real, real>(state, output, secondHalf, firstHalf, gatedLinearCSigMul_functor<real, accreal>());
- THLongStorage_free(newSizes);
THCTensor_(free)(state, firstHalf);
THCTensor_(free)(state, secondHalf);
}
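A shape sketch for the halving done above; the concrete sizes are illustrative, not taken from this diff:

#include <vector>
// GLU along dim == 1 of an {N, nIn} input: nIn must be even, and the output
// keeps every size except the halved dimension, i.e. {N, nIn / 2}.
static std::vector<int64_t> glu_output_size_sketch() {
  std::vector<int64_t> newSizes = {8, 10};  // e.g. N = 8, nIn = 10
  const int dim = 1;
  newSizes[dim] = newSizes[dim] / 2;        // output sizes: {8, 5}
  return newSizes;
}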
diff --git a/aten/src/THCUNN/generic/LookupTable.cu b/aten/src/THCUNN/generic/LookupTable.cu
index 5188609..4b96c6d 100644
--- a/aten/src/THCUNN/generic/LookupTable.cu
+++ b/aten/src/THCUNN/generic/LookupTable.cu
@@ -56,10 +56,8 @@
return;
}
- THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
- THCIndexTensor_(resize)(state, sortedIndices, inputSize, NULL);
- THCIndexTensor_(resize)(state, origIndices, inputSize, NULL);
- THLongStorage_free(inputSize);
+ THCIndexTensor_(resize)(state, sortedIndices, input->sizes(), {});
+ THCIndexTensor_(resize)(state, origIndices, input->sizes(), {});
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
diff --git a/aten/src/THCUNN/generic/LookupTableBag.cu b/aten/src/THCUNN/generic/LookupTableBag.cu
index 5ee4315..6d8cf77 100644
--- a/aten/src/THCUNN/generic/LookupTableBag.cu
+++ b/aten/src/THCUNN/generic/LookupTableBag.cu
@@ -31,15 +31,10 @@
cudaStream_t stream = THCState_getCurrentStream(state);
- THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
- THLongStorage *outputSize = THLongStorage_newWithSize(2);
- THLongStorage_data(outputSize)[0] = numBags;
- THLongStorage_data(outputSize)[1] = stride;
- THCTensor_(resize)(state, output, outputSize, NULL);
+ std::vector<int64_t> outputSize = {numBags, stride};
+ THCTensor_(resize)(state, output, outputSize, {});
THCTensor_(zero)(state, output);
- THCIndexTensor_(resize)(state, offset2bag, inputSize, NULL);
- THLongStorage_free(inputSize);
- THLongStorage_free(outputSize);
+ THCIndexTensor_(resize)(state, offset2bag, input->sizes(), {});
dim3 block = dim3(32, 8);
int grid = 1024;
@@ -99,10 +94,8 @@
cudaStream_t stream = THCState_getCurrentStream(state);
- THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
- THCIndexTensor_(resize)(state, sortedIndices, inputSize, NULL);
- THCIndexTensor_(resize)(state, origIndices, inputSize, NULL);
- THLongStorage_free(inputSize);
+ THCIndexTensor_(resize)(state, sortedIndices, input->sizes(), {});
+ THCIndexTensor_(resize)(state, origIndices, input->sizes(), {});
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
diff --git a/aten/src/THCUNN/generic/VolumetricDilatedMaxPooling.cu b/aten/src/THCUNN/generic/VolumetricDilatedMaxPooling.cu
index 0ec1c0d..0938064 100644
--- a/aten/src/THCUNN/generic/VolumetricDilatedMaxPooling.cu
+++ b/aten/src/THCUNN/generic/VolumetricDilatedMaxPooling.cu
@@ -240,17 +240,10 @@
THCDeviceTensor<real, 4> cudaOutput;
cudaOutput = toDeviceTensor<real, 4>(state, output);
- THLongStorage *indicesSize = THLongStorage_newWithSize(4);
- int64_t indicesSizeRaw[4] = { batchSize * inputSlices,
- outputTime, outputHeight, outputWidth };
- THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
-
THCIndexTensor *indices1 = THCIndexTensor_(newWithStorage)(
state, THCIndexTensor_(storage)(state, indices),
THCIndexTensor_(storageOffset)(state, indices),
- indicesSize, NULL);
-
- THLongStorage_free(indicesSize);
+ { batchSize * inputSlices, outputTime, outputHeight, outputWidth }, {});
THCDeviceTensor<THCIndex_t, 4> cudaIndices =
toDeviceTensor<THCIndex_t, 4>(state, indices1);
@@ -365,14 +358,10 @@
cudaGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
real* gradInputData = THCTensor_(data)(state, gradInput);
- THLongStorage *indicesSize = THLongStorage_newWithSize(4);
- int64_t indicesSizeRaw[4] = { batchSize * inputSlices,
- outputTime, outputHeight, outputWidth };
- THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
THCIndexTensor *indices1 = THCIndexTensor_(newWithStorage)(
state, THCIndexTensor_(storage)(state, indices),
- THCIndexTensor_(storageOffset)(state, indices), indicesSize, NULL);
- THLongStorage_free(indicesSize);
+ THCIndexTensor_(storageOffset)(state, indices),
+ { batchSize * inputSlices, outputTime, outputHeight, outputWidth }, {});
THCDeviceTensor<THCIndex_t, 4> cudaIndices =
toDeviceTensor<THCIndex_t, 4>(state, indices1);
diff --git a/aten/src/THNN/generic/GatedLinearUnit.c b/aten/src/THNN/generic/GatedLinearUnit.c
index 0f88874..fb5cea9 100644
--- a/aten/src/THNN/generic/GatedLinearUnit.c
+++ b/aten/src/THNN/generic/GatedLinearUnit.c
@@ -15,9 +15,9 @@
dim + TH_INDEX_BASE, nIn);
const int64_t inputSize = THTensor_(size)(input, dim) / 2;
- THLongStorage *newSizes = THTensor_(newSizeOf)(input);
- THLongStorage_set(newSizes, dim, inputSize);
- THTensor_(resize)(output, newSizes, NULL);
+ std::vector<int64_t> newSizes = input->sizes().vec();
+ newSizes[dim] = inputSize;
+ THTensor_(resize)(output, newSizes, {});
// halve tensor
THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize);
@@ -27,7 +27,6 @@
THTensor_(sigmoid)(output, secondHalf);
THTensor_(cmul)(output, output, firstHalf);
- THLongStorage_free(newSizes);
THTensor_(free)(firstHalf);
THTensor_(free)(secondHalf);
}