| // Copyright 2016 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "mojo/edk/system/channel.h" |
| |
| #include <stdint.h> |
| #include <windows.h> |
| |
| #include <algorithm> |
| #include <deque> |
| #include <limits> |
| #include <memory> |
| |
| #include "base/bind.h" |
| #include "base/location.h" |
| #include "base/macros.h" |
| #include "base/memory/ref_counted.h" |
| #include "base/message_loop/message_loop.h" |
| #include "base/synchronization/lock.h" |
| #include "base/task_runner.h" |
| #include "base/win/win_util.h" |
| #include "mojo/edk/embedder/platform_handle_vector.h" |
| |
| namespace mojo { |
| namespace edk { |
| |
| namespace { |
| |
| // A view over a Channel::Message object. The write queue uses these since |
| // large messages may need to be sent in chunks. |
| class MessageView { |
| public: |
| // Owns |message|. |offset| indexes the first unsent byte in the message. |
| MessageView(Channel::MessagePtr message, size_t offset) |
| : message_(std::move(message)), |
| offset_(offset) { |
| DCHECK_GT(message_->data_num_bytes(), offset_); |
| } |
| |
| MessageView(MessageView&& other) { *this = std::move(other); } |
| |
| MessageView& operator=(MessageView&& other) { |
| message_ = std::move(other.message_); |
| offset_ = other.offset_; |
| return *this; |
| } |
| |
| ~MessageView() {} |
| |
| const void* data() const { |
| return static_cast<const char*>(message_->data()) + offset_; |
| } |
| |
| size_t data_num_bytes() const { return message_->data_num_bytes() - offset_; } |
| |
| size_t data_offset() const { return offset_; } |
| void advance_data_offset(size_t num_bytes) { |
| DCHECK_GE(message_->data_num_bytes(), offset_ + num_bytes); |
| offset_ += num_bytes; |
| } |
| |
| Channel::MessagePtr TakeChannelMessage() { return std::move(message_); } |
| |
| private: |
| Channel::MessagePtr message_; |
| size_t offset_; |
| |
| DISALLOW_COPY_AND_ASSIGN(MessageView); |
| }; |
| |
| class ChannelWin : public Channel, |
| public base::MessageLoop::DestructionObserver, |
| public base::MessageLoopForIO::IOHandler { |
| public: |
| ChannelWin(Delegate* delegate, |
| ScopedPlatformHandle handle, |
| scoped_refptr<base::TaskRunner> io_task_runner) |
| : Channel(delegate), |
| self_(this), |
| handle_(std::move(handle)), |
| io_task_runner_(io_task_runner) { |
| CHECK(handle_.is_valid()); |
| |
| wait_for_connect_ = handle_.get().needs_connection; |
| } |
| |
| void Start() override { |
| io_task_runner_->PostTask( |
| FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this)); |
| } |
| |
| void ShutDownImpl() override { |
| // Always shut down asynchronously when called through the public interface. |
| io_task_runner_->PostTask( |
| FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this)); |
| } |
| |
| void Write(MessagePtr message) override { |
| bool write_error = false; |
| { |
| base::AutoLock lock(write_lock_); |
| if (reject_writes_) |
| return; |
| |
| bool write_now = !delay_writes_ && outgoing_messages_.empty(); |
| outgoing_messages_.emplace_back(std::move(message), 0); |
| |
| if (write_now && !WriteNoLock(outgoing_messages_.front())) |
| reject_writes_ = write_error = true; |
| } |
| if (write_error) { |
| // Do not synchronously invoke OnError(). Write() may have been called by |
| // the delegate and we don't want to re-enter it. |
| io_task_runner_->PostTask(FROM_HERE, |
| base::Bind(&ChannelWin::OnError, this)); |
| } |
| } |
| |
| void LeakHandle() override { |
| DCHECK(io_task_runner_->RunsTasksOnCurrentThread()); |
| leak_handle_ = true; |
| } |
| |
| bool GetReadPlatformHandles( |
| size_t num_handles, |
| const void* extra_header, |
| size_t extra_header_size, |
| ScopedPlatformHandleVectorPtr* handles) override { |
| if (num_handles > std::numeric_limits<uint16_t>::max()) |
| return false; |
| using HandleEntry = Channel::Message::HandleEntry; |
| size_t handles_size = sizeof(HandleEntry) * num_handles; |
| if (handles_size > extra_header_size) |
| return false; |
| DCHECK(extra_header); |
| handles->reset(new PlatformHandleVector(num_handles)); |
| const HandleEntry* extra_header_handles = |
| reinterpret_cast<const HandleEntry*>(extra_header); |
| for (size_t i = 0; i < num_handles; i++) { |
| (*handles)->at(i).handle = |
| base::win::Uint32ToHandle(extra_header_handles[i].handle); |
| } |
| return true; |
| } |
| |
| private: |
| // May run on any thread. |
| ~ChannelWin() override {} |
| |
| void StartOnIOThread() { |
| base::MessageLoop::current()->AddDestructionObserver(this); |
| base::MessageLoopForIO::current()->RegisterIOHandler( |
| handle_.get().handle, this); |
| |
| if (wait_for_connect_) { |
| BOOL ok = ConnectNamedPipe(handle_.get().handle, |
| &connect_context_.overlapped); |
| if (ok) { |
| PLOG(ERROR) << "Unexpected success while waiting for pipe connection"; |
| OnError(); |
| return; |
| } |
| |
| const DWORD err = GetLastError(); |
| switch (err) { |
| case ERROR_PIPE_CONNECTED: |
| wait_for_connect_ = false; |
| break; |
| case ERROR_IO_PENDING: |
| AddRef(); |
| return; |
| case ERROR_NO_DATA: |
| OnError(); |
| return; |
| } |
| } |
| |
| // Now that we have registered our IOHandler, we can start writing. |
| { |
| base::AutoLock lock(write_lock_); |
| if (delay_writes_) { |
| delay_writes_ = false; |
| WriteNextNoLock(); |
| } |
| } |
| |
| // Keep this alive in case we synchronously run shutdown. |
| scoped_refptr<ChannelWin> keep_alive(this); |
| ReadMore(0); |
| } |
| |
| void ShutDownOnIOThread() { |
| base::MessageLoop::current()->RemoveDestructionObserver(this); |
| |
| // BUG(crbug.com/583525): This function is expected to be called once, and |
| // |handle_| should be valid at this point. |
| CHECK(handle_.is_valid()); |
| CancelIo(handle_.get().handle); |
| if (leak_handle_) |
| ignore_result(handle_.release()); |
| handle_.reset(); |
| |
| // May destroy the |this| if it was the last reference. |
| self_ = nullptr; |
| } |
| |
| // base::MessageLoop::DestructionObserver: |
| void WillDestroyCurrentMessageLoop() override { |
| DCHECK(io_task_runner_->RunsTasksOnCurrentThread()); |
| if (self_) |
| ShutDownOnIOThread(); |
| } |
| |
| // base::MessageLoop::IOHandler: |
| void OnIOCompleted(base::MessageLoopForIO::IOContext* context, |
| DWORD bytes_transfered, |
| DWORD error) override { |
| if (error != ERROR_SUCCESS) { |
| OnError(); |
| } else if (context == &connect_context_) { |
| DCHECK(wait_for_connect_); |
| wait_for_connect_ = false; |
| ReadMore(0); |
| |
| base::AutoLock lock(write_lock_); |
| if (delay_writes_) { |
| delay_writes_ = false; |
| WriteNextNoLock(); |
| } |
| } else if (context == &read_context_) { |
| OnReadDone(static_cast<size_t>(bytes_transfered)); |
| } else { |
| CHECK(context == &write_context_); |
| OnWriteDone(static_cast<size_t>(bytes_transfered)); |
| } |
| Release(); // Balancing reference taken after ReadFile / WriteFile. |
| } |
| |
| void OnReadDone(size_t bytes_read) { |
| if (bytes_read > 0) { |
| size_t next_read_size = 0; |
| if (OnReadComplete(bytes_read, &next_read_size)) { |
| ReadMore(next_read_size); |
| } else { |
| OnError(); |
| } |
| } else if (bytes_read == 0) { |
| OnError(); |
| } |
| } |
| |
| void OnWriteDone(size_t bytes_written) { |
| if (bytes_written == 0) |
| return; |
| |
| bool write_error = false; |
| { |
| base::AutoLock lock(write_lock_); |
| |
| DCHECK(!outgoing_messages_.empty()); |
| |
| MessageView& message_view = outgoing_messages_.front(); |
| message_view.advance_data_offset(bytes_written); |
| if (message_view.data_num_bytes() == 0) { |
| Channel::MessagePtr message = message_view.TakeChannelMessage(); |
| outgoing_messages_.pop_front(); |
| |
| // Clear any handles so they don't get closed on destruction. |
| ScopedPlatformHandleVectorPtr handles = message->TakeHandles(); |
| if (handles) |
| handles->clear(); |
| } |
| |
| if (!WriteNextNoLock()) |
| reject_writes_ = write_error = true; |
| } |
| if (write_error) |
| OnError(); |
| } |
| |
| void ReadMore(size_t next_read_size_hint) { |
| size_t buffer_capacity = next_read_size_hint; |
| char* buffer = GetReadBuffer(&buffer_capacity); |
| DCHECK_GT(buffer_capacity, 0u); |
| |
| BOOL ok = ReadFile(handle_.get().handle, |
| buffer, |
| static_cast<DWORD>(buffer_capacity), |
| NULL, |
| &read_context_.overlapped); |
| |
| if (ok || GetLastError() == ERROR_IO_PENDING) { |
| AddRef(); // Will be balanced in OnIOCompleted |
| } else { |
| OnError(); |
| } |
| } |
| |
| // Attempts to write a message directly to the channel. If the full message |
| // cannot be written, it's queued and a wait is initiated to write the message |
| // ASAP on the I/O thread. |
| bool WriteNoLock(const MessageView& message_view) { |
| BOOL ok = WriteFile(handle_.get().handle, |
| message_view.data(), |
| static_cast<DWORD>(message_view.data_num_bytes()), |
| NULL, |
| &write_context_.overlapped); |
| |
| if (ok || GetLastError() == ERROR_IO_PENDING) { |
| AddRef(); // Will be balanced in OnIOCompleted. |
| return true; |
| } |
| return false; |
| } |
| |
| bool WriteNextNoLock() { |
| if (outgoing_messages_.empty()) |
| return true; |
| return WriteNoLock(outgoing_messages_.front()); |
| } |
| |
| // Keeps the Channel alive at least until explicit shutdown on the IO thread. |
| scoped_refptr<Channel> self_; |
| |
| ScopedPlatformHandle handle_; |
| scoped_refptr<base::TaskRunner> io_task_runner_; |
| |
| base::MessageLoopForIO::IOContext connect_context_; |
| base::MessageLoopForIO::IOContext read_context_; |
| base::MessageLoopForIO::IOContext write_context_; |
| |
| // Protects |reject_writes_| and |outgoing_messages_|. |
| base::Lock write_lock_; |
| |
| bool delay_writes_ = true; |
| |
| bool reject_writes_ = false; |
| std::deque<MessageView> outgoing_messages_; |
| |
| bool wait_for_connect_; |
| |
| bool leak_handle_ = false; |
| |
| DISALLOW_COPY_AND_ASSIGN(ChannelWin); |
| }; |
| |
| } // namespace |
| |
| // static |
| scoped_refptr<Channel> Channel::Create( |
| Delegate* delegate, |
| ConnectionParams connection_params, |
| scoped_refptr<base::TaskRunner> io_task_runner) { |
| return new ChannelWin(delegate, connection_params.TakeChannelHandle(), |
| io_task_runner); |
| } |
| |
| } // namespace edk |
| } // namespace mojo |