Add support for exporting Addmm with alpha != 1 or beta != 1
diff --git a/torch/autograd/_functions/blas.py b/torch/autograd/_functions/blas.py
index c7bcb3b2..787ac3b 100644
--- a/torch/autograd/_functions/blas.py
+++ b/torch/autograd/_functions/blas.py
@@ -17,10 +17,10 @@
@staticmethod
def symbolic(g, add_matrix, matrix1, matrix2, alpha=1, beta=1, inplace=False):
- # TODO: manually insert the necessary scaling, since ONNX doesn't
- # natively support it
- if alpha != 1 or beta != 1:
- return None
+ if alpha != 1:
+ matrix1 = g.op("Scale", matrix1, scale_f=alpha)
+ if beta != 1:
+ add_matrix = g.op("Scale", add_matrix, scale_f=beta)
# TODO: Talk to ONNX about why their FC involves a transpose
matrix2_t = g.op("Transpose", matrix2)
return g.op("FC", matrix1, matrix2_t, add_matrix)