| import time |
| import types |
| |
| from torch.utils.data import communication, MapDataPipe |
| |
# Number of seconds ``default_not_available_hook`` passes to ``time.sleep``.
DEFAULT_NON_BLOCKING_SLEEP = 1e-3

# Names exported by ``from ... import *``.
__all__ = [
    "DataPipeBehindQueues",
    "EnsureNonBlockingMapDataPipe",
    "NonBlockingMap",
    "NotAvailable",
    "QueueWrapperForMap",
    "default_not_available_hook",
]
| |
| |
def default_not_available_hook():
    """Back off briefly before a non-blocking operation is retried.

    Sleeps for ``DEFAULT_NON_BLOCKING_SLEEP`` seconds and returns ``None``.
    """
    pause_seconds = DEFAULT_NON_BLOCKING_SLEEP
    time.sleep(pause_seconds)
| |
| |
class NotAvailable(Exception):
    """Signals that a non-blocking request has no result ready yet; retry later."""
| |
| |
class NonBlockingMap(MapDataPipe):
    """Map-style DataPipe built on top of non-blocking accessors.

    Subclasses implement ``nonblocking_getitem`` and ``nonblocking_len``. Either
    may raise :class:`NotAvailable` when its result is not ready yet. The
    blocking ``__getitem__`` / ``__len__`` wrappers keep retrying until a result
    is produced. Between attempts they call the class-wide
    ``not_available_hook``, if one is set.
    """

    # Zero-argument callable invoked between retries; ``None`` disables it.
    # Replace it for every subclass via ``register_not_available_hook``.
    not_available_hook = default_not_available_hook

    def __getitem__(self, index):
        """Return ``nonblocking_getitem(index)``, retrying while it raises NotAvailable."""
        while True:
            try:
                return self.nonblocking_getitem(index)
            except NotAvailable:
                # Looked up on the class, not on ``self``: a plain function
                # stored as a class attribute would otherwise be bound as a
                # method and receive ``self`` as an unexpected argument.
                if NonBlockingMap.not_available_hook is not None:
                    NonBlockingMap.not_available_hook()

    def __len__(self):
        """Return ``nonblocking_len()``, retrying while it raises NotAvailable.

        The retry loop mirrors ``__getitem__``. Without it, a single
        NotAvailable made this method return ``None``, and ``len()`` then
        failed with a TypeError instead of waiting for the length.
        """
        while True:
            try:
                return self.nonblocking_len()
            except NotAvailable:
                if NonBlockingMap.not_available_hook is not None:
                    NonBlockingMap.not_available_hook()

    def nonblocking_len(self):
        """Return the length, or raise NotAvailable if it is not ready yet.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"nonblocking_len is not implemented for {self.__class__}")

    def nonblocking_getitem(self, index):
        """Return the item for ``index``, or raise NotAvailable if it is not ready yet.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"nonblocking_getitem is not implemented for {self.__class__}")

    @staticmethod
    def register_not_available_hook(hook_function):
        """Install ``hook_function`` (or ``None``) as the retry hook for all NonBlockingMaps."""
        NonBlockingMap.not_available_hook = hook_function
| |
| |
def EnsureNonBlockingMapDataPipe(validated_datapipe):
    """Give a MapDataPipe the ``nonblocking_len`` / ``nonblocking_getitem`` interface.

    A :class:`NonBlockingMap` is returned as-is. For any other ``MapDataPipe``,
    each of the two methods that the instance does not already have is attached
    to that instance as a thin wrapper around the blocking dunder
    (``__len__`` / ``__getitem__``). Existing methods are left untouched. The
    object is patched in place and returned.

    Raises:
        Exception: if ``validated_datapipe`` is not a ``MapDataPipe``.
    """
    if not isinstance(validated_datapipe, MapDataPipe):
        raise Exception(f'Not Map DataPipe - got {validated_datapipe.__class__}')
    if isinstance(validated_datapipe, NonBlockingMap):
        return validated_datapipe

    def nonblocking_len(self):
        return self.__len__()

    def nonblocking_getitem(self, index):
        return self.__getitem__(index)

    # Checked and patched in this order: length first, then item access.
    fallbacks = (
        ('nonblocking_len', nonblocking_len),
        ('nonblocking_getitem', nonblocking_getitem),
    )
    for method_name, fallback in fallbacks:
        if not hasattr(validated_datapipe, method_name):
            setattr(validated_datapipe, method_name,
                    types.MethodType(fallback, validated_datapipe))
    return validated_datapipe
| |
| |
def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
    """Serve ``source_datapipe`` over ``protocol`` as a cooperative generator.

    Repeatedly reads requests from ``protocol`` and answers them using the
    non-blocking interface of ``source_datapipe``. The generator yields ``True``
    whenever it hands control back to its driver: when no request is pending,
    when the source has no data ready yet, and after an item has been sent.

    Handled requests:
      * ``TerminateRequest``: acknowledge it and stop the generator.
      * ``LenRequest``: reply with the source's ``nonblocking_len()``.
      * ``GetItemRequest``: reply with the item for ``request.key``, or with an
        index-out-of-bound response if the source raises ``IndexError``.

    Because this is a generator, the argument checks run on the first
    ``next()`` call, not when the function is called.

    Args:
        source_datapipe: a ``MapDataPipe``; passed through
            ``EnsureNonBlockingMapDataPipe``.
        protocol: a ``MapDataPipeQueueProtocolServer``.
        full_stop: if True, stop serving after replying to an out-of-bound
            ``GetItemRequest``; otherwise keep serving.
        blocking_request_get: forwarded as ``block`` to
            ``protocol.get_new_request``.

    Raises:
        Exception: if ``protocol`` has the wrong type, or if a request of an
            unrecognized type is received.
    """
    if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer):
        raise Exception('Expecting MapDataPipeQueueProtocolServer, got', protocol)
    source_datapipe = EnsureNonBlockingMapDataPipe(source_datapipe)
    forever = True
    while forever:
        try:
            # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround
            request = protocol.get_new_request(block=blocking_request_get)
        except communication.protocol.EmptyQueue:
            yield True
            continue

        if isinstance(request, communication.messages.TerminateRequest):
            forever = False
            protocol.response_terminate()

        elif isinstance(request, communication.messages.LenRequest):
            # The source may itself be non-blocking (e.g. a QueueWrapperForMap)
            # and raise NotAvailable. Retry as GetItemRequest does, instead of
            # letting the exception kill the server with the request unanswered.
            while True:
                try:
                    size = source_datapipe.nonblocking_len()
                except NotAvailable:
                    yield True
                    continue
                protocol.response_len(size)
                break

        elif isinstance(request, communication.messages.GetItemRequest):
            while True:
                try:
                    value = source_datapipe.nonblocking_getitem(request.key)
                except NotAvailable:
                    yield True
                    continue
                except IndexError:
                    # Alternatively, we can just allow the underlying DataPipe to throw an exception?
                    protocol.response_index_out_of_bound()
                    if full_stop:
                        forever = False
                    else:
                        yield True
                    break
                protocol.response_item(request.key, value)
                yield True  # Returns control
                break
        else:
            raise Exception('Unrecognized type of request received', request)
| |
| |
class QueueWrapperForMap(NonBlockingMap):
    """Map-style DataPipe whose items and length are fetched through a queue protocol.

    Each non-blocking call sends a request, but only if
    ``protocol.can_take_request()`` allows it. It then waits for the reply,
    passing ``response_wait_time`` as the timeout. If no reply arrived, it
    raises :class:`NotAvailable`. The inherited blocking ``__getitem__`` /
    ``__len__`` turn that into a retry loop.

    Once the server answers a get-item request with ``StopIterationResponse``,
    the wrapper raises ``IndexError`` for that index. From then on it rejects
    every further ``getitem``/``len`` call.
    """

    def __init__(self, protocol, response_wait_time=0.00001):
        if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolClient):
            raise Exception('Got', protocol)
        self.protocol = protocol
        # NOTE(review): not read anywhere in this class; kept for external users.
        self.counter = 0
        self._stop_iteration = False
        self._response_wait_time = response_wait_time

    def nonblocking_getitem(self, index):
        """Request ``index`` and return the reply as a ``(key, value)`` tuple.

        Returns the pair from the response, not just the value.

        NOTE(review): if an earlier request is still pending,
        ``can_take_request()`` may be False and no new request is sent. The
        reply returned could then belong to the earlier key, so callers may
        want to check the returned key.

        Raises:
            NotAvailable: no reply arrived within the wait time.
            IndexError: the server reported ``index`` as out of bound.
            Exception: called after an out-of-bound reply was received.
        """
        if self._stop_iteration:
            raise Exception(
                '`getitem` or `nonblocking_getitem` called after receiving StopIteration')
        if self.protocol.can_take_request():
            self.protocol.request_item(index)
        try:
            response = self.protocol.get_response_item(block=True, timeout=self._response_wait_time)
        except communication.protocol.EmptyQueue:
            # An empty queue is the normal "not ready yet" signal; don't chain
            # the internal EmptyQueue onto the polling exception.
            raise NotAvailable from None
        if isinstance(response, communication.messages.StopIterationResponse):
            self._stop_iteration = True
            raise IndexError(f"Index {index} is out of bound.")
        return response.key, response.value

    def nonblocking_len(self):
        """Request the length and return ``response.len``.

        Raises:
            NotAvailable: no reply arrived within the wait time.
            Exception: called after an out-of-bound reply was received.
        """
        if self._stop_iteration:
            raise Exception(
                '`len` or `nonblocking_len` called after receiving StopIteration')
        if self.protocol.can_take_request():
            self.protocol.request_len()
        try:
            response = self.protocol.get_response_len(block=True, timeout=self._response_wait_time)
        except communication.protocol.EmptyQueue:
            raise NotAvailable from None
        return response.len