
Commit

improved dask logic.
vtimmel committed Apr 11, 2023
1 parent 86bcfab commit 32a744b
Showing 4 changed files with 71 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -93,11 +93,11 @@ karabo/test/data/*
result/*
karabo/result/*
karabo/examples/result/*
karabo/examples/karabo/test/data/results/*

# Vscode
.vscode/*
/.vs

# Slurm logs
**.out

3 changes: 3 additions & 0 deletions karabo/error.py
@@ -2,3 +2,6 @@ class KaraboError(Exception):
"""
Base Exception thrown by the Karabo Pipeline
"""

class NodeTermination(Exception):
pass
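NodeTermination is defined here but not yet raised anywhere in this commit; the worker loop in karabo/util/dask.py still exits via sys.exit(). A minimal sketch of how the exception might be used, with a hypothetical stop_if_requested helper that is not part of this commit:

    import os

    from karabo.error import NodeTermination

    def stop_if_requested():
        # Hypothetical helper: raise instead of calling sys.exit() so that
        # callers can catch the shutdown request and clean up first.
        if os.path.exists("stop_workers"):
            raise NodeTermination("stop_workers file detected")

    try:
        stop_if_requested()
    except NodeTermination as reason:
        print(f"Worker shutting down: {reason}")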
2 changes: 1 addition & 1 deletion karabo/examples/sbatch_script.sh
@@ -13,4 +13,4 @@
export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export CRAY_CUDA_MPS=1
conda activate karabo_dev_env
-srun python3 time_karabo_slurm.py
+srun python3 test_long_observation.py
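srun launches test_long_observation.py once per task, so every allocated node runs the same entry point, and setup_dask_for_slurm() (shown later in this commit) decides per node whether to become the scheduler or a worker. The contents of test_long_observation.py are not part of this diff; a minimal sketch of such a driver, assuming only the karabo.util.dask API shown below:

    # Hypothetical driver; test_long_observation.py itself is not shown.
    from karabo.util.dask import dask_cleanup, setup_dask_for_slurm

    # Returns a Client on the first node; worker nodes never return from
    # this call, and a non-SLURM or single-node run returns None.
    client = setup_dask_for_slurm()

    if client is not None:
        # Only the scheduler node reaches this point in a multi-node job.
        futures = client.map(lambda x: x**2, range(100))
        print(sum(client.gather(futures)))
        dask_cleanup(client)  # signal workers to stop, then close the client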
105 changes: 66 additions & 39 deletions karabo/util/dask.py
@@ -2,11 +2,13 @@
import sys
import time
from datetime import datetime
-from subprocess import call
+from subprocess import PIPE, Popen, call

import psutil
from distributed import Client, LocalCluster

+from karabo.error import NodeTermination
+

def get_global_client(
    min_ram_gb_per_worker: int = 2, threads_per_worker: int = 1
@@ -40,59 +42,73 @@ def get_local_dask_client(min_ram_gb_per_worker, threads_per_worker) -> Client:
    return client


+def dask_cleanup(client: Client):
+    # Create the stop_workers file
+    with open("stop_workers", "w") as f:
+        pass
+
+    # Give some time for the workers to exit before closing the client
+    time.sleep(10)
+
+    # Remove the stop_workers file
+    if os.path.exists("stop_workers"):
+        os.remove("stop_workers")
+    client.close()
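dask_cleanup coordinates with the worker loop further down in this file: workers poll for stop_workers every 5 seconds, so the 10-second sleep gives each of them at least one polling interval to notice the file before the client is closed. Both sides resolve the path relative to the current working directory, which assumes all nodes share a filesystem and start in the same directory. A short sketch of the intended call pattern on the scheduler node (assuming a multi-node SLURM job, where a Client is returned):

    client = setup_dask_for_slurm()
    try:
        pass  # submit dask work here
    finally:
        dask_cleanup(client)  # write stop_workers, wait, then close the client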

def setup_dask_for_slurm(number_of_workers_on_scheduler_node: int = 1):
    # Detect if we are on a slurm cluster
-    if "SLURM_JOB_ID" not in os.environ or os.getenv("SLURM_JOB_NUM_NODES") == "1":
+    if not is_on_slurm_cluster() or os.getenv("SLURM_JOB_NUM_NODES") == "1":
        print("Not on a SLURM cluster or only 1 node. Not setting up dask.")
        return None

    else:
        if is_first_node():
            # Remove old scheduler file
-            try:
-                os.remove("scheduler.txt")
-            except FileNotFoundError:
-                pass
+            if os.path.exists("scheduler.json"):
+                os.remove("scheduler.json")

            # Create client and scheduler
+            print(create_node_list_except_first())
            cluster = LocalCluster(
                host=get_lowest_node_name(),
                n_workers=number_of_workers_on_scheduler_node,
            )
            client = Client(cluster)

-            # Write the scheduler address to a file
-            with open("scheduler.txt", "w") as f:
+            # Write scheduler file
+            with open("scheduler.json", "w") as f:
                f.write(cluster.scheduler_address)

            print(
                f'Main Node. Name = {os.getenv("SLURMD_NODENAME")}. Client = {client}'
            )

-            while len(client.scheduler_info()["workers"]) != int(
-                os.getenv("SLURM_JOB_NUM_NODES")
-            ):
-                time.sleep(3)
+            # Wait until all workers are connected
+            n_workers_requested = (
+                get_number_of_nodes() - 1 + number_of_workers_on_scheduler_node
+            )
+            while len(client.scheduler_info()["workers"]) < n_workers_requested:
+                print(
+                    "Waiting for all workers to connect. Currently "
+                    f"{len(client.scheduler_info()['workers'])} workers "
+                    f"connected of {n_workers_requested} requested."
+                )
+                time.sleep(5)

            # Print the number of workers
            print(f'Number of workers: {len(client.scheduler_info()["workers"])}')
            return client

        else:
-            # Sleep first to make sure no old scheduler file is read
+            # Wait some time to make sure the scheduler file is created
            time.sleep(5)

-            # Read the scheduler address from the file
-            scheduler_address = None
-            timeout_time = datetime.now().timestamp() + 60
-            while (
-                scheduler_address is None and datetime.now().timestamp() < timeout_time
-            ):
-                try:
-                    with open("scheduler.txt", "r") as f:
-                        scheduler_address = f.read()
-                except FileNotFoundError:
-                    time.sleep(1)
-            call(["dask", "worker", scheduler_address])
-            sys.exit(1)
+            # Wait until scheduler file is created
+            while not os.path.exists("scheduler.json"):
+                print("Waiting for scheduler file to be created.")
+                time.sleep(5)
+
+            # Read scheduler file
+            with open("scheduler.json", "r") as f:
+                scheduler_address = f.read()
+
+            # Start the worker (Popen rather than call, so the loop below
+            # can poll for the stop file while the worker runs)
+            worker_process = Popen(["dask", "worker", scheduler_address])
+
+            # Run until stop_workers file is created
+            while True:
+                if os.path.exists("stop_workers"):
+                    print("Stop workers file detected. Exiting.")
+                    worker_process.terminate()
+                    sys.exit(0)
+                time.sleep(5)
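For reference, distributed ships a built-in version of this scheduler-file handshake: the scheduler can write its connection info as JSON, and workers and clients poll for that file themselves. A sketch of the equivalent wiring using that mechanism (an alternative, not what the code above does, since here the file holds a bare address string despite the .json name):

    # On the first node:
    #     dask scheduler --scheduler-file scheduler.json
    # On every other node:
    #     dask worker --scheduler-file scheduler.json
    # In the driver script:
    from distributed import Client

    client = Client(scheduler_file="scheduler.json")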



def get_min_max_of_node_id():
@@ -111,17 +127,25 @@ def get_min_max_of_node_id():
def get_lowest_node_id():
    return get_min_max_of_node_id()[0]

+def get_base_string_node_list():
+    return os.getenv("SLURM_JOB_NODELIST").split("[")[0]
+

def get_lowest_node_name():
-    return os.getenv("SLURM_JOB_NODELIST").split("[")[0] + str(get_lowest_node_id())
+    return get_base_string_node_list() + str(get_lowest_node_id())


-def create_list_of_node_names():
-    return [
-        os.getenv("SLURM_JOB_NODELIST").split("[")[0] + str(i)
-        for i in range(get_min_max_of_node_id()[0], get_min_max_of_node_id()[1] + 1)
-    ]
+def get_number_of_nodes():
+    return get_min_max_of_node_id()[1] - get_min_max_of_node_id()[0] + 1
+
+
+def create_node_list_except_first():
+    """
+    Returns a SLURM nodelist string covering every node except the first.
+    Example: node[2-4] if there are 4 nodes, or node[2] if there are 2 nodes.
+    """
+    min_node, max_node = get_min_max_of_node_id()
+    if get_number_of_nodes() == 2:
+        return get_base_string_node_list() + "[" + str(min_node + 1) + "]"
+
+    return (
+        get_base_string_node_list()
+        + "[" + str(min_node + 1) + "-" + str(max_node) + "]"
+    )
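A worked example of the node-name helpers above, assuming SLURM_JOB_NODELIST has the simple bracketed form they expect, e.g. "nid0[2-5]" (comma-separated or zero-padded nodelists are not handled):

    # With SLURM_JOB_NODELIST = "nid0[2-5]":
    # get_base_string_node_list()      -> "nid0"
    # get_min_max_of_node_id()         -> (2, 5)   (body collapsed in this diff)
    # get_number_of_nodes()            -> 4
    # get_lowest_node_name()           -> "nid02"
    # create_node_list_except_first()  -> "nid0[3-5]"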

def get_node_id():
    len_id = len(str(get_lowest_node_id()))
@@ -134,3 +158,6 @@ def is_first_node():

def get_current_time():
    return time.strftime("%H:%M:%S", time.localtime())

+def is_on_slurm_cluster():
+    return "SLURM_JOB_ID" in os.environ
