forked from QEDjl-project/QEDevents.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Proposal] multi particle distribution (QEDjl-project#24)
* added single particle dist interface, started implementing unit tests * reorganized; refined and improved tests * reorganization again * switched eltype to object based * changed output of rand to ParticleStateful * added dev-dependency on QEDprocesses and adjusted .gitlab-ci.toml -> REMOVE BEFORE MERGING * bugfix .gitlab-ci * made eltype more flat * implemented weight functionality * added ParticleSampleable, SingleParticleDistribution, test_implementation; removed sampler_interface * refac of the interface * added dev-dependency on QEDprocesses to docs building -> REMOVE BEFORE MERGING * bugfix github workflow * Apply suggestions from code review Co-authored-by: Anton Reinhard <[email protected]> * formatting * cleanup * added groundtruth to reproduction test * added tests for weights * Apply suggestions from code review Thanks for the suggestions, I added them. Co-authored-by: Anton Reinhard <[email protected]> * getting started * wip * added implementation for multi particle interface and test implementation * finalize implementation and tests, fixed type instability for rand * cleanup * opt out randmom and made it an interface function instead of rand * cleanup * rebased on dev * Apply suggestions from code review Co-authored-by: Anton Reinhard <[email protected]> * fixed cases in weigth tests * Update src/interfaces/multi_particle_distribution.jl Co-authored-by: Anton Reinhard <[email protected]> --------- Co-authored-by: Uwe Hernandez Acosta <[email protected]> Co-authored-by: Anton Reinhard <[email protected]>
- Loading branch information
1 parent
350f633
commit 7b496fd
Showing
9 changed files
with
336 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
|
||
""" | ||
MultiParticleDistribution | ||
Base type for sample drawing from multiple particle distributions. The following interface functions | ||
should be implemented: | ||
* [`QEDevents._particles(d::MultiParticleDistribution)`](@ref) | ||
* [`QEDevents._particle_directions(d::MultiParticleDistribution)`](@ref) | ||
* [`QEDevents.randmom(rng::AbstractRNG,d::MultiParticleDistribution)`](@ref) | ||
""" | ||
const MultiParticleDistribution = ParticleSampleable{MultiParticleVariate} | ||
|
||
Broadcast.broadcastable(d::MultiParticleDistribution) = Ref(d) | ||
|
||
Base.length(d::MultiParticleDistribution) = length(_particles(d)) | ||
Base.size(d::MultiParticleDistribution) = (length(d),) | ||
|
||
""" | ||
_particle(dist::MultiParticleDistribution) | ||
Return tuple of particles associated with the `dist`. | ||
!!! note | ||
Interface function to be implemented for multi-particle distributions. | ||
""" | ||
function _particles end | ||
|
||
""" | ||
_particle_direction(dist::MultiParticleDistribution) | ||
Return tuple of particle-directions for all particles associated with `dist`. | ||
!!! note | ||
Interface function to be implemented for multi-particle distributions. | ||
""" | ||
function _particle_directions end | ||
#default | ||
function _particle_direction(d::MultiParticleDistribution) | ||
return Tuple(fill(UnknownDirection(), length(d))) | ||
end | ||
|
||
""" | ||
randmom(rng::AbstractRNG,d::MultiParticleDistribution) | ||
Return an iterable container (e.g. vector or tuple) of momenta according to the distribution `d`. | ||
""" | ||
function randmom end | ||
|
||
# recursion termination: success | ||
@inline _recursive_type_check(::Tuple{}, ::Tuple{}, ::Tuple{}) = nothing | ||
|
||
# recursion termination: overload for unequal number of particles | ||
@inline function _recursive_type_check( | ||
::Tuple{Vararg{ParticleStateful,N}}, | ||
::Tuple{Vararg{AbstractParticleType,M}}, | ||
::Tuple{Vararg{ParticleDirection,M}}, | ||
) where {N,M} | ||
throw(InvalidInputError("expected $(M) particles but got $(N)")) | ||
return nothing | ||
end | ||
|
||
# recursion termination: overload for invalid types | ||
@inline function _recursive_type_check( | ||
::Tuple{ParticleStateful{DIR_IN_T,SPECIES_IN_T},Vararg{ParticleStateful,N}}, | ||
::Tuple{SPECIES_T,Vararg{AbstractParticleType,N}}, | ||
::Tuple{DIR_T,Vararg{ParticleDirection,N}}, | ||
) where { | ||
N, | ||
DIR_IN_T<:ParticleDirection, | ||
DIR_T<:ParticleDirection, | ||
SPECIES_IN_T<:AbstractParticleType, | ||
SPECIES_T<:AbstractParticleType, | ||
} | ||
throw( | ||
InvalidInputError( | ||
"expected $(DIR_T()) $(SPECIES_T()) but got $(DIR_IN_T()) $(SPECIES_IN_T())" | ||
), | ||
) | ||
return nothing | ||
end | ||
|
||
@inline function _recursive_type_check( | ||
t::Tuple{ParticleStateful{DIR_T,SPECIES_T},Vararg{ParticleStateful,N}}, | ||
p::Tuple{SPECIES_T,Vararg{AbstractParticleType,N}}, | ||
dir::Tuple{DIR_T,Vararg{ParticleDirection,N}}, | ||
) where {N,DIR_T<:ParticleDirection,SPECIES_T<:AbstractParticleType} | ||
return _recursive_type_check(t[2:end], p[2:end], dir[2:end]) | ||
end | ||
|
||
""" | ||
Interface function, which asserts that the given `input` is valid. | ||
""" | ||
function _assert_valid_input_type( | ||
d::MultiParticleDistribution, x::PS | ||
) where {PS<:Tuple{Vararg{ParticleStateful}}} | ||
# TODO: implement correct type check | ||
_recursive_type_check(x, _particles(d), _particle_directions(d)) | ||
return nothing | ||
end | ||
|
||
# recursion termination: base case | ||
@inline _assemble_tuple_types(::Tuple{}, ::Tuple{}, ::Type) = () | ||
|
||
@inline function _assemble_tuple_types( | ||
particle_types::Tuple{SPECIES_T,Vararg{AbstractParticleType}}, | ||
dir::Tuple{DIR_T,Vararg{ParticleDirection}}, | ||
ELTYPE::Type, | ||
) where {SPECIES_T<:AbstractParticleType,DIR_T<:ParticleDirection} | ||
return ( | ||
ParticleStateful{DIR_T,SPECIES_T,ELTYPE}, | ||
_assemble_tuple_types(particle_types[2:end], dir[2:end], ELTYPE)..., | ||
) | ||
end | ||
|
||
# used for pre-allocation of vectors of particle-stateful | ||
function Base.eltype(d::MultiParticleDistribution) | ||
return Tuple{ | ||
_assemble_tuple_types(_particles(d), _particle_directions(d), _momentum_type(d))... | ||
} | ||
end | ||
|
||
function Distributions.rand(rng::AbstractRNG, d::MultiParticleDistribution) | ||
n = length(d) | ||
moms = randmom(rng, d) | ||
dirs = _particle_directions(d) | ||
parts = _particles(d) | ||
|
||
return ntuple(i -> ParticleStateful(dirs[i], parts[i], moms[i]), Val(n)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
|
||
using QEDprocesses | ||
using QEDevents | ||
using QEDbase | ||
using Random: Random | ||
import Random: AbstractRNG, MersenneTwister, default_rng | ||
|
||
# only imported because we want to test if QEDevents works without this | ||
# (especially the Base.rand, which is exported by Distributions) | ||
using Distributions: Distributions | ||
|
||
include("../test_implementation/TestImpl.jl") | ||
|
||
RNG = MersenneTwister(137137137) | ||
|
||
ATOL = 0.0 | ||
RTOL = sqrt(eps()) | ||
|
||
struct WrongParticle <: AbstractParticleType end # for type checking in weight | ||
struct WrongDirection <: ParticleDirection end # for type checking in weight | ||
|
||
DIRECTIONS = (Incoming(), Outgoing(), QEDevents.UnknownDirection()) | ||
RND_SEED = ceil(Int, 1e6 * rand(RNG)) # for comparison | ||
@testset "N=$N" for N in (1, rand(RNG, 2:10)) | ||
@testset "default properties" begin | ||
test_dist_plain = TestImpl.TestMultiParticleDistPlain(N) | ||
@test all( | ||
QEDevents._particle_direction(test_dist_plain) .== QEDevents.UnknownDirection() | ||
) | ||
@test QEDevents._momentum_type(test_dist_plain) == SFourMomentum | ||
end | ||
|
||
test_particles = Tuple(rand(RNG, TestImpl.PARTICLE_SET, N)) | ||
test_directions = Tuple(rand(RNG, DIRECTIONS, N)) | ||
test_dist = TestImpl.TestMultiParticleDist(test_directions, test_particles) | ||
|
||
@testset "static properties" begin | ||
@test @inferred QEDevents._particles(test_dist) == test_particles | ||
@test @inferred QEDevents._particle_directions(test_dist) == test_directions | ||
@test @inferred length(test_dist) == N | ||
@test @inferred size(test_dist) == (N,) | ||
|
||
# todo: consider to move _assemble_tuple_types to the test implementation | ||
# (groundtruths must not rely on package internals) | ||
# See https://github.com/QEDjl-project/QEDprocesses.jl/issues/75 | ||
@test @inferred eltype(test_dist) == Tuple{ | ||
QEDevents._assemble_tuple_types( | ||
QEDevents._particles(test_dist), | ||
QEDevents._particle_directions(test_dist), | ||
QEDevents._momentum_type(test_dist), | ||
)..., | ||
} | ||
end | ||
|
||
@testset "single sample" begin | ||
Random.seed!(RND_SEED) | ||
rng = default_rng() | ||
moms_groundtruth = TestImpl._groundtruth_multi_randmom(rng, test_dist) | ||
psf_groundtruth = Tuple( | ||
ParticleStateful(test_directions[i], test_particles[i], moms_groundtruth[i]) for | ||
i in 1:N | ||
) | ||
|
||
Random.seed!(RND_SEED) | ||
rng = default_rng() | ||
psf_rng = @inferred rand(rng, test_dist) | ||
|
||
Random.seed!(RND_SEED) | ||
psf_default = @inferred rand(test_dist) | ||
|
||
@test psf_groundtruth == psf_rng | ||
@test psf_rng == psf_default | ||
end | ||
@testset "multiple samples" begin | ||
@testset "$dim" for dim in (1, 2, 3) | ||
checked_lengths = (1, 2, rand(RNG, 3:10)) | ||
shapes = Iterators.product(fill(checked_lengths, dim)...) | ||
|
||
@testset "$shape" for shape in shapes | ||
Random.seed!(RND_SEED) | ||
rng = default_rng() | ||
tuple_psf_rng = @inferred rand(rng, test_dist, shape...) | ||
|
||
Random.seed!(RND_SEED) | ||
tuple_psf_default = @inferred rand(test_dist, shape...) | ||
|
||
Random.seed!(RND_SEED) | ||
rng = default_rng() | ||
res_type = eltype(test_dist) | ||
tuple_psf_prealloc_rng = Array{res_type}(undef, shape...) | ||
@inferred Random.rand!(rng, test_dist, tuple_psf_prealloc_rng) | ||
|
||
Random.seed!(RND_SEED) | ||
res_type = eltype(test_dist) | ||
tuple_psf_prealloc_default = Array{res_type}(undef, shape...) | ||
@inferred Random.rand!(test_dist, tuple_psf_prealloc_default) | ||
|
||
@test all(tuple_psf_rng == tuple_psf_default) | ||
@test all(tuple_psf_rng == tuple_psf_prealloc_rng) | ||
@test all(tuple_psf_rng == tuple_psf_prealloc_default) | ||
end | ||
end | ||
end | ||
@testset "weights" begin | ||
@testset "evaluation" begin | ||
test_input = rand(RNG, test_dist) | ||
@test weight(test_dist, test_input) == | ||
TestImpl._groundtruth_multi_weight(test_dist, test_input) | ||
end | ||
|
||
@testset "fails" begin | ||
correct_input = rand(RNG, test_dist) | ||
|
||
# failing inputs with either wrong particle, wrong direction or both | ||
psf_wrong_particle = ParticleStateful( | ||
test_directions[1], WrongParticle(), rand(RNG, SFourMomentum) | ||
) | ||
input_wrong_particle = TestImpl.tuple_setindex( | ||
correct_input, 1, psf_wrong_particle | ||
) | ||
|
||
psf_wrong_direction = ParticleStateful( | ||
WrongDirection(), test_particles[1], rand(RNG, SFourMomentum) | ||
) | ||
input_wrong_direction = TestImpl.tuple_setindex( | ||
correct_input, 1, psf_wrong_direction | ||
) | ||
|
||
psf_wrong = ParticleStateful( | ||
WrongDirection(), WrongParticle(), rand(RNG, SFourMomentum) | ||
) | ||
input_wrong = TestImpl.tuple_setindex(correct_input, 1, psf_wrong) | ||
|
||
# failing input with wrong length | ||
input_wrong_length = (psf_wrong, correct_input...) | ||
|
||
@test_throws InvalidInputError weight(test_dist, input_wrong_particle) | ||
@test_throws InvalidInputError weight(test_dist, input_wrong_direction) | ||
@test_throws InvalidInputError weight(test_dist, input_wrong) | ||
@test_throws InvalidInputError weight(test_dist, input_wrong_length) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
function _groundtruth_multi_randmom(rng, d) | ||
return rand(rng, SFourMomentum, length(d)) | ||
end | ||
|
||
function _groundtruth_multi_weight(dist, psfs) | ||
@. getE(momentum(psfs)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
|
||
PARTICLE_DIRECTIONS = (Incoming(), Outgoing(), QEDevents.UnknownDirection()) | ||
|
||
struct TestMultiParticleDist{DT<:Tuple,PT<:Tuple,RT} <: MultiParticleDistribution | ||
dirs::DT | ||
parts::PT | ||
function TestMultiParticleDist(dirs::DT, parts::PT) where {DT,PT} | ||
res_type = Tuple{QEDevents._assemble_tuple_types(parts, dirs, SFourMomentum)...} | ||
return new{DT,PT,res_type}(dirs, parts) | ||
end | ||
end | ||
|
||
function QEDevents._particles(d::TestMultiParticleDist) | ||
return d.parts | ||
end | ||
|
||
QEDevents._particle_directions(d::TestMultiParticleDist) = d.dirs | ||
|
||
function QEDevents.randmom(rng::AbstractRNG, d::TestMultiParticleDist) | ||
return _groundtruth_multi_randmom(rng, d) | ||
end | ||
|
||
function QEDevents._weight(d::TestMultiParticleDist, x::Tuple{Vararg{ParticleStateful}}) | ||
return _groundtruth_multi_weight(d, x) | ||
end | ||
|
||
# plain multi particle distribution | ||
# for testing of default implementations | ||
struct TestMultiParticleDistPlain <: MultiParticleDistribution | ||
n::Int | ||
end | ||
|
||
QEDevents._particles(d::TestMultiParticleDistPlain) = Tuple(fill(TestParticle(), d.n)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" | ||
replace i-th entry of t with val | ||
""" | ||
function tuple_setindex(t::Tuple, i, val) | ||
return ntuple(j -> j == i ? val : t[j], length(t)) | ||
end |