Skip to content

Commit

Permalink
feat: improve test function classification (foundry-rs#8235)
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes authored Jun 23, 2024
1 parent 7074d20 commit ba9fa20
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 126 deletions.
1 change: 0 additions & 1 deletion crates/cli/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl<T: AsRef<Path>> FoundryPathExt for T {
}

/// Initializes a tracing Subscriber for logging
#[allow(dead_code)]
pub fn subscriber() {
tracing_subscriber::Registry::default()
.with(tracing_subscriber::EnvFilter::from_default_env())
Expand Down
20 changes: 10 additions & 10 deletions crates/common/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,16 @@ impl ProjectCompiler {
for (name, artifact) in artifacts {
let size = deployed_contract_size(artifact).unwrap_or_default();

let dev_functions =
artifact.abi.as_ref().map(|abi| abi.functions()).into_iter().flatten().filter(
|func| {
func.name.is_test() ||
func.name.eq("IS_TEST") ||
func.name.eq("IS_SCRIPT")
},
);

let is_dev_contract = dev_functions.count() > 0;
let is_dev_contract = artifact
.abi
.as_ref()
.map(|abi| {
abi.functions().any(|f| {
f.test_function_kind().is_known() ||
matches!(f.name.as_str(), "IS_TEST" | "IS_SCRIPT")
})
})
.unwrap_or(false);
size_report.contracts.insert(name, ContractInfo { size, is_dev_contract });
}

Expand Down
218 changes: 155 additions & 63 deletions crates/common/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use alloy_json_abi::Function;
use alloy_primitives::Bytes;
use alloy_sol_types::SolError;
use std::path::Path;
use std::{fmt, path::Path};

/// Test filter.
pub trait TestFilter: Send + Sync {
Expand All @@ -19,116 +19,208 @@ pub trait TestFilter: Send + Sync {

/// Extension trait for `Function`.
pub trait TestFunctionExt {
/// Returns whether this function should be executed as invariant test.
fn is_invariant_test(&self) -> bool;

/// Returns whether this function should be executed as fuzz test.
fn is_fuzz_test(&self) -> bool;
/// Returns the kind of test function.
fn test_function_kind(&self) -> TestFunctionKind {
TestFunctionKind::classify(self.tfe_as_str(), self.tfe_has_inputs())
}

/// Returns whether this function is a test.
fn is_test(&self) -> bool;
/// Returns `true` if this function is a `setUp` function.
fn is_setup(&self) -> bool {
self.test_function_kind().is_setup()
}

/// Returns whether this function is a test that should fail.
fn is_test_fail(&self) -> bool;
/// Returns `true` if this function is a unit, fuzz, or invariant test.
fn is_any_test(&self) -> bool {
self.test_function_kind().is_any_test()
}

/// Returns whether this function is a `setUp` function.
fn is_setup(&self) -> bool;
/// Returns `true` if this function is a test that should fail.
fn is_any_test_fail(&self) -> bool {
self.test_function_kind().is_any_test_fail()
}

/// Returns whether this function is `afterInvariant` function.
fn is_after_invariant(&self) -> bool;
/// Returns `true` if this function is a unit test.
fn is_unit_test(&self) -> bool {
matches!(self.test_function_kind(), TestFunctionKind::UnitTest { .. })
}

/// Returns whether this function is a fixture function.
fn is_fixture(&self) -> bool;
}
/// Returns `true` if this function is a fuzz test.
fn is_fuzz_test(&self) -> bool {
self.test_function_kind().is_fuzz_test()
}

impl TestFunctionExt for Function {
/// Returns `true` if this function is an invariant test.
fn is_invariant_test(&self) -> bool {
self.name.is_invariant_test()
self.test_function_kind().is_invariant_test()
}

fn is_fuzz_test(&self) -> bool {
// test functions that have inputs are considered fuzz tests as those inputs will be fuzzed
!self.inputs.is_empty()
/// Returns `true` if this function is an `afterInvariant` function.
fn is_after_invariant(&self) -> bool {
self.test_function_kind().is_after_invariant()
}

fn is_test(&self) -> bool {
self.name.is_test()
/// Returns `true` if this function is a `fixture` function.
fn is_fixture(&self) -> bool {
self.test_function_kind().is_fixture()
}

fn is_test_fail(&self) -> bool {
self.name.is_test_fail()
#[doc(hidden)]
fn tfe_as_str(&self) -> &str;
#[doc(hidden)]
fn tfe_has_inputs(&self) -> bool;
}

impl TestFunctionExt for Function {
fn tfe_as_str(&self) -> &str {
self.name.as_str()
}

fn is_setup(&self) -> bool {
self.name.is_setup()
fn tfe_has_inputs(&self) -> bool {
!self.inputs.is_empty()
}
}

fn is_after_invariant(&self) -> bool {
self.name.is_after_invariant()
impl TestFunctionExt for String {
fn tfe_as_str(&self) -> &str {
self
}

fn is_fixture(&self) -> bool {
self.name.is_fixture()
fn tfe_has_inputs(&self) -> bool {
false
}
}

impl TestFunctionExt for String {
fn is_invariant_test(&self) -> bool {
self.as_str().is_invariant_test()
impl TestFunctionExt for str {
fn tfe_as_str(&self) -> &str {
self
}

fn is_fuzz_test(&self) -> bool {
self.as_str().is_fuzz_test()
fn tfe_has_inputs(&self) -> bool {
false
}
}

/// Test function kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TestFunctionKind {
/// `setUp`.
Setup,
/// `test*`. `should_fail` is `true` for `testFail*`.
UnitTest { should_fail: bool },
/// `test*`, with arguments. `should_fail` is `true` for `testFail*`.
FuzzTest { should_fail: bool },
/// `invariant*` or `statefulFuzz*`.
InvariantTest,
/// `afterInvariant`.
AfterInvariant,
/// `fixture*`.
Fixture,
/// Unknown kind.
Unknown,
}

fn is_test(&self) -> bool {
self.as_str().is_test()
impl TestFunctionKind {
/// Classify a function.
#[inline]
pub fn classify(name: &str, has_inputs: bool) -> Self {
match () {
_ if name.starts_with("test") => {
let should_fail = name.starts_with("testFail");
if has_inputs {
Self::FuzzTest { should_fail }
} else {
Self::UnitTest { should_fail }
}
}
_ if name.starts_with("invariant") || name.starts_with("statefulFuzz") => {
Self::InvariantTest
}
_ if name.eq_ignore_ascii_case("setup") => Self::Setup,
_ if name.eq_ignore_ascii_case("afterinvariant") => Self::AfterInvariant,
_ if name.starts_with("fixture") => Self::Fixture,
_ => Self::Unknown,
}
}

fn is_test_fail(&self) -> bool {
self.as_str().is_test_fail()
/// Returns the name of the function kind.
pub const fn name(&self) -> &'static str {
match self {
Self::Setup => "setUp",
Self::UnitTest { should_fail: false } => "test",
Self::UnitTest { should_fail: true } => "testFail",
Self::FuzzTest { should_fail: false } => "fuzz",
Self::FuzzTest { should_fail: true } => "fuzz fail",
Self::InvariantTest => "invariant",
Self::AfterInvariant => "afterInvariant",
Self::Fixture => "fixture",
Self::Unknown => "unknown",
}
}

fn is_setup(&self) -> bool {
self.as_str().is_setup()
/// Returns `true` if this function is a `setUp` function.
#[inline]
pub const fn is_setup(&self) -> bool {
matches!(self, Self::Setup)
}

fn is_after_invariant(&self) -> bool {
self.as_str().is_after_invariant()
/// Returns `true` if this function is a unit, fuzz, or invariant test.
#[inline]
pub const fn is_any_test(&self) -> bool {
matches!(self, Self::UnitTest { .. } | Self::FuzzTest { .. } | Self::InvariantTest)
}

fn is_fixture(&self) -> bool {
self.as_str().is_fixture()
/// Returns `true` if this function is a test that should fail.
#[inline]
pub const fn is_any_test_fail(&self) -> bool {
matches!(self, Self::UnitTest { should_fail: true } | Self::FuzzTest { should_fail: true })
}
}

impl TestFunctionExt for str {
fn is_invariant_test(&self) -> bool {
self.starts_with("invariant") || self.starts_with("statefulFuzz")
/// Returns `true` if this function is a unit test.
#[inline]
pub fn is_unit_test(&self) -> bool {
matches!(self, Self::UnitTest { .. })
}

fn is_fuzz_test(&self) -> bool {
unimplemented!("no naming convention for fuzz tests")
/// Returns `true` if this function is a fuzz test.
#[inline]
pub const fn is_fuzz_test(&self) -> bool {
matches!(self, Self::FuzzTest { .. })
}

fn is_test(&self) -> bool {
self.starts_with("test")
/// Returns `true` if this function is an invariant test.
#[inline]
pub const fn is_invariant_test(&self) -> bool {
matches!(self, Self::InvariantTest)
}

fn is_test_fail(&self) -> bool {
self.starts_with("testFail")
/// Returns `true` if this function is an `afterInvariant` function.
#[inline]
pub const fn is_after_invariant(&self) -> bool {
matches!(self, Self::AfterInvariant)
}

fn is_setup(&self) -> bool {
self.eq_ignore_ascii_case("setup")
/// Returns `true` if this function is a `fixture` function.
#[inline]
pub const fn is_fixture(&self) -> bool {
matches!(self, Self::Fixture)
}

fn is_after_invariant(&self) -> bool {
self.eq_ignore_ascii_case("afterinvariant")
/// Returns `true` if this function kind is known.
#[inline]
pub const fn is_known(&self) -> bool {
!matches!(self, Self::Unknown)
}

fn is_fixture(&self) -> bool {
self.starts_with("fixture")
/// Returns `true` if this function kind is unknown.
#[inline]
pub const fn is_unknown(&self) -> bool {
matches!(self, Self::Unknown)
}
}

impl fmt::Display for TestFunctionKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.name().fmt(f)
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/evm/coverage/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ impl<'a> SourceAnalyzer<'a> {

let is_test = items.iter().any(|item| {
if let CoverageItemKind::Function { name } = &item.kind {
name.is_test()
name.is_any_test()
} else {
false
}
Expand Down
3 changes: 1 addition & 2 deletions crates/forge/src/gas_report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ impl GasReport {
} else if let Some(DecodedCallData { signature, .. }) = decoded.func {
let name = signature.split('(').next().unwrap();
// ignore any test/setup functions
let should_include = !(name.is_test() || name.is_invariant_test() || name.is_setup());
if should_include {
if !name.test_function_kind().is_known() {
trace!(contract_name, signature, "adding gas info");
let gas_info = contract_info
.functions
Expand Down
6 changes: 3 additions & 3 deletions crates/forge/src/multi_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl MultiContractRunner {
.iter()
.filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
.flat_map(|(_, TestContract { abi, .. })| abi.functions())
.filter(|func| func.is_test() || func.is_invariant_test())
.filter(|func| func.is_any_test())
}

/// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
Expand Down Expand Up @@ -392,7 +392,7 @@ impl MultiContractRunnerBuilder {

// if it's a test, link it and add to deployable contracts
if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true) &&
abi.functions().any(|func| func.name.is_test() || func.name.is_invariant_test())
abi.functions().any(|func| func.name.is_any_test())
{
let Some(bytecode) =
contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
Expand Down Expand Up @@ -434,5 +434,5 @@ pub fn matches_contract(id: &ArtifactId, abi: &JsonAbi, filter: &dyn TestFilter)

/// Returns `true` if the function is a test function that matches the given filter.
pub(crate) fn is_matching_test(func: &Function, filter: &dyn TestFilter) -> bool {
(func.is_test() || func.is_invariant_test()) && filter.matches_test(&func.signature())
func.is_any_test() && filter.matches_test(&func.signature())
}
Loading

0 comments on commit ba9fa20

Please sign in to comment.