From fd9451cbb94aa0f202a43636fb48fe35de200906 Mon Sep 17 00:00:00 2001 From: Julien Hoareau Date: Wed, 22 Feb 2023 15:59:50 +0100 Subject: [PATCH] Add state communication to iter and map datapipes, add snapshot to MPRS and add corresponding test --- test/dataloader2/test_mprs.py | 40 +++++++++++++++++-- torchdata/dataloader2/communication/iter.py | 20 ++++++++++ torchdata/dataloader2/communication/map.py | 9 +++++ .../dataloader2/communication/messages.py | 11 +++++ .../dataloader2/communication/protocol.py | 35 ++++++++++++++++ torchdata/dataloader2/reading_service.py | 16 ++++++++ 6 files changed, 128 insertions(+), 3 deletions(-) diff --git a/test/dataloader2/test_mprs.py b/test/dataloader2/test_mprs.py index 634094e49..bc90e6e24 100644 --- a/test/dataloader2/test_mprs.py +++ b/test/dataloader2/test_mprs.py @@ -276,10 +276,44 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr # cumulative_res.extend(res) # self.assertEqual(list(range(n_elements)), sorted(cumulative_res)) + @mp_ctx_parametrize + @dp_parametrize + @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(2, 1, 1), (4, 1, 0), (4, 0, 0)]) + def test_reading_service_snapshot(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None: + # Functional Test: Confirms that `snapshot` does capture the state of the underlying DataPipes properly + rs = MultiProcessingReadingService( + num_workers=n_workers, + worker_prefetch_cnt=worker_prefetch_cnt, + main_prefetch_cnt=main_prefetch_cnt, + multiprocessing_context=ctx, + ) + dl: DataLoader2 = DataLoader2(dp, reading_service=rs) + res = [] + stop_index = 3 + for i, x in enumerate(dl): + res.append(x) + if i == stop_index: + snapshot = dl.reading_service.snapshot() + break + self.assertEqual( + n_workers, + len(snapshot), + msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, " + f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", + ) + + if 
worker_prefetch_cnt == 0 and main_prefetch_cnt == 0 and dp == dp1: + for snapshot_worker in snapshot: + self.assertAlmostEqual( + stop_index + worker_prefetch_cnt, + snapshot_worker["_number_of_samples_yielded"], + delta=2, + msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, " + f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", + ) + dl.shutdown() + # TODO: Implemented in an upcoming PR - # def test_reading_service_snapshot(self) -> None: - # pass - # # def test_dataloader2_snapshot(self) -> None: # pass diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index ee0816879..a44ec4341 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -193,6 +193,15 @@ def DataPipeBehindQueues( forever = False protocol.response_terminate() + elif isinstance(request, communication.messages.GetStateRequest): + datapipe_state = source_datapipe.__getstate__() + # Remove pickle-incompatible keys from the state + datapipe_state = { + k: v for k, v in datapipe_state.items() if not callable(v) and not isinstance(v, types.GeneratorType) + } + protocol.response_state(datapipe_state) + yield True # Return control + elif isinstance(request, communication.messages.GetNextRequest): while forever: if protocol.is_paused(): @@ -273,6 +282,14 @@ def resume(self): if NonBlocking.not_available_hook is not None: NonBlocking.not_available_hook() + def state_dict(self): + self.protocol.request_state() + try: + response = self.protocol.get_response_state(block=True, timeout=self._response_wait_time) + except communication.protocol.EmptyQueue: + raise NotAvailable + return response.value + def nonblocking_next(self): if self._stop_iteration: raise Exception("`next` or `nonblocking_next` called after receiving StopIteration") @@ -374,3 +391,6 @@ def request_pause(self): def request_resume(self): for dp in self.datapipes: dp.resume() + + 
def state_dict(self): + return [dp.state_dict() for dp in self.datapipes] diff --git a/torchdata/dataloader2/communication/map.py b/torchdata/dataloader2/communication/map.py index 3dee2e419..3b106470b 100644 --- a/torchdata/dataloader2/communication/map.py +++ b/torchdata/dataloader2/communication/map.py @@ -183,3 +183,12 @@ def nonblocking_len(self): except communication.protocol.EmptyQueue: raise NotAvailable return response.len + + def state_dict(self): + if self.protocol.can_take_request(): + self.protocol.request_state() + try: + response = self.protocol.get_response_state(block=True, timeout=self._response_wait_time) + except communication.protocol.EmptyQueue: + raise NotAvailable + return response.value diff --git a/torchdata/dataloader2/communication/messages.py b/torchdata/dataloader2/communication/messages.py index 70cadc596..a284074f9 100644 --- a/torchdata/dataloader2/communication/messages.py +++ b/torchdata/dataloader2/communication/messages.py @@ -103,6 +103,17 @@ class StopIterationResponse(Response): pass +class GetStateRequest(Request): + pass + + +class GetStateResponse(Response): + __slots__ = "value" + + def __init__(self, value): + self.value = value + + class InvalidStateResponse(Response): """ Returned by DataPipe when it is expecting to get reset request, diff --git a/torchdata/dataloader2/communication/protocol.py b/torchdata/dataloader2/communication/protocol.py index f1b86d6e0..6c3d05910 100644 --- a/torchdata/dataloader2/communication/protocol.py +++ b/torchdata/dataloader2/communication/protocol.py @@ -64,6 +64,13 @@ def request_resume(self): self.request_queue.put(request) self.request_sent(request) + def request_state(self): + if not self.can_take_request(): + raise Exception("Can not request state while we are still waiting response for previous request") + request = communication.messages.GetStateRequest() + self.request_queue.put(request) + self.request_sent(request) + class ProtocolServer(Protocol): """ @@ -132,6 +139,12 @@ def 
response_resume(self): self.response_queue.put(communication.messages.ResumeResponse()) self._req_received = None + def response_state(self, value): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + self.response_queue.put(communication.messages.GetStateResponse(value)) + self._req_received = None + def response_worker_exception(self, exception): if not self.have_pending_request(): raise Exception("Attempting to reply with pending request") @@ -205,6 +218,17 @@ def get_response_item(self, block=False, timeout=None): # raise Exception('Invalid response received') return response + def get_response_state(self, block=False, timeout=None): + if not self.waiting_for_response(): + raise Exception("Can not expect any response without submitted request") + try: + response = self.response_queue.get(block=block, timeout=timeout) + except EmptyException: + raise EmptyQueue("queue is empty") + self.request_served(response) + + return response + class EmptyQueue(Exception): pass @@ -311,3 +335,14 @@ def get_response_next(self, block=False, timeout=None): # TODO(629): Add possible response types validation here return response + + def get_response_state(self, block=False, timeout=None): + if not self.waiting_for_response(): + raise Exception("Can not expect any response without submitted request") + try: + response = self.response_queue.get(block=block, timeout=timeout) + except EmptyException: + raise EmptyQueue("queue is empty") + self.request_served(response) + + return response diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index 934b468c5..0a3353bdb 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -376,6 +376,22 @@ def clean_me(process, req_queue, res_queue): self._worker_processes = [] self._dispatch_process = None + def snapshot(self): + """ + Captures the state_dict of the underlying worker datapipes via the consumer 
datapipe. + We only capture the worker datapipes' states and not the prefetching datapipe. + This is a PoC for now so there is no corresponding restoring action to make it properly checkpointable. + """ + if self.num_workers == 0: + raise RuntimeError( + "If you would like to use `snapshot` with `MultiProcessingReadingService`, please use more than 0 workers." + ) + + self._pause() + result = self._worker_consumer_datapipe.state_dict() + self._resume() + return result + def _pause(self): """ Pauses DataPipes' activities such as prefetching, in order to collect state.