| Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change. |
| This patch will be applied only until TF's TFRT commit is automatically bumped. |
| |
| --- |
| |
| diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h |
| index 3d311c3..a216716 100644 |
| --- a/backends/gpu/include/tfrt/gpu/gpu_types.h |
| +++ b/backends/gpu/include/tfrt/gpu/gpu_types.h |
| @@ -295,11 +295,7 @@ |
| wrapper::CurrentContext current, wrapper::Stream stream, |
| wrapper::CclComm comm)>; |
| |
| - explicit GpuCclHandle(AsyncValueRef<GpuContext> context, |
| - wrapper::OwningCclComm comm, int num_ranks); |
| - // TODO(hanbinyoon): Remove after transitioning to the above constructor. |
| - explicit GpuCclHandle(AsyncValueRef<GpuContext> context, |
| - wrapper::OwningCclComm comm); |
| + GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm); |
| ~GpuCclHandle(); |
| |
| GpuCclHandle(GpuCclHandle&&) = default; |
| @@ -311,8 +307,6 @@ |
| llvm::Error ExecuteCallbacks(wrapper::CurrentContext current, |
| wrapper::Stream stream); |
| |
| - int num_ranks() const { return num_ranks_; } |
| - |
| const wrapper::OwningCclComm& operator->() const { return comm_; } |
| wrapper::CclComm get() const { return comm_.get(); } |
| wrapper::CclComm release(); |
| @@ -322,7 +316,6 @@ |
| private: |
| AsyncValueRef<GpuContext> context_; |
| wrapper::OwningCclComm comm_; |
| - int num_ranks_; |
| std::vector<Callback> callbacks_; |
| }; |
| |
| diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc |
| index 38529bc..01e3dba 100644 |
| --- a/backends/gpu/lib/gpu_types.cc |
| +++ b/backends/gpu/lib/gpu_types.cc |
| @@ -214,15 +214,8 @@ |
| GpuBlasHandle::~GpuBlasHandle() = default; |
| |
| GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context, |
| - wrapper::OwningCclComm comm, int num_ranks) |
| - : context_(std::move(context)), |
| - comm_(std::move(comm)), |
| - num_ranks_(num_ranks) {} |
| - |
| -// TODO(hanbinyoon): Remove after transitioning to the above constructor. |
| -GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context, |
| wrapper::OwningCclComm comm) |
| - : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {} |
| + : context_(std::move(context)), comm_(std::move(comm)) {} |
| |
| GpuCclHandle::~GpuCclHandle() = default; |
| |
| diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc |
| index 52ce820..9cfc1de 100644 |
| --- a/backends/gpu/lib/kernels/ccl_kernels.cc |
| +++ b/backends/gpu/lib/kernels/ccl_kernels.cc |
| @@ -107,8 +107,6 @@ |
| auto width = ToWidthInBytes(type); |
| if (!width) return width.takeError(); |
| assert(*width != 0); |
| - if (input->size() != output->size() * handle->num_ranks()) |
| - return MakeStringError("Input size must be output size times ranks."); |
| |
| handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(), |
| recvcount = output->size() / *width, type, |
| @@ -116,6 +114,10 @@ |
| wrapper::CurrentContext current, |
| wrapper::Stream stream, |
| wrapper::CclComm comm) -> llvm::Error { |
| + auto count = wrapper::CclCommCount(comm); |
| + if (!count) return count.takeError(); |
| + if (input->size() != output->size() * *count) |
| + return MakeStringError("Input size must be output size times ranks."); |
| return wrapper::CclReduceScatter(current, input->pointer(), |
| output->pointer(), recvcount, type, op, |
| comm, stream); |