Skip to content

Commit

Permalink
Support subtype checks (#42)
Browse files Browse the repository at this point in the history
* More performance tests (prevent inlining).

* Support subtype checks.

* Bump version.
  • Loading branch information
MrVPlusOne authored Feb 6, 2022
1 parent d3b2579 commit 197e615
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArgCheck"
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
license = "MIT"
version = "2.2.0"
version = "2.3.0"

[compat]
julia = "1"
Expand Down
29 changes: 28 additions & 1 deletion src/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct CheckFlavor <: AbstractCheckFlavor end
abstract type AbstractCodeFlavor end
struct CallFlavor <: AbstractCodeFlavor end
struct ComparisonFlavor <: AbstractCodeFlavor end
struct SubtypeFlavor <: AbstractCodeFlavor end
struct FallbackFlavor <: AbstractCodeFlavor end

struct Checker
Expand All @@ -49,6 +50,13 @@ struct ComparisonErrorInfo <: AbstractErrorInfo
argument_values::Vector
options::Tuple
end
struct SubtypeErrorInfo <: AbstractErrorInfo
code
checkflavor::AbstractCheckFlavor
argument_expressions::Vector
argument_values::Vector
options::Tuple
end
struct FallbackErrorInfo <: AbstractErrorInfo
code
checkflavor::AbstractCheckFlavor
Expand Down Expand Up @@ -97,6 +105,8 @@ function check(ex, checkflavor, options...)
ComparisonFlavor()
elseif isexpr(ex, :call)
CallFlavor()
elseif isexpr(ex, :(<:))
SubtypeFlavor()
else
FallbackFlavor()
end
Expand Down Expand Up @@ -180,7 +190,7 @@ function build_call(ana)
return Expr(:call, f, callargs...)
end

function check(c, ::CallFlavor)
function check(c::Checker, ::CallFlavor)
ana = analyze_call(c.code)
variables = []
argument_expressions = []
Expand Down Expand Up @@ -236,6 +246,22 @@ function check(c::Checker, ::ComparisonFlavor)
Expr(:block, ret...)
end

function check(c::Checker, ::SubtypeFlavor)
lhs, rhs = c.code.args
vlhs, vrhs = gensym(:lhs), gensym(:rhs)

condition = Expr(:(<:), vlhs, vrhs)
assignments = (Expr(:(=), vlhs, esc(lhs)), Expr(:(=), vrhs, esc(rhs)))
info = Expr(:call, :SubtypeErrorInfo,
QuoteNode(c.code),
c.checkflavor,
QuoteNode([lhs, rhs]),
Expr(:vect, vlhs, vrhs),
Expr(:tuple, esc.(c.options)...))

expr_error_block(info, condition, assignments...)
end

function expr_error_block(info, condition, preamble...)
quote
$(preamble...)
Expand Down Expand Up @@ -277,6 +303,7 @@ end
error_message(info::FallbackErrorInfo) = "$(info.code) must hold."
error_message(info::CallErrorInfo) = fancy_error_message(info)
error_message(info::ComparisonErrorInfo) = fancy_error_message(info)
error_message(info::SubtypeErrorInfo) = fancy_error_message(info)

function pretty_string(data)
io = IOBuffer()
Expand Down
11 changes: 10 additions & 1 deletion test/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ end
a = 1.0; b = 1.2; atol = 0.1; nvalue = false; rtol = 0.05;
@check isapprox(a, b, atol=atol, nans=nvalue; rtol)
end
locations = map(["a", "b", "atol", "nvalue", "rtol"]) do name
locations = map(["a =>", "b =>", "atol =>", "nvalue =>", "rtol =>"]) do name
findfirst(name, err.msg)
end
@test all(x -> x isa UnitRange, locations)
Expand All @@ -206,6 +206,15 @@ end
@test occursin(string(x), err.msg)
err = @catch_exception_object @argcheck !isfinite(x)
@test_broken occursin(string(x), err.msg)


t1 = Int32
t2 = Integer
err = @catch_exception_object @argcheck t2 <: t1
@test occursin(string(t1), err.msg)
@test occursin(string(t2), err.msg)
@test occursin("t1 =>", err.msg)
@test occursin("t2 =>", err.msg)
end

@testset "complicated calls" begin
Expand Down
28 changes: 18 additions & 10 deletions test/perf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,46 @@ module Perf
using BenchmarkTools
using ArgCheck

truthy(x) = true
@noinline truthy(x) = x == x
truthy2(x) = true

function fallback_argcheck(x)
@noinline function fallback_argcheck(x)
@argcheck x
end
function comparison_argcheck(x)
@noinline function comparison_argcheck(x)
@argcheck x == x
end
function call_argcheck(x)
@noinline function call_argcheck(x)
@argcheck truthy(x)
end
function fallback_assert(x)
function call_argcheck2(x)
@argcheck truthy2(x)
end
@noinline function fallback_assert(x)
@assert x
end
function comparison_assert(x)
@noinline function comparison_assert(x)
@assert x == x
end
function call_assert(x)
@noinline function call_assert(x)
@assert truthy(x)
end
function call_assert2(x)
@assert truthy2(x)
end

benchmarks =[
(fallback_assert, fallback_argcheck, true),
(call_assert, call_argcheck, 42),
(call_assert2, call_argcheck2, 42),
(comparison_assert, comparison_argcheck, 42),
]

for (f_argcheck, f_assert, arg) in benchmarks
println(f_argcheck)
@btime ($f_argcheck)($arg)
for (f_assert, f_argcheck, arg) in benchmarks
println(f_assert)
@btime ($f_assert)($arg)
println(f_argcheck)
@btime ($f_argcheck)($arg)
end

end#module

2 comments on commit 197e615

@jw3126
Copy link
Owner

@jw3126 jw3126 commented on 197e615 Feb 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/54038

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.3.0 -m "<description of version>" 197e6155704ff40633517c9cc480c6a690fb1c90
git push origin v2.3.0

Please sign in to comment.