Skip to content

Commit

Permalink
[Proposal] multi particle distribution (QEDjl-project#24)
Browse files Browse the repository at this point in the history
* added single particle dist interface, started implementing unit tests

* reorganized; refined and improved tests

* reorganization again

* switched eltype to object based

* changed output of rand to ParticleStateful

* added dev-dependency on QEDprocesses and adjusted .gitlab-ci.toml -> REMOVE BEFORE MERGING

* bugfix .gitlab-ci

* made eltype more flat

* implemented weight functionality

* added ParticleSampleable, SingleParticleDistribution, test_implementation; removed sampler_interface

* refac of the interface

* added dev-dependency on QEDprocesses to docs building -> REMOVE BEFORE MERGING

* bugfix github workflow

* Apply suggestions from code review

Co-authored-by: Anton Reinhard <[email protected]>

* formatting

* cleanup

* added groundtruth to reproduction test

* added tests for weights

* Apply suggestions from code review

Thanks for the suggestions, I added them.

Co-authored-by: Anton Reinhard <[email protected]>

* getting started

* wip

* added implementation for multi particle interface and test implementation

* finalize implementation and tests, fixed type instability for rand

* cleanup

* opt out randmom and made it an interface function instead of rand

* cleanup

* rebased on dev

* Apply suggestions from code review

Co-authored-by: Anton Reinhard <[email protected]>

* fixed cases in weigth tests

* Update src/interfaces/multi_particle_distribution.jl

Co-authored-by: Anton Reinhard <[email protected]>

---------

Co-authored-by: Uwe Hernandez Acosta <[email protected]>
Co-authored-by: Anton Reinhard <[email protected]>
  • Loading branch information
3 people authored Jun 19, 2024
1 parent 350f633 commit 7b496fd
Show file tree
Hide file tree
Showing 9 changed files with 336 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/QEDevents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module QEDevents
export weight

export SingleParticleDistribution
export MultiParticleDistribution, randmom

import Random: AbstractRNG
import Distributions: rand, rand!, _rand!
Expand All @@ -15,6 +16,7 @@ using DocStringExtensions

include("interfaces/particle_distribution.jl")
include("interfaces/single_particle_distribution.jl")
include("interfaces/multi_particle_distribution.jl")

include("patch_QEDbase.jl")

Expand Down
137 changes: 137 additions & 0 deletions src/interfaces/multi_particle_distribution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@

"""
MultiParticleDistribution
Base type for sample drawing from multiple particle distributions. The following interface functions
should be implemented:
* [`QEDevents._particles(d::MultiParticleDistribution)`](@ref)
* [`QEDevents._particle_directions(d::MultiParticleDistribution)`](@ref)
* [`QEDevents.randmom(rng::AbstractRNG,d::MultiParticleDistribution)`](@ref)
"""
const MultiParticleDistribution = ParticleSampleable{MultiParticleVariate}

Broadcast.broadcastable(d::MultiParticleDistribution) = Ref(d)

Base.length(d::MultiParticleDistribution) = length(_particles(d))
Base.size(d::MultiParticleDistribution) = (length(d),)

"""
_particle(dist::MultiParticleDistribution)
Return tuple of particles associated with the `dist`.
!!! note
Interface function to be implemented for multi-particle distributions.
"""
function _particles end

"""
_particle_direction(dist::MultiParticleDistribution)
Return tuple of particle-directions for all particles associated with `dist`.
!!! note
Interface function to be implemented for multi-particle distributions.
"""
function _particle_directions end
#default
function _particle_direction(d::MultiParticleDistribution)
return Tuple(fill(UnknownDirection(), length(d)))
end

"""
randmom(rng::AbstractRNG,d::MultiParticleDistribution)
Return an iterable container (e.g. vector or tuple) of momenta according to the distribution `d`.
"""
function randmom end

# recursion termination: success
@inline _recursive_type_check(::Tuple{}, ::Tuple{}, ::Tuple{}) = nothing

# recursion termination: overload for unequal number of particles
@inline function _recursive_type_check(
::Tuple{Vararg{ParticleStateful,N}},
::Tuple{Vararg{AbstractParticleType,M}},
::Tuple{Vararg{ParticleDirection,M}},
) where {N,M}
throw(InvalidInputError("expected $(M) particles but got $(N)"))
return nothing
end

# recursion termination: overload for invalid types
@inline function _recursive_type_check(
::Tuple{ParticleStateful{DIR_IN_T,SPECIES_IN_T},Vararg{ParticleStateful,N}},
::Tuple{SPECIES_T,Vararg{AbstractParticleType,N}},
::Tuple{DIR_T,Vararg{ParticleDirection,N}},
) where {
N,
DIR_IN_T<:ParticleDirection,
DIR_T<:ParticleDirection,
SPECIES_IN_T<:AbstractParticleType,
SPECIES_T<:AbstractParticleType,
}
throw(
InvalidInputError(
"expected $(DIR_T()) $(SPECIES_T()) but got $(DIR_IN_T()) $(SPECIES_IN_T())"
),
)
return nothing
end

@inline function _recursive_type_check(
t::Tuple{ParticleStateful{DIR_T,SPECIES_T},Vararg{ParticleStateful,N}},
p::Tuple{SPECIES_T,Vararg{AbstractParticleType,N}},
dir::Tuple{DIR_T,Vararg{ParticleDirection,N}},
) where {N,DIR_T<:ParticleDirection,SPECIES_T<:AbstractParticleType}
return _recursive_type_check(t[2:end], p[2:end], dir[2:end])
end

"""
Interface function, which asserts that the given `input` is valid.
"""
function _assert_valid_input_type(
d::MultiParticleDistribution, x::PS
) where {PS<:Tuple{Vararg{ParticleStateful}}}
# TODO: implement correct type check
_recursive_type_check(x, _particles(d), _particle_directions(d))
return nothing
end

# recursion termination: base case
@inline _assemble_tuple_types(::Tuple{}, ::Tuple{}, ::Type) = ()

@inline function _assemble_tuple_types(
particle_types::Tuple{SPECIES_T,Vararg{AbstractParticleType}},
dir::Tuple{DIR_T,Vararg{ParticleDirection}},
ELTYPE::Type,
) where {SPECIES_T<:AbstractParticleType,DIR_T<:ParticleDirection}
return (
ParticleStateful{DIR_T,SPECIES_T,ELTYPE},
_assemble_tuple_types(particle_types[2:end], dir[2:end], ELTYPE)...,
)
end

# used for pre-allocation of vectors of particle-stateful
function Base.eltype(d::MultiParticleDistribution)
return Tuple{
_assemble_tuple_types(_particles(d), _particle_directions(d), _momentum_type(d))...
}
end

function Distributions.rand(rng::AbstractRNG, d::MultiParticleDistribution)
n = length(d)
moms = randmom(rng, d)
dirs = _particle_directions(d)
parts = _particles(d)

return ntuple(i -> ParticleStateful(dirs[i], parts[i], moms[i]), Val(n))
end
143 changes: 143 additions & 0 deletions test/interfaces/multi_particle_distribution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@

using QEDprocesses
using QEDevents
using QEDbase
using Random: Random
import Random: AbstractRNG, MersenneTwister, default_rng

# only imported because we want to test if QEDevents works without this
# (especially the Base.rand, which is exported by Distributions)
using Distributions: Distributions

include("../test_implementation/TestImpl.jl")

RNG = MersenneTwister(137137137)

ATOL = 0.0
RTOL = sqrt(eps())

struct WrongParticle <: AbstractParticleType end # for type checking in weight
struct WrongDirection <: ParticleDirection end # for type checking in weight

DIRECTIONS = (Incoming(), Outgoing(), QEDevents.UnknownDirection())
RND_SEED = ceil(Int, 1e6 * rand(RNG)) # for comparison
@testset "N=$N" for N in (1, rand(RNG, 2:10))
@testset "default properties" begin
test_dist_plain = TestImpl.TestMultiParticleDistPlain(N)
@test all(
QEDevents._particle_direction(test_dist_plain) .== QEDevents.UnknownDirection()
)
@test QEDevents._momentum_type(test_dist_plain) == SFourMomentum
end

test_particles = Tuple(rand(RNG, TestImpl.PARTICLE_SET, N))
test_directions = Tuple(rand(RNG, DIRECTIONS, N))
test_dist = TestImpl.TestMultiParticleDist(test_directions, test_particles)

@testset "static properties" begin
@test @inferred QEDevents._particles(test_dist) == test_particles
@test @inferred QEDevents._particle_directions(test_dist) == test_directions
@test @inferred length(test_dist) == N
@test @inferred size(test_dist) == (N,)

# todo: consider to move _assemble_tuple_types to the test implementation
# (groundtruths must not rely on package internals)
# See https://github.com/QEDjl-project/QEDprocesses.jl/issues/75
@test @inferred eltype(test_dist) == Tuple{
QEDevents._assemble_tuple_types(
QEDevents._particles(test_dist),
QEDevents._particle_directions(test_dist),
QEDevents._momentum_type(test_dist),
)...,
}
end

@testset "single sample" begin
Random.seed!(RND_SEED)
rng = default_rng()
moms_groundtruth = TestImpl._groundtruth_multi_randmom(rng, test_dist)
psf_groundtruth = Tuple(
ParticleStateful(test_directions[i], test_particles[i], moms_groundtruth[i]) for
i in 1:N
)

Random.seed!(RND_SEED)
rng = default_rng()
psf_rng = @inferred rand(rng, test_dist)

Random.seed!(RND_SEED)
psf_default = @inferred rand(test_dist)

@test psf_groundtruth == psf_rng
@test psf_rng == psf_default
end
@testset "multiple samples" begin
@testset "$dim" for dim in (1, 2, 3)
checked_lengths = (1, 2, rand(RNG, 3:10))
shapes = Iterators.product(fill(checked_lengths, dim)...)

@testset "$shape" for shape in shapes
Random.seed!(RND_SEED)
rng = default_rng()
tuple_psf_rng = @inferred rand(rng, test_dist, shape...)

Random.seed!(RND_SEED)
tuple_psf_default = @inferred rand(test_dist, shape...)

Random.seed!(RND_SEED)
rng = default_rng()
res_type = eltype(test_dist)
tuple_psf_prealloc_rng = Array{res_type}(undef, shape...)
@inferred Random.rand!(rng, test_dist, tuple_psf_prealloc_rng)

Random.seed!(RND_SEED)
res_type = eltype(test_dist)
tuple_psf_prealloc_default = Array{res_type}(undef, shape...)
@inferred Random.rand!(test_dist, tuple_psf_prealloc_default)

@test all(tuple_psf_rng == tuple_psf_default)
@test all(tuple_psf_rng == tuple_psf_prealloc_rng)
@test all(tuple_psf_rng == tuple_psf_prealloc_default)
end
end
end
@testset "weights" begin
@testset "evaluation" begin
test_input = rand(RNG, test_dist)
@test weight(test_dist, test_input) ==
TestImpl._groundtruth_multi_weight(test_dist, test_input)
end

@testset "fails" begin
correct_input = rand(RNG, test_dist)

# failing inputs with either wrong particle, wrong direction or both
psf_wrong_particle = ParticleStateful(
test_directions[1], WrongParticle(), rand(RNG, SFourMomentum)
)
input_wrong_particle = TestImpl.tuple_setindex(
correct_input, 1, psf_wrong_particle
)

psf_wrong_direction = ParticleStateful(
WrongDirection(), test_particles[1], rand(RNG, SFourMomentum)
)
input_wrong_direction = TestImpl.tuple_setindex(
correct_input, 1, psf_wrong_direction
)

psf_wrong = ParticleStateful(
WrongDirection(), WrongParticle(), rand(RNG, SFourMomentum)
)
input_wrong = TestImpl.tuple_setindex(correct_input, 1, psf_wrong)

# failing input with wrong length
input_wrong_length = (psf_wrong, correct_input...)

@test_throws InvalidInputError weight(test_dist, input_wrong_particle)
@test_throws InvalidInputError weight(test_dist, input_wrong_direction)
@test_throws InvalidInputError weight(test_dist, input_wrong)
@test_throws InvalidInputError weight(test_dist, input_wrong_length)
end
end
end
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ using QEDevents
using Test
using SafeTestsets

@testset "QEDevents.jl" begin
begin
@time @safetestset "single particle distribution" begin
include("interfaces/single_particle_distribution.jl")
end
@time @safetestset "multi particle distribution" begin
include("interfaces/multi_particle_distribution.jl")
end
end
3 changes: 3 additions & 0 deletions test/test_implementation/TestImpl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ using Distributions
import Random: AbstractRNG

include("groundtruths/single_particle.jl")
include("groundtruths/multi_particle.jl")

include("test_particles.jl")
include("test_model.jl")
include("test_process.jl")
include("single_particle_dist.jl")
include("multi_particle_dist.jl")
include("utils.jl")

end
7 changes: 7 additions & 0 deletions test/test_implementation/groundtruths/multi_particle.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function _groundtruth_multi_randmom(rng, d)
return rand(rng, SFourMomentum, length(d))
end

function _groundtruth_multi_weight(dist, psfs)
@. getE(momentum(psfs))
end
33 changes: 33 additions & 0 deletions test/test_implementation/multi_particle_dist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

PARTICLE_DIRECTIONS = (Incoming(), Outgoing(), QEDevents.UnknownDirection())

struct TestMultiParticleDist{DT<:Tuple,PT<:Tuple,RT} <: MultiParticleDistribution
dirs::DT
parts::PT
function TestMultiParticleDist(dirs::DT, parts::PT) where {DT,PT}
res_type = Tuple{QEDevents._assemble_tuple_types(parts, dirs, SFourMomentum)...}
return new{DT,PT,res_type}(dirs, parts)
end
end

function QEDevents._particles(d::TestMultiParticleDist)
return d.parts
end

QEDevents._particle_directions(d::TestMultiParticleDist) = d.dirs

function QEDevents.randmom(rng::AbstractRNG, d::TestMultiParticleDist)
return _groundtruth_multi_randmom(rng, d)
end

function QEDevents._weight(d::TestMultiParticleDist, x::Tuple{Vararg{ParticleStateful}})
return _groundtruth_multi_weight(d, x)
end

# plain multi particle distribution
# for testing of default implementations
struct TestMultiParticleDistPlain <: MultiParticleDistribution
n::Int
end

QEDevents._particles(d::TestMultiParticleDistPlain) = Tuple(fill(TestParticle(), d.n))
1 change: 1 addition & 0 deletions test/test_implementation/test_particles.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

# dummy particles
struct TestParticle <: AbstractParticleType end # generic particle
struct TestParticleFermion <: FermionLike end
struct TestParticleBoson <: BosonLike end

Expand Down
6 changes: 6 additions & 0 deletions test/test_implementation/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
replace i-th entry of t with val
"""
function tuple_setindex(t::Tuple, i, val)
return ntuple(j -> j == i ? val : t[j], length(t))
end

0 comments on commit 7b496fd

Please sign in to comment.