Support "device" keyword argument (#79)
Add the optional "device" keyword argument to Tensor and Storage
constructors and .new methods.
diff --git a/torch/_utils.py b/torch/_utils.py
index 18e6f97..b9a3250 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -21,7 +21,7 @@
else:
return self
else:
- ctx = torch.cuda.device(idx) if idx else torch.cuda._dummy_ctx()
+ ctx = torch.cuda.device(idx if idx else -1)
with ctx:
return self.type(getattr(torch.cuda, self.__class__.__name__), async)
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index ced5135..ccaa77a 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -36,14 +36,17 @@
@contextlib.contextmanager
def device(idx):
- _lazy_init()
- prev_idx = torch._C._cuda_getDevice()
- if prev_idx != idx:
- torch._C._cuda_setDevice(idx)
+    if idx == -1:
yield
- torch._C._cuda_setDevice(prev_idx)
else:
- yield
+ _lazy_init()
+ prev_idx = torch._C._cuda_getDevice()
+ if prev_idx != idx:
+ torch._C._cuda_setDevice(idx)
+ yield
+ torch._C._cuda_setDevice(prev_idx)
+ else:
+ yield
@contextlib.contextmanager
@@ -55,15 +58,11 @@
yield
-@contextlib.contextmanager
-def _dummy_ctx():
- yield
-
-
def device_count():
_lazy_init()
return torch._C._cuda_getDeviceCount()
+
def current_device():
_lazy_init()
return torch._C._cuda_getDevice()
@@ -76,8 +75,8 @@
################################################################################
-from .tensor import _CudaTensorBase
-from .storage import _CudaStorageBase
+from ..tensor import _TensorBase
+from ..storage import _StorageBase
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
# Define dummy base classes
@@ -88,51 +87,64 @@
torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
-class InitCuda(object):
+
+class _CudaBase(object):
+ is_cuda = True
+
+ def type(self, *args, **kwargs):
+ with device(self.get_device()):
+ return super(_CudaBase, self).type(*args, **kwargs)
+
+ def new(self, *args, **kwargs):
+ with device(kwargs.pop('device', self.get_device())):
+ return super(_CudaBase, self).new(*args, **kwargs)
+
def __new__(cls, *args, **kwargs):
_lazy_init()
- return super(InitCuda, cls).__new__(cls, *args, **kwargs)
+ with device(kwargs.pop('device', -1)):
+ return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
-class DoubleStorage(InitCuda, torch._C.CudaDoubleStorageBase, _CudaStorageBase):
+
+class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
pass
-class FloatStorage(InitCuda, torch._C.CudaFloatStorageBase, _CudaStorageBase):
+class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
pass
-class LongStorage(InitCuda, torch._C.CudaLongStorageBase, _CudaStorageBase):
+class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
pass
-class IntStorage(InitCuda, torch._C.CudaIntStorageBase, _CudaStorageBase):
+class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
pass
-class ShortStorage(InitCuda, torch._C.CudaShortStorageBase, _CudaStorageBase):
+class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
pass
-class CharStorage(InitCuda, torch._C.CudaCharStorageBase, _CudaStorageBase):
+class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
pass
-class ByteStorage(InitCuda, torch._C.CudaByteStorageBase, _CudaStorageBase):
+class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
pass
-class HalfStorage(InitCuda, torch._C.CudaHalfStorageBase, _CudaStorageBase):
+class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
pass
-class DoubleTensor(InitCuda, torch._C.CudaDoubleTensorBase, _CudaTensorBase):
+class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
def is_signed(self):
return True
-class FloatTensor(InitCuda, torch._C.CudaFloatTensorBase, _CudaTensorBase):
+class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
def is_signed(self):
return True
-class LongTensor(InitCuda, torch._C.CudaLongTensorBase, _CudaTensorBase):
+class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
def is_signed(self):
return True
-class IntTensor(InitCuda, torch._C.CudaIntTensorBase, _CudaTensorBase):
+class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
def is_signed(self):
return True
-class ShortTensor(InitCuda, torch._C.CudaShortTensorBase, _CudaTensorBase):
+class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
def is_signed(self):
return True
-class CharTensor(InitCuda, torch._C.CudaCharTensorBase, _CudaTensorBase):
+class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
def is_signed(self):
# TODO
return False
-class ByteTensor(InitCuda, torch._C.CudaByteTensorBase, _CudaTensorBase):
+class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
def is_signed(self):
return False
-class HalfTensor(InitCuda, torch._C.CudaHalfTensorBase, _CudaTensorBase):
+class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
def is_signed(self):
return True
diff --git a/torch/cuda/storage.py b/torch/cuda/storage.py
deleted file mode 100644
index 99358c2..0000000
--- a/torch/cuda/storage.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..storage import _StorageBase
-
-
-class _CudaStorageBase(_StorageBase):
- is_cuda = True
-
- def type(self, *args, **kwargs):
- source_device = self.get_device()
- ctx = device(source_device) if source_device != -1 else _dummy_ctx()
- with ctx:
- return super(_CudaStorageBase, self).type(*args, **kwargs)
-
- def new(self, *args, **kwargs):
- source_device = self.get_device()
- ctx = device(source_device) if source_device != -1 else _dummy_ctx()
- with ctx:
- return super(_CudaStorageBase, self).new(*args, **kwargs)
-
diff --git a/torch/cuda/tensor.py b/torch/cuda/tensor.py
deleted file mode 100644
index fabd98b..0000000
--- a/torch/cuda/tensor.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..tensor import _TensorBase
-
-
-class _CudaTensorBase(_TensorBase):
- is_cuda = True
-
- def type(self, *args, **kwargs):
- source_device = self.get_device()
- ctx = device(source_device) if source_device != -1 else _dummy_ctx()
- with ctx:
- return super(_CudaTensorBase, self).type(*args, **kwargs)
-
- def new(self, *args, **kwargs):
- source_device = self.get_device()
- ctx = device(source_device) if source_device != -1 else _dummy_ctx()
- with ctx:
- return super(_CudaTensorBase, self).new(*args, **kwargs)
-