Skip to content

Commit

Permalink
pagination improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 22, 2021
1 parent 3ffe35c commit 9230133
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 55 deletions.
12 changes: 6 additions & 6 deletions api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
| `SECRET_KEY` | `top-secret!` | A secret key used when signing tokens. |
| `DATABASE_URL` | `sqlite:///db.sqlite` | The database URL, as defined by the [SQLAlchemy](https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls) framework. |
| `SQL_ECHO` | not defined | Whether to echo SQL statements to the console for debugging purposes. |
| `DISABLE_AUTH` | not defined | Whether to disable authentication. When running with authentication disabled, the user is assumed to be logged as the first user in the database. |
| `DISABLE_AUTH` | not defined | Whether to disable authentication. When running with authentication disabled, the user is assumed to be logged as the user with `id=1`, which must exist in the database. |
| `ACCESS_TOKEN_EXPIRATION` | `60` (1 hour) | The number of minutes an access token is valid for. |
| `REFRESH_TOKEN_EXPIRATION` | `1440` (24 hours) | The number of minutes a refresh token is valid for. |
| `RESET_TOKEN_EXPIRATION` | `15` (15 minutes) | The number of minutes a reset token is valid for. |
Expand Down Expand Up @@ -80,10 +80,11 @@
### Password Resets
This API also supports a password reset flow, to help users who forget their
passwords. To issue a password request, the client must send a `POST` request
to `/api/tokens/reset`, passing the user's email in the body of the request.
The user will receive a password reset link by email, which is the request's
This API supports a password reset flow, to help users who forget their
passwords regain access to their accounts. To issue a password reset request,
the client must send a `POST` request to `/api/tokens/reset`, passing the
user's email in the body of the request. The user will receive a password reset
link by email, which is the request's
[Referrer](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referer)
URL, with an added `token` query string argument set to an email reset token,
with a validity of 15 minutes.
Expand Down Expand Up @@ -115,7 +116,6 @@
"errors": [ <error details>, ... ]
}
```
""" # noqa: E501

from api.app import create_app, db, ma # noqa: F401
49 changes: 34 additions & 15 deletions api/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,63 @@
from apifairy import arguments, response
import sqlalchemy as sqla
from api.app import db
from api.schemas import PaginationRequestSchema, PaginatedCollection
from api.schemas import StringPaginationSchema, PaginatedCollection


def paginated_response(schema, max_limit=10, model_from_statement=None):
def paginated_response(schema, max_limit=25, model_from_statement=None,
order_by=None, order_direction='asc',
pagination_schema=StringPaginationSchema):
def inner(f):
@wraps(f)
def paginate(*args, **kwargs):
args = list(args)
pagination = args.pop(-1)
select_query = f(*args, **kwargs)
if order_by is not None:
o = order_by.desc() if order_direction == 'desc' else order_by
select_query = select_query.order_by(o)

count = db.session.scalar(sqla.select(
sqla.func.count()).select_from(select_query))

offset = pagination.get('offset', 0)
limit = pagination.get('limit', max_limit)
offset = pagination.get('offset')
after = pagination.get('after')
if limit > max_limit:
limit = max_limit

if offset < 0 or (count > 0 and offset >= count) or limit <= 0:
abort(400)

if model_from_statement:
data = db.session.scalars(
model_from_statement.select().from_statement(
select_query.limit(limit).offset(offset))).all()
if after is not None:
if offset is not None or order_by is None: # pragma: no cover
abort(400)
if order_direction != 'desc':
order_condition = order_by > after
offset_condition = order_by <= after
else:
order_condition = order_by < after
offset_condition = order_by >= after
query = select_query.limit(limit).filter(order_condition)
offset = db.session.scalar(sqla.select(
sqla.func.count()).select_from(select_query.filter(
offset_condition)))
else:
data = db.session.scalars(select_query.limit(limit).offset(
offset)).all()
if offset is None:
offset = 0
if offset < 0 or (count > 0 and offset >= count) or limit <= 0:
abort(400)

query = select_query.limit(limit).offset(offset)
if model_from_statement:
query = model_from_statement.select().from_statement(query)

data = db.session.scalars(query).all()
return {'data': data, 'pagination': {
'offset': offset,
'limit': limit,
'count': len(data),
'total': count,
}}

# wrap with APIFairy's arguments and response decorators
return arguments(PaginationRequestSchema)(response(PaginatedCollection(
schema))(paginate))
return arguments(pagination_schema)(response(PaginatedCollection(
schema, pagination_schema=pagination_schema))(paginate))

return inner
2 changes: 1 addition & 1 deletion api/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def posts(num): # pragma: no cover
for i in range(num):
user = random.choice(users)
post = Post(body=faker.paragraph(), author=user,
timestamp=faker.date_this_year())
timestamp=faker.date_time_this_year())
db.session.add(post)
db.session.commit()
print(num, 'posts added.')
13 changes: 10 additions & 3 deletions api/posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from api.schemas import PostSchema
from api.auth import token_auth
from api.decorators import paginated_response
from api.schemas import DateTimePaginationSchema

posts = Blueprint('posts', __name__)
post_schema = PostSchema()
Expand Down Expand Up @@ -36,15 +37,19 @@ def get(id):


@posts.route('/posts', methods=['GET'])
@paginated_response(posts_schema)
@paginated_response(posts_schema, order_by=Post.timestamp,
order_direction='desc',
pagination_schema=DateTimePaginationSchema)
def all():
"""Retrieve all posts"""
return Post.select()


@posts.route('/users/<int:id>/posts', methods=['GET'])
@authenticate(token_auth)
@paginated_response(posts_schema)
@paginated_response(posts_schema, order_by=Post.timestamp,
order_direction='desc',
pagination_schema=DateTimePaginationSchema)
@other_responses({404: 'User not found'})
def user_all(id):
"""Retrieve all posts from a user"""
Expand Down Expand Up @@ -83,7 +88,9 @@ def delete(id):

@posts.route('/posts/timeline', methods=['GET'])
@authenticate(token_auth)
@paginated_response(posts_schema, model_from_statement=Post)
@paginated_response(posts_schema, order_by=Post.timestamp,
order_direction='desc', model_from_statement=Post,
pagination_schema=DateTimePaginationSchema)
def timeline():
"""Retrieve the user's post timeline"""
user = token_auth.current_user()['user']
Expand Down
28 changes: 21 additions & 7 deletions api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from marshmallow import validates_schema, ValidationError
from api import ma
from api.models import User, Post

Expand All @@ -8,34 +9,47 @@ class EmptySchema(ma.Schema):
pass


class PaginationRequestSchema(ma.Schema):
class DateTimePaginationSchema(ma.Schema):
class Meta:
ordered = True

offset = ma.Integer()
limit = ma.Integer()
offset = ma.Integer()
after = ma.DateTime(load_only=True)
count = ma.Integer(dump_only=True)
total = ma.Integer(dump_only=True)

@validates_schema
def validate_schema(self, data, **kwargs):
if data.get('offset') is not None and data.get('after') is not None:
raise ValidationError('Cannot specify both offset and after')


class PaginationSchema(ma.Schema):
class StringPaginationSchema(ma.Schema):
class Meta:
ordered = True

limit = ma.Integer()
offset = ma.Integer()
count = ma.Integer()
after = ma.String(load_only=True)
count = ma.Integer(dump_only=True)
total = ma.Integer(dump_only=True)

@validates_schema
def validate_schema(self, data, **kwargs):
if data.get('offset') is not None and data.get('after') is not None:
raise ValidationError('Cannot specify both offset and after')


def PaginatedCollection(schema):
def PaginatedCollection(schema, pagination_schema=StringPaginationSchema):
if schema in paginated_schema_cache:
return paginated_schema_cache[schema]

class PaginatedSchema(ma.Schema):
class Meta:
ordered = True
name = 'Foo'

pagination = ma.Nested(PaginationSchema)
pagination = ma.Nested(pagination_schema)
data = ma.Nested(schema, many=True)

PaginatedSchema.__name__ = 'Paginated{}'.format(schema.__class__.__name__)
Expand Down
26 changes: 13 additions & 13 deletions api/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,6 @@ def put(data):
return user


@users.route('/users/<int:id>/following', methods=['GET'])
@authenticate(token_auth)
@paginated_response(users_schema)
@other_responses({404: 'User not found'})
def following(id):
"""Retrieve the users this user is following"""
user = db.session.get(User, id) or abort(404)
return user.following_select()


@users.route('/users/me/following/<int:id>', methods=['POST'])
@authenticate(token_auth)
@response(EmptySchema, status_code=204,
Expand Down Expand Up @@ -105,9 +95,19 @@ def unfollow(id):
return {}


@users.route('/users/<int:id>/following', methods=['GET'])
@authenticate(token_auth)
@paginated_response(users_schema, order_by=User.username)
@other_responses({404: 'User not found'})
def following(id):
"""Retrieve the users this user is following"""
user = db.session.get(User, id) or abort(404)
return user.following_select()


@users.route('/users/me/following', methods=['GET'])
@authenticate(token_auth)
@paginated_response(users_schema)
@paginated_response(users_schema, order_by=User.username)
def my_following():
"""Retrieve the users the logged in user is following"""
user = token_auth.current_user()['user']
Expand All @@ -116,7 +116,7 @@ def my_following():

@users.route('/users/<int:id>/followers', methods=['GET'])
@authenticate(token_auth)
@paginated_response(users_schema)
@paginated_response(users_schema, order_by=User.username)
@other_responses({404: 'User not found'})
def followers(id):
"""Retrieve the followers of the user"""
Expand All @@ -126,7 +126,7 @@ def followers(id):

@users.route('/users/me/followers', methods=['GET'])
@authenticate(token_auth)
@paginated_response(users_schema)
@paginated_response(users_schema, order_by=User.username)
def my_followers():
"""Retrieve the followers of the logged in user"""
user = token_auth.current_user()['user']
Expand Down
2 changes: 1 addition & 1 deletion tests/base_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TestConfig(Config):
SERVER_NAME = 'localhost:5000'
TESTING = True
DISABLE_AUTH = True
SQLALCHEMY_DATABASE_URI = 'sqlite://'
ALCHEMICAL_DATABASE_URL = 'sqlite://'


class TestConfigWithAuth(TestConfig):
Expand Down
63 changes: 54 additions & 9 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timedelta
from api.app import db
from api.models import User, Post
from tests.base_test_case import BaseTestCase
Expand All @@ -7,27 +8,36 @@ class PaginationTests(BaseTestCase):
def setUp(self):
super().setUp()
user = db.session.get(User, 1)
tm = datetime.utcnow()
for i in range(105):
post = Post(body=f'Post {i + 1}', author=user)
tm -= timedelta(minutes=1)
post = Post(body=f'Post {i + 1}', author=user, timestamp=tm)
db.session.add(post)
for i in range(26):
follower = User(username=chr(ord('a') + i),
email=f'{chr(ord("a") + i)}@example.com')
db.session.add(follower)
follower.follow(user)
db.session.commit()

def test_pagination_default(self):
rv = self.client.get('/api/posts')
assert rv.status_code == 200
assert rv.json['pagination']['total'] == 105
assert rv.json['pagination']['offset'] == 0
assert rv.json['pagination']['count'] == 10
assert len(rv.json['data']) == 10
assert rv.json['pagination']['count'] == 25
assert rv.json['pagination']['limit'] == 25
assert len(rv.json['data']) == 25
assert rv.json['data'][0]['body'] == 'Post 1'
assert rv.json['data'][9]['body'] == 'Post 10'
assert rv.json['data'][24]['body'] == 'Post 25'

def test_pagination_page(self):
rv = self.client.get('/api/posts?offset=30')
rv = self.client.get('/api/posts?offset=30&limit=10')
assert rv.status_code == 200
assert rv.json['pagination']['total'] == 105
assert rv.json['pagination']['offset'] == 30
assert rv.json['pagination']['count'] == 10
assert rv.json['pagination']['limit'] == 10
assert len(rv.json['data']) == 10
assert rv.json['data'][0]['body'] == 'Post 31'
assert rv.json['data'][9]['body'] == 'Post 40'
Expand All @@ -38,6 +48,7 @@ def test_pagination_last(self):
assert rv.json['pagination']['total'] == 105
assert rv.json['pagination']['offset'] == 99
assert rv.json['pagination']['count'] == 6
assert rv.json['pagination']['limit'] == 25
assert len(rv.json['data']) == 6
assert rv.json['data'][0]['body'] == 'Post 100'
assert rv.json['data'][5]['body'] == 'Post 105'
Expand All @@ -58,16 +69,50 @@ def test_pagination_custom_limit(self):
assert rv.json['pagination']['total'] == 105
assert rv.json['pagination']['offset'] == 16
assert rv.json['pagination']['count'] == 5
assert rv.json['pagination']['limit'] == 5
assert len(rv.json['data']) == 5
assert rv.json['data'][0]['body'] == 'Post 17'
assert rv.json['data'][4]['body'] == 'Post 21'

def test_pagination_large_per_page(self):
rv = self.client.get('/api/posts?offset=37&limit=25')
rv = self.client.get('/api/posts?offset=37&limit=50')
assert rv.status_code == 200
assert rv.json['pagination']['total'] == 105
assert rv.json['pagination']['offset'] == 37
assert rv.json['pagination']['count'] == 10
assert len(rv.json['data']) == 10
assert rv.json['pagination']['count'] == 25
assert rv.json['pagination']['limit'] == 25
assert len(rv.json['data']) == 25
assert rv.json['data'][0]['body'] == 'Post 38'
assert rv.json['data'][9]['body'] == 'Post 47'
assert rv.json['data'][24]['body'] == 'Post 62'

def test_pagination_offset_and_after(self):
rv = self.client.get('/api/posts?offset=37&after=2021-01-01T00:00:00')
assert rv.status_code == 400
rv = self.client.get('/api/users/1/following?offset=37&after=foo')
assert rv.status_code == 400

def test_pagination_after_desc(self):
rv = self.client.get('/api/posts')
assert rv.status_code == 200
tm = rv.json['data'][5]['timestamp']

rv = self.client.get(f'/api/posts?after={tm}')
assert rv.status_code == 200
assert rv.json['pagination']['total'] == 105
assert rv.json['pagination']['offset'] == 6
assert rv.json['pagination']['count'] == 25
assert rv.json['pagination']['limit'] == 25
assert len(rv.json['data']) == 25
assert rv.json['data'][0]['body'] == 'Post 7'
assert rv.json['data'][24]['body'] == 'Post 31'

def test_pagination_after_asc(self):
rv = self.client.get('/api/users/1/followers?after=g')
assert rv.status_code == 200
assert rv.json['pagination']['total'] == 26
assert rv.json['pagination']['offset'] == 7
assert rv.json['pagination']['count'] == 19
assert rv.json['pagination']['limit'] == 25
assert len(rv.json['data']) == 19
assert rv.json['data'][0]['username'] == 'h'
assert rv.json['data'][-1]['username'] == 'z'

0 comments on commit 9230133

Please sign in to comment.