Skip to content

Commit

Permalink
Refactor MetricMixin.get_top_by_count for testability
Browse files Browse the repository at this point in the history
  • Loading branch information
sloria committed Oct 18, 2018
1 parent 767d4a9 commit 2e38abf
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 17 deletions.
47 changes: 30 additions & 17 deletions osf/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,29 @@

class MetricMixin(object):

@classmethod
def _get_id_to_count(cls, size, metric_field, count_field, after=None):
    """Perform the elasticsearch aggregation for get_top_by_count.

    Documents are grouped by ``metric_field`` with a ``terms`` aggregation,
    ``count_field`` is summed within each group, and the groups are ordered
    by that sum, largest first.

    :param int size: Maximum number of buckets (distinct ids) to request.
        When falsy, no size is sent and Elasticsearch applies its own
        default bucket count.
    :param str metric_field: Name of the field holding the id to group by.
    :param str count_field: Name of the field where count values are stored.
    :param after: Optional lower bound. When truthy, only documents whose
        ``timestamp`` is greater than or equal to this value are aggregated.
    :return: A dict mapping each bucket key to its summed count (as an
        ``int``), or ``None`` if the response has no ``by_id`` aggregation
        (i.e. there is no indexed data).
    """
    search = cls.search()
    if after:
        search = search.filter('range', timestamp={'gte': after})

    terms_kwargs = {
        'field': metric_field,
        'order': {'sum_count': 'desc'},
    }
    # Forward the requested bucket count; without it the terms aggregation
    # is capped at the Elasticsearch default no matter what the caller asked
    # for. A falsy size (None, or 0 from an empty queryset) is left out
    # rather than sent as an invalid value.
    if size:
        terms_kwargs['size'] = size

    # The aggregation is registered on ``search.aggs`` in place, so the
    # return value of the chain is intentionally discarded.
    search.aggs.\
        bucket('by_id', 'terms', **terms_kwargs).\
        metric('sum_count', 'sum', field=count_field)

    # Optimization: set size to 0 so that hits aren't returned (we only care about the aggregation)
    response = search.extra(size=0).execute()

    # No indexed data
    if not hasattr(response.aggregations, 'by_id'):
        return None

    # Map _id => count
    return {
        bucket.key: int(bucket.sum_count.value)
        for bucket in response.aggregations.by_id.buckets
    }

@classmethod
def get_top_by_count(cls, qs, model_field, metric_field,
size=None, order_by=None,
Expand Down Expand Up @@ -39,24 +62,14 @@ def get_top_by_count(cls, qs, model_field, metric_field,
:param str count_field: Name of the field where count values are stored.
:param str annotation: Name of the annotation.
"""
search = cls.search()
size = size or qs.count()
if after:
search = search.filter('range', timestamp={'gte': after})
search.aggs.\
bucket('by_id', 'terms', field=metric_field, order={'sum_count': 'desc'}).\
metric('sum_count', 'sum', field=count_field)
# Optimization: set size to 0 so that hits aren't returns (we only care about the aggregation)
response = search.extra(size=0).execute()
# No indexed data
if not hasattr(response.aggregations, 'by_id'):
id_to_count = cls._get_id_to_count(
size=size or qs.count(),
metric_field=metric_field,
count_field=count_field,
after=after,
)
if id_to_count is None:
return qs.annotate(**{annotation: models.Value(0, models.IntegerField())})
buckets = response.aggregations.by_id.buckets
# Map _id => count
id_to_count = {
bucket.key: int(bucket.sum_count.value)
for bucket in buckets
}
# Annotate the queryset with the counts for each id
# https://stackoverflow.com/a/48187723/1157536
whens = [
Expand Down
34 changes: 34 additions & 0 deletions osf_tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import mock
import pytest
from elasticsearch_metrics import metrics

from osf.metrics import MetricMixin
from osf.models import OSFUser
from osf_tests.factories import UserFactory

class DummyMetric(MetricMixin, metrics.Metric):
    """Minimal concrete metric used to exercise ``MetricMixin`` in tests."""

    # Integer count value stored on each metric document.
    count = metrics.Integer(required=True, index=True, doc_values=True)
    # Optional keyword id; the test below passes it as ``metric_field``.
    user_id = metrics.Keyword(required=False, index=True, doc_values=True)

    class Meta:
        app_label = 'osf'

@pytest.mark.django_db
@mock.patch.object(DummyMetric, '_get_id_to_count')
def test_get_top_by_count(mock_get_id_to_count):
    """Check that get_top_by_count annotates users with the stubbed counts.

    The elasticsearch aggregation is patched out, so only the queryset
    annotation built from the id => count mapping is exercised here.
    """
    user1 = UserFactory()
    user2 = UserFactory()

    # Stubbed aggregation result: user2 has the larger count.
    stubbed_counts = {}
    stubbed_counts[user1._id] = 41
    stubbed_counts[user2._id] = 42
    mock_get_id_to_count.return_value = stubbed_counts

    all_users = OSFUser.objects.all()
    metric_qs = DummyMetric.get_top_by_count(
        qs=all_users,
        model_field='guids___id',
        metric_field='user_id',
        annotation='dummies',
    )

    # The first row is expected to be the user with the highest count.
    top_user = metric_qs.first()
    assert top_user._id == user2._id
    assert top_user.dummies == 42

0 comments on commit 2e38abf

Please sign in to comment.