Fix undefined symbol errors in THD
diff --git a/torch/lib/THD/base/TensorDescriptor.cpp b/torch/lib/THD/base/TensorDescriptor.cpp
index d7dcfd5..742f49c 100644
--- a/torch/lib/THD/base/TensorDescriptor.cpp
+++ b/torch/lib/THD/base/TensorDescriptor.cpp
@@ -1,4 +1,5 @@
#include "TensorDescriptor.hpp"
+#include "Cuda.hpp"
#include <THPP/tensors/THTensor.hpp>
#ifdef WITH_CUDA
#include <THPP/tensors/THCTensor.hpp>
@@ -40,41 +41,40 @@
}
#ifdef WITH_CUDA
-extern THCState* state;
THDTensorDescriptor* THDTensorDescriptor_newFromTHCudaDoubleTensor(THCudaDoubleTensor *tensor) {
- THCudaDoubleTensor_retain(state, tensor);
- return new thpp::THCTensor<double>(state, tensor);
+ THCudaDoubleTensor_retain(THDGetCudaState(), tensor);
+ return new thpp::THCTensor<double>(THDGetCudaState(), tensor);
}
THDTensorDescriptor* THDTensorDescriptor_newFromTHCudaFloatTensor(THCudaTensor *tensor) {
- THCudaTensor_retain(state, tensor);
- return new thpp::THCTensor<float>(state, tensor);
+ THCudaTensor_retain(THDGetCudaState(), tensor);
+ return new thpp::THCTensor<float>(THDGetCudaState(), tensor);
}
THDTensorDescriptor* THDTensorDescriptor_newFromTHCudaLongTensor(THCudaLongTensor *tensor) {
- THCudaLongTensor_retain(state, tensor);
- return new thpp::THCTensor<long>(state, tensor);
+ THCudaLongTensor_retain(THDGetCudaState(), tensor);
+ return new thpp::THCTensor<long>(THDGetCudaState(), tensor);
}
THDTensorDescriptor* THDTensorDescriptor_newFromTHCudaIntTensor(THCudaIntTensor *tensor) {
- THCudaIntTensor_retain(state, tensor);
- return new thpp::THCTensor<int>(state, tensor);
+ THCudaIntTensor_retain(THDGetCudaState(), tensor);
+ return new thpp::THCTensor<int>(THDGetCudaState(), tensor);
}
THDTensorDescriptor* THDTensorDescriptor_newFromTHCudaShortTensor(THCudaShortTensor *tensor) {
- THCudaShortTensor_retain(state, tensor);
- return new thpp::THCTensor<short>(state, tensor);
+ THCudaShortTensor_retain(THDGetCudaState(), tensor);
+ return new thpp::THCTensor<short>(THDGetCudaState(), tensor);
}
THDTensorDescriptor* THDTensorDescriptor_newFromTHCudaCharTensor(THCudaCharTensor *tensor) {
- THCudaCharTensor_retain(state, tensor);
- return new thpp::THCTensor<char>(state, tensor);
+ THCudaCharTensor_retain(THDGetCudaState(), tensor);
+ return new thpp::THCTensor<char>(THDGetCudaState(), tensor);
}
THDTensorDescriptor* THDTensorDescriptor_newFromTHCudaByteTensor(THCudaByteTensor *tensor) {
- THCudaByteTensor_retain(state, tensor);
- return new thpp::THCTensor<unsigned char>(state, tensor);
+ THCudaByteTensor_retain(THDGetCudaState(), tensor);
+ return new thpp::THCTensor<unsigned char>(THDGetCudaState(), tensor);
}
#endif