@@ -16,14 +16,17 @@ struct GoalObservation end
16
16
# ----------
17
17
18
18
# Metas
19
- struct BetheMeta end # Forces explicit constraint specification to prevent mixups
20
- struct ObservedMeta{I}
21
- x:: I # Pointmass value for observation
19
+ struct BetheMeta{P} # Meta parameterized by x type for rule overloading
20
+ x:: P # Pointmass value for observation
22
21
end
23
- struct GeneralizedMeta
22
+ BetheMeta () = BetheMeta (nothing ) # Absent observation indicated by nothing
23
+
24
+ struct GeneralizedMeta{P}
25
+ x:: P # Pointmass value for observation
24
26
newton_iterations:: Int64
25
27
end
26
- GeneralizedMeta () = GeneralizedMeta (20 ) # Default number of iterations
28
+ GeneralizedMeta () = GeneralizedMeta (nothing , 20 )
29
+ GeneralizedMeta (point) = GeneralizedMeta (point, 20 )
27
30
28
31
# Pipelines
29
32
struct BethePipeline <: AbstractNodeFunctionalDependenciesPipeline end
44
47
function message_dependencies (pipeline:: GeneralizedPipeline , nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex)
45
48
if iindex === 2 # Message towards state
46
49
input = ReactiveMP. messagein (nodeinterfaces[iindex])
47
- ReactiveMP. setmessage! (input, pipeline. init_message)
50
+ ReactiveMP. setmessage! (input, pipeline. init_message) # Predefine breaker message
48
51
return (nodeinterfaces[iindex],) # Include inbound message on state
49
52
else
50
53
return ()
@@ -61,14 +64,14 @@ function marginal_dependencies(::GeneralizedPipeline, nodeinterfaces, nodelocalm
61
64
end
62
65
63
66
64
- # -------------------
65
- # Bethe Update Rules
66
- # -------------------
67
+ # ------------------------------
68
+ # Unobserved Bethe Update Rules
69
+ # ------------------------------
67
70
68
71
@rule GoalObservation (:c , Marginalisation) (q_c:: Union{Dirichlet, PointMass} ,
69
72
q_z:: Categorical ,
70
73
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
71
- meta:: BetheMeta ) = begin
74
+ meta:: BetheMeta{Nothing} ) = begin
72
75
log_c = mean (log, q_c)
73
76
z = probvec (q_z)
74
77
log_A = mean (log, q_A)
82
85
@rule GoalObservation (:z , Marginalisation) (q_c:: Union{Dirichlet, PointMass} ,
83
86
q_z:: Categorical ,
84
87
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
85
- meta:: BetheMeta ) = begin
88
+ meta:: BetheMeta{Nothing} ) = begin
86
89
log_c = mean (log, q_c)
87
90
z = probvec (q_z)
88
91
log_A = mean (log, q_A)
96
99
@rule GoalObservation (:A , Marginalisation) (q_c:: Union{Dirichlet, PointMass} ,
97
100
q_z:: Categorical ,
98
101
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
99
- meta:: BetheMeta ) = begin
102
+ meta:: BetheMeta{Nothing} ) = begin
100
103
log_c = mean (log, q_c)
101
104
z = probvec (q_z)
102
105
log_A = mean (log, q_A)
110
113
@average_energy GoalObservation (q_c:: Union{Dirichlet, PointMass} ,
111
114
q_z:: Categorical ,
112
115
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
113
- meta:: BetheMeta ) = begin
116
+ meta:: BetheMeta{Nothing} ) = begin
114
117
log_c = mean (log, q_c)
115
118
z = probvec (q_z)
116
119
log_A = mean (log, q_A)
@@ -122,27 +125,30 @@ end
122
125
end
123
126
124
127
125
- # ----------------------
126
- # Observed Update Rules
127
- # ----------------------
128
+ # ----------------------------
129
+ # Observed Bethe Update Rules
130
+ # ----------------------------
128
131
129
- @rule GoalObservation (:c , Marginalisation) (q_z:: Categorical ,
132
+ @rule GoalObservation (:c , Marginalisation) (q_c:: Union{Dirichlet, PointMass} , # Unused
133
+ q_z:: Categorical ,
130
134
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
131
- meta:: ObservedMeta ) = begin
135
+ meta:: BetheMeta{<:AbstractVector} ) = begin
132
136
return Dirichlet (meta. x .+ 1 )
133
137
end
134
138
135
139
@rule GoalObservation (:z , Marginalisation) (q_c:: Union{Dirichlet, PointMass} ,
140
+ q_z:: Categorical , # Unused
136
141
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
137
- meta:: ObservedMeta ) = begin
142
+ meta:: BetheMeta{<:AbstractVector} ) = begin
138
143
log_A = mean (log, q_A)
139
144
140
145
return Categorical (softmax (log_A' * meta. x))
141
146
end
142
147
143
148
@rule GoalObservation (:A , Marginalisation) (q_c:: Union{Dirichlet, PointMass} ,
144
149
q_z:: Categorical ,
145
- meta:: ObservedMeta ) = begin
150
+ q_A:: Union{SampleList, MatrixDirichlet, PointMass} , # Unused
151
+ meta:: BetheMeta{<:AbstractVector} ) = begin
146
152
z = probvec (q_z)
147
153
148
154
return MatrixDirichlet (meta. x* z' .+ 1 )
151
157
@average_energy GoalObservation (q_c:: Union{Dirichlet, PointMass} ,
152
158
q_z:: Categorical ,
153
159
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
154
- meta:: ObservedMeta ) = begin
160
+ meta:: BetheMeta{<:AbstractVector} ) = begin
155
161
log_c = mean (log, q_c)
156
162
z = probvec (q_z)
157
163
log_A = mean (log, q_A)
@@ -160,13 +166,13 @@ end
160
166
end
161
167
162
168
163
- # -------------------------
164
- # Generalized Update Rules
165
- # -------------------------
169
+ # ------------------------------------
170
+ # Unobserved Generalized Update Rules
171
+ # ------------------------------------
166
172
167
173
@rule GoalObservation (:c , Marginalisation) (q_z:: Categorical ,
168
174
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
169
- meta:: GeneralizedMeta ) = begin
175
+ meta:: GeneralizedMeta{Nothing} ) = begin
170
176
z = probvec (q_z)
171
177
A = mean (q_A)
172
178
177
183
q_c:: Union{Dirichlet, PointMass} ,
178
184
q_z:: Categorical ,
179
185
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
180
- meta:: GeneralizedMeta ) = begin
186
+ meta:: GeneralizedMeta{Nothing} ) = begin
181
187
d = probvec (m_z)
182
188
log_c = mean (log, q_c)
183
189
z_0 = probvec (q_z)
200
206
@rule GoalObservation (:A , Marginalisation) (q_c:: Union{Dirichlet, PointMass} ,
201
207
q_z:: Categorical ,
202
208
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
203
- meta:: GeneralizedMeta ) = begin
209
+ meta:: GeneralizedMeta{Nothing} ) = begin
204
210
log_c = mean (log, q_c)
205
211
z = probvec (q_z)
206
212
A_bar = mean (q_A)
@@ -213,10 +219,51 @@ end
213
219
@average_energy GoalObservation (q_c:: Union{Dirichlet, PointMass} ,
214
220
q_z:: Categorical ,
215
221
q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
216
- meta:: GeneralizedMeta ) = begin
222
+ meta:: GeneralizedMeta{Nothing} ) = begin
217
223
log_c = mean (log, q_c)
218
224
z = probvec (q_z)
219
225
(A, h_A) = mean_h (q_A)
220
226
221
227
return z' * h_A - (A* z)' * (log_c - safelog .(A* z))
228
+ end
229
+
230
+
231
+ # ----------------------------------
232
+ # Observed Generalized Update Rules
233
+ # ----------------------------------
234
+
235
+ @rule GoalObservation (:c , Marginalisation) (q_z:: Categorical , # Unused
236
+ q_A:: Union{SampleList, MatrixDirichlet, PointMass} , # Unused
237
+ meta:: GeneralizedMeta{<:AbstractVector} ) = begin
238
+ return Dirichlet (meta. x .+ 1 )
239
+ end
240
+
241
+ @rule GoalObservation (:z , Marginalisation) (m_z:: Categorical , # Unused
242
+ q_c:: Union{Dirichlet, PointMass} , # Unused
243
+ q_z:: Categorical , # Unused
244
+ q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
245
+ meta:: GeneralizedMeta{<:AbstractVector} ) = begin
246
+ log_A = mean (log, q_A)
247
+
248
+ return Categorical (softmax (log_A' * meta. x))
249
+ end
250
+
251
+ @rule GoalObservation (:A , Marginalisation) (q_c:: Union{Dirichlet, PointMass} , # Unused
252
+ q_z:: Categorical ,
253
+ q_A:: Union{SampleList, MatrixDirichlet, PointMass} , # Unused
254
+ meta:: GeneralizedMeta{<:AbstractVector} ) = begin
255
+ z = probvec (q_z)
256
+
257
+ return MatrixDirichlet (meta. x* z' .+ 1 )
258
+ end
259
+
260
+ @average_energy GoalObservation (q_c:: Union{Dirichlet, PointMass} ,
261
+ q_z:: Categorical ,
262
+ q_A:: Union{SampleList, MatrixDirichlet, PointMass} ,
263
+ meta:: GeneralizedMeta{<:AbstractVector} ) = begin
264
+ log_c = mean (log, q_c)
265
+ z = probvec (q_z)
266
+ log_A = mean (log, q_A)
267
+
268
+ return - meta. x' * (log_A* z + log_c)
222
269
end
0 commit comments