Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix run_in_subprocess #299

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions src/spdl/dataloader/_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# Message from worker to the parent
_MSG_GENERATOR_FAILED = "GENERATOR_FAILED_TO_INITIALIZE"
_MSG_ITERATION_FINISHED = "ITERATION_FINISHED"
_MSG_DATA_QUEUE_FULL = "DATA_QUEUE_FULL"
_MSG_DATA_QUEUE_FAILED = "DATA_QUEUE_FAILED"


def _execute_iterator(
Expand All @@ -63,7 +63,7 @@ def _execute_iterator(
else:
if msg == _MSG_PARENT_REQUEST_STOP:
return
raise ValueError(f"Unexpected message redeived: {msg}")
raise ValueError(f"[INTERNAL ERROR] Unexpected message received: {msg}")

try:
item = next(gen)
Expand All @@ -76,8 +76,8 @@ def _execute_iterator(

try:
data_queue.put(item)
except queue.Full:
msg_queue.put(_MSG_DATA_QUEUE_FULL)
except Exception:
msg_queue.put(_MSG_DATA_QUEUE_FAILED)
return


Expand Down Expand Up @@ -111,7 +111,11 @@ def run_in_subprocess(
"""
ctx = mp.get_context(mp_context)
msg_q = ctx.Queue()
data_q = ctx.Queue(maxsize=queue_size)
data_q: mp.Queue = ctx.Queue(maxsize=queue_size)

def _drain() -> Iterator[T]:
while not data_q.empty():
yield data_q.get_nowait()

process = ctx.Process(
target=_execute_iterator,
Expand All @@ -127,18 +131,21 @@ def run_in_subprocess(
except queue.Empty:
pass
else:
# When a message is found, the child process stopped putting data.
yield from _drain()

if msg == _MSG_ITERATION_FINISHED:
return
if msg == _MSG_GENERATOR_FAILED:
raise RuntimeError(
"The worker process quit because the generator failed."
)
if msg == _MSG_DATA_QUEUE_FULL:
if msg == _MSG_DATA_QUEUE_FAILED:
raise RuntimeError(
"The worker process quit because the data queue is full for too long."
"The worker process quit because it failed at passing the data."
)

raise ValueError(f"Unexpected message received: {msg}")
raise ValueError(f"[INTERNAL ERROR] Unexpected message received: {msg}")

try:
yield data_q.get(timeout=1)
Expand All @@ -153,12 +160,11 @@ def run_in_subprocess(
f"The worker process did not produce any data for {elapsed:.2f} seconds."
)

except Exception:
except (Exception, KeyboardInterrupt):
msg_q.put(_MSG_PARENT_REQUEST_STOP)
raise
finally:
while not data_q.empty():
data_q.get_nowait()
yield from _drain()
process.join(3)

if process.exitcode is None:
Expand Down
Loading