| /** |
| * Copyright (c) 2016-present, Facebook, Inc. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ |
| #define CAFFE2_OPERATORS_MATMUL_OP_H_ |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/utils/math.h" |
| |
| namespace caffe2 { |
| |
| template <typename T, class Context, class Engine = DefaultEngine> |
| class MatMulOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| MatMulOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| axis_a_(OperatorBase::GetSingleArgument<int>("axis_a", 1)), |
| axis_b_(OperatorBase::GetSingleArgument<int>("axis_b", 1)), |
| trans_a_(OperatorBase::GetSingleArgument<int>("trans_a", 0)), |
| trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)) {} |
| ~MatMulOp() {} |
| |
| bool RunOnDevice() override { |
| const auto& A = Input(0); |
| const auto& B = Input(1); |
| auto* Y = Output(0); |
| |
| const auto canonical_axis_a = A.canonical_axis_index(axis_a_); |
| const auto canonical_axis_b = B.canonical_axis_index(axis_b_); |
| int A_dim0 = A.size_to_dim(canonical_axis_a); |
| int A_dim1 = A.size_from_dim(canonical_axis_a); |
| int B_dim0 = B.size_to_dim(canonical_axis_b); |
| int B_dim1 = B.size_from_dim(canonical_axis_b); |
| |
| int a_dim0, a_dim1, b_dim0, b_dim1; |
| |
| if (trans_a_) { |
| a_dim0 = A_dim1; |
| a_dim1 = A_dim0; |
| } else { |
| a_dim0 = A_dim0; |
| a_dim1 = A_dim1; |
| } |
| |
| if (trans_b_) { |
| b_dim0 = B_dim1; |
| b_dim1 = B_dim0; |
| } else { |
| b_dim0 = B_dim0; |
| b_dim1 = B_dim1; |
| } |
| |
| auto dimErrorString = [&]() { |
| return MakeString( |
| "Dimension mismatch: ", |
| trans_a_ ? "trans(A): " : "A: ", |
| a_dim0, |
| " ", |
| a_dim1, |
| trans_b_ ? ", trans(B): " : ", B: ", |
| b_dim0, |
| " ", |
| b_dim1); |
| }; |
| // Error checking |
| CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString()); |
| |
| Y_shape_cache_[0] = a_dim0; |
| Y_shape_cache_[1] = b_dim1; |
| Y->Resize(Y_shape_cache_); |
| CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->size(), dimErrorString()); |
| // Y = A * B |
| math::Gemm<T, Context, Engine>( |
| trans_a_ ? CblasTrans : CblasNoTrans, |
| trans_b_ ? CblasTrans : CblasNoTrans, |
| a_dim0, |
| b_dim1, |
| a_dim1, |
| 1, |
| A.template data<T>(), |
| B.template data<T>(), |
| 0, |
| Y->template mutable_data<T>(), |
| &context_); |
| |
| if (InputSize() == 3) { |
| // In gradient op, resize to input |
| Y->ResizeLike(Input(2)); |
| } |
| return true; |
| } |
| |
| protected: |
| // A local vector to cache the output shape so we don't need to recreate |
| // a vector object every time we run Run(). |
| vector<TIndex> Y_shape_cache_{0, 0}; |
| int axis_a_{1}; |
| int axis_b_{1}; |
| bool trans_a_; |
| bool trans_b_; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_MATMUL_OP_H_ |