Skip to content

Commit

Permalink
feat: Polling for dataset registration and query execution to avoid timeouts (#178)

Browse files Browse the repository at this point in the history

This switches dataset registration, query execution, and checkpoint
restoration to polling (periodic status checks) instead of blocking calls.

Fixes #92.
  • Loading branch information
MaxiBoether authored Feb 19, 2025
1 parent c985ad5 commit 053d8ef
Show file tree
Hide file tree
Showing 16 changed files with 327 additions and 140 deletions.
6 changes: 4 additions & 2 deletions examples/client_local_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ def setup_local_client(directory: Path):
client.register_metadata_parser("TEST_PARSER", TestMetadataParser)

# Registering the dataset with the client.
client.register_dataset(
if not client.register_dataset(
"local_integrationtest_dataset",
directory / "testd.jsonl",
JSONLDataset,
parsing_func,
"TEST_PARSER",
)
):
raise RuntimeError("Error while registering dataset!")

return client

Expand All @@ -120,6 +121,7 @@ def run_query(client: MixteraClient, chunk_size: int):
mixture = ArbitraryMixture(chunk_size=chunk_size)
qea = QueryExecutionArgs(mixture=mixture)
client.execute_query(query, qea)
client.wait_for_execution(job_id)

rsa = ResultStreamingArgs(job_id=job_id)
result_samples = list(client.stream_results(rsa))
Expand Down
6 changes: 4 additions & 2 deletions examples/client_server_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def run_query(client: MixteraClient, chunk_size: int, tunnel: bool):
mixture = ArbitraryMixture(chunk_size=chunk_size)
qea = QueryExecutionArgs(mixture=mixture)
client.execute_query(query, qea)
client.wait_for_execution(job_id)

rsa = ResultStreamingArgs(job_id=job_id)
result_samples = list(client.stream_results(rsa))
Expand Down Expand Up @@ -147,13 +148,14 @@ def main(server_host: str, server_port: int):
client.register_metadata_parser("TEST_PARSER", TestMetadataParser)

# Registering the dataset with the client.
client.register_dataset(
if not client.register_dataset(
"server_integrationtest_dataset",
server_dir / "testd.jsonl",
JSONLDataset,
parsing_func,
"TEST_PARSER",
)
):
raise RuntimeError("Error while registering dataset.")

# Run queries on server
chunk_size = 42
Expand Down
2 changes: 2 additions & 0 deletions integrationtests/checkpointing/test_local_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def run_test_arbitrarymixture(client: MixteraClient):
num_workers=0,
)
client.execute_query(query, query_execution_args)
client.wait_for_execution(job_id)
result_streaming_args = ResultStreamingArgs(job_id=job_id)
logger.info("Executed query.")
# Get one chunk for each worker
Expand Down Expand Up @@ -64,6 +65,7 @@ def run_test_arbitrarymixture(client: MixteraClient):
logger.info(f"Got all chunks.")

client.restore_checkpoint(job_id, checkpoint_id)
client.wait_for_execution(job_id)

logger.info("Restored checkpoint.")

Expand Down
2 changes: 2 additions & 0 deletions integrationtests/checkpointing/test_server_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run_test_arbitrarymixture_server(client: ServerStub, dp_groups, nodes_per_gr
num_workers=num_workers,
)
client.execute_query(query, query_execution_args)
client.wait_for_execution(job_id)
logger.info(
f"Executed query for job {job_id} with dp_groups={dp_groups}, nodes_per_group={nodes_per_group}, num_workers={num_workers}"
)
Expand Down Expand Up @@ -110,6 +111,7 @@ def run_test_arbitrarymixture_server(client: ServerStub, dp_groups, nodes_per_gr

# Restore from checkpoint
client.restore_checkpoint(job_id, checkpoint_id)
client.wait_for_execution(job_id)
logger.info("Restored from checkpoint.")

# Obtain chunks after restoring from checkpoint
Expand Down
11 changes: 11 additions & 0 deletions integrationtests/local/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_filter_javascript(
)
query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "JavaScript"))
client.execute_query(query, query_exec_args)
client.wait_for_execution(result_streaming_args.job_id)
result_samples = []
for sample in client.stream_results(result_streaming_args):
result_samples.append(sample)
Expand All @@ -65,6 +66,7 @@ def test_filter_html(
)
query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML"))
client.execute_query(query, query_exec_args)
client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand All @@ -91,6 +93,7 @@ def test_filter_both(
.select(("language", "==", "JavaScript"))
)
client.execute_query(query, query_exec_args)
client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand All @@ -113,6 +116,7 @@ def test_filter_license(
)
query = Query.for_job(result_streaming_args.job_id).select(("license", "==", "CC"))
client.execute_query(query, query_exec_args)
client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand All @@ -135,6 +139,7 @@ def test_filter_unknown_license(
)
query = Query.for_job(result_streaming_args.job_id).select(("license", "==", "All rights reserved."))
client.execute_query(query, query_exec_args)
client.wait_for_execution(result_streaming_args.job_id)
assert len(list(client.stream_results(result_streaming_args))) == 0, "Got results back for expected empty results."


Expand All @@ -150,6 +155,7 @@ def test_filter_license_and_html(
Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML")).select(("license", "==", "CC"))
)
client.execute_query(query, query_exec_args)
client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand Down Expand Up @@ -186,6 +192,7 @@ def test_reproducibility(
)
query_exec_args.mixture = mixture
client.execute_query(query, query_exec_args)
client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand Down Expand Up @@ -216,6 +223,8 @@ def test_mixture_schedule(client: MixteraClient):
query_execution_args = QueryExecutionArgs(mixture=mixture_schedule)
result_streaming_args = ResultStreamingArgs(job_id)
assert client.execute_query(query, query_execution_args)
assert client.wait_for_execution(job_id)

logger.info(f"Executed query for job {job_id} for mixture schedule.")

result_samples = []
Expand Down Expand Up @@ -269,6 +278,8 @@ def test_dynamic_mixture(client: MixteraClient):
result_streaming_args = ResultStreamingArgs(job_id)

assert client.execute_query(query, query_execution_args)
assert client.wait_for_execution(job_id)

logger.info(f"Executed query for job {job_id} for dynamic mixture.")

result_iter = client.stream_results(result_streaming_args)
Expand Down
12 changes: 10 additions & 2 deletions integrationtests/server/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def test_filter_javascript(
)
query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "JavaScript"))
assert client.execute_query(query, query_exec_args)
assert client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand All @@ -73,6 +74,7 @@ def test_filter_html(
)
query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML"))
assert client.execute_query(query, query_exec_args)
assert client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand Down Expand Up @@ -100,6 +102,7 @@ def test_filter_both(
.select(("language", "==", "JavaScript"))
)
assert client.execute_query(query, query_exec_args)
assert client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand Down Expand Up @@ -148,6 +151,7 @@ def test_filter_unknown_license(
)
query = Query.for_job(result_streaming_args.job_id).select(("license", "==", "All rights reserved."))
assert client.execute_query(query, query_exec_args)
assert client.wait_for_execution(result_streaming_args.job_id)
assert len(list(client.stream_results(result_streaming_args))) == 0, "Got results back for expected empty results."


Expand All @@ -164,6 +168,7 @@ def test_filter_license_and_html(
Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML")).select(("license", "==", "CC"))
)
assert client.execute_query(query, query_exec_args)
assert client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand All @@ -190,15 +195,16 @@ def test_reproducibility(
f"6_{query_exec_args.mixture.chunk_size}_{query_exec_args.dp_groups}"
+ f"_{query_exec_args.nodes_per_group}_{query_exec_args.num_workers}_{result_streaming_args.chunk_reading_degree_of_parallelism}"
+ f"_{result_streaming_args.chunk_reading_window_size}_{result_streaming_args.chunk_reading_mixture_type}"
+ f"_reproducibility_{i}"
+ f"_{result_streaming_args.tunnel_via_server}_reproducibility_{i}"
)
query = (
Query.for_job(result_streaming_args.job_id)
.select(("language", "==", "HTML"))
.select(("language", "==", "JavaScript"))
)
query_exec_args.mixture = mixture
client.execute_query(query, query_exec_args)
assert client.execute_query(query, query_exec_args)
assert client.wait_for_execution(result_streaming_args.job_id)
result_samples = []

for sample in client.stream_results(result_streaming_args):
Expand Down Expand Up @@ -229,6 +235,7 @@ def test_mixture_schedule(client: ServerStub):
query_execution_args = QueryExecutionArgs(mixture=mixture_schedule)
result_streaming_args = ResultStreamingArgs(job_id)
assert client.execute_query(query, query_execution_args)
assert client.wait_for_execution(job_id)
logger.info(f"Executed query for job {job_id} for mixture schedule.")

result_samples = []
Expand Down Expand Up @@ -282,6 +289,7 @@ def test_dynamic_mixture(client: MixteraClient):
result_streaming_args = ResultStreamingArgs(job_id)

assert client.execute_query(query, query_execution_args)
assert client.wait_for_execution(job_id)
logger.info(f"Executed query for job {job_id} for dynamic mixture.")

result_iter = client.stream_results(result_streaming_args)
Expand Down
5 changes: 4 additions & 1 deletion mixtera/core/client/local/local_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def register_dataset(
) -> bool:
if isinstance(loc, Path):
loc = str(loc)

return self._mdc.register_dataset(identifier, loc, dtype, parsing_func, metadata_parser_identifier)

def register_metadata_parser(
Expand Down Expand Up @@ -95,6 +94,10 @@ def execute_query(self, query: Query, args: QueryExecutionArgs) -> bool:
query, args.mixture, args.dp_groups, args.nodes_per_group, args.num_workers, cache_path
)

def wait_for_execution(self, job_id: str, timeout: int = 30) -> bool:
    """
    Block until the locally executed query for `job_id` has finished.

    Polls the training query map until an entry for `job_id` appears or
    the timeout elapses.

    Args:
        job_id (str): The job id of the query to wait for.
        timeout (int): Maximum time to wait, forwarded verbatim to
            `wait_for_key_in_dict`. Default 30, preserving the previous
            hard-coded value (presumably seconds — TODO confirm; note the
            server stub waits up to 30 minutes, which looks inconsistent).

    Returns:
        bool indicating success (presumably True once the job's entry
        appeared in the map within the timeout — verify against
        `wait_for_key_in_dict`).
    """
    logger.info(f"Waiting for execution of {job_id}.")
    return wait_for_key_in_dict(self._training_query_map, job_id, timeout)

def is_remote(self) -> bool:
return False

Expand Down
18 changes: 16 additions & 2 deletions mixtera/core/client/mixtera_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,20 @@ def execute_query(self, query: Query, args: QueryExecutionArgs) -> bool:

raise NotImplementedError()

@abstractmethod
def wait_for_execution(self, job_id: str) -> bool:
    """
    Block until query execution for the given job has completed.

    Concrete clients implement this by polling their execution backend.

    Args:
        job_id (str): Identifier of the job whose query was submitted
            via `execute_query`.

    Returns:
        bool indicating whether execution finished successfully.
    """
    raise NotImplementedError()

def stream_results(self, args: ResultStreamingArgs) -> Generator[tuple[int, int, Sample], None, None]:
"""
Given a job ID, returns the QueryResult object from which the result chunks can be obtained.
Expand All @@ -265,7 +279,7 @@ def stream_results(self, args: ResultStreamingArgs) -> Generator[tuple[int, int,
with self.current_mixture_id_val.get_lock():
new_id = max(result_chunk.mixture_id, self.current_mixture_id_val.get_obj().value)
self.current_mixture_id_val.get_obj().value = new_id
logger.debug(f"Set current mixture ID to {new_id}")
# logger.debug(f"Set current mixture ID to {new_id}")

result_chunk.configure_result_streaming(
client=self,
Expand All @@ -275,7 +289,7 @@ def stream_results(self, args: ResultStreamingArgs) -> Generator[tuple[int, int,

with self.current_mixture_id_val.get_lock():
self.current_mixture_id_val.get_obj().value = -1
logger.debug("Reset current mixture ID to -1.")
# logger.debug("Reset current mixture ID to -1.")

@abstractmethod
def _stream_result_chunks(
Expand Down
21 changes: 20 additions & 1 deletion mixtera/core/client/server/server_stub.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from pathlib import Path
from typing import Any, Callable, Generator, Type

Expand Down Expand Up @@ -69,7 +70,7 @@ def execute_query(self, query: Query, args: QueryExecutionArgs) -> bool:
logger.error("Could not register query at server!")
return False

logger.info(f"Registered query for job {query.job_id} at server!")
logger.info(f"Started query registration for job {query.job_id} at server!")

return True

Expand Down Expand Up @@ -132,3 +133,21 @@ def checkpoint_completed(self, job_id: str, chkpnt_id: str, on_disk: bool) -> bo

def restore_checkpoint(self, job_id: str, chkpnt_id: str) -> None:
    """
    Ask the server to restore checkpoint `chkpnt_id` for job `job_id`.

    Thin delegate: forwards both identifiers to the underlying server
    connection and returns its result (annotated as None).
    """
    connection = self.server_connection
    return connection.restore_checkpoint(job_id, chkpnt_id)

def wait_for_execution(self, job_id: str) -> bool:
    """
    Poll the server until query execution for `job_id` completes.

    Checks the execution status roughly once per second for up to
    30 minutes. Status semantics (as used by this loop): 0 = still
    running, 1 = finished successfully, anything else = failure.

    Args:
        job_id (str): The job id of the query to wait for.

    Returns:
        bool: True if the query finished successfully within the
        timeout, False on failure or timeout.
    """
    logger.info("Waiting for query execution at server to finish.")

    # Use a wall-clock deadline instead of counting loop iterations:
    # each check_query_exec_status call takes nonzero time over the
    # network, so an iteration counter would let the loop run well
    # past the intended 30-minute timeout.
    timeout_seconds = 30 * 60
    deadline = time.monotonic() + timeout_seconds

    status = self.server_connection.check_query_exec_status(job_id)
    while status == 0 and time.monotonic() < deadline:
        time.sleep(1)
        status = self.server_connection.check_query_exec_status(job_id)

    if status != 1:
        # Covers both explicit failure statuses and a timeout while
        # still running (status 0), matching the original behavior.
        logger.error(f"Query execution failed with status {status}.")
        return False

    logger.info("Query execution finished.")
    return True
Loading

0 comments on commit 053d8ef

Please sign in to comment.