blob: 508939a900c48d32e4fdb32b394b3aad02aa454f [file] [log] [blame]
#include "broadcast_ops.h"
#include "gloo/broadcast_one_to_all.h"
namespace caffe2 {
namespace gloo {
// Builds the gloo broadcast algorithm matching the element type of the
// tensors this operator was initialized with, and stores it in algorithm_.
//
// The algorithm is constructed from:
//   - init_.context:           the gloo communication context,
//   - init_.getOutputs<T>():   typed pointers to the output buffers,
//   - init_.size:              element count handed to gloo (presumably per
//                              buffer -- confirm against init_'s definition),
//   - root_:                   the root rank passed to BroadcastOneToAll.
//
// Supported element types: float, long, int, double. Any other type fails
// through CAFFE_ENFORCE with a message naming the unhandled type.
template <class Context>
void BroadcastOp<Context>::initializeAlgorithm() {
  if (init_.template IsType<float>()) {
    algorithm_.reset(new ::gloo::BroadcastOneToAll<float>(
        init_.context, init_.template getOutputs<float>(), init_.size, root_));
  } else if (init_.template IsType<long>()) {
    algorithm_.reset(new ::gloo::BroadcastOneToAll<long>(
        init_.context, init_.template getOutputs<long>(), init_.size, root_));
  } else if (init_.template IsType<int>()) {
    // A broadcast only copies the root's data to the other ranks (there is no
    // reduction), so further element types need nothing but typed pointers.
    algorithm_.reset(new ::gloo::BroadcastOneToAll<int>(
        init_.context, init_.template getOutputs<int>(), init_.size, root_));
  } else if (init_.template IsType<double>()) {
    algorithm_.reset(new ::gloo::BroadcastOneToAll<double>(
        init_.context, init_.template getOutputs<double>(), init_.size, root_));
  } else {
    CAFFE_ENFORCE(false, "Unhandled type: ", init_.meta.name());
  }
}
// Registers BroadcastOp<CPUContext> as the CPU implementation of the
// "Broadcast" operator for the GLOO engine. The anonymous namespace gives any
// symbols the registration macro defines internal linkage, keeping them local
// to this translation unit.
namespace {
REGISTER_CPU_OPERATOR_WITH_ENGINE(Broadcast, GLOO, BroadcastOp<CPUContext>);
} // namespace
} // namespace gloo
} // namespace caffe2