Skip to content

Commit 7479b96

Browse files
committedMay 25, 2023
refactor metas
1 parent 188271c commit 7479b96

File tree

4 files changed

+104
-51
lines changed

4 files changed

+104
-51
lines changed
 

‎src/Part2/Rx/T-maze_Bethe.ipynb

+3-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@
6565
" z_k_min = z_0\n",
6666
" for k=1:2\n",
6767
" z[k] ~ Transition(z_k_min, u[k])\n",
68-
" c[k] ~ GoalObservation(z[k], A) where {meta=BetheMeta(), pipeline=BethePipeline()} # Goal-observation composite node\n",
68+
" c[k] ~ GoalObservation(z[k], A) where {\n",
69+
" meta=BetheMeta(), \n",
70+
" pipeline=BethePipeline()}\n",
6971
"\n",
7072
" z_k_min = z[k] # Reset for next slice\n",
7173
" end\n",

‎src/Part2/Rx/T-maze_Generalized.ipynb

+13-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 55,
5+
"execution_count": 1,
66
"metadata": {},
77
"outputs": [
88
{
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 56,
24+
"execution_count": 2,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -36,7 +36,7 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": 57,
39+
"execution_count": 3,
4040
"metadata": {},
4141
"outputs": [
4242
{
@@ -45,7 +45,7 @@
4545
"structured (generic function with 1 method)"
4646
]
4747
},
48-
"execution_count": 57,
48+
"execution_count": 3,
4949
"metadata": {},
5050
"output_type": "execute_result"
5151
}
@@ -65,7 +65,9 @@
6565
" z_k_min = z_0\n",
6666
" for k=1:2\n",
6767
" z[k] ~ Transition(z_k_min, u[k])\n",
68-
" c[k] ~ GoalObservation(z[k], A) where {meta=GeneralizedMeta(), pipeline=GeneralizedPipeline(vague(Categorical,8))} # Goal-observation composite node\n",
68+
" c[k] ~ GoalObservation(z[k], A) where {\n",
69+
" meta=GeneralizedMeta(), \n",
70+
" pipeline=GeneralizedPipeline(vague(Categorical,8))}\n",
6971
"\n",
7072
" z_k_min = z[k] # Reset for next slice\n",
7173
" end\n",
@@ -80,7 +82,7 @@
8082
},
8183
{
8284
"cell_type": "code",
83-
"execution_count": 58,
85+
"execution_count": 4,
8486
"metadata": {},
8587
"outputs": [],
8688
"source": [
@@ -100,7 +102,7 @@
100102
},
101103
{
102104
"cell_type": "code",
103-
"execution_count": 59,
105+
"execution_count": 5,
104106
"metadata": {},
105107
"outputs": [
106108
{
@@ -109,7 +111,7 @@
109111
"infer (generic function with 1 method)"
110112
]
111113
},
112-
"execution_count": 59,
114+
"execution_count": 5,
113115
"metadata": {},
114116
"output_type": "execute_result"
115117
}
@@ -135,18 +137,18 @@
135137
},
136138
{
137139
"cell_type": "code",
138-
"execution_count": 60,
140+
"execution_count": 6,
139141
"metadata": {},
140142
"outputs": [
141143
{
142144
"data": {
143145
"text/plain": [
144146
"Inference results:\n",
145147
" Posteriors | available for (A, z_0, z)\n",
146-
" Free Energy: | Real[5.99812, 5.68334, 5.74714, 5.72778, 5.76155, 5.70259, 5.76313, 5.97126, 5.73973, 5.86003 … 5.84994, 5.7725, 5.94248, 5.83697, 5.93566, 5.82777, 6.05745, 5.80951, 5.78432, 5.75771]\n"
148+
" Free Energy: | Real[5.80212, 5.786, 5.87746, 5.62839, 6.06458, 5.87783, 6.14873, 5.786, 5.86232, 5.87551 … 5.8343, 5.75434, 5.56626, 5.85493, 5.78456, 6.02905, 6.09559, 6.02818, 5.89753, 5.70935]\n"
147149
]
148150
},
149-
"execution_count": 60,
151+
"execution_count": 6,
150152
"metadata": {},
151153
"output_type": "execute_result"
152154
}

‎src/Part2/Rx/T-maze_Observed.ipynb

+13-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 13,
5+
"execution_count": 1,
66
"metadata": {},
77
"outputs": [
88
{
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 14,
24+
"execution_count": 2,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -36,7 +36,7 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": 15,
39+
"execution_count": 3,
4040
"metadata": {},
4141
"outputs": [
4242
{
@@ -45,7 +45,7 @@
4545
"structured (generic function with 1 method)"
4646
]
4747
},
48-
"execution_count": 15,
48+
"execution_count": 3,
4949
"metadata": {},
5050
"output_type": "execute_result"
5151
}
@@ -65,7 +65,9 @@
6565
" z_k_min = z_0\n",
6666
" for k=1:2\n",
6767
" z[k] ~ Transition(z_k_min, u[k])\n",
68-
" c[k] ~ GoalObservation(z[k], A) where {meta=ObservedMeta(ones(16)./16)} # Goal-observation composite node\n",
68+
" c[k] ~ GoalObservation(z[k], A) where {\n",
69+
" meta=GeneralizedMeta(ones(16)./16), \n",
70+
" pipeline=GeneralizedPipeline(vague(Categorical, 8))}\n",
6971
"\n",
7072
" z_k_min = z[k] # Reset for next slice\n",
7173
" end\n",
@@ -79,7 +81,7 @@
7981
},
8082
{
8183
"cell_type": "code",
82-
"execution_count": 16,
84+
"execution_count": 4,
8385
"metadata": {},
8486
"outputs": [],
8587
"source": [
@@ -99,7 +101,7 @@
99101
},
100102
{
101103
"cell_type": "code",
102-
"execution_count": 17,
104+
"execution_count": 5,
103105
"metadata": {},
104106
"outputs": [
105107
{
@@ -108,7 +110,7 @@
108110
"infer (generic function with 1 method)"
109111
]
110112
},
111-
"execution_count": 17,
113+
"execution_count": 5,
112114
"metadata": {},
113115
"output_type": "execute_result"
114116
}
@@ -134,18 +136,18 @@
134136
},
135137
{
136138
"cell_type": "code",
137-
"execution_count": 18,
139+
"execution_count": 10,
138140
"metadata": {},
139141
"outputs": [
140142
{
141143
"data": {
142144
"text/plain": [
143145
"Inference results:\n",
144146
" Posteriors | available for (A, z_0, z)\n",
145-
" Free Energy: | Real[28.4268, 27.4004, 27.3962, 27.3684, 27.2546, 26.9013, 26.5377, 26.4734, 26.4717, 26.4717 … 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717]\n"
147+
" Free Energy: | Real[28.4283, 27.3949, 27.3765, 27.2636, 26.9394, 26.5511, 26.474, 26.4717, 26.4717, 26.4717 … 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717, 26.4717]\n"
146148
]
147149
},
148-
"execution_count": 18,
150+
"execution_count": 10,
149151
"metadata": {},
150152
"output_type": "execute_result"
151153
}

‎src/Part2/Rx/goal_observation.jl

+75-28
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@ struct GoalObservation end
1616
#----------
1717

1818
# Metas
19-
struct BetheMeta end # Forces explicit constraint specification to prevent mixups
20-
struct ObservedMeta{I}
21-
x::I # Pointmass value for observation
19+
struct BetheMeta{P} # Meta parameterized by x type for rule overloading
20+
x::P # Pointmass value for observation
2221
end
23-
struct GeneralizedMeta
22+
BetheMeta() = BetheMeta(nothing) # Absent observation indicated by nothing
23+
24+
struct GeneralizedMeta{P}
25+
x::P # Pointmass value for observation
2426
newton_iterations::Int64
2527
end
26-
GeneralizedMeta() = GeneralizedMeta(20) # Default number of iterations
28+
GeneralizedMeta() = GeneralizedMeta(nothing, 20)
29+
GeneralizedMeta(point) = GeneralizedMeta(point, 20)
2730

2831
# Pipelines
2932
struct BethePipeline <: AbstractNodeFunctionalDependenciesPipeline end
@@ -44,7 +47,7 @@ end
4447
function message_dependencies(pipeline::GeneralizedPipeline, nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex)
4548
if iindex === 2 # Message towards state
4649
input = ReactiveMP.messagein(nodeinterfaces[iindex])
47-
ReactiveMP.setmessage!(input, pipeline.init_message)
50+
ReactiveMP.setmessage!(input, pipeline.init_message) # Predefine breaker message
4851
return (nodeinterfaces[iindex],) # Include inbound message on state
4952
else
5053
return ()
@@ -61,14 +64,14 @@ function marginal_dependencies(::GeneralizedPipeline, nodeinterfaces, nodelocalm
6164
end
6265

6366

64-
#-------------------
65-
# Bethe Update Rules
66-
#-------------------
67+
#------------------------------
68+
# Unobserved Bethe Update Rules
69+
#------------------------------
6770

6871
@rule GoalObservation(:c, Marginalisation) (q_c::Union{Dirichlet, PointMass},
6972
q_z::Categorical,
7073
q_A::Union{SampleList, MatrixDirichlet, PointMass},
71-
meta::BetheMeta) = begin
74+
meta::BetheMeta{Nothing}) = begin
7275
log_c = mean(log, q_c)
7376
z = probvec(q_z)
7477
log_A = mean(log, q_A)
@@ -82,7 +85,7 @@ end
8285
@rule GoalObservation(:z, Marginalisation) (q_c::Union{Dirichlet, PointMass},
8386
q_z::Categorical,
8487
q_A::Union{SampleList, MatrixDirichlet, PointMass},
85-
meta::BetheMeta) = begin
88+
meta::BetheMeta{Nothing}) = begin
8689
log_c = mean(log, q_c)
8790
z = probvec(q_z)
8891
log_A = mean(log, q_A)
@@ -96,7 +99,7 @@ end
9699
@rule GoalObservation(:A, Marginalisation) (q_c::Union{Dirichlet, PointMass},
97100
q_z::Categorical,
98101
q_A::Union{SampleList, MatrixDirichlet, PointMass},
99-
meta::BetheMeta) = begin
102+
meta::BetheMeta{Nothing}) = begin
100103
log_c = mean(log, q_c)
101104
z = probvec(q_z)
102105
log_A = mean(log, q_A)
@@ -110,7 +113,7 @@ end
110113
@average_energy GoalObservation (q_c::Union{Dirichlet, PointMass},
111114
q_z::Categorical,
112115
q_A::Union{SampleList, MatrixDirichlet, PointMass},
113-
meta::BetheMeta) = begin
116+
meta::BetheMeta{Nothing}) = begin
114117
log_c = mean(log, q_c)
115118
z = probvec(q_z)
116119
log_A = mean(log, q_A)
@@ -122,27 +125,30 @@ end
122125
end
123126

124127

125-
#----------------------
126-
# Observed Update Rules
127-
#----------------------
128+
#----------------------------
129+
# Observed Bethe Update Rules
130+
#----------------------------
128131

129-
@rule GoalObservation(:c, Marginalisation) (q_z::Categorical,
132+
@rule GoalObservation(:c, Marginalisation) (q_c::Union{Dirichlet, PointMass}, # Unused
133+
q_z::Categorical,
130134
q_A::Union{SampleList, MatrixDirichlet, PointMass},
131-
meta::ObservedMeta) = begin
135+
meta::BetheMeta{<:AbstractVector}) = begin
132136
return Dirichlet(meta.x .+ 1)
133137
end
134138

135139
@rule GoalObservation(:z, Marginalisation) (q_c::Union{Dirichlet, PointMass},
140+
q_z::Categorical, # Unused
136141
q_A::Union{SampleList, MatrixDirichlet, PointMass},
137-
meta::ObservedMeta) = begin
142+
meta::BetheMeta{<:AbstractVector}) = begin
138143
log_A = mean(log, q_A)
139144

140145
return Categorical(softmax(log_A'*meta.x))
141146
end
142147

143148
@rule GoalObservation(:A, Marginalisation) (q_c::Union{Dirichlet, PointMass},
144149
q_z::Categorical,
145-
meta::ObservedMeta) = begin
150+
q_A::Union{SampleList, MatrixDirichlet, PointMass}, # Unused
151+
meta::BetheMeta{<:AbstractVector}) = begin
146152
z = probvec(q_z)
147153

148154
return MatrixDirichlet(meta.x*z' .+ 1)
@@ -151,7 +157,7 @@ end
151157
@average_energy GoalObservation (q_c::Union{Dirichlet, PointMass},
152158
q_z::Categorical,
153159
q_A::Union{SampleList, MatrixDirichlet, PointMass},
154-
meta::ObservedMeta) = begin
160+
meta::BetheMeta{<:AbstractVector}) = begin
155161
log_c = mean(log, q_c)
156162
z = probvec(q_z)
157163
log_A = mean(log, q_A)
@@ -160,13 +166,13 @@ end
160166
end
161167

162168

163-
#-------------------------
164-
# Generalized Update Rules
165-
#-------------------------
169+
#------------------------------------
170+
# Unobserved Generalized Update Rules
171+
#------------------------------------
166172

167173
@rule GoalObservation(:c, Marginalisation) (q_z::Categorical,
168174
q_A::Union{SampleList, MatrixDirichlet, PointMass},
169-
meta::GeneralizedMeta) = begin
175+
meta::GeneralizedMeta{Nothing}) = begin
170176
z = probvec(q_z)
171177
A = mean(q_A)
172178

@@ -177,7 +183,7 @@ end
177183
q_c::Union{Dirichlet, PointMass},
178184
q_z::Categorical,
179185
q_A::Union{SampleList, MatrixDirichlet, PointMass},
180-
meta::GeneralizedMeta) = begin
186+
meta::GeneralizedMeta{Nothing}) = begin
181187
d = probvec(m_z)
182188
log_c = mean(log, q_c)
183189
z_0 = probvec(q_z)
@@ -200,7 +206,7 @@ end
200206
@rule GoalObservation(:A, Marginalisation) (q_c::Union{Dirichlet, PointMass},
201207
q_z::Categorical,
202208
q_A::Union{SampleList, MatrixDirichlet, PointMass},
203-
meta::GeneralizedMeta) = begin
209+
meta::GeneralizedMeta{Nothing}) = begin
204210
log_c = mean(log, q_c)
205211
z = probvec(q_z)
206212
A_bar = mean(q_A)
@@ -213,10 +219,51 @@ end
213219
@average_energy GoalObservation (q_c::Union{Dirichlet, PointMass},
214220
q_z::Categorical,
215221
q_A::Union{SampleList, MatrixDirichlet, PointMass},
216-
meta::GeneralizedMeta) = begin
222+
meta::GeneralizedMeta{Nothing}) = begin
217223
log_c = mean(log, q_c)
218224
z = probvec(q_z)
219225
(A, h_A) = mean_h(q_A)
220226

221227
return z'*h_A - (A*z)'*(log_c - safelog.(A*z))
228+
end
229+
230+
231+
#----------------------------------
232+
# Observed Generalized Update Rules
233+
#----------------------------------
234+
235+
@rule GoalObservation(:c, Marginalisation) (q_z::Categorical, # Unused
236+
q_A::Union{SampleList, MatrixDirichlet, PointMass}, # Unused
237+
meta::GeneralizedMeta{<:AbstractVector}) = begin
238+
return Dirichlet(meta.x .+ 1)
239+
end
240+
241+
@rule GoalObservation(:z, Marginalisation) (m_z::Categorical, # Unused
242+
q_c::Union{Dirichlet, PointMass}, # Unused
243+
q_z::Categorical, # Unused
244+
q_A::Union{SampleList, MatrixDirichlet, PointMass},
245+
meta::GeneralizedMeta{<:AbstractVector}) = begin
246+
log_A = mean(log, q_A)
247+
248+
return Categorical(softmax(log_A'*meta.x))
249+
end
250+
251+
@rule GoalObservation(:A, Marginalisation) (q_c::Union{Dirichlet, PointMass}, # Unused
252+
q_z::Categorical,
253+
q_A::Union{SampleList, MatrixDirichlet, PointMass}, # Unused
254+
meta::GeneralizedMeta{<:AbstractVector}) = begin
255+
z = probvec(q_z)
256+
257+
return MatrixDirichlet(meta.x*z' .+ 1)
258+
end
259+
260+
@average_energy GoalObservation (q_c::Union{Dirichlet, PointMass},
261+
q_z::Categorical,
262+
q_A::Union{SampleList, MatrixDirichlet, PointMass},
263+
meta::GeneralizedMeta{<:AbstractVector}) = begin
264+
log_c = mean(log, q_c)
265+
z = probvec(q_z)
266+
log_A = mean(log, q_A)
267+
268+
return -meta.x'*(log_A*z + log_c)
222269
end

0 commit comments

Comments
 (0)
Please sign in to comment.