
Commit

Added timeout for merge and hash bucket tasks as a symptomatic fix (ray-project#306)

* Added timeout for merge and hash bucket tasks as a symptomatic fix for the Daft resource leaks

* Add documentation

* Address comments
raghumdani authored May 28, 2024
1 parent 8d789da commit 09d0495
Showing 5 changed files with 97 additions and 0 deletions.
12 changes: 12 additions & 0 deletions deltacat/compute/compactor_v2/constants.py
@@ -1,3 +1,5 @@
from deltacat.utils.common import env_integer

TOTAL_BYTES_IN_SHA1_HASH = 20

PK_DELIMITER = "L6kl7u5f"
@@ -41,6 +43,16 @@
# size in metadata to pyarrow table size.
PARQUET_TO_PYARROW_INFLATION = 4

# A merge task will fail after this timeout.
# The default is currently double the observed maximum task duration.
# An appropriate timeout depends on the total data processed per task.
MERGE_TASK_TIMEOUT_IN_SECONDS = env_integer("MERGE_TASK_TIMEOUT_IN_SECONDS", 25 * 60)

# A hash bucket task will fail after this timeout
HASH_BUCKET_TASK_TIMEOUT_IN_SECONDS = env_integer(
    "HASH_BUCKET_TASK_TIMEOUT_IN_SECONDS", 25 * 60
)

# Metric Names
# Time taken for a hash bucket task
HASH_BUCKET_TIME_IN_SECONDS = "hash_bucket_time"
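Both new constants are read through env_integer, so the timeouts can be tuned per deployment without a code change. A minimal sketch of overriding the merge timeout, assuming env_integer parses the variable as an int and falls back to its second argument when unset (the 40-minute value is illustrative only):

import os

# The override must be in place before the constants module is first
# imported, because env_integer() is evaluated at module import time.
os.environ["MERGE_TASK_TIMEOUT_IN_SECONDS"] = str(40 * 60)

from deltacat.compute.compactor_v2.constants import MERGE_TASK_TIMEOUT_IN_SECONDS

assert MERGE_TASK_TIMEOUT_IN_SECONDS == 40 * 60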
6 changes: 6 additions & 0 deletions deltacat/compute/compactor_v2/steps/hash_bucket.py
@@ -29,12 +29,14 @@
from deltacat.utils.resources import (
    get_current_process_peak_memory_usage_in_bytes,
    ProcessUtilizationOverTimeRange,
    timeout,
)
from deltacat.constants import BYTES_PER_GIBIBYTE
from deltacat.compute.compactor_v2.constants import (
    HASH_BUCKET_TIME_IN_SECONDS,
    HASH_BUCKET_FAILURE_COUNT,
    HASH_BUCKET_SUCCESS_COUNT,
    HASH_BUCKET_TASK_TIMEOUT_IN_SECONDS,
)

if importlib.util.find_spec("memray"):
@@ -96,8 +98,12 @@ def _group_file_records_by_pk_hash_bucket(
    return hb_to_delta_file_envelopes, total_record_count, total_size_bytes


# TODO: use timeout parameter in ray.remote
# https://github.com/ray-project/ray/issues/18916
# Note: order of decorators is important
@success_metric(name=HASH_BUCKET_SUCCESS_COUNT)
@failure_metric(name=HASH_BUCKET_FAILURE_COUNT)
@timeout(HASH_BUCKET_TASK_TIMEOUT_IN_SECONDS)
def _timed_hash_bucket(input: HashBucketInput):
    task_id = get_current_ray_task_id()
    worker_id = get_current_ray_worker_id()
6 changes: 6 additions & 0 deletions deltacat/compute/compactor_v2/steps/merge.py
@@ -28,6 +28,7 @@
from deltacat.utils.resources import (
    get_current_process_peak_memory_usage_in_bytes,
    ProcessUtilizationOverTimeRange,
    timeout,
)
from deltacat.compute.compactor_v2.utils.primary_key_index import (
    generate_pk_hash_column,
@@ -46,6 +47,7 @@
    MERGE_TIME_IN_SECONDS,
    MERGE_SUCCESS_COUNT,
    MERGE_FAILURE_COUNT,
    MERGE_TASK_TIMEOUT_IN_SECONDS,
)


@@ -484,8 +486,12 @@ def _copy_manifests_from_hash_bucketing(
    return materialized_results


# TODO: use timeout parameter in ray.remote
# https://github.com/ray-project/ray/issues/18916
# Note: order of decorators is important
@success_metric(name=MERGE_SUCCESS_COUNT)
@failure_metric(name=MERGE_FAILURE_COUNT)
@timeout(MERGE_TASK_TIMEOUT_IN_SECONDS)
def _timed_merge(input: MergeInput) -> MergeResult:
    task_id = get_current_ray_task_id()
    worker_id = get_current_ray_worker_id()
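The ordering note in both hash_bucket.py and merge.py follows from Python applying decorators bottom-up: @timeout wraps the task body directly, so the alarm is set and cancelled around the task alone, and a TimeoutError it raises still unwinds through @failure_metric and is counted as a task failure. A minimal sketch with a hypothetical stand-in for the metric decorator (only the nesting mirrors the real code):

import functools
import time

from deltacat.utils.resources import timeout


def failure_metric(name: str):
    # Hypothetical stand-in: counts any exception, then re-raises it.
    def _decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except BaseException:
                print(f"metric emitted: {name} += 1")
                raise

        return wrapper

    return _decorate


@failure_metric(name="merge_failure_count")
@timeout(1)
def slow_merge_task():
    time.sleep(5)  # simulates a hung merge task


try:
    slow_merge_task()  # raises TimeoutError after ~1 second (non-Windows)
except TimeoutError:
    pass  # the stand-in metric line has already been printed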
28 changes: 28 additions & 0 deletions deltacat/tests/utils/test_resources.py
@@ -1,6 +1,8 @@
import unittest
from unittest import mock
import time
from multiprocessing import Pool
import platform


class TestGetCurrentClusterUtilization(unittest.TestCase):
@@ -70,3 +72,29 @@ def test_callback():
        nu.schedule_callback(test_callback, 1)
        time.sleep(3)
        self.assertTrue(nu.test_field_set)


class TestTimeoutDecorator(unittest.TestCase):
    from deltacat.utils.resources import timeout

    @staticmethod
    @timeout(2)
    def something_that_runs_xs(x, *args, **kwargs):
        time.sleep(x)

    def test_timeout(self):
        if platform.system() != "Windows":
            self.assertRaises(
                TimeoutError, lambda: self.something_that_runs_xs(3, test=10)
            )

    def test_sanity_in_multiprocess(self):
        if platform.system() != "Windows":
            # An alarm works per process.
            # https://pubs.opengroup.org/onlinepubs/9699919799/functions/alarm.html
            with Pool(3) as p:
                p.map(self.something_that_runs_xs, [1, 1.1, 1.2])

    def test_sanity(self):
        if platform.system() != "Windows":
            self.something_that_runs_xs(1, test=10)
45 changes: 45 additions & 0 deletions deltacat/utils/resources.py
@@ -1,6 +1,8 @@
# Allow classes to use self-referencing Type hints in Python 3.7.
from __future__ import annotations

import functools
import signal
from contextlib import AbstractContextManager
from types import TracebackType
import ray
@@ -230,3 +232,46 @@ def run(cls):
    continuous_thread = ScheduleThread()
    continuous_thread.start()
    return cease_continuous_run


def timeout(value_in_seconds: int):
    """
    A decorator that raises a TimeoutError if the decorated function runs
    longer than the specified timeout.
    Note: the decorator relies on SIGALRM, so it does not work in a
    multithreaded environment or on the Windows platform; there, the function
    executes as if no timeout were set.
    Also note: the caller remains responsible for cleaning up any resources
    leaked while the underlying function was executing.
    """

    def _decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            current_platform = platform.system()

            def handler(signum, frame):
                raise TimeoutError(
                    f"Timeout occurred on method: {func.__name__},"
                    f" args={args}, kwargs={kwargs}"
                )

            if current_platform == "Windows":
                return func(*args, **kwargs)

            old_handler = signal.signal(signal.SIGALRM, handler)
            # An alarm works per process.
            # https://pubs.opengroup.org/onlinepubs/9699919799/functions/alarm.html
            signal.alarm(value_in_seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Reset the SIGALRM handler
                signal.signal(signal.SIGALRM, old_handler)
                # Cancel the alarm
                signal.alarm(0)

        return wrapper

    return _decorate
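For completeness, a small usage sketch of the decorator on its own (the function name is illustrative):

import time

from deltacat.utils.resources import timeout


@timeout(5)
def hung_task():
    while True:  # simulates a task that never completes
        time.sleep(1)


try:
    hung_task()
except TimeoutError as e:
    print(f"task aborted: {e}")
# On Windows the decorator is a no-op, so this call would block forever there.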
