|  | #include "counter_ops.h" | 
|  | #include "caffe2/core/blob_serialization.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  | /* Markdown footer (Github Links heading plus a link to counter_ops.cc). It is | 
|  |  * concatenated onto the SetDoc() text of every OPERATOR_SCHEMA in this file. | 
|  |  */ | 
|  | const char* githubLinks = R"DOC( | 
|  | Github Links: | 
|  | - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/counter_ops.cc | 
|  |  | 
|  | )DOC"; | 
|  | /* Collapsible markdown details section holding a Python example and its | 
|  |  * printed output. The example runs CreateCounter, RetrieveCount, | 
|  |  * CheckCounterDone, CountUp and CountDown; a ResetCounter op is constructed | 
|  |  * but never run. The text is concatenated onto the SetDoc() text of every | 
|  |  * OPERATOR_SCHEMA in this file. | 
|  |  */ | 
|  | const char* kCountExample = R"DOC( | 
|  | <details> | 
|  |  | 
|  | <summary> <b>Example</b> </summary> | 
|  |  | 
|  | **Code** | 
|  |  | 
|  | ``` | 
|  | workspace.ResetWorkspace() | 
|  |  | 
|  | createcounter_op = core.CreateOperator( | 
|  | "CreateCounter", | 
|  | [], | 
|  | ["counter"], | 
|  | init_count=5 | 
|  | ) | 
|  |  | 
|  | retrievecount_op = core.CreateOperator( | 
|  | "RetrieveCount", | 
|  | ["counter"], | 
|  | ["count"] | 
|  | ) | 
|  |  | 
|  | checkcounterdone_op = core.CreateOperator( | 
|  | "CheckCounterDone", | 
|  | ["counter"], | 
|  | ["done"] | 
|  | ) | 
|  |  | 
|  | countup_op = core.CreateOperator( | 
|  | "CountUp", | 
|  | ["counter"], | 
|  | ["previous_count"], | 
|  | ) | 
|  |  | 
|  | countdown_op = core.CreateOperator( | 
|  | "CountDown", | 
|  | ["counter"], | 
|  | ["done"], | 
|  | ) | 
|  |  | 
|  | resetcounter_op = core.CreateOperator( | 
|  | "ResetCounter", | 
|  | ["counter"], | 
|  | ["previous_count"], | 
|  | init_count=3 | 
|  | ) | 
|  |  | 
|  |  | 
|  | # Create counter | 
|  | workspace.RunOperatorOnce(createcounter_op) | 
|  | print("'counter' pointer:", workspace.FetchBlob("counter")) | 
|  |  | 
|  |  | 
|  | # Retrieve initial counter value | 
|  | workspace.RunOperatorOnce(retrievecount_op) | 
|  | print("Initial 'count':", workspace.FetchBlob("count")) | 
|  |  | 
|  |  | 
|  | # Check if counter is done | 
|  | workspace.RunOperatorOnce(checkcounterdone_op) | 
|  | print("Initial 'done' value:", workspace.FetchBlob("done")) | 
|  |  | 
|  |  | 
|  | # Test CountUp operator | 
|  | print("\nTesting CountUp operator...") | 
|  | for i in range(5): | 
|  | workspace.RunOperatorOnce(countup_op) | 
|  | print("'previous_count' after CountUp:", workspace.FetchBlob("previous_count")) | 
|  |  | 
|  | workspace.RunOperatorOnce(retrievecount_op) | 
|  | print("'count' value after CountUp test:", workspace.FetchBlob("count")) | 
|  |  | 
|  |  | 
|  | # Test CountDown operator | 
|  | print("\nTesting CountDown operator...") | 
|  | for i in range(11): | 
|  | workspace.RunOperatorOnce(countdown_op) | 
|  | workspace.RunOperatorOnce(retrievecount_op) | 
|  | print("'count' value after CountDown: {}\t'done' value: {}".format(workspace.FetchBlob("count"), workspace.FetchBlob("done"))) | 
|  | ``` | 
|  |  | 
|  | **Result** | 
|  |  | 
|  | ``` | 
|  | 'counter' pointer: counter, a C++ native class of type std::__1::unique_ptr<caffe2::Counter<long long>, std::__1::default_delete<caffe2::Counter<long long> > >. | 
|  | Initial 'count': 5 | 
|  | Initial 'done' value: False | 
|  |  | 
|  | Testing CountUp operator... | 
|  | 'previous_count' after CountUp: 5 | 
|  | 'previous_count' after CountUp: 6 | 
|  | 'previous_count' after CountUp: 7 | 
|  | 'previous_count' after CountUp: 8 | 
|  | 'previous_count' after CountUp: 9 | 
|  | 'count' value after CountUp test: 10 | 
|  |  | 
|  | Testing CountDown operator... | 
|  | 'count' value after CountDown: 9	'done' value: False | 
|  | 'count' value after CountDown: 8	'done' value: False | 
|  | 'count' value after CountDown: 7	'done' value: False | 
|  | 'count' value after CountDown: 6	'done' value: False | 
|  | 'count' value after CountDown: 5	'done' value: False | 
|  | 'count' value after CountDown: 4	'done' value: False | 
|  | 'count' value after CountDown: 3	'done' value: False | 
|  | 'count' value after CountDown: 2	'done' value: False | 
|  | 'count' value after CountDown: 1	'done' value: False | 
|  | 'count' value after CountDown: 0	'done' value: False | 
|  | 'count' value after CountDown: -1	'done' value: True | 
|  | ``` | 
|  |  | 
|  | </details> | 
|  |  | 
|  | )DOC"; | 
|  |  | 
|  | namespace { | 
|  | /** | 
|  |  * @brief Blob serializer for `std::unique_ptr<Counter<int64_t>>`. | 
|  |  * | 
|  |  * Given a blob holding a Counter, emits a BlobProto whose tensor field is a | 
|  |  * one-element INT64 tensor carrying the counter's current value. Only int64_t | 
|  |  * counters are supported for now (it is the only one that is really used). | 
|  |  */ | 
|  | class CounterSerializer : public BlobSerializerBase { | 
|  |  public: | 
|  |   CounterSerializer() = default; | 
|  |   ~CounterSerializer() = default; | 
|  |  | 
|  |   /** | 
|  |    * Serializes the counter stored in `blob`. | 
|  |    * | 
|  |    * @param blob     Must hold a std::unique_ptr<Counter<int64_t>>; enforced. | 
|  |    * @param name     Used as the BlobProto name, the tensor name, and the key | 
|  |    *                 handed to `acceptor`. | 
|  |    * @param acceptor Called once with `name` and the serialized BlobProto. | 
|  |    */ | 
|  |   void Serialize( | 
|  |       const Blob& blob, | 
|  |       const string& name, | 
|  |       SerializationAcceptor acceptor) override { | 
|  |     CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>()); | 
|  |     const auto& counter = blob.Get<std::unique_ptr<Counter<int64_t>>>(); | 
|  |  | 
|  |     BlobProto serialized; | 
|  |     serialized.set_name(name); | 
|  |     serialized.set_type("std::unique_ptr<Counter<int64_t>>"); | 
|  |  | 
|  |     // The count travels as a one-element INT64 tensor. | 
|  |     TensorProto* payload = serialized.mutable_tensor(); | 
|  |     payload->set_name(name); | 
|  |     payload->set_data_type(TensorProto_DataType_INT64); | 
|  |     payload->add_dims(1); | 
|  |     payload->add_int64_data(counter->retrieve()); | 
|  |  | 
|  |     acceptor(name, serialized.SerializeAsString()); | 
|  |   } | 
|  | }; | 
|  |  | 
|  | /** | 
|  |  * @brief Blob deserializer for `std::unique_ptr<Counter<int64_t>>`. | 
|  |  * | 
|  |  * Accepts only a one-element INT64 tensor (a single dim equal to 1) and | 
|  |  * rebuilds the counter from that single value. | 
|  |  */ | 
|  | class CounterDeserializer : public BlobDeserializerBase { | 
|  |  public: | 
|  |   /** | 
|  |    * Validates `proto.tensor()` and stores a freshly constructed | 
|  |    * Counter<int64_t>, initialized with the serialized value, in `blob`. | 
|  |    * | 
|  |    * @param proto Serialized blob. Each shape, type and size check is a | 
|  |    *              CAFFE_ENFORCE_EQ, so a malformed proto fails that enforce. | 
|  |    * @param blob  Destination; its std::unique_ptr<Counter<int64_t>> is | 
|  |    *              overwritten with the new counter. | 
|  |    */ | 
|  |   void Deserialize(const BlobProto& proto, Blob* blob) override { | 
|  |     // Bind by const reference: a plain `auto` would copy the whole TensorProto | 
|  |     // on every call although it is only read here. | 
|  |     const auto& tensorProto = proto.tensor(); | 
|  |     CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1, "Unexpected size of dims"); | 
|  |     CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1, "Unexpected value of dims"); | 
|  |     CAFFE_ENFORCE_EQ( | 
|  |         tensorProto.data_type(), | 
|  |         TensorProto_DataType_INT64, | 
|  |         "Only int64_t counters supported"); | 
|  |     CAFFE_ENFORCE_EQ( | 
|  |         tensorProto.int64_data_size(), 1, "Unexpected size of data"); | 
|  |     *blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() = | 
|  |         caffe2::make_unique<Counter<int64_t>>(tensorProto.int64_data(0)); | 
|  |   } | 
|  | }; | 
|  | } | 
|  |  | 
|  | // TODO(jiayq): deprecate these ops & consolidate them with | 
|  | // IterOp/AtomicIterOp | 
|  | /* Registers, for each of the six counter operator names, the operator | 
|  |  * template instantiated with int64_t and CPUContext. | 
|  |  */ | 
|  | REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>); | 
|  | REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>); | 
|  | REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>); | 
|  | REGISTER_CPU_OPERATOR( | 
|  | CheckCounterDone, | 
|  | CheckCounterDoneOp<int64_t, CPUContext>); | 
|  | REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>); | 
|  | REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>); | 
|  |  | 
|  | // CreateCounter: no inputs, one output (the counter blob). The doc string is | 
|  | // the description below followed by githubLinks and kCountExample. | 
|  | OPERATOR_SCHEMA(CreateCounter) | 
|  |     .NumInputs(0) | 
|  |     .NumOutputs(1) | 
|  |     .SetDoc(R"DOC( | 
|  | Creates a count-down counter with initial value specified by the `init_count` | 
|  | argument. | 
|  |  | 
|  | )DOC" + string(githubLinks) + string(kCountExample)) | 
|  |     .Output( | 
|  |         0, | 
|  |         "counter", | 
|  |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a new counter.") | 
|  |     .Arg( | 
|  |         "init_count", | 
|  |         "*(type: int; default: 0)* Initial count for the counter, must be >= 0."); | 
|  |  | 
|  | // ResetCounter: one input (the counter) and zero or one outputs (the count | 
|  | // before the reset). The doc string is the description below followed by | 
|  | // githubLinks and kCountExample. | 
|  | OPERATOR_SCHEMA(ResetCounter) | 
|  |     .NumInputs(1) | 
|  |     .NumOutputs(0, 1) | 
|  |     .SetDoc(R"DOC( | 
|  | Resets a count-down counter with initial value specified by the `init_count` | 
|  | argument. | 
|  | )DOC" + string(githubLinks) + string(kCountExample)) | 
|  |     .Input( | 
|  |         0, | 
|  |         "counter", | 
|  |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
|  |     .Output( | 
|  |         0, | 
|  |         "previous_value", | 
|  |         "*(type: int)* [OPTIONAL] count value BEFORE this operation.") | 
|  |     .Arg( | 
|  |         "init_count", | 
|  |         "*(type: int; default: 0)* Resets counter to this value, must be >= 0."); | 
|  |  | 
|  | // CountDown: one input (the counter), one output (the `done` flag). The doc | 
|  | // string is the description below followed by githubLinks and kCountExample. | 
|  | OPERATOR_SCHEMA(CountDown) | 
|  |     .NumInputs(1) | 
|  |     .NumOutputs(1) | 
|  |     .SetDoc(R"DOC( | 
|  | If the internal count value > 0, decreases count value by 1 and outputs False, | 
|  | otherwise outputs True. | 
|  | )DOC" + string(githubLinks) + string(kCountExample)) | 
|  |     .Input( | 
|  |         0, | 
|  |         "counter", | 
|  |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
|  |     .Output( | 
|  |         0, | 
|  |         "done", | 
|  |         "*(type: bool)* False unless the internal count is zero."); | 
|  |  | 
|  | // CheckCounterDone: one input (the counter), one output (the `done` flag). The | 
|  | // doc string is the description below followed by githubLinks and | 
|  | // kCountExample. | 
|  | OPERATOR_SCHEMA(CheckCounterDone) | 
|  |     .NumInputs(1) | 
|  |     .NumOutputs(1) | 
|  |     .SetDoc(R"DOC( | 
|  | If the internal count value <= 0, outputs true, otherwise outputs false. | 
|  | )DOC" + string(githubLinks) + string(kCountExample)) | 
|  |     .Input( | 
|  |         0, | 
|  |         "counter", | 
|  |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
|  |     .Output( | 
|  |         0, | 
|  |         "done", | 
|  |         "*(type: bool)* True if the internal count is zero or negative, otherwise False."); | 
|  |  | 
|  | // CountUp: one input (the counter), one output (the count before the | 
|  | // increment). The doc string is the description below followed by githubLinks | 
|  | // and kCountExample. | 
|  | OPERATOR_SCHEMA(CountUp) | 
|  |     .NumInputs(1) | 
|  |     .NumOutputs(1) | 
|  |     .SetDoc(R"DOC( | 
|  | Increases count value by 1 and outputs the previous value atomically. | 
|  | )DOC" + string(githubLinks) + string(kCountExample)) | 
|  |     .Input( | 
|  |         0, | 
|  |         "counter", | 
|  |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
|  |     .Output( | 
|  |         0, | 
|  |         "previous_count", | 
|  |         "*(type: int)* Count value BEFORE this operation."); | 
|  |  | 
|  | // RetrieveCount: one input (the counter), one output (the current count), | 
|  | // with the schema's scalar type set to TensorProto::INT64. The doc string is | 
|  | // the description below followed by githubLinks and kCountExample. | 
|  | OPERATOR_SCHEMA(RetrieveCount) | 
|  |     .NumInputs(1) | 
|  |     .NumOutputs(1) | 
|  |     .ScalarType(TensorProto::INT64) | 
|  |     .SetDoc(R"DOC( | 
|  | Retrieve the current value from the counter as an integer. | 
|  | )DOC" + string(githubLinks) + string(kCountExample)) | 
|  |     .Input( | 
|  |         0, | 
|  |         "counter", | 
|  |         "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.") | 
|  |     .Output( | 
|  |         0, | 
|  |         "count", | 
|  |         "*(type: int)* Current count value."); | 
|  |  | 
|  | // Declares every counter operator registered in this file as having no | 
|  | // gradient. The list mirrors the REGISTER_CPU_OPERATOR list one-to-one, so an | 
|  | // operator added there should be added here as well. | 
|  | SHOULD_NOT_DO_GRADIENT(CreateCounter); | 
|  | SHOULD_NOT_DO_GRADIENT(ResetCounter); | 
|  | SHOULD_NOT_DO_GRADIENT(CountDown); | 
|  | SHOULD_NOT_DO_GRADIENT(CheckCounterDone); | 
|  | SHOULD_NOT_DO_GRADIENT(CountUp); | 
|  | SHOULD_NOT_DO_GRADIENT(RetrieveCount); | 
|  | /* Declares std::unique_ptr<Counter<int64_t>> as a known type and registers | 
|  |  * CounterSerializer and CounterDeserializer for it. | 
|  |  * NOTE(review): CounterSerializer writes the same spelling of the type into | 
|  |  * BlobProto.type; presumably it has to match the name passed to | 
|  |  * REGISTER_BLOB_DESERIALIZER below. Confirm before changing either one. | 
|  |  */ | 
|  | CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>); | 
|  | REGISTER_BLOB_SERIALIZER( | 
|  | (TypeMeta::Id<std::unique_ptr<Counter<int64_t>>>()), | 
|  | CounterSerializer); | 
|  | REGISTER_BLOB_DESERIALIZER( | 
|  | std::unique_ptr<Counter<int64_t>>, | 
|  | CounterDeserializer); | 
|  |  | 
|  | } // namespace caffe2 |