Support "device" keyword argument (#79)

Adds an optional "device" keyword argument to the CUDA Tensor and
Storage constructors and to their .new() methods. Constructors default
to device=-1, which allocates on the current device; .new() defaults to
the device of the source tensor or storage. torch.cuda.device(idx) now
accepts idx == -1 as a no-op context, replacing the _dummy_ctx() helper.
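
For example (a minimal sketch, assuming a machine with at least two
visible GPUs):

    import torch

    x = torch.cuda.FloatTensor(100, device=1)  # allocated on GPU 1
    y = x.new(100)                             # .new() stays on x's device
    z = x.new(100, device=0)                   # override: allocate on GPU 0
    s = torch.cuda.FloatStorage(10, device=0)  # Storage takes device too
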
diff --git a/torch/_utils.py b/torch/_utils.py
index 18e6f97..b9a3250 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -21,7 +21,7 @@
         else:
             return self
     else:
-        ctx = torch.cuda.device(idx) if idx else torch.cuda._dummy_ctx()
+        ctx = torch.cuda.device(idx if idx is not None else -1)
         with ctx:
             return self.type(getattr(torch.cuda, self.__class__.__name__), async)
 
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index ced5135..ccaa77a 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -36,14 +36,19 @@
 
 @contextlib.contextmanager
 def device(idx):
-    _lazy_init()
-    prev_idx = torch._C._cuda_getDevice()
-    if prev_idx != idx:
-        torch._C._cuda_setDevice(idx)
+    if idx == -1:
         yield
-        torch._C._cuda_setDevice(prev_idx)
     else:
-        yield
+        _lazy_init()
+        prev_idx = torch._C._cuda_getDevice()
+        if prev_idx != idx:
+            torch._C._cuda_setDevice(idx)
+            try:
+                yield
+            finally:
+                torch._C._cuda_setDevice(prev_idx)
+        else:
+            yield
 
 
 @contextlib.contextmanager
@@ -55,15 +60,11 @@
         yield
 
 
-@contextlib.contextmanager
-def _dummy_ctx():
-    yield
-
-
 def device_count():
     _lazy_init()
     return torch._C._cuda_getDeviceCount()
 
+
 def current_device():
     _lazy_init()
     return torch._C._cuda_getDevice()
@@ -76,8 +77,8 @@
 ################################################################################
 
 
-from .tensor import _CudaTensorBase
-from .storage import _CudaStorageBase
+from ..tensor import _TensorBase
+from ..storage import _StorageBase
 
 if not hasattr(torch._C, 'CudaDoubleStorageBase'):
     # Define dummy base classes
@@ -88,51 +89,64 @@
         torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
         torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
 
-class InitCuda(object):
+
+class _CudaBase(object):
+    is_cuda = True
+
+    def type(self, *args, **kwargs):
+        with device(self.get_device()):
+            return super(_CudaBase, self).type(*args, **kwargs)
+
+    def new(self, *args, **kwargs):
+        with device(kwargs.pop('device', self.get_device())):
+            return super(_CudaBase, self).new(*args, **kwargs)
+
     def __new__(cls, *args, **kwargs):
         _lazy_init()
-        return super(InitCuda, cls).__new__(cls, *args, **kwargs)
+        with device(kwargs.pop('device', -1)):
+            return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
 
-class DoubleStorage(InitCuda, torch._C.CudaDoubleStorageBase, _CudaStorageBase):
+
+class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
     pass
-class FloatStorage(InitCuda, torch._C.CudaFloatStorageBase, _CudaStorageBase):
+class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
     pass
-class LongStorage(InitCuda, torch._C.CudaLongStorageBase, _CudaStorageBase):
+class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
     pass
-class IntStorage(InitCuda, torch._C.CudaIntStorageBase, _CudaStorageBase):
+class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
     pass
-class ShortStorage(InitCuda, torch._C.CudaShortStorageBase, _CudaStorageBase):
+class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
     pass
-class CharStorage(InitCuda, torch._C.CudaCharStorageBase, _CudaStorageBase):
+class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
     pass
-class ByteStorage(InitCuda, torch._C.CudaByteStorageBase, _CudaStorageBase):
+class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
     pass
-class HalfStorage(InitCuda, torch._C.CudaHalfStorageBase, _CudaStorageBase):
+class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
     pass
 
-class DoubleTensor(InitCuda, torch._C.CudaDoubleTensorBase, _CudaTensorBase):
+class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class FloatTensor(InitCuda, torch._C.CudaFloatTensorBase, _CudaTensorBase):
+class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class LongTensor(InitCuda, torch._C.CudaLongTensorBase, _CudaTensorBase):
+class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class IntTensor(InitCuda, torch._C.CudaIntTensorBase, _CudaTensorBase):
+class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class ShortTensor(InitCuda, torch._C.CudaShortTensorBase, _CudaTensorBase):
+class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class CharTensor(InitCuda, torch._C.CudaCharTensorBase, _CudaTensorBase):
+class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
     def is_signed(self):
         # TODO
         return False
-class ByteTensor(InitCuda, torch._C.CudaByteTensorBase, _CudaTensorBase):
+class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
     def is_signed(self):
         return False
-class HalfTensor(InitCuda, torch._C.CudaHalfTensorBase, _CudaTensorBase):
+class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
     def is_signed(self):
         return True
 
diff --git a/torch/cuda/storage.py b/torch/cuda/storage.py
deleted file mode 100644
index 99358c2..0000000
--- a/torch/cuda/storage.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..storage import _StorageBase
-
-
-class _CudaStorageBase(_StorageBase):
-    is_cuda = True
-
-    def type(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaStorageBase, self).type(*args, **kwargs)
-
-    def new(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaStorageBase, self).new(*args, **kwargs)
-
diff --git a/torch/cuda/tensor.py b/torch/cuda/tensor.py
deleted file mode 100644
index fabd98b..0000000
--- a/torch/cuda/tensor.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..tensor import _TensorBase
-
-
-class _CudaTensorBase(_TensorBase):
-    is_cuda = True
-
-    def type(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaTensorBase, self).type(*args, **kwargs)
-
-    def new(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaTensorBase, self).new(*args, **kwargs)
-
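
Note on semantics: torch.cuda.device(-1) yields without initializing
CUDA or switching devices, which is what lets _CudaBase wrap every
allocation and conversion in a single device(...) context instead of
choosing between device(...) and the removed _dummy_ctx(). A sketch of
the behavior (assuming CUDA is available):

    import torch

    with torch.cuda.device(1):
        a = torch.cuda.FloatTensor(10)  # allocated on GPU 1
    with torch.cuda.device(-1):
        b = torch.cuda.FloatTensor(10)  # current device; no device switch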