| #include "counter_ops.h" | 
 | #include "caffe2/core/blob_serialization.h" | 
 |  | 
 | namespace caffe2 { | 
 |  | 
 | // Markdown footer appended to the SetDoc() text of every operator schema in | 
 | // this file; it links to counter_ops.cc on GitHub. | 
 | const char* githubLinks = R"DOC( | 
 |   Github Links: | 
 |   - https://github.com/pytorch/pytorch/blob/main/caffe2/operators/counter_ops.cc | 
 |  | 
 | )DOC"; | 
 |  | 
 | // Shared, collapsible Markdown example (Python code plus its expected output) | 
 | // appended to the SetDoc() text of every counter operator schema in this file. | 
 | // The embedded snippet is Python, so its comments use '#'. | 
 | const char* kCountExample = R"DOC( | 
 | <details> | 
 |  | 
 | <summary> <b>Example</b> </summary> | 
 |  | 
 | **Code** | 
 |  | 
 | ``` | 
 | workspace.ResetWorkspace() | 
 |  | 
 | createcounter_op = core.CreateOperator( | 
 |     "CreateCounter", | 
 |     [], | 
 |     ["counter"], | 
 |     init_count=5 | 
 | ) | 
 |  | 
 | retrievecount_op = core.CreateOperator( | 
 |     "RetrieveCount", | 
 |     ["counter"], | 
 |     ["count"] | 
 | ) | 
 |  | 
 | checkcounterdone_op = core.CreateOperator( | 
 |     "CheckCounterDone", | 
 |     ["counter"], | 
 |     ["done"] | 
 | ) | 
 |  | 
 | countup_op = core.CreateOperator( | 
 |     "CountUp", | 
 |     ["counter"], | 
 |     ["previous_count"], | 
 | ) | 
 |  | 
 | countdown_op = core.CreateOperator( | 
 |     "CountDown", | 
 |     ["counter"], | 
 |     ["done"], | 
 | ) | 
 |  | 
 | resetcounter_op = core.CreateOperator( | 
 |     "ResetCounter", | 
 |     ["counter"], | 
 |     ["previous_count"], | 
 |     init_count=3 | 
 | ) | 
 |  | 
 |  | 
 | # Create counter | 
 | workspace.RunOperatorOnce(createcounter_op) | 
 | print("'counter' pointer:", workspace.FetchBlob("counter")) | 
 |  | 
 |  | 
 | # Retrieve initial counter value | 
 | workspace.RunOperatorOnce(retrievecount_op) | 
 | print("Initial 'count':", workspace.FetchBlob("count")) | 
 |  | 
 |  | 
 | # Check if counter is done | 
 | workspace.RunOperatorOnce(checkcounterdone_op) | 
 | print("Initial 'done' value:", workspace.FetchBlob("done")) | 
 |  | 
 |  | 
 | # Test CountUp operator | 
 | print("\nTesting CountUp operator...") | 
 | for i in range(5): | 
 |     workspace.RunOperatorOnce(countup_op) | 
 |     print("'previous_count' after CountUp:", workspace.FetchBlob("previous_count")) | 
 |  | 
 | workspace.RunOperatorOnce(retrievecount_op) | 
 | print("'count' value after CountUp test:", workspace.FetchBlob("count")) | 
 |  | 
 |  | 
 | # Test CountDown operator | 
 | print("\nTesting CountDown operator...") | 
 | for i in range(11): | 
 |     workspace.RunOperatorOnce(countdown_op) | 
 |     workspace.RunOperatorOnce(retrievecount_op) | 
 |     print("'count' value after CountDown: {}\t'done' value: {}".format(workspace.FetchBlob("count"), workspace.FetchBlob("done"))) | 
 | ``` | 
 |  | 
 | **Result** | 
 |  | 
 | ``` | 
 | 'counter' pointer: counter, a C++ native class of type std::__1::unique_ptr<caffe2::Counter<long long>, std::__1::default_delete<caffe2::Counter<long long> > >. | 
 | Initial 'count': 5 | 
 | Initial 'done' value: False | 
 |  | 
 | Testing CountUp operator... | 
 | 'previous_count' after CountUp: 5 | 
 | 'previous_count' after CountUp: 6 | 
 | 'previous_count' after CountUp: 7 | 
 | 'previous_count' after CountUp: 8 | 
 | 'previous_count' after CountUp: 9 | 
 | 'count' value after CountUp test: 10 | 
 |  | 
 | Testing CountDown operator... | 
 | 'count' value after CountDown: 9        'done' value: False | 
 | 'count' value after CountDown: 8        'done' value: False | 
 | 'count' value after CountDown: 7        'done' value: False | 
 | 'count' value after CountDown: 6        'done' value: False | 
 | 'count' value after CountDown: 5        'done' value: False | 
 | 'count' value after CountDown: 4        'done' value: False | 
 | 'count' value after CountDown: 3        'done' value: False | 
 | 'count' value after CountDown: 2        'done' value: False | 
 | 'count' value after CountDown: 1        'done' value: False | 
 | 'count' value after CountDown: 0        'done' value: False | 
 | 'count' value after CountDown: -1        'done' value: True | 
 | ``` | 
 |  | 
 | </details> | 
 |  | 
 | )DOC"; | 
 |  | 
 | namespace { | 
 | /** | 
 |  * @brief CounterSerializer is the serializer for Counter type. | 
 |  * | 
 |  * CounterSerializer takes in a blob that contains a | 
 |  * std::unique_ptr<Counter<int64_t>> and serializes it into a BlobProto | 
 |  * protocol buffer: the value returned by the counter's retrieve() is stored | 
 |  * as the single element of a one-dimensional INT64 TensorProto. At the moment | 
 |  * only int64_t counters are supported (since it's the only one that is really | 
 |  * used). | 
 |  */ | 
 | class CounterSerializer : public BlobSerializerBase { | 
 |  public: | 
 |   CounterSerializer() = default; | 
 |   ~CounterSerializer() override = default; | 
 |  | 
 |   /** | 
 |    * @param pointer  Address of the std::unique_ptr<Counter<int64_t>> to | 
 |    *                 serialize. | 
 |    * @param typeMeta Type of the pointed-to object; enforced to match | 
 |    *                 std::unique_ptr<Counter<int64_t>>. | 
 |    * @param name     Name written into both the BlobProto and its TensorProto. | 
 |    * @param acceptor Callback invoked once with the name and the serialized | 
 |    *                 BlobProto string. | 
 |    */ | 
 |   void Serialize( | 
 |       const void* pointer, | 
 |       TypeMeta typeMeta, | 
 |       const string& name, | 
 |       SerializationAcceptor acceptor) override { | 
 |     CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Counter<int64_t>>>()); | 
 |  | 
 |     const auto& counter = | 
 |         *static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer); | 
 |     // The unique_ptr itself may be empty; fail with a clear error instead of | 
 |     // dereferencing a null pointer below. | 
 |     CAFFE_ENFORCE( | 
 |         counter != nullptr, "Cannot serialize an empty counter blob: ", name); | 
 |  | 
 |     BlobProto blob_proto; | 
 |     blob_proto.set_name(name); | 
 |     blob_proto.set_type("std::unique_ptr<Counter<int64_t>>"); | 
 |     TensorProto& proto = *blob_proto.mutable_tensor(); | 
 |     proto.set_name(name); | 
 |     proto.set_data_type(TensorProto_DataType_INT64); | 
 |     // Encoded as a tensor of shape [1] holding the current count. | 
 |     proto.add_dims(1); | 
 |     proto.add_int64_data(counter->retrieve()); | 
 |     acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); | 
 |   } | 
 | }; | 
 |  | 
 | /** | 
 |  * @brief CounterDeserializer is the deserializer for Counters. | 
 |  * | 
 |  * Expects proto.tensor() to be a one-dimensional INT64 tensor with exactly one | 
 |  * element (dims == [1]) and stores a new Counter<int64_t>, constructed from | 
 |  * that element, into the destination blob. Any mismatch raises through | 
 |  * CAFFE_ENFORCE_EQ. | 
 |  */ | 
 | class CounterDeserializer : public BlobDeserializerBase { | 
 |  public: | 
 |   /** | 
 |    * @param proto Serialized blob whose tensor field carries the count. | 
 |    * @param blob  Destination blob; receives a freshly allocated counter. | 
 |    */ | 
 |   void Deserialize(const BlobProto& proto, Blob* blob) override { | 
 |     // Bind by const reference: the tensor is only read, so copying the whole | 
 |     // TensorProto message is unnecessary. | 
 |     const auto& tensorProto = proto.tensor(); | 
 |     CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1, "Unexpected size of dims"); | 
 |     CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1, "Unexpected value of dims"); | 
 |     CAFFE_ENFORCE_EQ( | 
 |         tensorProto.data_type(), | 
 |         TensorProto_DataType_INT64, | 
 |         "Only int64_t counters supported"); | 
 |     CAFFE_ENFORCE_EQ( | 
 |         tensorProto.int64_data_size(), 1, "Unexpected size of data"); | 
 |     *blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() = | 
 |         std::make_unique<Counter<int64_t>>(tensorProto.int64_data(0)); | 
 |   } | 
 | }; | 
 | } | 
 |  | 
 | // TODO(jiayq): deprecate these ops & consolidate them with | 
 | // IterOp/AtomicIterOp | 
 |  | 
 | // CPU registrations for the six counter operators. Each one is instantiated | 
 | // with an int64_t count type on CPUContext. | 
 | REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>); | 
 | REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>); | 
 | REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>); | 
 | REGISTER_CPU_OPERATOR( | 
 |     CheckCounterDone, | 
 |     CheckCounterDoneOp<int64_t, CPUContext>); | 
 | REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>); | 
 | REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>); | 
 |  | 
 | // Schema for CreateCounter: no inputs, one output blob holding the new | 
 | // counter; the starting value comes from the `init_count` argument. | 
 | OPERATOR_SCHEMA(CreateCounter) | 
 |     .NumInputs(0) | 
 |     .NumOutputs(1) | 
 |     .SetDoc( | 
 |         string(R"DOC( | 
 | Creates a count-down counter with initial value specified by the `init_count` | 
 | argument. | 
 |  | 
 | )DOC") + | 
 |         githubLinks + kCountExample) | 
 |     .Output( | 
 |         0, | 
 |         "counter", | 
 |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a new counter.") | 
 |     .Arg( | 
 |         "init_count", | 
 |         "*(type: int; default: 0)* Initial count for the counter, must be >= 0."); | 
 |  | 
 | // Schema for ResetCounter: one counter input and an optional output (zero or | 
 | // one outputs) carrying the count held before the reset. | 
 | OPERATOR_SCHEMA(ResetCounter) | 
 |     .NumInputs(1) | 
 |     .NumOutputs(0, 1) | 
 |     .SetDoc( | 
 |         string(R"DOC( | 
 | Resets a count-down counter with initial value specified by the `init_count` | 
 | argument. | 
 | )DOC") + | 
 |         githubLinks + kCountExample) | 
 |     .Input( | 
 |         0, | 
 |         "counter", | 
 |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
 |     .Output( | 
 |         0, | 
 |         "previous_value", | 
 |         "*(type: int)* [OPTIONAL] count value BEFORE this operation.") | 
 |     .Arg( | 
 |         "init_count", | 
 |         "*(type: int; default: 0)* Resets counter to this value, must be >= 0."); | 
 |  | 
 | // Schema for CountDown: one counter input, one boolean `done` output. | 
 | // The output description is kept in line with the SetDoc text: `done` is True | 
 | // whenever the count was not > 0, which includes negative counts. | 
 | OPERATOR_SCHEMA(CountDown) | 
 |     .NumInputs(1) | 
 |     .NumOutputs(1) | 
 |     .SetDoc( | 
 |         string(R"DOC( | 
 | If the internal count value > 0, decreases count value by 1 and outputs False, | 
 | otherwise outputs True. | 
 | )DOC") + | 
 |         githubLinks + kCountExample) | 
 |     .Input( | 
 |         0, | 
 |         "counter", | 
 |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
 |     .Output( | 
 |         0, | 
 |         "done", | 
 |         "*(type: bool)* False if the internal count was > 0 before this operation, otherwise True."); | 
 |  | 
 | // Schema for CheckCounterDone: one counter input, one boolean `done` output | 
 | // reporting whether the count is <= 0. | 
 | OPERATOR_SCHEMA(CheckCounterDone) | 
 |     .NumInputs(1) | 
 |     .NumOutputs(1) | 
 |     .SetDoc( | 
 |         string(R"DOC( | 
 | If the internal count value <= 0, outputs true, otherwise outputs false. | 
 | )DOC") + | 
 |         githubLinks + kCountExample) | 
 |     .Input( | 
 |         0, | 
 |         "counter", | 
 |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
 |     .Output( | 
 |         0, | 
 |         "done", | 
 |         "*(type: bool)* True if the internal count is zero or negative, otherwise False."); | 
 |  | 
 | // Schema for CountUp: one counter input, one output carrying the count held | 
 | // before the increment. | 
 | OPERATOR_SCHEMA(CountUp) | 
 |     .NumInputs(1) | 
 |     .NumOutputs(1) | 
 |     .SetDoc( | 
 |         string(R"DOC( | 
 | Increases count value by 1 and outputs the previous value atomically. | 
 | )DOC") + | 
 |         githubLinks + kCountExample) | 
 |     .Input( | 
 |         0, | 
 |         "counter", | 
 |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
 |     .Output( | 
 |         0, | 
 |         "previous_count", | 
 |         "*(type: int)* Count value BEFORE this operation."); | 
 |  | 
 | // Schema for RetrieveCount: one counter input, one output declared with | 
 | // scalar type INT64 that carries the current count. | 
 | OPERATOR_SCHEMA(RetrieveCount) | 
 |     .NumInputs(1) | 
 |     .NumOutputs(1) | 
 |     .ScalarType(TensorProto::INT64) | 
 |     .SetDoc( | 
 |         string(R"DOC( | 
 | Retrieve the current value from the counter as an integer. | 
 | )DOC") + | 
 |         githubLinks + kCountExample) | 
 |     .Input( | 
 |         0, | 
 |         "counter", | 
 |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
 |     .Output( | 
 |         0, | 
 |         "count", | 
 |         "*(type: int)* Current count value."); | 
 |  | 
 | // None of the counter operators has a gradient. The list mirrors the | 
 | // REGISTER_CPU_OPERATOR list in this file, so CheckCounterDone is included | 
 | // alongside the other five operators. | 
 | SHOULD_NOT_DO_GRADIENT(CreateCounter); | 
 | SHOULD_NOT_DO_GRADIENT(ResetCounter); | 
 | SHOULD_NOT_DO_GRADIENT(CountDown); | 
 | SHOULD_NOT_DO_GRADIENT(CheckCounterDone); | 
 | SHOULD_NOT_DO_GRADIENT(CountUp); | 
 | SHOULD_NOT_DO_GRADIENT(RetrieveCount); | 
 |  | 
 | // Declare std::unique_ptr<Counter<int64_t>> as a known type, then register | 
 | // CounterSerializer / CounterDeserializer as the (de)serializers for blobs of | 
 | // that type. The serializer is keyed by the type's TypeMeta id; the | 
 | // deserializer is keyed by the type name given here. | 
 | CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>); | 
 | REGISTER_BLOB_SERIALIZER( | 
 |     (TypeMeta::Id<std::unique_ptr<Counter<int64_t>>>()), | 
 |     CounterSerializer); | 
 | REGISTER_BLOB_DESERIALIZER( | 
 |     std::unique_ptr<Counter<int64_t>>, | 
 |     CounterDeserializer); | 
 |  | 
 | } // namespace caffe2 |