Skip to content

Commit

Permalink
fix filter pagination. it won't work for sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
llazzaro committed Aug 27, 2020
1 parent 0086f07 commit 6d5c661
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 44 deletions.
69 changes: 42 additions & 27 deletions faraday/server/api/modules/vulns.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,15 +730,14 @@ def _hostname_filters(self, filters):

return res_filters, hostname_filters

def _filter_vulns(self, vulnerability_class, filters, hostname_filters, workspace, marshmallow_params, is_web, limit=None, offset=None):
def _filter_vulns(self, vulnerability_class, filters, hostname_filters, workspace, marshmallow_params, is_web):
hosts_os_filter = [host_os_filter for host_os_filter in filters.get('filters', []) if host_os_filter.get('name') == 'host__os']

if hosts_os_filter:
# remove host__os filters from filters due to a bug
hosts_os_filter = hosts_os_filter[0]
filters['filters'] = [host_os_filter for host_os_filter in filters.get('filters', []) if host_os_filter.get('name') != 'host__os']

# SQLAlchemy query can't be extended with filters after applying limits/offsets
vulns = search(db.session,
vulnerability_class,
filters)
Expand Down Expand Up @@ -767,29 +766,13 @@ def _filter_vulns(self, vulnerability_class, filters, hostname_filters, workspac

else:
_type = 'Vulnerability'
if limit:
vulns = vulns.limit(limit)
if offset:
vulns = vulns.offset(offset)
if 'group_by' not in filters:
vulns = vulns.options(
joinedload(VulnerabilityGeneric.tags),
joinedload(Vulnerability.host),
joinedload(Vulnerability.service),
joinedload(VulnerabilityWeb.service),
)

vulns = self.schema_class_dict[_type](**marshmallow_params).dumps(
vulns.all())
vulns_data = json.loads(vulns)
else:
column_names = ['count'] + [field['field'] for field in filters.get('group_by',[])]
rows = [list(zip(column_names, row)) for row in vulns.all()]
vulns_data = []
for row in rows:
vulns_data.append({field[0]:field[1] for field in row})

return vulns_data
return vulns

def _filter(self, filters, workspace_name, confirmed=False):
try:
Expand Down Expand Up @@ -824,21 +807,53 @@ def _filter(self, filters, workspace_name, confirmed=False):
hostname_filters,
workspace,
marshmallow_params,
is_web=False,
limit=limit,
offset=offset)
is_web=False)

web_vulns_data = self._filter_vulns(
VulnerabilityWeb,
filters,
hostname_filters,
workspace,
marshmallow_params,
is_web=True,
limit=limit,
offset=offset)
return normal_vulns_data + web_vulns_data
is_web=True)

if db.engine.dialect.name == 'postgresql':
vulns = normal_vulns_data.union(web_vulns_data)
# postgresql pagination with offset and limit need to order by a field
# to guarentee that all pages returns all objects
# without order by postgresql could repeat rows
vulns = vulns.order_by(VulnerabilityGeneric.id)

if limit:
vulns = vulns.limit(limit)
if offset:
vulns = vulns.offset(offset)
vulns = self.schema_class_dict['VulnerabilityWeb'](**marshmallow_params).dumps(
vulns.all())
return json.loads(vulns)
else:

normal_vulns = self.schema_class_dict['VulnerabilityWeb'](**marshmallow_params).dumps(
normal_vulns_data.all())

web_vulns = self.schema_class_dict['VulnerabilityWeb'](**marshmallow_params).dumps(
web_vulns_data.all())
return json.loads(normal_vulns) + json.loads(web_vulns)
else:
vulns_data = self._filter_vulns(VulnerabilityGeneric, filters, hostname_filters, workspace, marshmallow_params, False)
vulns = self._filter_vulns(
VulnerabilityGeneric,
filters,
hostname_filters,
workspace,
marshmallow_params,
False
)
column_names = ['count'] + [field['field'] for field in filters.get('group_by',[])]
rows = [list(zip(column_names, row)) for row in vulns.all()]
vulns_data = []
for row in rows:
vulns_data.append({field[0]:field[1] for field in row})

return vulns_data

@route('/<int:vuln_id>/attachment/<attachment_filename>/', methods=['GET'])
Expand Down
19 changes: 10 additions & 9 deletions faraday/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,16 @@ class VulnerabilityGeneric(VulnerabilityABC):
association_date = Column(DateTime, nullable=True)
disassociated_manually = Column(Boolean, nullable=False, default=False)
tool = BlankColumn(Text, nullable=False)
method = BlankColumn(Text)
parameters = BlankColumn(Text)
parameter_name = BlankColumn(Text)
path = BlankColumn(Text)
query_string = BlankColumn(Text)
request = BlankColumn(Text)
response = BlankColumn(Text)
website = BlankColumn(Text)
status_code = Column(Integer, nullable=True)


vulnerability_duplicate_id = Column(
Integer,
Expand Down Expand Up @@ -1118,15 +1128,6 @@ def parent(self):

class VulnerabilityWeb(VulnerabilityGeneric):
__tablename__ = None
method = BlankColumn(Text)
parameters = BlankColumn(Text)
parameter_name = BlankColumn(Text)
path = BlankColumn(Text)
query_string = BlankColumn(Text)
request = BlankColumn(Text)
response = BlankColumn(Text)
website = BlankColumn(Text)
status_code = Column(Integer, nullable=True)

@declared_attr
def service_id(cls):
Expand Down
21 changes: 13 additions & 8 deletions tests/test_api_vulnerability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2921,6 +2921,7 @@ def test_search_by_hostname_vulns_with_service(self, test_client, session):
assert res.json['vulnerabilities'][0]['id'] == vuln.id

@pytest.mark.skip_sql_dialect('sqlite')
@pytest.mark.usefixtures('ignore_nplusone')
def test_search_hostname_web_vulns(self, test_client, session):
workspace = WorkspaceFactory.create()
host = HostFactory.create(workspace=workspace)
Expand Down Expand Up @@ -2977,8 +2978,8 @@ def test_search_by_hostname_multiple_logic(self, test_client, session):
assert res.json['count'] == 1
assert res.json['vulnerabilities'][0]['id'] == vuln.id

@pytest.mark.skip(reason="Refactor to remove VulnerabilityWeb")
@pytest.mark.skip_sql_dialect('sqlite')
@pytest.mark.usefixtures('ignore_nplusone')
def test_search_filter_offset_and_limit_mixed_vulns_type_bug(self, test_client, session):
workspace = WorkspaceFactory.create()
host = HostFactory.create(workspace=workspace)
Expand All @@ -2987,31 +2988,32 @@ def test_search_filter_offset_and_limit_mixed_vulns_type_bug(self, test_client,
severity='high'
)
session.add_all(vulns)
vulns = VulnerabilityWebFactory.create_batch(10,
web_vulns = VulnerabilityWebFactory.create_batch(10,
workspace=workspace,
severity='high'
)
session.add_all(vulns)
session.add_all(web_vulns)
session.add(host)
session.commit()
paginated_vulns = set()
expected_vulns = set([vuln.id for vuln in vulns])
for offset in range(0, 10):
expected_vulns = set([vuln.id for vuln in vulns] + [vuln.id for vuln in web_vulns])
for offset in range(0, 2):
query_filter = {
"filters":[{"name":"severity","op":"eq","val":"high"}],
"limit": 10,
"offset": 10 * offset,
"offset": offset * 10,
}
res = test_client.get(
'/v2/ws/{}/vulns/filter?q={}'.format(workspace.name, json.dumps(query_filter)))
assert res.status_code == 200
assert res.json['count'] == 10 # Before the refactor this return 20, where it should return 10
print(res.json['vulnerabilities'][0]['id'])
assert res.json['count'] == 10, query_filter # Before the refactor this return 20, where it should return 10
for vuln in res.json['vulnerabilities']:
print(vuln['id'])
paginated_vulns.add(vuln['id'])
assert expected_vulns == paginated_vulns

@pytest.mark.skip_sql_dialect('sqlite')
@pytest.mark.usefixtures('ignore_nplusone')
def test_search_filter_offset_and_limit_page_size_10(self, test_client, session):
workspace = WorkspaceFactory.create()
host = HostFactory.create(workspace=workspace)
Expand Down Expand Up @@ -3040,6 +3042,7 @@ def test_search_filter_offset_and_limit_page_size_10(self, test_client, session)
assert expected_vulns == paginated_vulns

@pytest.mark.skip_sql_dialect('sqlite')
@pytest.mark.usefixtures('ignore_nplusone')
def test_search_filter_offset_and_limit(self, test_client, session):
workspace = WorkspaceFactory.create()
host = HostFactory.create(workspace=workspace)
Expand Down Expand Up @@ -3077,6 +3080,7 @@ def test_search_filter_offset_and_limit(self, test_client, session):
assert expected_vulns == paginated_vulns

@pytest.mark.skip_sql_dialect('sqlite')
@pytest.mark.usefixtures('ignore_nplusone')
def test_search_by_host_os_with_vulnerability_web_bug(self, test_client, session):
"""
When searching by the host os an error was raised when a vuln web exists in the ws
Expand Down Expand Up @@ -3112,6 +3116,7 @@ def test_search_by_host_os_with_vulnerability_web_bug(self, test_client, session
assert res.json['vulnerabilities'][0]['id'] == vuln.id

@pytest.mark.skip_sql_dialect('sqlite')
@pytest.mark.usefixtures('ignore_nplusone')
def test_search_by_date_equals(self, test_client, session):
"""
When searching by the host os an error was raised when a vuln web exists in the ws
Expand Down

0 comments on commit 6d5c661

Please sign in to comment.