-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathsubmit_slurm_jobs.py
220 lines (177 loc) · 8.95 KB
/
submit_slurm_jobs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
from enum import Enum
import os
from jinja2 import Template
import subprocess
import json
from typing import List
class Status(Enum):
    """Lifecycle states of a job.

    Expected flow:
    INIT -> PENDING -> [RUNNING | FAIL | TIMEOUT | OOM] -> COMPLETED
    """

    # Job is created
    INIT = "init"
    # Job is waiting for resources
    PENDING = "pending"
    # Job is running
    RUNNING = "running"
    # Job failed
    FAIL = "fail"
    # Job failed due to out of memory (expected behavior)
    OOM = "oom"
    # Job failed due to timeout
    TIMEOUT = "timeout"
    # Job is completed
    COMPLETED = "completed"
class Job:
    """A job rooted at a directory on disk.

    The job's configuration is expected at `<root_path>/config.json` and its
    current state is persisted in `<root_path>/status.txt` as one of the
    `Status` values.
    """

    def __init__(self, root_path: str, qos: str) -> None:
        """
        Bind the job to `root_path` and load (or initialise) its status.

        Args:
            root_path: Directory of the job; `status.txt` is created in it
                with the INIT status when it does not exist yet.
            qos: QOS the job should be submitted with.

        Raises:
            ValueError: If an existing `status.txt` holds an unknown status.
        """
        self.root_path = root_path
        self.name = os.path.basename(root_path)
        self.config = os.path.join(root_path, "config.json")
        self.qos = qos

        # A job seen for the first time has no status file yet: start at INIT.
        if not os.path.exists(self._status_file_path):
            self.set_status(Status.INIT)
        self.status = self.get_status()

    @property
    def _status_file_path(self) -> str:
        """Path of the file in which the job status is persisted."""
        return os.path.join(self.root_path, "status.txt")

    def get_status(self) -> Status:
        """
        Read the status of the job from `status.txt` and return it

        Surrounding whitespace is ignored so that a file ending with a
        newline (e.g. one written with `echo`) is still understood.

        Raises:
            ValueError: If the file content is not a valid `Status` value.
        """
        with open(self._status_file_path, 'r') as f:
            raw_status = f.read().strip()
        try:
            return Status(raw_status)
        except ValueError:
            raise ValueError(
                f"Invalid status: {raw_status!r} (read from {self._status_file_path})"
            ) from None

    def set_status(self, status: Status) -> Status:
        """
        Update the status of the job in `status.txt` and return the new status

        The cached `self.status` attribute is refreshed as well so that it
        never lags behind what was just written to disk.
        """
        with open(self._status_file_path, 'w') as f:
            f.write(status.value)
        self.status = status
        return status
class Scheduler:
    """Collects the jobs found under a directory tree and submits them with `sbatch`."""

    def __init__(self, inp_dir: str, qos: str) -> None:
        """
        Build one `Job` per leaf directory found under `inp_dir`.

        Args:
            inp_dir: Directory that is walked recursively to discover jobs.
            qos: QOS forwarded to every created `Job`.
        """
        # Directories without sub-directories are the job candidates.
        leaf_directories = [os.path.abspath(root) for root, dirs, _ in os.walk(inp_dir) if not dirs]
        # NOTE(review): presumably a leaf named "profiler" holds profiler output
        # written inside a job directory, so the job root is its parent.
        # Only the final path component is inspected: an ancestor directory
        # whose name merely starts with "profiler" must not be rewritten.
        jobs_directory_paths = [
            os.path.dirname(leaf) if os.path.basename(leaf) == "profiler" else leaf
            for leaf in leaf_directories
        ]
        self.job_lists = [Job(job_path, qos) for job_path in jobs_directory_paths]

    def keep_only_jobs(self, status: Status) -> List[Job]:
        """Return the jobs whose cached status equals `status`."""
        return [job for job in self.job_lists if job.status == status]

    def filter_out_jobs(self, status: Status) -> List[Job]:
        """Return the jobs whose cached status differs from `status`."""
        return [job for job in self.job_lists if job.status != status]

    def create_slurm_script(self, job: Job, max_gpu_per_node: int = 8):
        """
        Render `template/base_job.slurm` for `job` into `<job.root_path>/job.slurm`.

        The number of nodes and processes per node is derived from the
        product of the `tp_size`, `cp_size`, `pp_size` and `dp_size` entries
        of the `distributed` section of the job's JSON config.

        Args:
            job: Job to create the script for.
            max_gpu_per_node: Maximum number of processes placed on one node.

        Raises:
            ValueError: If the world size cannot be laid out on full nodes,
                i.e. it exceeds `max_gpu_per_node` without being a multiple
                of it, or if either value is smaller than 1.
        """
        with open(job.config, 'r') as file:
            config = json.load(file)

        distributed = config["distributed"]
        world_size = (
            distributed["tp_size"]
            * distributed["cp_size"]
            * distributed["pp_size"]
            * distributed["dp_size"]
        )

        # Explicit exceptions instead of `assert`: assertions vanish under `python -O`.
        if max_gpu_per_node < 1 or world_size < 1:
            raise ValueError(
                f"world_size ({world_size}) and max_gpu_per_node ({max_gpu_per_node}) must be >= 1"
            )
        if world_size > max_gpu_per_node and world_size % max_gpu_per_node != 0:
            raise ValueError(
                f"world_size ({world_size}) must fit on one node or be a multiple of {max_gpu_per_node}"
            )

        # Either everything fits on a single node, or only full nodes are used.
        if world_size <= max_gpu_per_node:
            nodes, n_proc_per_node = 1, world_size
        else:
            nodes, n_proc_per_node = world_size // max_gpu_per_node, max_gpu_per_node

        context_bench = {
            'nodes': nodes,
            'n_proc_per_node': n_proc_per_node,
            'root_path': job.root_path,
            "config": job.config,
            "qos": job.qos,
        }

        # The template is looked up relative to the current working directory.
        base_path = os.path.join(os.getcwd(), "template/base_job.slurm")
        with open(base_path, 'r') as file:
            base_job_file = file.read()
        base_job_template = Template(base_job_file)

        # Write the rendered script to a new file located at the job root_path
        output_file_path = os.path.join(job.root_path, "job.slurm")
        with open(output_file_path, 'w') as file:
            file.write(base_job_template.render(context_bench))

        print(f"Slurm script created at {output_file_path}")

    def launch_dependency(self, job_array: List[Job], env_vars):
        """
        Submit the jobs of `job_array` as a chain of dependent Slurm jobs.

        Each job is submitted with `--dependency=afterany:<id>` on the last
        job that was successfully submitted before it. A job whose submission
        fails is reported, keeps its current status and is left out of the
        chain.

        Args:
            job_array: Jobs to submit, in order; each needs a `job.slurm`
                file in its root path.
            env_vars: Environment passed to the `sbatch` processes.
        """
        prev_job_id = None
        for job in job_array:
            script_path = os.path.join(job.root_path, "job.slurm")
            command = ["sbatch", '--parsable']
            if prev_job_id is not None:
                command.append('--dependency=afterany:' + prev_job_id)
            command.append(script_path)

            result = subprocess.run(command, env=env_vars, capture_output=True, text=True)
            # `--parsable` prints "<job id>" or "<job id>;<cluster name>".
            job_id = result.stdout.strip().split(";")[0]
            if result.returncode != 0 or not job_id:
                # Do not mark the job as pending nor chain on an empty job id.
                print(f"Failed to submit {script_path}: {result.stderr.strip()}")
                continue

            job.set_status(Status.PENDING)
            prev_job_id = job_id

    def check_status(self):
        """
        Print a table with the number of jobs in each status.

        The statuses are read from the `status.txt` file of every job.

        Raises:
            ValueError: If a status file holds an unknown status.
        """
        # Derive the counters from the enum so both can never drift apart.
        status_counts = {status.value: 0 for status in Status}

        for job in self.job_lists:
            with open(os.path.join(job.root_path, "status.txt"), 'r') as f:
                status = f.read().strip()
            if status not in status_counts:
                raise ValueError(f"Invalid status: {status}")
            status_counts[status] += 1

        total = sum(status_counts.values())

        # Print the status counts in a formatted table
        print(f"{'Status':<10} | {'Count':<6}")
        print(f"{'-'*10}-|-{'-'*6}")
        for status, count in status_counts.items():
            print(f"{status.capitalize():<10} | {count:<6}")
        print(f"{'-'*10}-|-{'-'*6}")
        print(f"{'Total':<10} | {total:<6}")
def submit_jobs(inp_dir, qos, hf_token, nb_slurm_array, only: str = None):
    """
    Create the Slurm scripts of the jobs found under `inp_dir` and submit them.

    Jobs whose status is COMPLETED are never submitted.

    Args:
        inp_dir: Directory containing the jobs.
        qos: QOS of the jobs.
        hf_token: Value exported to the jobs as `HUGGINGFACE_TOKEN`.
        nb_slurm_array: When greater than 0, the jobs are spread over that
            many chains of dependent jobs; otherwise every job is submitted
            independently.
        only: Restrict the submission to jobs currently in that status; one
            of "fail", "pending", "timeout" or "running". `None` disables
            the filter.

    Raises:
        ValueError: If `only` is not one of the supported values.
    """
    resubmittable = {
        "fail": Status.FAIL,
        "pending": Status.PENDING,
        "timeout": Status.TIMEOUT,
        "running": Status.RUNNING,
    }
    # Reject an unsupported filter up front instead of silently submitting
    # every job while reporting that a filter was applied.
    if only is not None and only not in resubmittable:
        raise ValueError(
            f"Unsupported value for 'only': {only!r} (expected one of {sorted(resubmittable)})"
        )

    scheduler = Scheduler(inp_dir, qos)

    #TODO: batch into job arrays
    env_vars = os.environ.copy()
    env_vars["HUGGINGFACE_TOKEN"] = hf_token

    total_jobs = len(scheduler.job_lists)

    if only is not None:
        scheduler.job_lists = scheduler.keep_only_jobs(resubmittable[only])
        filtered_jobs = len(scheduler.job_lists)
        if filtered_jobs == 0:
            print(f"No '{only}' jobs to resubmit")
            return
        print(f"Only {filtered_jobs}/{total_jobs} jobs with status '{only}' will be resubmitted")

    scheduler.job_lists = scheduler.filter_out_jobs(Status.COMPLETED)

    if nb_slurm_array > 0:
        # Use job dependencies
        # Distribute the jobs as evenly as possible: the first `extra_jobs`
        # chains receive one additional job.
        base_jobs_per_array, extra_jobs = divmod(len(scheduler.job_lists), nb_slurm_array)
        distribution = [
            base_jobs_per_array + (1 if i < extra_jobs else 0)
            for i in range(nb_slurm_array)
        ]

        start = 0
        for i, nb_jobs in enumerate(distribution):
            if nb_jobs == 0:
                # More chains requested than jobs available: nothing to launch.
                continue
            end = start + nb_jobs
            job_array = scheduler.job_lists[start:end]

            print(f"Launching job Dependency array {i+1} with {nb_jobs} jobs")

            for job in job_array:
                scheduler.create_slurm_script(job)

            scheduler.launch_dependency(job_array, env_vars)

            start = end
    else:
        # Don't use job dependencies
        for job in scheduler.job_lists:
            scheduler.create_slurm_script(job)
            script_path = os.path.join(job.root_path, "job.slurm")
            print(script_path)
            result = subprocess.run(["sbatch", script_path], env=env_vars)
            if result.returncode != 0:
                # Keep the previous status: the job never reached the queue.
                print(f"sbatch failed for {script_path} (exit code {result.returncode})")
                continue
            job.set_status(Status.PENDING)
# Command-line entry point: parse the options and submit the jobs.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Submit jobs to the cluster')
    # Required: without it the directory walk receives None and fails with a TypeError.
    parser.add_argument('--inp_dir', type=str, required=True, help='Input directory containing the jobs')
    parser.add_argument('--qos', type=str, help='QOS of the jobs')
    parser.add_argument('--nb_slurm_array', type=int, default=0, help='Number of slurm arrays')
    # Restricted to the statuses submit_jobs knows how to filter on, so a typo
    # is rejected by the parser instead of being silently ignored.
    parser.add_argument('--only', type=str, default=None, choices=['fail', 'pending', 'timeout', 'running'], help='Filter the jobs to submit')
    parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')

    args = parser.parse_args()

    #TODO: add more option like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
    submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)