Skip to content

Commit

Permalink
Add API support to get runs by id (#2157)
Browse files Browse the repository at this point in the history
* Support getting runs by id

* Test getting runs

* Add deleted field to run API response
  • Loading branch information
r4victor authored Dec 30, 2024
1 parent 4ece690 commit b315d52
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ class Run(CoreModel):
service: Optional[ServiceSpec] = None
# TODO: make error a computed field after migrating to pydanticV2
error: Optional[str] = None
deleted: Optional[bool] = None

@root_validator
def _error(cls, values) -> Dict:
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/server/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,16 @@ async def get_run(
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> Run:
"""
Returns a run given a run name.
Returns a run given `run_name` or `id`.
If given `run_name`, does not return deleted runs.
If given `id`, returns deleted runs.
"""
_, project = user_project
run = await runs.get_run(
session=session,
project=project,
run_name=body.run_name,
run_id=body.id,
)
if run is None:
raise ResourceNotExistsError("Run not found")
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/schemas/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class ListRunsRequest(CoreModel):


class GetRunRequest(CoreModel):
run_name: str
run_name: Optional[str] = None
id: Optional[UUID] = None


class GetRunPlanRequest(CoreModel):
Expand Down
47 changes: 44 additions & 3 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,27 @@ async def list_projects_run_models(


async def get_run(
session: AsyncSession,
project: ProjectModel,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
) -> Optional[Run]:
if run_id is not None:
return await get_run_by_id(
session=session,
project=project,
run_id=run_id,
)
elif run_name is not None:
return await get_run_by_name(
session=session,
project=project,
run_name=run_name,
)
raise ServerClientError("run_name or id must be specified")


async def get_run_by_name(
session: AsyncSession,
project: ProjectModel,
run_name: str,
Expand All @@ -230,6 +251,25 @@ async def get_run(
return run_model_to_run(run_model, return_in_api=True)


async def get_run_by_id(
session: AsyncSession,
project: ProjectModel,
run_id: uuid.UUID,
) -> Optional[Run]:
res = await session.execute(
select(RunModel)
.where(
RunModel.project_id == project.id,
RunModel.id == run_id,
)
.options(joinedload(RunModel.user))
)
run_model = res.scalar()
if run_model is None:
return None
return run_model_to_run(run_model, return_in_api=True)


async def get_plan(
session: AsyncSession,
project: ProjectModel,
Expand All @@ -244,7 +284,7 @@ async def get_plan(
current_resource = None
action = ApplyAction.CREATE
if run_spec.run_name is not None:
current_resource = await get_run(
current_resource = await get_run_by_name(
session=session,
project=project,
run_name=run_spec.run_name,
Expand Down Expand Up @@ -333,7 +373,7 @@ async def apply_plan(
project=project,
run_spec=plan.run_spec,
)
current_resource = await get_run(
current_resource = await get_run_by_name(
session=session,
project=project,
run_name=plan.run_spec.run_name,
Expand Down Expand Up @@ -366,7 +406,7 @@ async def apply_plan(
.where(RunModel.id == current_resource.id)
.values(run_spec=plan.run_spec.json())
)
run = await get_run(
run = await get_run_by_name(
session=session,
project=project,
run_name=plan.run_spec.run_name,
Expand Down Expand Up @@ -621,6 +661,7 @@ def run_model_to_run(
jobs=jobs,
latest_job_submission=latest_job_submission,
service=service_spec,
deleted=run_model.deleted,
)
run.cost = _get_run_cost(run)
return run
Expand Down
83 changes: 83 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def get_dev_env_run_dict(
last_processed_at: str = "2023-01-02T03:04:00+00:00",
finished_at: Optional[str] = "2023-01-02T03:04:00+00:00",
privileged: bool = False,
deleted: bool = False,
) -> Dict:
return {
"id": run_id,
Expand Down Expand Up @@ -369,6 +370,7 @@ def get_dev_env_run_dict(
"service": None,
"termination_reason": None,
"error": "",
"deleted": deleted,
}


Expand Down Expand Up @@ -492,6 +494,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli
"service": None,
"termination_reason": None,
"error": "",
"deleted": False,
},
{
"id": str(run2.id),
Expand All @@ -507,6 +510,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli
"service": None,
"termination_reason": None,
"error": "",
"deleted": False,
},
]

Expand Down Expand Up @@ -572,6 +576,85 @@ async def test_lists_runs_pagination(
assert response2_json[0]["id"] == str(run2.id)


class TestGetRun:
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_403_if_not_project_member(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
response = await client.post(
f"/api/project/{project.name}/runs/get",
headers=get_auth_headers(user.token),
json={"run_name": "myrun"},
)
assert response.status_code == 403

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_run_given_name(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
repo = await create_repo(
session=session,
project_id=project.id,
)
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
)
response = await client.post(
f"/api/project/{project.name}/runs/get",
headers=get_auth_headers(user.token),
json={"run_name": "nonexistent_run_name"},
)
assert response.status_code == 400
response = await client.post(
f"/api/project/{project.name}/runs/get",
headers=get_auth_headers(user.token),
json={"run_name": run.run_name},
)
assert response.status_code == 200, response.json()
assert response.json()["id"] == str(run.id)

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_deleted_run_given_id(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
repo = await create_repo(
session=session,
project_id=project.id,
)
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
deleted=True,
)
response = await client.post(
f"/api/project/{project.name}/runs/get",
headers=get_auth_headers(user.token),
json={"id": str(run.id)},
)
assert response.status_code == 200, response.json()
assert response.json()["id"] == str(run.id)


class TestGetRunPlan:
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down

0 comments on commit b315d52

Please sign in to comment.