fix #16448 (#18479)
Summary:
Fixes #16448
bddppq
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18479
Differential Revision: D14635360
Pulled By: ezyang
fbshipit-source-id: 4010319fbce050dd0bdf4da3cd1171b9737f3c4c
diff --git a/test/common_utils.py b/test/common_utils.py
index 43131c0..cca6663 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -421,12 +421,7 @@
if a.device.type == 'cpu' and a.dtype == torch.float16:
# CPU half tensors don't have the methods we need below
a = a.to(torch.float32)
- if TEST_WITH_ROCM:
- # Workaround for bug https://github.com/pytorch/pytorch/issues/16448
- # TODO: remove after the bug is resolved.
- b = b.to(a.dtype).to(a.device)
- else:
- b = b.to(a)
+ b = b.to(a)
if (a.dtype == torch.bool) != (b.dtype == torch.bool):
raise TypeError("Was expecting both tensors to be bool type.")