blob: e636b083802a420958a0467701a05565be1bf06e [file] [log] [blame]
#include "counter_ops.h"
#include "caffe2/core/blob_serialization.h"
namespace caffe2 {
namespace {
/**
* @brief CounterSerializer is the serializer for Counter type.
*
* CounterSerializer takes in a blob that contains a Counter, and serializes
* it into a BlobProto protocol buffer. At the moment only int64_t counters are
* supported (since it's the only once that is really used).
*
*/
class CounterSerializer : public BlobSerializerBase {
public:
CounterSerializer() {}
~CounterSerializer() {}
void Serialize(
const Blob& blob,
const string& name,
SerializationAcceptor acceptor) override {
CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>());
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("std::unique_ptr<Counter<int64_t>>");
TensorProto& proto = *blob_proto.mutable_tensor();
proto.set_name(name);
proto.set_data_type(TensorProto_DataType_INT64);
proto.add_dims(1);
proto.add_int64_data(
blob.template Get<std::unique_ptr<Counter<int64_t>>>()->retrieve());
acceptor(name, blob_proto.SerializeAsString());
}
};
/**
* @brief CounterDeserializer is the deserializer for Counters.
*
*/
class CounterDeserializer : public BlobDeserializerBase {
public:
void Deserialize(const BlobProto& proto, Blob* blob) override {
auto tensorProto = proto.tensor();
CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1, "Unexpected size of dims");
CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1, "Unexpected value of dims");
CAFFE_ENFORCE_EQ(
tensorProto.data_type(),
TensorProto_DataType_INT64,
"Only int64_t counters supported");
CAFFE_ENFORCE_EQ(
tensorProto.int64_data_size(), 1, "Unexpected size of data");
*blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() =
caffe2::make_unique<Counter<int64_t>>(tensorProto.int64_data(0));
}
};
}
// TODO(jiayq): deprecate these ops & consolidate them with
// IterOp/AtomicIterOp
REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>);
REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>);
REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>);
REGISTER_CPU_OPERATOR(
CheckCounterDone,
CheckCounterDoneOp<int64_t, CPUContext>);
REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>);
REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>);
OPERATOR_SCHEMA(CreateCounter)
.NumInputs(0)
.NumOutputs(1)
.SetDoc(R"DOC(
Creates a count-down counter with initial value specified by the 'init_count'
argument.
)DOC")
.Output(0, "counter", "A blob pointing to an instance of a new counter.")
.Arg("init_count", "Initial count for the counter, must be >= 0.");
OPERATOR_SCHEMA(ResetCounter)
.NumInputs(1)
.NumOutputs(0, 1)
.SetDoc(R"DOC(
Resets a count-down counter with initial value specified by the 'init_count'
argument.
)DOC")
.Input(0, "counter", "A blob pointing to an instance of a new counter.")
.Output(0, "previous_value", "(optional) Previous value of the counter.")
.Arg("init_count", "Resets counter to this value, must be >= 0.");
OPERATOR_SCHEMA(CountDown)
.NumInputs(1)
.NumOutputs(1)
.SetDoc(R"DOC(
If the internal count value > 0, decreases count value by 1 and outputs false,
otherwise outputs true.
)DOC")
.Input(0, "counter", "A blob pointing to an instance of a counter.")
.Output(0, "done", "false unless the internal count is zero.");
OPERATOR_SCHEMA(CheckCounterDone)
.NumInputs(1)
.NumOutputs(1)
.SetDoc(R"DOC(
If the internal count value <= 0, outputs true, otherwise outputs false,
)DOC")
.Input(0, "counter", "A blob pointing to an instance of a counter.")
.Output(0, "done", "true if the internal count is zero or negative.");
OPERATOR_SCHEMA(CountUp)
.NumInputs(1)
.NumOutputs(1)
.SetDoc(R"DOC(
Increases count value by 1 and outputs the previous value atomically
)DOC")
.Input(0, "counter", "A blob pointing to an instance of a counter.")
.Output(0, "previous_count", "count value BEFORE this operation");
OPERATOR_SCHEMA(RetrieveCount)
.NumInputs(1)
.NumOutputs(1)
.ScalarType(TensorProto::INT64)
.SetDoc(R"DOC(
Retrieve the current value from the counter.
)DOC")
.Input(0, "counter", "A blob pointing to an instance of a counter.")
.Output(0, "count", "current count value.");
SHOULD_NOT_DO_GRADIENT(CreateCounter);
SHOULD_NOT_DO_GRADIENT(ResetCounter);
SHOULD_NOT_DO_GRADIENT(CountDown);
SHOULD_NOT_DO_GRADIENT(CountUp);
SHOULD_NOT_DO_GRADIENT(RetrieveCount);
CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>);
REGISTER_BLOB_SERIALIZER(
(TypeMeta::Id<std::unique_ptr<Counter<int64_t>>>()),
CounterSerializer);
REGISTER_BLOB_DESERIALIZER(
std::unique_ptr<Counter<int64_t>>,
CounterDeserializer);
} // namespace caffe2