Skip to content

Commit

Permalink
[perf] improve runs by tag lookups (dagster-io#9420)
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld authored Aug 18, 2022
1 parent 4ee3505 commit 2ad016f
Showing 1 changed file with 44 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union

import pendulum
import sqlalchemy as db
Expand Down Expand Up @@ -232,26 +232,6 @@ def _add_filters_to_query(self, query, filters: RunsFilter):
RunsTable.c.status.in_([status.value for status in filters.statuses])
)

if filters.tags:
query = query.where(
db.or_(
*(
db.and_(
RunTagsTable.c.key == key,
(
RunTagsTable.c.value == value
if isinstance(value, str)
else RunTagsTable.c.value.in_(value)
),
)
for key, value in filters.tags.items()
)
)
).group_by(RunsTable.c.run_body, RunsTable.c.id)

if len(filters.tags) > 0:
query = query.having(db.func.count(RunsTable.c.run_id) == len(filters.tags))

if filters.snapshot_id:
query = query.where(RunsTable.c.snapshot_id == filters.snapshot_id)

Expand All @@ -263,6 +243,28 @@ def _add_filters_to_query(self, query, filters: RunsFilter):

return query

def _apply_tags_table_joins(
    self,
    table: db.Table,
    tags: Mapping[str, Union[str, Sequence[str]]],
):
    """Join ``table`` against the run tags table once for every entry in ``tags``.

    Each tag gets its own join whose ON clause matches the run id, the tag
    key, and the tag value. A string value is compared with equality; any
    other value is treated as a collection and compared with ``IN``.

    Args:
        table: The selectable to start joining from.
        tags: Mapping of tag key to either a single value or a sequence of
            acceptable values.

    Returns:
        The result of chaining one join per tag onto ``table`` (``table``
        itself, unchanged, when ``tags`` is empty).
    """
    # With more than one tag, the tags table is joined several times, so each
    # join gets its own alias; a single join uses the table directly.
    needs_alias = len(tags) > 1

    joined = table
    for tag_key, tag_value in tags.items():
        if needs_alias:
            run_tags = RunTagsTable.alias()
        else:
            run_tags = RunTagsTable

        run_id_matches = RunsTable.c.run_id == run_tags.c.run_id
        key_matches = run_tags.c.key == tag_key
        if isinstance(tag_value, str):
            value_matches = run_tags.c.value == tag_value
        else:
            value_matches = run_tags.c.value.in_(tag_value)

        joined = joined.join(
            run_tags,
            db.and_(run_id_matches, key_matches, value_matches),
        )

    return joined

def _runs_query(
self,
filters: Optional[RunsFilter] = None,
Expand All @@ -288,14 +290,14 @@ def _runs_query(
check.failed("cannot specify bucket_by and limit/cursor at the same time")
return self._bucketed_runs_query(bucket_by, filters, columns, order_by, ascending)

query_columns = [getattr(RunsTable.c, column) for column in columns]
if filters.tags:
base_query = db.select(query_columns).select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
)
table = self._apply_tags_table_joins(RunsTable, filters.tags)
else:
base_query = db.select(query_columns).select_from(RunsTable)
table = RunsTable

base_query = db.select([getattr(RunsTable.c, column) for column in columns]).select_from(
table
)
base_query = self._add_filters_to_query(base_query, filters)
return self._add_cursor_limit_to_query(base_query, cursor, limit, order_by, ascending)

Expand Down Expand Up @@ -328,47 +330,37 @@ def _bucketed_runs_query(

if isinstance(bucket_by, JobBucket):
# bucketing by job
base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
if filters.tags
else RunsTable
)
.where(RunsTable.c.pipeline_name.in_(bucket_by.job_names))
)
if filters.tags:
table = self._apply_tags_table_joins(RunsTable, filters.tags)
else:
table = RunsTable

base_query = db.select(query_columns).select_from(table)
base_query = base_query.where(RunsTable.c.pipeline_name.in_(bucket_by.job_names))
base_query = self._add_filters_to_query(base_query, filters)

elif not filters.tags:
# bucketing by tag, no tag filters
base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
base_query = db.select(query_columns).select_from(
self._apply_tags_table_joins(
RunsTable,
{bucket_by.tag_key: bucket_by.tag_values},
)
.where(RunTagsTable.c.key == bucket_by.tag_key)
.where(RunTagsTable.c.value.in_(bucket_by.tag_values))
)
base_query = self._add_filters_to_query(base_query, filters)

else:
# there are tag filters as well as tag buckets, so we have to apply the tag filters in
# a separate join
filtered_query = db.select([RunsTable.c.run_id]).select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id)
self._apply_tags_table_joins(RunsTable, filters.tags)
)
filtered_query = self._add_filters_to_query(filtered_query, filters)
filtered_query = filtered_query.alias("filtered_query")

base_query = (
db.select(query_columns)
.select_from(
RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id).join(
filtered_query, RunsTable.c.run_id == filtered_query.c.run_id
)
)
.where(RunTagsTable.c.key == bucket_by.tag_key)
.where(RunTagsTable.c.value.in_(bucket_by.tag_values))
base_query = db.select(query_columns).select_from(
self._apply_tags_table_joins(
RunsTable, {bucket_by.tag_key: bucket_by.tag_values}
).join(filtered_query, RunsTable.c.run_id == filtered_query.c.run_id)
)

subquery = base_query.alias("subquery")
Expand Down

0 comments on commit 2ad016f

Please sign in to comment.