| #pragma once |
| |
| #include <torch/csrc/distributed/c10d/Backend.hpp> |
| |
| namespace c10d { |
| |
| class FakeWork : public Work { |
| public: |
| bool wait(std::chrono::milliseconds timeout) override { |
| return true; |
| } |
| |
| c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { |
| auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get()); |
| fut->markCompleted(); |
| return fut; |
| } |
| }; |
| |
| class FakeProcessGroup : public Backend { |
| public: |
| FakeProcessGroup(int rank, int size) : Backend(rank, size) {} |
| |
| c10::intrusive_ptr<Work> broadcast( |
| std::vector<at::Tensor>& /* tensors */, |
| const BroadcastOptions& /* opts */ = BroadcastOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> allreduce( |
| std::vector<at::Tensor>& /* tensors */, |
| const AllreduceOptions& /* opts */ = AllreduceOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> allreduce_sparse( |
| std::vector<at::Tensor>& /* tensors */, |
| const AllreduceOptions& /* opts */ = AllreduceOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> allreduce_coalesced( |
| std::vector<at::Tensor>& /* tensors */, |
| const AllreduceCoalescedOptions& /* opts */ = |
| AllreduceCoalescedOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> reduce( |
| std::vector<at::Tensor>& /* tensors */, |
| const ReduceOptions& /* opts */ = ReduceOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| // NOTE [allgather on FakeProcessGroup] |
| // Assume each rank have the same input tensor so we just copy to the results |
| // since it's not a real allgather, we simply make this copying logic to let |
| // some simple validation works (i.e. calling allgather to see if each rank |
| // have the same tensor or not). |
| // |
| // NOTE: in general it's not good form to try to make FakeProcessGroup work |
| // with real data, but the reasoning here is that we want FakeProcessGroup to |
| // work with DeviceMesh's init code that have the data validation, which |
| // makes it worth the tradeoff. |
| c10::intrusive_ptr<Work> allgather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllgatherOptions& /* opts */ = AllgatherOptions()) override { |
| for (auto& tensor : outputTensors[0]) { |
| tensor.copy_(inputTensors[0]); |
| } |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> _allgather_base( |
| at::Tensor& outputBuffer, |
| at::Tensor& inputBuffer, |
| const AllgatherOptions& /* opts */ = AllgatherOptions()) override { |
| auto chunks = outputBuffer.chunk(size_); |
| for (auto& tensor : chunks) { |
| tensor.copy_(inputBuffer); |
| } |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> allgather_coalesced( |
| std::vector<std::vector<at::Tensor>>& /* outputTensorLists */, |
| std::vector<at::Tensor>& /* inputTensors */, |
| const AllgatherOptions& /* opts */ = AllgatherOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> allgather_into_tensor_coalesced( |
| std::vector<at::Tensor>& outputs, |
| std::vector<at::Tensor>& inputs, |
| const AllgatherOptions& /* opts */ = AllgatherOptions()) override { |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| auto chunks = outputs[i].chunk(size_); |
| for (auto& chunk : chunks) { |
| chunk.copy_(inputs[i]); |
| } |
| } |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> gather( |
| std::vector<std::vector<at::Tensor>>& /* outputTensors */, |
| std::vector<at::Tensor>& /* inputTensors */, |
| const GatherOptions& /* opts */ = GatherOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> scatter( |
| std::vector<at::Tensor>& /* outputTensors */, |
| std::vector<std::vector<at::Tensor>>& /* inputTensors */, |
| const ScatterOptions& /* opts */ = ScatterOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> reduce_scatter( |
| std::vector<at::Tensor>& /* outputTensors */, |
| std::vector<std::vector<at::Tensor>>& /* inputTensors */, |
| const ReduceScatterOptions& /* opts */ = |
| ReduceScatterOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> _reduce_scatter_base( |
| at::Tensor& /* outputBuffer */, |
| at::Tensor& /* inputBuffer */, |
| const ReduceScatterOptions& /* opts */ = |
| ReduceScatterOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced( |
| std::vector<at::Tensor>& /* outputs */, |
| std::vector<at::Tensor>& /* inputs */, |
| const ReduceScatterOptions& /* opts */ = |
| ReduceScatterOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> alltoall_base( |
| at::Tensor& /* outputBuffer */, |
| at::Tensor& /* inputBuffer */, |
| std::vector<int64_t>& /* outputSplitSizes */, |
| std::vector<int64_t>& /* inputSplitSizes */, |
| const AllToAllOptions& /* opts */ = AllToAllOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> alltoall( |
| std::vector<at::Tensor>& /* outputTensors */, |
| std::vector<at::Tensor>& /* inputTensors */, |
| const AllToAllOptions& opts = AllToAllOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> send( |
| std::vector<at::Tensor>& /* tensors */, |
| int /* dstRank */, |
| int /* tag */) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> recv( |
| std::vector<at::Tensor>& /* tensors */, |
| int /* srcRank */, |
| int /* tag */) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> recvAnysource( |
| std::vector<at::Tensor>& /* tensors */, |
| int /* tag */) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| |
| c10::intrusive_ptr<Work> barrier( |
| const BarrierOptions& /* opts */ = BarrierOptions()) override { |
| return c10::make_intrusive<FakeWork>(); |
| } |
| }; |
| |
| } // namespace c10d |