[functorch] Fix incorrect device querying on CUDA
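
OpInfo.supported_dtypes() takes a device type string ("cpu"/"cuda"), not a
dtype. The skip check in test_ops.py was passing TEST_DTYPE, so on CUDA runs
it queried the wrong thing; pass self.device_type instead. Also extend
upcast_tensor() so a float16 dtype argument (ScalarType 5) is remapped to
float64, matching the existing bfloat16 handling.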
diff --git a/functorch/functorch/_src/decompositions.py b/functorch/functorch/_src/decompositions.py
index d6f70e5..bf38722 100644
--- a/functorch/functorch/_src/decompositions.py
+++ b/functorch/functorch/_src/decompositions.py
@@ -269,6 +269,7 @@
zero_grad = aten.full_like(grad, 0)
return aten.index_put(grad_weight, [indices_rank1], aten.where(skip_padding, grad, zero_grad), accumulate=True)
+
# @register_decomposition(aten.addmm)
# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
# if not self.is_floating_point():
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index bad1dc8..c1bc351 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -1106,9 +1106,10 @@
def upcast_tensor(x, dtype=torch.float32):
if isinstance(x, Tensor) and (x.dtype == torch.bfloat16 or x.dtype == torch.float16):
x = x.to(dtype=dtype)
+ FLOAT16_DTYPE = 5  # ScalarType index of torch.float16
BFLOAT16_DTYPE = 15
FLOAT64_DTYPE = 7
- if isinstance(x, int) and func in dtype_arg_table and x == BFLOAT16_DTYPE:
+ if isinstance(x, int) and func in dtype_arg_table and x in [FLOAT16_DTYPE, BFLOAT16_DTYPE]:
x = FLOAT64_DTYPE
return x
@@ -1143,7 +1144,7 @@
wrapped_out = tree_map(wrap_tensor, real_out)
return wrapped_out
- if TEST_DTYPE not in op.supported_dtypes(TEST_DTYPE):
+ if TEST_DTYPE not in op.supported_dtypes(self.device_type):
self.skipTest("Dtype not in op's supported dtypes")
return
if is_inplace(op, op.get_op()):
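
For reference, a minimal sketch of the corrected query pattern (assuming
PyTorch's OpInfo test metadata, which the diff above also relies on; op and
TEST_DTYPE below are illustrative stand-ins for the test's values):

    import torch
    from torch.testing._internal.common_methods_invocations import op_db

    TEST_DTYPE = torch.float16   # stand-in for the dtype under test
    op = op_db[0]                # stand-in for the OpInfo being tested

    # supported_dtypes() is keyed by device *type* ("cpu"/"cuda"), not by a
    # dtype, which is why the skip check must pass self.device_type rather
    # than TEST_DTYPE.
    if TEST_DTYPE not in op.supported_dtypes("cuda"):
        print(f"skip: {op.name} does not support {TEST_DTYPE} on cuda")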