|  | #include "caffe2/operators/batch_gather_ops.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>); | 
|  | REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>); | 
|  |  | 
|  | OPERATOR_SCHEMA(BatchGather) | 
|  | .NumInputs(2) | 
|  | .NumOutputs(1) | 
|  | .TensorInferenceFunction([](const OperatorDef& def, | 
|  | const vector<TensorShape>& in) { | 
|  | vector<TensorShape> out(1); | 
|  | ArgumentHelper helper(def); | 
|  | const auto& data_dims = GetDimsVector(in[0]); | 
|  | const auto& indices_dims = GetDimsVector(in[1]); | 
|  |  | 
|  | vector<int> output_dims = | 
|  | caffe2::gather_helper::calc_output_shape_vector<int>( | 
|  | data_dims, indices_dims, 1, false); | 
|  | out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT); | 
|  | return out; | 
|  | }) | 
|  | .SetDoc(R"DOC( | 
|  | Batch gather operation, first dimension in DATA is the batch size. | 
|  | Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather | 
|  | entries of the second outer dimension (axis == 1) of DATA indexed by INDICES, | 
|  | and concatenate them in an output tensor of rank q + (r - 1). | 
|  |  | 
|  | Example: | 
|  | DATA  = [ | 
|  | [1.0, 1.2, 2.4, 4.5], | 
|  | [2.3, 3.4, 3.6, 2.3], | 
|  | [4.5, 5.7, 1.2, 4.5], | 
|  | ] | 
|  | INDICES = [0, 2] | 
|  |  | 
|  | OUTPUT = [ | 
|  | [1.0, 2.4], | 
|  | [2.3, 3.6], | 
|  | [4.5, 1.2], | 
|  | ] | 
|  | )DOC") | 
|  | .Input(0, "DATA", "Tensor of rank r >= 2.") | 
|  | .Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.") | 
|  | .Output(0, "OUTPUT", "Tensor of rank q + (r - 1).") | 
|  | .InheritOnnxSchema(); | 
|  |  | 
|  | OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1); | 
|  |  | 
|  | class GetBatchGatherGradient : public GradientMakerBase { | 
|  | using GradientMakerBase::GradientMakerBase; | 
|  | vector<OperatorDef> GetGradientDefs() override { | 
|  | using Op = BatchGatherOp<CPUContext>; | 
|  | return SingleGradientDef( | 
|  | "BatchGatherGradient", | 
|  | "", | 
|  | vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)}, | 
|  | vector<string>{GI(0)}); | 
|  | } | 
|  | }; | 
|  |  | 
|  | REGISTER_GRADIENT(BatchGather, GetBatchGatherGradient); | 
|  |  | 
|  | } // namespace caffe2 |