[vulkan] addmm support non-vulkan inputs (#39078)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39078
Adding support of non-vulkan inputs for addmm operator:
if it is not on vulkan - converting to it inside operator,
if we run torchscript pretrained model - weights of linear op will be on CPU, we need this to run mobilenetV2 on Vulkan backend
Test Plan: Imported from OSS
Differential Revision: D21962425
Pulled By: IvanKobzarev
fbshipit-source-id: 8222edd31dfb14b326d15e6fec5c8778783479df
diff --git a/aten/src/ATen/native/vulkan/VulkanAten.cpp b/aten/src/ATen/native/vulkan/VulkanAten.cpp
index c8b830a..6dd3683 100644
--- a/aten/src/ATen/native/vulkan/VulkanAten.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanAten.cpp
@@ -251,9 +251,12 @@
const Tensor& mat2,
Scalar beta,
Scalar alpha) {
- VulkanTensor& t = vtensor_from_vulkan(self);
- VulkanTensor& m1 = vtensor_from_vulkan(mat1);
- VulkanTensor& m2 = vtensor_from_vulkan(mat2);
+ const VulkanTensor t =
+ vtensor_from_vulkan(self.is_vulkan() ? self : self.vulkan());
+ const VulkanTensor m1 =
+ vtensor_from_vulkan(mat1.is_vulkan() ? mat1 : mat1.vulkan());
+ const VulkanTensor m2 =
+ vtensor_from_vulkan(mat2.is_vulkan() ? mat2 : mat2.vulkan());
float b = beta.to<float>();
float a = alpha.to<float>();