| //===- RawByteChannel.h -----------------------------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H |
| #define LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H |
| |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ExecutionEngine/Orc/Shared/Serialization.h" |
| #include "llvm/Support/Endian.h" |
| #include "llvm/Support/Error.h" |
| #include <cstdint> |
| #include <mutex> |
| #include <string> |
| #include <type_traits> |
| |
| namespace llvm { |
| namespace orc { |
| namespace shared { |
| |
| /// Interface for byte-streams to be used with ORC Serialization. |
| class RawByteChannel { |
| public: |
| virtual ~RawByteChannel() = default; |
| |
| /// Read Size bytes from the stream into *Dst. |
| virtual Error readBytes(char *Dst, unsigned Size) = 0; |
| |
| /// Read size bytes from *Src and append them to the stream. |
| virtual Error appendBytes(const char *Src, unsigned Size) = 0; |
| |
| /// Flush the stream if possible. |
| virtual Error send() = 0; |
| |
| /// Notify the channel that we're starting a message send. |
| /// Locks the channel for writing. |
| template <typename FunctionIdT, typename SequenceIdT> |
| Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { |
| writeLock.lock(); |
| if (auto Err = serializeSeq(*this, FnId, SeqNo)) { |
| writeLock.unlock(); |
| return Err; |
| } |
| return Error::success(); |
| } |
| |
| /// Notify the channel that we're ending a message send. |
| /// Unlocks the channel for writing. |
| Error endSendMessage() { |
| writeLock.unlock(); |
| return Error::success(); |
| } |
| |
| /// Notify the channel that we're starting a message receive. |
| /// Locks the channel for reading. |
| template <typename FunctionIdT, typename SequenceNumberT> |
| Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { |
| readLock.lock(); |
| if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { |
| readLock.unlock(); |
| return Err; |
| } |
| return Error::success(); |
| } |
| |
| /// Notify the channel that we're ending a message receive. |
| /// Unlocks the channel for reading. |
| Error endReceiveMessage() { |
| readLock.unlock(); |
| return Error::success(); |
| } |
| |
| /// Get the lock for stream reading. |
| std::mutex &getReadLock() { return readLock; } |
| |
| /// Get the lock for stream writing. |
| std::mutex &getWriteLock() { return writeLock; } |
| |
| private: |
| std::mutex readLock, writeLock; |
| }; |
| |
| template <typename ChannelT, typename T> |
| class SerializationTraits< |
| ChannelT, T, T, |
| std::enable_if_t< |
| std::is_base_of<RawByteChannel, ChannelT>::value && |
| (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || |
| std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || |
| std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || |
| std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || |
| std::is_same<T, char>::value)>> { |
| public: |
| static Error serialize(ChannelT &C, T V) { |
| support::endian::byte_swap<T, support::big>(V); |
| return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); |
| }; |
| |
| static Error deserialize(ChannelT &C, T &V) { |
| if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) |
| return Err; |
| support::endian::byte_swap<T, support::big>(V); |
| return Error::success(); |
| }; |
| }; |
| |
| template <typename ChannelT> |
| class SerializationTraits< |
| ChannelT, bool, bool, |
| std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { |
| public: |
| static Error serialize(ChannelT &C, bool V) { |
| uint8_t Tmp = V ? 1 : 0; |
| if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) |
| return Err; |
| return Error::success(); |
| } |
| |
| static Error deserialize(ChannelT &C, bool &V) { |
| uint8_t Tmp = 0; |
| if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) |
| return Err; |
| V = Tmp != 0; |
| return Error::success(); |
| } |
| }; |
| |
| template <typename ChannelT> |
| class SerializationTraits< |
| ChannelT, std::string, StringRef, |
| std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { |
| public: |
| /// Serialization channel serialization for std::strings. |
| static Error serialize(RawByteChannel &C, StringRef S) { |
| if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) |
| return Err; |
| return C.appendBytes((const char *)S.data(), S.size()); |
| } |
| }; |
| |
| template <typename ChannelT, typename T> |
| class SerializationTraits< |
| ChannelT, std::string, T, |
| std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value && |
| (std::is_same<T, const char *>::value || |
| std::is_same<T, char *>::value)>> { |
| public: |
| static Error serialize(RawByteChannel &C, const char *S) { |
| return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, |
| S); |
| } |
| }; |
| |
| template <typename ChannelT> |
| class SerializationTraits< |
| ChannelT, std::string, std::string, |
| std::enable_if_t<std::is_base_of<RawByteChannel, ChannelT>::value>> { |
| public: |
| /// Serialization channel serialization for std::strings. |
| static Error serialize(RawByteChannel &C, const std::string &S) { |
| return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, |
| S); |
| } |
| |
| /// Serialization channel deserialization for std::strings. |
| static Error deserialize(RawByteChannel &C, std::string &S) { |
| uint64_t Count = 0; |
| if (auto Err = deserializeSeq(C, Count)) |
| return Err; |
| S.resize(Count); |
| return C.readBytes(&S[0], Count); |
| } |
| }; |
| |
| } // end namespace shared |
| } // end namespace orc |
| } // end namespace llvm |
| |
| #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_RAWBYTECHANNEL_H |