blob: bd45be9192dcad078b5c7248be3b853dd7f907eb [file] [log] [blame]
#ifndef CAFFE2_UTILS_ZMQ_HELPER_H_
#define CAFFE2_UTILS_ZMQ_HELPER_H_
#include <zmq.h>
#include "caffe2/core/logging.h"
namespace caffe2 {
class ZmqContext {
public:
explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) {
CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context.");
int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads);
CAFFE_ENFORCE_EQ(rc, 0);
rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT);
CAFFE_ENFORCE_EQ(rc, 0);
}
~ZmqContext() {
int rc = zmq_ctx_destroy(ptr_);
CAFFE_ENFORCE_EQ(rc, 0);
}
void* ptr() { return ptr_; }
private:
void* ptr_;
C10_DISABLE_COPY_AND_ASSIGN(ZmqContext);
};
class ZmqMessage {
public:
ZmqMessage() {
int rc = zmq_msg_init(&msg_);
CAFFE_ENFORCE_EQ(rc, 0);
}
~ZmqMessage() {
int rc = zmq_msg_close(&msg_);
CAFFE_ENFORCE_EQ(rc, 0);
}
zmq_msg_t* msg() { return &msg_; }
void* data() { return zmq_msg_data(&msg_); }
size_t size() { return zmq_msg_size(&msg_); }
private:
zmq_msg_t msg_;
C10_DISABLE_COPY_AND_ASSIGN(ZmqMessage);
};
class ZmqSocket {
public:
explicit ZmqSocket(int type)
: context_(1), ptr_(zmq_socket(context_.ptr(), type)) {
CAFFE_ENFORCE(ptr_ != nullptr, "Faild to create zmq socket.");
}
~ZmqSocket() {
int rc = zmq_close(ptr_);
CAFFE_ENFORCE_EQ(rc, 0);
}
void Bind(const string& addr) {
int rc = zmq_bind(ptr_, addr.c_str());
CAFFE_ENFORCE_EQ(rc, 0);
}
void Unbind(const string& addr) {
int rc = zmq_unbind(ptr_, addr.c_str());
CAFFE_ENFORCE_EQ(rc, 0);
}
void Connect(const string& addr) {
int rc = zmq_connect(ptr_, addr.c_str());
CAFFE_ENFORCE_EQ(rc, 0);
}
void Disconnect(const string& addr) {
int rc = zmq_disconnect(ptr_, addr.c_str());
CAFFE_ENFORCE_EQ(rc, 0);
}
int Send(const string& msg, int flags) {
int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags);
if (nbytes) {
return nbytes;
} else if (zmq_errno() == EAGAIN) {
return 0;
} else {
LOG(FATAL) << "Cannot send zmq message. Error number: "
<< zmq_errno();
return 0;
}
}
int SendTillSuccess(const string& msg, int flags) {
CAFFE_ENFORCE(msg.size(), "You cannot send an empty message.");
int nbytes = 0;
do {
nbytes = Send(msg, flags);
} while (nbytes == 0);
return nbytes;
}
int Recv(ZmqMessage* msg) {
int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0);
if (nbytes >= 0) {
return nbytes;
} else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
return 0;
} else {
LOG(FATAL) << "Cannot receive zmq message. Error number: "
<< zmq_errno();
return 0;
}
}
int RecvTillSuccess(ZmqMessage* msg) {
int nbytes = 0;
do {
nbytes = Recv(msg);
} while (nbytes == 0);
return nbytes;
}
private:
ZmqContext context_;
void* ptr_;
};
} // namespace caffe2
#endif // CAFFE2_UTILS_ZMQ_HELPER_H_