[functorch] Fix incorrect device querying on CUDA in test_ops
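
OpInfo.supported_dtypes expects a device type string ("cpu", "cuda"),
not a dtype. Passing TEST_DTYPE made the skip check fall back to the
default (CPU) dtype set even when the test ran on CUDA, so ops were
skipped (or not skipped) based on the wrong device. Query
self.device_type instead.

While here, teach upcast_tensor to promote float16 dtype arguments to
float64 the same way it already promotes bfloat16 ones.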
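
The integer constants in upcast_tensor mirror c10's ScalarType enum,
in which (assuming the enum ordering current at the time of this
change) Half is 5, Double is 7, and BFloat16 is 15. A minimal sketch
that spells out that assumed ordering and checks the three constants
against it:

    import torch

    # Assumed c10::ScalarType ordering; the test's constants
    # (FLOAT16_DTYPE = 5, FLOAT64_DTYPE = 7, BFLOAT16_DTYPE = 15)
    # are indices into this list.
    SCALAR_TYPES = [
        torch.uint8,       # 0  Byte
        torch.int8,        # 1  Char
        torch.int16,       # 2  Short
        torch.int32,       # 3  Int
        torch.int64,       # 4  Long
        torch.float16,     # 5  Half
        torch.float32,     # 6  Float
        torch.float64,     # 7  Double
        torch.complex32,   # 8  ComplexHalf
        torch.complex64,   # 9  ComplexFloat
        torch.complex128,  # 10 ComplexDouble
        torch.bool,        # 11 Bool
        torch.qint8,       # 12 QInt8
        torch.quint8,      # 13 QUInt8
        torch.qint32,      # 14 QInt32
        torch.bfloat16,    # 15 BFloat16
    ]

    assert SCALAR_TYPES.index(torch.float16) == 5
    assert SCALAR_TYPES.index(torch.float64) == 7
    assert SCALAR_TYPES.index(torch.bfloat16) == 15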
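
And a sketch of the supported_dtypes fix itself. op_db is the real
OpInfo registry, but the fallback behavior noted in the comments is an
assumption about the OpInfo implementation of this era (an
unrecognized device argument returns the default/CPU dtype set):

    import torch
    from torch.testing._internal.common_methods_invocations import op_db

    op = op_db[0]  # any OpInfo entry works for illustration

    # Buggy call: a dtype is not a device type string, so the lookup
    # is assumed to fall back to the default (CPU) dtype set; the
    # wrong device gets queried on CUDA runs.
    wrong = op.supported_dtypes(torch.bfloat16)

    # Fixed call: query the dtype set for the device under test.
    right = op.supported_dtypes("cuda")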
diff --git a/functorch/functorch/_src/decompositions.py b/functorch/functorch/_src/decompositions.py
index d6f70e5..bf38722 100644
--- a/functorch/functorch/_src/decompositions.py
+++ b/functorch/functorch/_src/decompositions.py
@@ -269,6 +269,7 @@
     zero_grad = aten.full_like(grad, 0)
     return aten.index_put(grad_weight, [indices_rank1], aten.where(skip_padding, grad, zero_grad), accumulate=True)
 
+
 # @register_decomposition(aten.addmm)
 # def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
 #     if not self.is_floating_point():
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index bad1dc8..c1bc351 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -1106,9 +1106,10 @@
                     def upcast_tensor(x, dtype=torch.float32):
                         if isinstance(x, Tensor) and (x.dtype == torch.bfloat16 or x.dtype == torch.float16):
                             x = x.to(dtype=dtype)
+                        FLOAT16_DTYPE = 5
                         BFLOAT16_DTYPE = 15
                         FLOAT64_DTYPE = 7
-                        if isinstance(x, int) and func in dtype_arg_table and x == BFLOAT16_DTYPE:
+                        if isinstance(x, int) and func in dtype_arg_table and x in [FLOAT16_DTYPE, BFLOAT16_DTYPE]:
                             x = FLOAT64_DTYPE
                         return x
 
@@ -1143,7 +1144,7 @@
                 wrapped_out = tree_map(wrap_tensor, real_out)
                 return wrapped_out
 
-        if TEST_DTYPE not in op.supported_dtypes(TEST_DTYPE):
+        if TEST_DTYPE not in op.supported_dtypes(self.device_type):
             self.skipTest("Dtype not in op's supported dtypes")
             return
         if is_inplace(op, op.get_op()):