Skip to content

Commit

Permalink
Nx test matrix speedup via multiprocessing, N = CPU count (semgrep#2224)
Browse files Browse the repository at this point in the history
  • Loading branch information
mschwager authored Dec 10, 2020
1 parent c140887 commit 55a918d
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 55 deletions.
101 changes: 61 additions & 40 deletions scripts/generate_test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import glob
import io
import json
import multiprocessing
import os
import subprocess
import sys
Expand Down Expand Up @@ -193,52 +194,72 @@ def run_semgrep_on_example(lang: str, config_arg_str: str, code_path: str) -> st
# sys.exit(1)


def invoke_semgrep_multi(semgrep_path, code_path, lang, category, subcategory):
    """Run semgrep on one example and echo the inputs alongside the result.

    Designed for use with ``multiprocessing.Pool.starmap``: because results
    may come back in any order, every argument is returned in the 6-tuple so
    the caller can match the raw semgrep output back to its test-matrix entry.
    """
    output = run_semgrep_on_example(lang, semgrep_path, code_path)
    return (semgrep_path, code_path, lang, category, subcategory, output)


def paths_exist(*paths):
    """Return True when every given path exists on disk (vacuously True for none)."""
    return all(map(os.path.exists, paths))


def generate_cheatsheet(root_dir: str, html: bool):
    """Build the cheatsheet data structure by running semgrep over every example.

    The scraped diff interleaved the old serial loop with the new
    multiprocessing implementation; this is the coherent post-change version.

    Returns a nested mapping:
        {lang: {feature_name: {subcategory_name: [entry, ...]}}}
    where each entry holds one pattern/code example pair plus the match
    highlights semgrep reported for it.
    """
    # output : {'dots': {'arguments': ['foo(...)', 'foo(1)'], } }
    output = collections.defaultdict(
        lambda: collections.defaultdict(lambda: collections.defaultdict(list))
    )
    langs = get_language_directories(root_dir)

    # Collect every (pattern, code) pair whose files both exist, so the
    # semgrep invocations can be fanned out across all CPUs at once.
    semgrep_multi_args = [
        (
            find_path(root_dir, lang, category, subcategory, "sgrep"),
            find_path(root_dir, lang, category, subcategory, lang_dir_to_ext(lang)),
            lang,
            category,
            subcategory,
        )
        for lang in langs
        for category, subcategories in CHEATSHEET_ENTRIES.items()
        for subcategory in subcategories
        if paths_exist(
            find_path(root_dir, lang, category, subcategory, "sgrep"),
            find_path(root_dir, lang, category, subcategory, lang_dir_to_ext(lang)),
        )
    ]
    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        results = pool.starmap(invoke_semgrep_multi, semgrep_multi_args)

    for semgrep_path, code_path, lang, category, subcategory, result in results:
        highlights = []
        if result:
            # semgrep emits JSON; each result carries start/end match positions.
            j = json.loads(result)
            for entry in j["results"]:
                highlights.append({"start": entry["start"], "end": entry["end"]})

        entry = {
            "pattern": read_if_exists(semgrep_path),
            "pattern_path": os.path.relpath(semgrep_path, root_dir),
            "code": read_if_exists(code_path),
            "code_path": os.path.relpath(code_path, root_dir),
            "highlights": highlights,
        }

        if html:
            # HTML output links paths relative to the CWD, not the repo root.
            entry["pattern_path"] = os.path.relpath(semgrep_path)
            entry["code_path"] = os.path.relpath(code_path)

        feature_name = VERBOSE_FEATURE_NAME.get(category, category)
        subcategory_name = VERBOSE_SUBCATEGORY_NAME.get(subcategory, subcategory)
        # A language may opt out of a feature or a subcategory entirely.
        feature_exception = feature_name in LANGUAGE_EXCEPTIONS.get(lang, [])
        subcategory_exception = subcategory in LANGUAGE_EXCEPTIONS.get(lang, [])
        if not feature_exception and not subcategory_exception:
            output[lang][feature_name][subcategory_name].append(entry)

    return output

Expand Down
16 changes: 4 additions & 12 deletions semgrep/semgrep/core_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import functools
import json
import logging
import multiprocessing
import re
import subprocess
import tempfile
from datetime import datetime
from multiprocessing import pool
from pathlib import Path
from typing import Any
from typing import cast
Expand Down Expand Up @@ -139,14 +139,12 @@ def __init__(
timeout: int,
max_memory: int,
timeout_threshold: int,
testing: bool = False,
):
self._allow_exec = allow_exec
self._jobs = jobs
self._timeout = timeout
self._max_memory = max_memory
self._timeout_threshold = timeout_threshold
self._testing = testing

def _flatten_rule_patterns(self, rules: List[Rule]) -> Iterator[Pattern]:
"""
Expand Down Expand Up @@ -463,15 +461,9 @@ def handle_regex_patterns(
except re.error as err:
raise SemgrepError(f"invalid regular expression specified: {err}")

if self._testing:
# Testing functionality runs in a multiprocessing.Pool. We cannot run
# a Pool inside a Pool, so we have to avoid multiprocessing when testing.
# https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic
matches = [get_re_matches(patterns_re, target) for target in targets]
else:
re_fn = functools.partial(get_re_matches, patterns_re)
with multiprocessing.Pool(self._jobs) as pool:
matches = pool.map(re_fn, targets)
re_fn = functools.partial(get_re_matches, patterns_re)
with pool.ThreadPool(self._jobs) as tpool:
matches = tpool.map(re_fn, targets)

outputs.extend(
single_match for file_matches in matches for single_match in file_matches
Expand Down
2 changes: 0 additions & 2 deletions semgrep/semgrep/semgrep_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def main(
max_memory: int = 0,
timeout_threshold: int = 0,
skip_unknown_extensions: bool = False,
testing: bool = False,
severity: Optional[List[str]] = None,
) -> None:
if include is None:
Expand Down Expand Up @@ -246,7 +245,6 @@ def main(
timeout=timeout,
max_memory=max_memory,
timeout_threshold=timeout_threshold,
testing=testing,
).invoke_semgrep(
target_manager, filtered_rules
)
Expand Down
1 change: 0 additions & 1 deletion semgrep/semgrep/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ def generate_file_pairs(
no_rewrite_rule_ids=True,
strict=strict,
dangerously_allow_arbitrary_code_execution_from_rules=unsafe,
testing=True,
)
with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
results = pool.starmap(invoke_semgrep_fn, config_with_tests)
Expand Down

0 comments on commit 55a918d

Please sign in to comment.