Separate cuda-ness from dtype. (#6470)

* Separate cuda-ness from dtype.

There are no longer torch.cuda.int64, etc.; there is only torch.int64, which corresponds directly to at::ScalarType.
At the Python arg parser level, the corresponding ATen type is selected from the combination of (ScalarType, Layout, Device).
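
For example, a minimal sketch of the new user-facing behavior (assuming a CUDA build; the factory and device spellings follow the updated tests below):

```python
import torch

# dtype no longer encodes cuda-ness; the device is a separate argument.
x = torch.zeros(2, 3, dtype=torch.float64, device='cuda:0')
assert x.dtype is torch.float64   # a plain ScalarType-backed dtype
assert x.is_cuda                  # placement comes from device, not dtype
```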

There is also some currently unused code in here for supporting ScalarType in native_functions; this will be used for specifying
aggregate types on reduction functions.
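
The intended eventual use, sketched as a hypothetical call (the dtype argument on reductions is not wired up in this commit):

```python
# Hypothetical: accumulate a uint8 reduction in a wider aggregate type,
# avoiding overflow of the 8-bit element type.
x = torch.ones(1000, dtype=torch.uint8)
s = x.sum(dtype=torch.int64)
```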

* Fix test_autograd.

* Add defaults to randint_like.

* Track is_cuda in py tensor types.

* Fix test_sparse.

* Fix multiprocessing.

* Fix rnn.

* Fix test_nn.

* Fix flake8.
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5d02a2e..6533c82 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -310,6 +310,8 @@
 
 - func: empty_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
 
 - func: exp(Tensor self) -> Tensor
 
@@ -357,6 +359,8 @@
 
 - func: full_like(Tensor self, Scalar fill_value, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
 
 - func: hinge_embedding_loss(Tensor self, Tensor target, double margin=1.0, bool size_average=true, bool reduce=true) -> Tensor
   variants: function
@@ -470,6 +474,8 @@
 
 - func: ones_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
 
 - func: pairwise_distance(Tensor x1, Tensor x2, double p=2, double eps=1e-6, bool keepdim=false) -> Tensor
   variants: function
@@ -490,6 +496,9 @@
 
 - func: rand_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
+
 
 - func: randint(Type dtype, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor
   variants: function
@@ -511,9 +520,13 @@
 
 - func: randint_like(Tensor self, int64_t high, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
 
 - func: randint_like(Tensor self, int64_t low, int64_t high, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
 
 - func: randn(Type dtype, IntList size, *, Generator* generator=nullptr) -> Tensor
   variants: function
@@ -526,6 +539,8 @@
 
 - func: randn_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
 
 - func: randperm(Type dtype, int64_t n, *, Generator* generator=nullptr) -> Tensor
   variants: function
@@ -732,6 +747,8 @@
 
 - func: zeros_like(Tensor self, *, Type dtype) -> Tensor
   variants: function
+  python_default_init:
+    dtype: self.type()
 
 - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
   dispatch:
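
The `python_default_init: dtype: self.type()` entries above make the `*_like` factories default to the input tensor's type when `dtype` is omitted; roughly, a sketch of the resulting Python behavior:

```python
x = torch.randn(2, 3, dtype=torch.float64)
y = torch.zeros_like(x)                     # inherits torch.float64 from x
z = torch.zeros_like(x, dtype=torch.int64)  # explicit dtype overrides the default
assert y.dtype is torch.float64 and z.dtype is torch.int64
```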
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 3692b86..dbc1a71 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -851,19 +851,20 @@
     def test_requires_grad_factory(self):
         x = Variable(torch.randn(2, 3))
         fns = [torch.ones_like, torch.testing.randn_like]
-        dtypes = [torch.float32, torch.float64, torch.cuda.float32, torch.cuda.float64]
+        dtypes = [torch.float32, torch.float64]
         for fn in fns:
             for requires_grad in [True, False]:
                 for dtype in dtypes:
-                    if not dtype.is_cuda:
-                        output = fn(x, dtype=dtype, requires_grad=requires_grad)
-                        self.assertEqual(requires_grad, output.requires_grad)
-                        self.assertIs(dtype, output.dtype)
-                    elif torch.cuda.is_available() and torch.cuda.device_count() > 1:
-                        output = fn(x, dtype=dtype, device=1, requires_grad=requires_grad)
-                        self.assertEqual(requires_grad, output.requires_grad)
-                        self.assertIs(dtype, output.dtype)
-                        self.assertEqual(1, output.get_device())
+                    for use_cuda in [True, False]:
+                        if not use_cuda:
+                            output = fn(x, dtype=dtype, requires_grad=requires_grad)
+                            self.assertEqual(requires_grad, output.requires_grad)
+                            self.assertIs(dtype, output.dtype)
+                        elif torch.cuda.is_available() and torch.cuda.device_count() > 1:
+                            output = fn(x, dtype=dtype, device=1, requires_grad=requires_grad)
+                            self.assertEqual(requires_grad, output.requires_grad)
+                            self.assertIs(dtype, output.dtype)
+                            self.assertEqual(1, output.get_device())
 
     def test_grad_assignment(self):
         x = Variable(torch.randn(5, 5))
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 3a2a2af..37af436 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -35,7 +35,7 @@
 
 def is_half(t):
     if isinstance(t, torch.Tensor):
-        return t.dtype in [torch.float16, torch.cuda.float16]
+        return t.dtype == torch.float16
     assert isinstance(t, type)
     assert t != torch.autograd.Variable
     return t in [torch.HalfTensor, torch.cuda.HalfTensor]
@@ -1069,7 +1069,7 @@
         TestTorch._test_cat_empty(self, use_cuda=True)
 
     def test_bernoulli(self):
-        x = torch.tensor([0, 1], dtype=torch.cuda.float32)
+        x = torch.tensor([0, 1], dtype=torch.float32, device='cuda')
         self.assertEqual(x.bernoulli().tolist(), [0, 1])
 
     def test_cat_bad_input_sizes(self):
@@ -1432,7 +1432,7 @@
         TestTorch._test_int_pow(self, lambda x: x.cuda())
 
     def test_remainder_overflow(self):
-        TestTorch._test_remainder_overflow(self, dtype=torch.cuda.int64)
+        TestTorch._test_remainder_overflow(self, dtype=torch.int64, device='cuda')
 
     def test_var(self):
         cpu_tensor = torch.randn(2, 3, 3)
@@ -1541,10 +1541,10 @@
             self.assertEqual(a, b.cuda())
 
     def test_diagonal(self):
-        TestTorch._test_diagonal(self, dtype=torch.cuda.float32)
+        TestTorch._test_diagonal(self, dtype=torch.float32, device='cuda')
 
     def test_diagflat(self):
-        TestTorch._test_diagflat(self, dtype=torch.cuda.float32)
+        TestTorch._test_diagflat(self, dtype=torch.float32, device='cuda')
 
     @unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
     def test_get_set_rng_state_all(self):
diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py
index 9e57bf7..4eec708 100644
--- a/test/test_multiprocessing.py
+++ b/test/test_multiprocessing.py
@@ -371,21 +371,21 @@
             self.assertEqual(list(tensor), [4, 4, 4, 4])
         p.join()
 
-    def _test_empty_tensor_sharing(self, dtype):
+    def _test_empty_tensor_sharing(self, dtype, device):
         q = mp.Queue()
-        empty = torch.tensor([], dtype=dtype)
+        empty = torch.tensor([], dtype=dtype, device=device)
         q.put(empty)
         out = q.get(timeout=1)
         self.assertEqual(out, empty)
 
     def test_empty_tensor_sharing(self):
-        self._test_empty_tensor_sharing(torch.float32)
-        self._test_empty_tensor_sharing(torch.int64)
+        self._test_empty_tensor_sharing(torch.float32, torch.device('cpu'))
+        self._test_empty_tensor_sharing(torch.int64, torch.device('cpu'))
 
     @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
     def test_empty_tensor_sharing_cuda(self):
-        self._test_empty_tensor_sharing(torch.cuda.float32)
-        self._test_empty_tensor_sharing(torch.cuda.int64)
+        self._test_empty_tensor_sharing(torch.float32, torch.device('cuda'))
+        self._test_empty_tensor_sharing(torch.int64, torch.device('cuda'))
 
     def _test_autograd_sharing(self, var):
         ready = mp.Event()
diff --git a/test/test_nn.py b/test/test_nn.py
index a6ab445..fb91dc6 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2155,7 +2155,7 @@
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_broadcast_no_grad(self):
-        x = torch.randn(1, 2, dtype=torch.cuda.float32, requires_grad=True)
+        x = torch.randn(1, 2, dtype=torch.float32, requires_grad=True, device='cuda')
         with torch.no_grad():
             broadcasted = Broadcast.apply((0, 1), x)
         self.assertTrue(x.requires_grad)
diff --git a/test/test_sparse.py b/test/test_sparse.py
index e8fdf7f..26c9f45 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -850,9 +850,9 @@
                     for use_cuda in ([False] if not torch.cuda.is_available() else [True, False]):
                         # have to include size with cuda sparse tensors
                         include_size = include_size or use_cuda
-                        dtype = torch.cuda.float64 if use_cuda else torch.float64
-                        long_dtype = torch.cuda.int64 if use_cuda else torch.int64
-                        device = -1 if not use_cuda else torch.cuda.device_count() - 1
+                        dtype = torch.float64
+                        long_dtype = torch.int64
+                        device = torch.device('cpu') if not use_cuda else torch.device(torch.cuda.device_count() - 1)
                         indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2])
                         values = torch.tensor([1.], dtype=dtype) if use_tensor_val else 1.
                         if include_size:
@@ -866,7 +866,7 @@
                         self.assertEqual(size if include_size else default_size, sparse_tensor.size())
                         self.assertEqual(dtype, sparse_tensor.dtype)
                         if use_cuda:
-                            self.assertEqual(device, sparse_tensor._values().get_device())
+                            self.assertEqual(device, sparse_tensor._values().device)
                         self.assertEqual(True, sparse_tensor.requires_grad)
 
     @cpu_only
@@ -910,17 +910,18 @@
 
     @cpu_only  # not really, but we only really want to run this once
     def test_dtypes(self):
-        all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda and d != torch.float16]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda and d != torch.cuda.float16]
-        TestTorch._test_dtypes(self, cpu_dtypes, cuda_dtypes, torch.sparse_coo)
+        all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]
+        TestTorch._test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
+        if torch.cuda.is_available():
+            TestTorch._test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))
 
     @cpu_only  # not really, but we only really want to run this once
     def test_empty_full(self):
-        all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda and d != torch.half]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda and d != torch.cuda.half]
-        TestTorch._test_empty_full(self, cpu_dtypes, cuda_dtypes, torch.sparse_coo)
+        all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]
+        TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
+        if torch.cuda.device_count() > 0:
+            TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, -1)
+            TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))
 
     def test_is_sparse(self):
         x = torch.randn(3, 3)
diff --git a/test/test_torch.py b/test/test_torch.py
index f71d0cd..f63695a 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -957,9 +957,9 @@
         long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))
 
     @staticmethod
-    def _test_remainder_overflow(self, dtype=torch.int64):
+    def _test_remainder_overflow(self, dtype, device):
         # Check Integer Overflows
-        x = torch.tensor(23500, dtype=dtype)
+        x = torch.tensor(23500, dtype=dtype, device=device)
         q = 392486996410368
         self.assertEqual(x % q, x)
         self.assertEqual(-x % q, q - x)
@@ -967,7 +967,7 @@
         self.assertEqual(-x % -q, -x)
 
     def test_remainder_overflow(self):
-        self._test_remainder_overflow(self, dtype=torch.int64)
+        self._test_remainder_overflow(self, dtype=torch.int64, device='cpu')
 
     def test_mm(self):
         # helper function
@@ -1429,28 +1429,19 @@
         self.assertEqual(output, expected)
 
     @staticmethod
-    def _test_dtypes(self, cpu_dtypes, cuda_dtypes, layout):
-        dtypes = cpu_dtypes + (cuda_dtypes if torch.cuda.is_available() else [])
-
+    def _test_dtypes(self, dtypes, layout, device):
         for dtype in dtypes:
-            # no ops on torch.float16 currently, cuda.float16 doesn't work on windows
             if dtype != torch.float16:
-                if dtype.is_cuda and torch.cuda.device_count() > 1:
-                    out = torch.zeros((2, 3), device=1, dtype=dtype, layout=layout)
-                    self.assertIs(dtype, out.dtype)
-                    self.assertIs(layout, out.layout)
-                    self.assertEqual(1, out.get_device())
-                else:
-                    out = torch.zeros((2, 3), dtype=dtype, layout=layout)
-                    self.assertIs(dtype, out.dtype)
-                    self.assertIs(layout, out.layout)
-            self.assertEqual(dtype in cuda_dtypes, dtype.is_cuda)
+                out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
+                self.assertIs(dtype, out.dtype)
+                self.assertIs(layout, out.layout)
+                self.assertEqual(device, out.device)
 
     def test_dtypes(self):
         all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda]
-        self._test_dtypes(self, cpu_dtypes, cuda_dtypes, torch.strided)
+        self._test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu'))
+        if torch.cuda.is_available():
+            self._test_dtypes(self, all_dtypes, torch.strided, torch.device('cuda:0'))
 
     def test_device(self):
         cpu = torch.device('cpu')
@@ -1508,20 +1499,19 @@
             assertEqual('cuda:0', lambda: torch.tensor(5).cuda('cuda:0'))
             self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu'))
             self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu:0'))
-            assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.cuda.int64, device=0))
-            assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.cuda.int64, device='cuda:0'))
+            assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device=0))
+            assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:0'))
             assertEqual('cuda:' + str(torch.cuda.current_device()),
-                        lambda: torch.tensor(5, dtype=torch.cuda.int64, device='cuda'))
+                        lambda: torch.tensor(5, dtype=torch.int64, device='cuda'))
 
             if torch.cuda.device_count() > 1:
                 assertEqual('cuda:1', lambda: torch.tensor(5).cuda(1))
                 assertEqual('cuda:1', lambda: torch.tensor(5).cuda('cuda:1'))
-                assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.cuda.int64, device=1))
-                assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.cuda.int64, device='cuda:1'))
+                assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device=1))
+                assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:1'))
 
     @staticmethod
-    def _test_empty_full(self, cpu_dtypes, cuda_dtypes, layout):
-        dtypes = cpu_dtypes + (cuda_dtypes if torch.cuda.is_available() else [])
+    def _test_empty_full(self, dtypes, layout, device):
         shape = torch.Size([2, 3])
 
         def check_value(tensor, dtype, layout, device, value, requires_grad):
@@ -1530,7 +1520,7 @@
             self.assertIs(layout, tensor.layout)
             self.assertEqual(tensor.requires_grad, requires_grad)
             if tensor.is_cuda and device != -1:
-                self.assertEqual(device, tensor.get_device())
+                self.assertEqual(device, tensor.device)
             if value is not None:
                 fill = tensor.new(shape).fill_(value)
                 self.assertEqual(tensor, fill)
@@ -1547,7 +1537,6 @@
         for dtype in dtypes:
             for rg in [True, False]:
                 int64_dtype = get_int64_dtype(dtype)
-                device = -1 if not (dtype.is_cuda and torch.cuda.device_count() > 1) else 1
                 v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
                 check_value(v, dtype, layout, device, None, rg)
                 out = v.new()
@@ -1576,10 +1565,10 @@
                                 int64_dtype, layout, device, fv + 5, rg)
 
     def test_empty_full(self):
-        all_dtypes = torch.testing.get_all_dtypes()
-        cpu_dtypes = [d for d in all_dtypes if not d.is_cuda]
-        cuda_dtypes = [d for d in all_dtypes if d.is_cuda]
-        self._test_empty_full(self, cpu_dtypes, cuda_dtypes, torch.strided)
+        self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu'))
+        if torch.cuda.device_count() > 0:
+            self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, -1)
+            self._test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cuda:0'))
 
     def test_dtype_out_match(self):
         d = torch.autograd.Variable(torch.DoubleTensor(2, 3))
@@ -1606,9 +1595,9 @@
         self.assertIs(torch.FloatStorage, torch.Storage)
 
         if torch.cuda.is_available():
-            torch.set_default_tensor_type(torch.cuda.float32)
-            self.assertIs(torch.cuda.float32, torch.get_default_dtype())
-            self.assertIs(torch.cuda.float32, torch.cuda.FloatTensor.dtype)
+            torch.set_default_tensor_type(torch.cuda.FloatTensor)
+            self.assertIs(torch.float32, torch.get_default_dtype())
+            self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
             self.assertIs(torch.cuda.FloatStorage, torch.Storage)
 
         # don't support integral or sparse default types.
@@ -1686,8 +1675,22 @@
         saved_dtype = torch.get_default_dtype()
         torch.set_default_tensor_type(torch.float32)
         self.assertIs(torch.float32, torch.tensor(0.).dtype)
-        torch.set_default_tensor_type(torch.cuda.float64)
-        self.assertIs(torch.cuda.float64, torch.tensor(0.).dtype)
+        self.assertEqual(torch.device('cpu'), torch.tensor(0.).device)
+        torch.set_default_tensor_type(torch.float64)
+        self.assertIs(torch.float64, torch.tensor(0.).dtype)
+        torch.set_default_tensor_type(saved_dtype)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_tensor_factory_cuda_type(self):
+        saved_dtype = torch.get_default_dtype()
+        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        x = torch.zeros((5, 5))
+        self.assertIs(torch.float32, x.dtype)
+        self.assertTrue(x.is_cuda)
+        torch.set_default_tensor_type(torch.cuda.DoubleTensor)
+        x = torch.zeros((5, 5))
+        self.assertIs(torch.float64, x.dtype)
+        self.assertTrue(x.is_cuda)
         torch.set_default_tensor_type(saved_dtype)
 
     def test_new_tensor(self):
@@ -1721,23 +1724,23 @@
             expected = expected.cuda(1)
             res1 = expected.new_tensor([1, 1])
             self.assertEqual(res1.get_device(), expected.get_device())
-            res1 = expected.new_tensor([1, 1], dtype=torch.cuda.int)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res1 = expected.new_tensor([1, 1], dtype=torch.int)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res1.get_device(), expected.get_device())
 
             res2 = expected.new_tensor(expected)
             self.assertEqual(res2.get_device(), expected.get_device())
-            res2 = expected.new_tensor(expected, dtype=torch.cuda.int)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res2 = expected.new_tensor(expected, dtype=torch.int)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res2.get_device(), expected.get_device())
-            res2 = expected.new_tensor(expected, dtype=torch.cuda.int, device=0)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res2 = expected.new_tensor(expected, dtype=torch.int, device=0)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res2.get_device(), 0)
 
             res1 = expected.new_tensor(1)
             self.assertEqual(res1.get_device(), expected.get_device())
-            res1 = expected.new_tensor(1, dtype=torch.cuda.int)
-            self.assertIs(torch.cuda.int, res1.dtype)
+            res1 = expected.new_tensor(1, dtype=torch.int)
+            self.assertIs(torch.int, res1.dtype)
             self.assertEqual(res1.get_device(), expected.get_device())
 
     def test_diag(self):
@@ -1748,49 +1751,49 @@
         self.assertEqual(res1, res2)
 
     @staticmethod
-    def _test_diagonal(self, dtype=torch.float32):
-        x = torch.randn((100, 100), dtype=dtype)
+    def _test_diagonal(self, dtype, device):
+        x = torch.randn((100, 100), dtype=dtype, device=device)
         result = torch.diagonal(x)
         expected = torch.diag(x)
         self.assertEqual(result, expected)
 
-        x = torch.randn((100, 100), dtype=dtype)
+        x = torch.randn((100, 100), dtype=dtype, device=device)
         result = torch.diagonal(x, 17)
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)
 
     def test_diagonal(self):
-        self._test_diagonal(self, dtype=torch.float32)
+        self._test_diagonal(self, dtype=torch.float32, device='cpu')
 
     @staticmethod
-    def _test_diagflat(self, dtype=torch.float32):
+    def _test_diagflat(self, dtype, device):
         # Basic sanity test
-        x = torch.randn((100,), dtype=dtype)
+        x = torch.randn((100,), dtype=dtype, device=device)
         result = torch.diagflat(x)
         expected = torch.diag(x)
         self.assertEqual(result, expected)
 
         # Test offset
-        x = torch.randn((100,), dtype=dtype)
+        x = torch.randn((100,), dtype=dtype, device=device)
         result = torch.diagflat(x, 17)
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)
 
         # Test where input has more than one dimension
-        x = torch.randn((2, 3, 4), dtype=dtype)
+        x = torch.randn((2, 3, 4), dtype=dtype, device=device)
         result = torch.diagflat(x)
         expected = torch.diag(x.contiguous().view(-1))
         self.assertEqual(result, expected)
 
         # Noncontig input
-        x = torch.randn((2, 3, 4), dtype=dtype).transpose(2, 0)
+        x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0)
         self.assertFalse(x.is_contiguous())
         result = torch.diagflat(x)
         expected = torch.diag(x.contiguous().view(-1))
         self.assertEqual(result, expected)
 
     def test_diagflat(self):
-        self._test_diagflat(self, dtype=torch.float32)
+        self._test_diagflat(self, dtype=torch.float32, device='cpu')
 
     def test_eye(self):
         res1 = torch.eye(100, 100)
@@ -2667,13 +2670,11 @@
     def _test_cat_empty(self, use_cuda=False):
         # FIXME: this is legacy behavior and should be removed
         # when we support empty tensors with arbitrary sizes
-        if use_cuda:
-            dtype = torch.cuda.float32
-        else:
-            dtype = torch.float32
+        dtype = torch.float32
+        device = 'cuda' if use_cuda else 'cpu'
 
-        x = torch.randn((4, 3, 32, 32), dtype=dtype)
-        empty = torch.randn((0,), dtype=dtype)
+        x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
+        empty = torch.randn((0,), dtype=dtype, device=device)
 
         res1 = torch.cat([x, empty], dim=1)
         res2 = torch.cat([empty, x], dim=1)
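
The default-type tests above lean on the legacy tensor types (not dtypes) to carry cuda-ness; a sketch, assuming a CUDA build:

```python
saved = torch.get_default_dtype()
torch.set_default_tensor_type(torch.cuda.FloatTensor)
x = torch.zeros(5, 5)
assert x.is_cuda and x.dtype is torch.float32
assert torch.get_default_dtype() is torch.float32  # the dtype stays device-agnostic
torch.set_default_tensor_type(saved)
```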
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 731ca49..54a2e6c 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -73,7 +73,8 @@
   ${call_dispatch}
 } else {
   if (!r.isNone(${type_idx})) {
-    check_out_type_matches(r.tensor(${out_idx}), r.dtype(${type_idx}), r.layout(${layout_idx}));
+    check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.layout(${layout_idx}),
+                           r.device(${device_idx}), r.isNone(${device_idx}));
   }
   ${call_dispatch_out}
 }
@@ -207,9 +208,9 @@
         'Tensor &': 'tensor',
         'Generator *': 'generator',
         'Storage &': 'storage',
-        'const Type &': 'dtype',
+        'const Type &': 'scalartype',
         'const THPLayout &': 'layout',
-        'const Device &': 'deviceInt64',
+        'const Device &': 'device',
         'int64_t': 'toInt64',
         'bool': 'toBool',
         'double': 'toDouble',
@@ -221,6 +222,10 @@
         'int64_t': 'toInt64WithDefault',
         'bool': 'setDefaultBool',
         'double': 'setDefaultDouble',
+        'const Type &': 'scalartypeWithDefault',
+        'const THPLayout &': 'layoutWithDefault',
+        'const Device &': 'deviceWithDefault',
+        'ScalarType': 'scalartypeWithDefault',
     }
 
     def first_tensor_arg(arguments):
@@ -286,6 +291,9 @@
                     '`{}` type is not supported in python_default_init'.format(typename)
                 unpack_with_default = unpack_with_default_methods.get(typename)
                 default_expr = arg.get('python_default_init')
+                # TODO: Type currently maps to ScalarType, figure out a cleaner solution
+                if typename == 'const Type &':
+                    default_expr += '.scalarType()'
                 expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
             else:
                 unpack = unpack_methods.get(typename, typename.lower())
@@ -335,7 +343,6 @@
                 actuals.append('results[{}]'.format(i))
 
         layout = None
-        parsed_type_dispatch = None
         # type args go after the outputs to match the signature generation.
         arg_idx = arg_idx if out_idx is None else out_idx + 1
         for arg in type_args:
@@ -357,23 +364,29 @@
         for arg in python_binding_arguments:
             if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                 pass  # already handled by type_dispatched_args
-            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
-                if len(outputs) == 0:
-                    has_device_bind = True
-                    append_actuals_formals(*parse_arg(arg, device_idx))
-            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
-                requires_grad = parse_arg(arg, requires_grad_idx)[0]
             elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                 # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                 if len(outputs) == 0:
-                    layout = parse_arg(arg, layout_idx)[0]
+                    layout = parse_arg(arg, layout_idx, arg.get('python_default_init'))[0]
+            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
+                if len(outputs) == 0:
                     assert parsed_type_args
-                    actuals.append("torch::getType({}, {})".format(parsed_type_args[0], layout))
+                    assert layout
+                    device_arg = parse_arg(arg, device_idx, True)
+                    # add type, device formals and corresponding actuals.
+                    # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
+                    # The device actual is the corresponding AutoGPU index for the Device.
                     formal_args.append(parsed_type_args[1])
+                    formal_args.append(device_arg[1])
+                    actuals.append("torch::getType({}, {}, {}.type)".format(parsed_type_args[0], layout, device_arg[0]))
+                    actuals.append('{}.deviceInt64()'.format(device_arg[0]))
+                    has_device_bind = True
+            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
+                requires_grad = parse_arg(arg, requires_grad_idx)[0]
             else:
                 raise RuntimeError(("found {} in python_binding_arguments but only "
-                                    "\"bool requires_grad\", \"Dtype dtype\", \"Layout layout\", \"Device device\" "
-                                    "are supported".format(arg)))
+                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
+                                    "\"Device device\" are supported".format(arg)))
 
         env['unpack_args'] = []
         env['formal_args'] = formal_args
@@ -414,7 +427,7 @@
             has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
             if has_dtype_bind:
                 body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
-                                                             layout_idx=out_idx + 2).split('\n')
+                                                             layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
             else:
                 body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
         else:
@@ -463,6 +476,7 @@
             }
             python_binding_arguments.append(dtype_arg)
         if is_factory_function or is_typed_like_function:
+            py_default_layout = '*torch::getLayout(self.type().backend())' if is_typed_like_function else None
             layout_arg = {
                 'default': 'torch.strided',
                 'dynamic_type': 'Layout',
@@ -470,9 +484,10 @@
                 'name': 'layout',
                 'type': 'const THPLayout &',
                 'simple_type': 'Layout',
+                'python_default_init': py_default_layout,
             }
             python_binding_arguments.append(layout_arg)
-        if is_factory_or_like_function:
+            py_default_device = 'torch::utils::getDevice(self)' if is_typed_like_function else None
             device_arg = {
                 'default': 'None',
                 'default_init': 'None',
@@ -480,9 +495,11 @@
                 'kwarg_only': True,
                 'name': 'device',
                 'type': 'const Device &',
-                'simple_type': 'Device'
+                'simple_type': 'Device',
+                'python_default_init': py_default_device
             }
             python_binding_arguments.append(device_arg)
+        if is_factory_or_like_function:
             requires_grad_arg = {
                 'default': False,
                 'dynamic_type': 'bool',
@@ -590,7 +607,7 @@
     positional = True
 
     def get_py_formal_arg(arg):
-        typename = arg['simple_type'] if arg['simple_type'] != 'Type' else 'Dtype'
+        typename = arg['simple_type'] if arg['simple_type'] != 'Type' else 'ScalarType'
         if arg.get('is_nullable'):
             typename = '{}?'.format(typename)
         if arg.get('size') is not None:
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 2b16567..d33e75c 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -318,7 +318,8 @@
     def emit_record_trace(env):
         # Operations involving Generator, Storage, Type are not traceable
         # at the moment
-        if any(arg['simple_type'] in {'Generator', 'Storage', 'Type'} for arg in declaration['arguments']):
+        if any(arg['simple_type'] in {'Generator', 'Storage', 'ScalarType', 'Type'}
+               for arg in declaration['arguments']):
             return ('', '')
         # We can't trace functions which don't have any Tensor or TensorList returns
         if 'Tensor' not in declaration['return_type']:
diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h
index 938a425..a215656 100644
--- a/tools/autograd/templates/VariableType.h
+++ b/tools/autograd/templates/VariableType.h
@@ -22,6 +22,7 @@
 using at::Tensor;
 using at::TensorList;
 using at::Type;
+using at::ScalarType;
 
 struct VariableType final : public at::Type {
   VariableType(Context* context, at::Type* baseType);
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index e2009ba..35fb89b 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -16,6 +16,7 @@
 #include "torch/csrc/utils/python_arg_parser.h"
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_numpy.h"
+#include "torch/csrc/utils/tensor_devices.h"
 #include "torch/csrc/utils/tensor_layouts.h"
 
 #include "python_torch_functions_dispatch.h"
@@ -33,8 +34,11 @@
   return self;
 }
 
-static void check_out_type_matches(Tensor result, const THPDtype &dtype, const THPLayout& layout) {
-  const auto& type = torch::getType(dtype, layout);
+static void check_out_type_matches(Tensor result, ScalarType scalarType, const THPLayout& layout,
+                                   const Device& device, bool device_is_none) {
+  auto result_device_type = torch::getDeviceType(result.type());
+  auto device_type = device_is_none ? result_device_type : device.type;
+  const auto& type = torch::getType(scalarType, layout, device_type);
   if (result.type() != type) {
     AT_ERROR(
         "type corresponding to %s does not match type of out parameter (%s)",
@@ -90,19 +94,13 @@
 {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-    "_promote_types(Dtype type1, Dtype type2)",
+    "_promote_types(ScalarType type1, ScalarType type2)",
   });
   ParsedArgs<2> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    auto& d1 = r.dtype(0);
-    auto& d2 = r.dtype(1);
-    if (d1.is_cuda != d2.is_cuda) {
-      AT_ERROR("_promote_types only supports dtypes being both on cpu or cuda.  Got %s and %s",
-               d1.is_cuda ? "true" : "false", d2.is_cuda ? "true" : "false");
-    }
-    ScalarType promoted = at::promoteTypes(d1.scalar_type, d2.scalar_type);
-    return torch::autograd::utils::wrap(torch::getDtype(promoted, d1.is_cuda));
+    ScalarType promoted = at::promoteTypes(r.scalartype(0), r.scalartype(1));
+    return torch::autograd::utils::wrap(torch::getDtype(promoted));
   }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 5d0341a..9ad5dbb 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -565,7 +565,8 @@
   } else {
     throw TypeError("dtype must be a type, str, or dtype object");
   }
-  auto& type = is_dtype ? torch::getType(r.dtype(0), *torch::getLayout(self_.type().backend())) :
+  auto self_device_type = torch::getDeviceType(self_.type());
+  auto& type = is_dtype ? torch::getType(r.scalartype(0), *torch::getLayout(self_.type().backend()), self_device_type) :
                           torch::utils::type_from_string(type_name);
   return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, -1, r.toBool(1)));
   END_HANDLE_TH_ERRORS
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index 4bf910a..c61d69d 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -68,6 +68,7 @@
             not any(arg['simple_type'] == 'Generator' for arg in decl['arguments']) and
             not any(arg['simple_type'] == 'SparseTensor' for arg in decl['arguments']) and
             not any(arg['simple_type'] == 'Storage' for arg in decl['arguments']) and
+            not any(arg['simple_type'] == 'ScalarType' for arg in decl['arguments']) and
             not any(arg['simple_type'] == 'Type' for arg in decl['arguments']) and
             uses_tensors)
 
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py
index 43f9ec6..7edbbc6 100644
--- a/torch/backends/cudnn/rnn.py
+++ b/torch/backends/cudnn/rnn.py
@@ -35,12 +35,12 @@
         self.inner = None
 
 
-def init_dropout_state(ty, dropout, train, dropout_seed, dropout_state):
+def init_dropout_state(ty, device, dropout, train, dropout_seed, dropout_state):
     dropout_desc_name = 'desc_' + str(torch.cuda.current_device())
     dropout_p = dropout if train else 0
     if (dropout_desc_name not in dropout_state) or (dropout_state[dropout_desc_name].get() is None):
         dropout_state[dropout_desc_name] = Unserializable(
-            torch._C._VariableFunctions._cudnn_init_dropout_state(dropout_p, train, dropout_seed, ty=ty)
+            torch._C._VariableFunctions._cudnn_init_dropout_state(dropout_p, train, dropout_seed, ty=ty, device=device)
             if dropout_p != 0 else None
         )
     dropout_ts = dropout_state[dropout_desc_name].get()
diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp
index 442af5c..cae999c 100644
--- a/torch/csrc/Device.cpp
+++ b/torch/csrc/Device.cpp
@@ -35,7 +35,7 @@
 PyObject *THPDevice_repr(THPDevice *self)
 {
   std::ostringstream oss;
-  oss << "Device(device_type=\'" << deviceTypeString(self->device.type) << "\'";
+  oss << "device(device_type=\'" << deviceTypeString(self->device.type) << "\'";
   if (!self->device.is_default) {
     oss << ", device_index=" << self->device.index;
   }
diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp
index 241682b..807c17a 100644
--- a/torch/csrc/Dtype.cpp
+++ b/torch/csrc/Dtype.cpp
@@ -8,33 +8,18 @@
 #include "torch/csrc/utils/tensor_dtypes.h"
 #include "torch/csrc/utils/tensor_types.h"
 
-PyObject * THPDtype_New(at::ScalarType scalar_type, bool is_cuda, const std::string& name)
+PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name)
 {
   auto type = (PyTypeObject*)&THPDtypeType;
   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
   if (!self) throw python_error();
   auto self_ = reinterpret_cast<THPDtype*>(self.get());
   self_->scalar_type = scalar_type;
-  self_->is_cuda = is_cuda;
   std::strncpy (self_->name, name.c_str(), DTYPE_NAME_LEN);
   self_->name[DTYPE_NAME_LEN] = '\0';
   return self.release();
 }
 
-PyObject *THPDtype_repr(THPDtype *self)
-{
-  return THPUtils_packString(self->name);
-}
-
-PyObject *THPDtype_is_cuda(THPDtype *self)
-{
-  if (self->is_cuda) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-}
-
 PyObject *THPDtype_is_floating_point(THPDtype *self)
 {
   if (at::isFloatingType(self->scalar_type)) {
@@ -47,11 +32,15 @@
 typedef PyObject *(*getter)(PyObject *, void *);
 
 static struct PyGetSetDef THPDtype_properties[] = {
-  {"is_cuda",      (getter)THPDtype_is_cuda, nullptr, nullptr, nullptr},
   {"is_floating_point", (getter)THPDtype_is_floating_point, nullptr, nullptr, nullptr},
   {nullptr}
 };
 
+PyObject *THPDtype_repr(THPDtype *self)
+{
+  return THPUtils_packString(self->name);
+}
+
 PyTypeObject THPDtypeType = {
   PyVarObject_HEAD_INIT(nullptr, 0)
   "torch.dtype",                         /* tp_name */
diff --git a/torch/csrc/Dtype.h b/torch/csrc/Dtype.h
index e244cb1..4a67cd2 100644
--- a/torch/csrc/Dtype.h
+++ b/torch/csrc/Dtype.h
@@ -8,7 +8,6 @@
 struct THPDtype {
   PyObject_HEAD
   at::ScalarType scalar_type;
-  bool is_cuda;
   char name[DTYPE_NAME_LEN + 1];
 };
 
@@ -18,6 +17,6 @@
   return Py_TYPE(obj) == &THPDtypeType;
 }
 
-PyObject * THPDtype_New(at::ScalarType scalar_type, bool is_cuda, const std::string& name);
+PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name);
 
 void THPDtype_init(PyObject *module);
diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp
index 6510e98..08702be 100644
--- a/torch/csrc/DynamicTypes.cpp
+++ b/torch/csrc/DynamicTypes.cpp
@@ -33,8 +33,7 @@
 
 static const int NumBoolOptions = 2;
 static THPDtype* dtype_registry
-  [static_cast<int>(at::ScalarType::NumOptions)]
-  [NumBoolOptions] = {};
+  [static_cast<int>(at::ScalarType::NumOptions)] = {};
 
 static THPLayout* layout_registry
   [static_cast<int>(at::Backend::NumOptions)] = {};
@@ -72,8 +71,8 @@
   }
 }
 
-void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType, bool is_cuda) {
-  dtype_registry[static_cast<int>(scalarType)][is_cuda] = dtype;
+void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType) {
+  dtype_registry[static_cast<int>(scalarType)] = dtype;
 }
 
 void registerLayoutObject(THPLayout *layout, at::Backend backend) {
@@ -89,15 +88,16 @@
   throw std::invalid_argument("unsupported Storage type");
 }
 
-at::Type& getType(const THPDtype &dtype, const THPLayout& layout) {
-  at::Backend backend = get_backend(dtype.is_cuda, !layout.is_strided);
+at::Type& getType(at::ScalarType scalarType, const THPLayout& layout, const DeviceType& deviceType) {
+  at::Backend backend = get_backend(deviceType == DeviceType::CUDA, !layout.is_strided);
   // use type_registry rather than context.getType() because getType throws exceptions.
   auto baseType = at::globalContext().type_registry[static_cast<int>(backend)]
-                                                   [static_cast<int>(dtype.scalar_type)].get();
+                                                   [static_cast<int>(scalarType)].get();
   if (!baseType) {
     std::ostringstream oss;
-    oss << "Error attempting to use dtype " << dtype.name << " with layout " << layout.name << ".";
-    if (!torch::utils::cuda_enabled()) {
+    oss << "Error attempting to use dtype " << getDtype(scalarType)->name << " with layout " << layout.name
+        << " and device type " << (deviceType == DeviceType::CPU ? "CPU" : "CUDA") << ".";
+    if (deviceType == DeviceType::CUDA && !torch::utils::cuda_enabled()) {
       oss << "  Torch not compiled with CUDA enabled." << std::endl;
     }
     throw std::runtime_error(oss.str());
@@ -105,10 +105,10 @@
   return *torch::autograd::VariableType::getType(*baseType);
 }
 
-THPDtype* getDtype(at::ScalarType scalarType, bool is_cuda) {
-  auto dtype = dtype_registry[static_cast<int>(scalarType)][is_cuda];
+THPDtype* getDtype(at::ScalarType scalarType) {
+  auto dtype = dtype_registry[static_cast<int>(scalarType)];
   if (!dtype) {
-    throw std::invalid_argument("unsupported backend, scalarType");
+    throw std::invalid_argument("unsupported scalarType");
   }
   return dtype;
 }
@@ -121,6 +121,10 @@
   return layout;
 }
 
+DeviceType getDeviceType(const at::Type& type) {
+  return type.is_cuda() ? torch::DeviceType::CUDA : torch::DeviceType::CPU;
+}
+
 PyObject* createPyObject(const at::Storage& storage)
 {
   auto type = getPyTypeObject(storage);
diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h
index 87c82e3..32225a2 100644
--- a/torch/csrc/DynamicTypes.h
+++ b/torch/csrc/DynamicTypes.h
@@ -8,6 +8,7 @@
 #include <ATen/ATen.h>
 #include "torch/csrc/Dtype.h"
 #include "torch/csrc/Layout.h"
+#include "torch/csrc/utils/device.h"
 
 namespace torch {
 
@@ -16,15 +17,16 @@
     PyTypeObject *pytype, const std::string& name,
     bool is_cuda, bool is_sparse);
 
-void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType, bool is_cuda);
+void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType);
 void registerLayoutObject(THPLayout *layout, at::Backend backend);
 
 PyObject* createPyObject(const at::Storage& storage);
 std::unique_ptr<at::Storage> createStorage(PyObject* obj);
 bool isStorage(PyObject* obj);
 
-THPDtype* getDtype(at::ScalarType scalarType, bool is_cuda);
+THPDtype* getDtype(at::ScalarType scalarType);
 THPLayout* getLayout(at::Backend backend);
-at::Type& getType(const THPDtype &dtype, const THPLayout& layout);
+at::Type& getType(at::ScalarType scalarType, const THPLayout& layout, const DeviceType& deviceType);
+DeviceType getDeviceType(const at::Type& type);
 
 }  // namespace torch
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 6263e6c..167ca46 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -331,8 +331,7 @@
 PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) {
   HANDLE_TH_ERRORS
   auto& type = torch::tensor::get_default_tensor_type();
-  bool is_cuda = type.backend() == at::kCUDA;
-  auto dtype = (PyObject*)torch::getDtype(type.scalarType(), is_cuda);
+  auto dtype = (PyObject*)torch::getDtype(type.scalarType());
   Py_INCREF(dtype);
   return dtype;
   END_HANDLE_TH_ERRORS
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 7ea866f..c89ab7b 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -378,7 +378,7 @@
 {
   HANDLE_TH_ERRORS
   auto& self_ = self->cdata;
-  return torch::autograd::utils::wrap(torch::getDtype(self_.type().scalarType(), self_.type().is_cuda()));
+  return torch::autograd::utils::wrap(torch::getDtype(self_.type().scalarType()));
   END_HANDLE_TH_ERRORS
 }
 
diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp
index 1e2b515..4ad21db 100644
--- a/torch/csrc/tensor/python_tensor.cpp
+++ b/torch/csrc/tensor/python_tensor.cpp
@@ -30,6 +30,7 @@
   at::Type* aten_type;
   THPDtype* dtype;
   THPLayout* layout;
+  bool is_cuda;
   char name[64];
 };
 
@@ -51,7 +52,7 @@
   if (!tensor_type.aten_type) {
     throw unavailable_type(tensor_type);
   }
-  if (tensor_type.dtype->is_cuda) {
+  if (tensor_type.aten_type->is_cuda()) {
     torch::utils::cuda_lazy_init();
   }
   return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(*tensor_type.aten_type, args, kwargs));
@@ -79,7 +80,7 @@
 }
 
 PyObject *Tensor_is_cuda(PyTensorType* self) {
-  if (self->dtype->is_cuda) {
+  if (self->is_cuda) {
     Py_RETURN_TRUE;
   } else {
     Py_RETURN_FALSE;
@@ -178,7 +179,8 @@
   auto baseType = globalContext().type_registry[static_cast<int>(backend)][static_cast<int>(scalarType)].get();
   type_obj.aten_type = baseType ? torch::autograd::VariableType::getType(*baseType) : nullptr;
   type_obj.layout = torch::getLayout(backend);
-  type_obj.dtype = torch::getDtype(scalarType, backend == kCUDA || backend == kSparseCUDA);
+  type_obj.dtype = torch::getDtype(scalarType);
+  type_obj.is_cuda = (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA);
 }
 
 static void set_name(PyTensorType& type_obj, const std::string& name) {
diff --git a/torch/csrc/utils/device.h b/torch/csrc/utils/device.h
index 91c5378..0487726 100644
--- a/torch/csrc/utils/device.h
+++ b/torch/csrc/utils/device.h
@@ -11,7 +11,9 @@
   int64_t index;
   bool is_default;   // is default device for type.
   Device(DeviceType type, int64_t index, bool is_default);
+
   bool operator==(const Device& rhs);
+  inline int64_t deviceInt64() { return (this->is_default || this->type == DeviceType::CPU) ? -1 : this->index; }
 };
 
 }
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 4327dda..b1ffea2 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -23,7 +23,7 @@
   {"bool", ParameterType::BOOL},
   {"Storage", ParameterType::STORAGE},
   {"PyObject*", ParameterType::PYOBJECT},
-  {"Dtype", ParameterType::DTYPE},
+  {"ScalarType", ParameterType::SCALARTYPE},
   {"Layout", ParameterType::LAYOUT},
   {"Device", ParameterType::DEVICE},
   {"String", ParameterType::STRING},
@@ -111,7 +111,7 @@
     case ParameterType::BOOL: return PyBool_Check(obj);
     case ParameterType::STORAGE: return isStorage(obj);
     case ParameterType::PYOBJECT: return true;
-    case ParameterType::DTYPE: return THPDtype_Check(obj);
+    case ParameterType::SCALARTYPE: return THPDtype_Check(obj);
     case ParameterType::LAYOUT: return THPLayout_Check(obj);
     case ParameterType::DEVICE:
       return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);
@@ -132,7 +132,7 @@
     case ParameterType::BOOL: return "bool";
     case ParameterType::STORAGE: return "torch.Storage";
     case ParameterType::PYOBJECT: return "object";
-    case ParameterType::DTYPE: return "torch.dtype";
+    case ParameterType::SCALARTYPE: return "torch.dtype";
     case ParameterType::LAYOUT: return "torch.layout";
     case ParameterType::DEVICE: return "torch.device";
     case ParameterType::STRING: return "str";
@@ -166,21 +166,23 @@
     if (str != "None") {
       default_intlist.assign(size, std::stoi(str));
     }
-  } else if (type_ == ParameterType::DTYPE) {
+  } else if (type_ == ParameterType::SCALARTYPE) {
     if (str == "None") {
-      default_dtype = nullptr;
+      default_scalartype = at::ScalarType::Undefined;
     } else if (str == "torch.int64") {
-      default_dtype = torch::getDtype(kLong, false);
+      default_scalartype = at::ScalarType::Long;
     } else {
-      throw std::runtime_error("invalid default value for dtype: " + str);
+      throw std::runtime_error("invalid default value for ScalarType: " + str);
     }
   } else if (type_ == ParameterType::LAYOUT) {
-    if (str == "torch.strided") {
+    if (str == "None") {
+      default_layout = nullptr;
+    } else if (str == "torch.strided") {
       default_layout = torch::getLayout(at::Backend::CPU);
     } else if (str == "torch.sparse_coo") {
       default_layout = torch::getLayout(at::Backend::SparseCPU);
     } else {
-      throw std::runtime_error("invalid default value for dtype: " + str);
+      throw std::runtime_error("invalid default value for layout: " + str);
     }
   } else if (type_ == ParameterType::DEVICE) {
     if (str != "None") {
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 590fd05..032ea18 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -44,7 +44,7 @@
 
 enum class ParameterType {
   TENSOR, SCALAR, INT64, DOUBLE, TENSOR_LIST, INT_LIST, GENERATOR,
-  BOOL, STORAGE, PYOBJECT, DTYPE, LAYOUT, DEVICE, STRING
+  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, DEVICE, STRING
 };
 
 struct FunctionParameter;
@@ -93,10 +93,12 @@
   inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
   inline at::Generator* generator(int i);
   inline std::unique_ptr<at::Storage> storage(int i);
-  inline const THPDtype& dtype(int i);
-  inline const THPDtype& dtypeWithDefault(int i, const THPDtype& default_dtype);
+  inline at::ScalarType scalartype(int i);
+  inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
   inline const THPLayout& layout(int i);
+  inline const THPLayout& layoutWithDefault(int i, const THPLayout& default_layout);
   inline Device device(int i);
+  inline Device deviceWithDefault(int i, const Device& default_device);
   inline int64_t deviceInt64(int i);
   inline std::string string(int i);
   inline PyObject* pyobject(int i);
@@ -146,7 +148,7 @@
     bool default_bool;
     int64_t default_int;
     double default_double;
-    THPDtype* default_dtype;
+    at::ScalarType default_scalartype;
     THPLayout* default_layout;
   };
 };
@@ -256,21 +258,18 @@
   return res;
 }
 
-inline const THPDtype& PythonArgs::dtypeWithDefault(int i, const THPDtype& default_dtype) {
-  if (!args[i]) return default_dtype;
-  return dtype(i);
+inline at::ScalarType PythonArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
+  if (!args[i]) return default_scalartype;
+  return scalartype(i);
 }
 
-inline const THPDtype& PythonArgs::dtype(int i) {
+inline at::ScalarType PythonArgs::scalartype(int i) {
   if (!args[i]) {
-    auto dtype = signature.params[i].default_dtype;
-    if (!dtype) {
-      const auto& type = torch::tensor::get_default_tensor_type();
-      dtype = torch::getDtype(type.scalarType(), type.is_cuda());
-    }
-    return *dtype;
+    auto scalartype = signature.params[i].default_scalartype;
+    return (scalartype == at::ScalarType::Undefined) ?
+            torch::tensor::get_default_tensor_type().scalarType() : scalartype;
   }
-  return *reinterpret_cast<THPDtype*>(args[i]);
+  return reinterpret_cast<THPDtype*>(args[i])->scalar_type;
 }
 
 inline const THPLayout& PythonArgs::layout(int i) {
@@ -278,13 +277,22 @@
   return *reinterpret_cast<THPLayout*>(args[i]);
 }
 
+inline const THPLayout& PythonArgs::layoutWithDefault(int i, const THPLayout& default_layout) {
+  if (!args[i]) return default_layout;
+  return layout(i);
+}
+
 static std::string cuda_str = "cuda";
 static std::string cpu_str = "cpu";
 static std::string cuda_prefix = "cuda:";
 static std::string cpu_prefix = "cpu:";
 
 inline Device PythonArgs::device(int i) {
-  if (!args[i]) return Device(DeviceType::CPU, -1, true);  // TODO: use CUDA if default type is a cuda type.
+  if (!args[i]) {
+    const auto& default_tensor_type = torch::tensor::get_default_tensor_type();
+    const auto device_type = torch::getDeviceType(default_tensor_type);
+    return Device(device_type, -1, true);
+  }
   if (THPDevice_Check(args[i])) {
     auto device = reinterpret_cast<THPDevice*>(args[i]);
     return device->device;
@@ -308,9 +316,14 @@
   throw torch::TypeError("only \"cuda\" and \"cpu\" are valid device types, got %s", device_str.c_str());
 }
 
+inline Device PythonArgs::deviceWithDefault(int i, const Device& default_device) {
+  if (!args[i]) return default_device;
+  return device(i);
+}
+
 inline int64_t PythonArgs::deviceInt64(int i) {
   auto dev = device(i);
-  return (dev.is_default || dev.type == DeviceType::CPU) ? -1 : dev.index;
+  return dev.deviceInt64();
 }
 
 inline std::string PythonArgs::string(int i) {
diff --git a/torch/csrc/utils/tensor_devices.h b/torch/csrc/utils/tensor_devices.h
new file mode 100644
index 0000000..31515f3
--- /dev/null
+++ b/torch/csrc/utils/tensor_devices.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "DynamicTypes.h"
+#include "device.h"
+
+namespace torch { namespace utils {
+
+Device getDevice(const at::Tensor tensor) {
+  return torch::Device(torch::getDeviceType(tensor.type()), tensor.type().is_cuda() ? tensor.get_device() : 0, false);
+}
+
+}} // namespace torch::utils
diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp
index c8f7963..68ed303 100644
--- a/torch/csrc/utils/tensor_dtypes.cpp
+++ b/torch/csrc/utils/tensor_dtypes.cpp
@@ -38,45 +38,26 @@
 void initializeDtypes() {
   auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
   if (!torch_module) python_error();
-  auto cuda_module = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
-  if (!cuda_module) python_error();
-  for (auto type_pair : torch::utils::all_declared_types()) {
-    at::Backend backend;
-    at::ScalarType scalarType;
-    std::tie(backend, scalarType) = type_pair;
+
+#define DEFINE_SCALAR_TYPE(_1,n,_2) at::ScalarType::n,
+
+  at::ScalarType all_scalar_types[] = {
+    AT_FORALL_SCALAR_TYPES(DEFINE_SCALAR_TYPE)
+  };
+
+  for (at::ScalarType scalarType: all_scalar_types) {
     std::string primary_name, legacy_name;
     std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
-    PyObject *module = nullptr;
-    bool is_cuda;
-    switch (backend) {
-      case at::kCPU: {
-        module = torch_module.get();
-        is_cuda = false;
-        break;
-      }
-      case at::kCUDA: {
-        module = cuda_module.get();
-        is_cuda = true;
-        break;
-      }
-      case at::kSparseCPU: {
-        continue;
-      }
-      case at::kSparseCUDA: {
-        continue;
-      }
-      default: throw std::runtime_error("Unimplemented backend");
-    }
-    std::string name = std::string(PyModule_GetName(module)) + '.' + primary_name;
-    PyObject *dtype = THPDtype_New(scalarType, is_cuda, name);
-    torch::registerDtypeObject((THPDtype*)dtype, scalarType, is_cuda);
+    std::string name = std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
+    PyObject *dtype = THPDtype_New(scalarType, name);
+    torch::registerDtypeObject((THPDtype*)dtype, scalarType);
     Py_INCREF(dtype);
-    if (PyModule_AddObject(module, primary_name.c_str(), dtype) != 0) {
+    if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != 0) {
       throw python_error();
     }
     if (legacy_name != "") {
       Py_INCREF(dtype);
-      if (PyModule_AddObject(module, legacy_name.c_str(), dtype) != 0) {
+      if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != 0) {
         throw python_error();
       }
     }
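
Every scalar type is now registered exactly once, on the torch module, so
the torch.cuda.* dtype aliases cease to exist. The intended end state,
roughly:

    import torch
    torch.float32                                # dtypes live on torch only
    assert not hasattr(torch.cuda, 'float32')    # no separate cuda dtypes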
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 25d6145..90426ae 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -371,9 +371,11 @@
   throw std::runtime_error("new(): invalid arguments");
 }
 
-static const Type& typeWithDefault(PythonArgs& r, int64_t idx, const Type& type) {
-  auto dtype = r.dtypeWithDefault(idx, *torch::getDtype(type.scalarType(), type.is_cuda()));
-  return torch::getType(dtype, *torch::getLayout(type.backend()));
+static const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t device_idx, const Type& type) {
+  auto scalartype = r.scalartypeWithDefault(dtype_idx, type.scalarType());
+  auto types_device_type = torch::getDeviceType(type);
+  auto device_type = r.isNone(device_idx) ? types_device_type : r.device(device_idx).type;
+  return torch::getType(scalartype, *torch::getLayout(type.backend()), device_type);
 }
 
 static Tensor set_requires_grad(Tensor self, bool requires_grad) {
@@ -386,15 +388,15 @@
   const auto& default_sparse_type = type.toBackend(sparse_backend);
 
   static PythonArgParser parser({
-    "sparse_coo_tensor(PyObject* indices, PyObject* values, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
-    "sparse_coo_tensor(PyObject* indices, PyObject* values, IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+    "sparse_coo_tensor(PyObject* indices, PyObject* values, IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<6> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     bool type_inference = r.isNone(2);
-    const auto& sparse_type = typeWithDefault(r, 2, default_sparse_type);
+    const auto& sparse_type = typeWithDefault(r, 2, 3, default_sparse_type);
     const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
     const auto& index_type = dense_type.toScalarType(kLong);
     AutoGPU autogpu(r.deviceInt64(3));
@@ -405,7 +407,7 @@
     return set_requires_grad(sparse_type_to_use.sparse_coo_tensor(indices, values), r.toBool(4));
   } else if (r.idx == 1) {
     bool type_inference = r.isNone(3);
-    const auto& sparse_type = typeWithDefault(r, 3, default_sparse_type);
+    const auto& sparse_type = typeWithDefault(r, 3, 4, default_sparse_type);
     const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
     const auto& index_type = dense_type.toScalarType(kLong);
     AutoGPU autogpu(r.deviceInt64(4));
@@ -420,7 +422,7 @@
 
 Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "tensor(PyObject* data, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
@@ -428,7 +430,7 @@
   if (r.idx == 0) {
     bool type_inference = r.isNone(1);
     return set_requires_grad(internal_new_from_data(
-        typeWithDefault(r, 1, type), r.deviceInt64(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
+        typeWithDefault(r, 1, 2, type), r.deviceInt64(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
   }
   throw std::runtime_error("tensor(): invalid arguments");
 }
@@ -436,27 +438,27 @@
 
 Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_tensor(PyObject* data, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     return set_requires_grad(new_from_data_copy(
-        typeWithDefault(r, 1, type), r.deviceInt64(2), r.pyobject(0)), r.toBool(3));
+        typeWithDefault(r, 1, 2, type), r.deviceInt64(2), r.pyobject(0)), r.toBool(3));
   }
   throw std::runtime_error("new_tensor(): invalid arguments");
 }
 
 Tensor new_empty(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_empty(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_empty(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type);
     return set_requires_grad(new_with_sizes(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
@@ -464,13 +466,13 @@
 
 Tensor new_full(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_full(IntList size, Scalar fill_value, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_full(IntList size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 2, type);
+    const auto& actual_type = typeWithDefault(r, 2, 3, type);
     return set_requires_grad(dispatch_full(actual_type, r.scalar(1), r.deviceInt64(3), r.intlist(0)), r.toBool(4));
   }
   throw std::runtime_error("new_full(): invalid arguments");
@@ -478,13 +480,13 @@
 
 Tensor new_ones(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_ones(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_ones(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type);
     return set_requires_grad(dispatch_ones(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
   }
   throw std::runtime_error("new_ones(): invalid arguments");
@@ -492,13 +494,13 @@
 
 Tensor new_zeros(const at::Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_zeros(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_zeros(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
 
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type);
     return set_requires_grad(dispatch_zeros(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
   }
   throw std::runtime_error("new_zeros(): invalid arguments");
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 16b47fa..3eb2c2e 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -270,7 +270,8 @@
             cx = None
 
         handle = cudnn.get_handle()
-        dropout_ts = cudnn.rnn.init_dropout_state(torch.cuda.uint8, dropout, train, dropout_seed, dropout_state)
+        dropout_ts = cudnn.rnn.init_dropout_state(torch.uint8, torch.device('cuda'), dropout,
+                                                  train, dropout_seed, dropout_state)
 
         weight_arr = list(itertools.chain.from_iterable(weight))
         weight_stride0 = len(weight[0])
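
The dropout-state tensor shows the same split at a call site: the old
torch.cuda.uint8 carried both the scalar type and the device, which are now
passed as the separate pair (torch.uint8, torch.device('cuda')). For
illustration only (a sketch of the split, assuming the factory accepts a
device keyword; not the cudnn API):

    state = torch.zeros(1, dtype=torch.uint8, device=torch.device('cuda'))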
diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py
index 015180e..e609ac5 100644
--- a/torch/testing/__init__.py
+++ b/torch/testing/__init__.py
@@ -85,11 +85,8 @@
 
 
 def get_all_dtypes():
-    cpu_dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
-                  torch.float16, torch.float32, torch.float64]
-    cuda_dtypes = [torch.cuda.uint8, torch.cuda.int8, torch.cuda.int16, torch.cuda.int32, torch.cuda.int64,
-                   torch.cuda.float16, torch.cuda.float32, torch.cuda.float64]
-    return cpu_dtypes + cuda_dtypes
+    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
+            torch.float16, torch.float32, torch.float64]
 
 
 # 'dtype': (rtol, atol)
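
Since cuda-ness no longer lives in the dtype, one list covers both devices;
callers that want CUDA coverage pair it with an explicit device rather than
a second dtype list. A hedged sketch of the intended test pattern (assumes
the factories accept a device keyword):

    for dtype in torch.testing.get_all_dtypes():
        devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
        for device in devices:
            t = torch.ones(2, dtype=dtype, device=torch.device(device))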