Skip to content

Commit 53f5542

Browse files
Merge pull request project-codeflare#29 from project-codeflare/fix_yref
Fix yref assignment for pipeline PREDICT and SCORE
2 parents bbbed8f + d2aa799 commit 53f5542

File tree

5 files changed

+76
-84
lines changed

5 files changed

+76
-84
lines changed

codeflare/pipelines/Runtime.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class ExecutionType(Enum):
6262

6363

6464
@ray.remote
65-
def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref: dm.XYRef):
65+
def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref: dm.XYRef, is_outputNode: bool):
6666
"""
6767
Helper remote function that executes an OR node. As such, this is a remote task that runs the estimator
6868
in the provided mode with the data pointed to by XYRef. The key aspect to note here is the choice of input
@@ -107,9 +107,16 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
107107
elif mode == ExecutionType.SCORE:
108108
if base.is_classifier(estimator) or base.is_regressor(estimator):
109109
estimator = node.get_estimator()
110-
res_Xref = ray.put(estimator.score(X, y))
111-
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
112-
return result
110+
if is_outputNode:
111+
score_ref = ray.put(estimator.score(X, y))
112+
result = dm.XYRef(score_ref, score_ref, prev_node_ptr, prev_node_ptr, [xy_ref])
113+
return result
114+
else:
115+
res_xy = estimator.score(xy_list)
116+
res_xref = ray.put(res_xy.get_x())
117+
res_yref = ray.put(res_xy.get_y())
118+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
119+
return result
113120
else:
114121
res_Xref = ray.put(estimator.transform(X))
115122
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
@@ -118,16 +125,24 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
118125
elif mode == ExecutionType.PREDICT:
119126
# Test mode does not clone as it is a simple predict or transform
120127
if base.is_classifier(estimator) or base.is_regressor(estimator):
121-
res_Xref = ray.put(estimator.predict(X))
122-
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
123-
return result
128+
if is_outputNode:
129+
predict_ref = ray.put(estimator.predict(X))
130+
result = dm.XYRef(predict_ref, predict_ref, prev_node_ptr, prev_node_ptr, [xy_ref])
131+
return result
132+
else:
133+
res_xy = estimator.predict(xy_list)
134+
res_xref = ray.put(res_xy.get_x())
135+
res_yref = ray.put(res_xy.get_y())
136+
137+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
138+
return result
124139
else:
125140
res_Xref = ray.put(estimator.transform(X))
126141
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
127142
return result
128143

129144

130-
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType):
145+
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType, is_outputNode):
131146
"""
132147
Inner method that executes the estimator node parallelizing at the level of input objects. This defines the
133148
strategy of execution of the node, in this case, parallel for each object that is input. The function takes
@@ -147,7 +162,7 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
147162
exec_xyrefs = []
148163
for xy_ref_ptr in Xyref_ptrs:
149164
xy_ref = ray.get(xy_ref_ptr)
150-
inner_result = execute_or_node_remote.remote(node, mode, xy_ref)
165+
inner_result = execute_or_node_remote.remote(node, mode, xy_ref, is_outputNode)
151166
exec_xyrefs.append(inner_result)
152167

153168
for post_edge in post_edges:
@@ -337,7 +352,7 @@ def execute_pipeline(pipeline: dm.Pipeline, mode: ExecutionType, pipeline_input:
337352
pre_edges = pipeline.get_pre_edges(node)
338353
post_edges = pipeline.get_post_edges(node)
339354
if node.get_node_input_type() == dm.NodeInputType.OR:
340-
execute_or_node(node, pre_edges, edge_args, post_edges, mode)
355+
execute_or_node(node, pre_edges, edge_args, post_edges, mode, pipeline.is_output(node))
341356
elif node.get_node_input_type() == dm.NodeInputType.AND:
342357
execute_and_node(node, pre_edges, edge_args, post_edges, mode)
343358

@@ -662,4 +677,4 @@ def save(pipeline_output: dm.PipelineOutput, xy_ref: dm.XYRef, filehandle):
662677
:return: None
663678
"""
664679
pipeline = select_pipeline(pipeline_output, xy_ref)
665-
pipeline.save(filehandle)
680+
pipeline.save(filehandle)

codeflare/pipelines/tests/test_pipeline_predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def test_pipeline_predict():
7777

7878
predict_clf_output = predict_output.get_xyrefs(node_clf)
7979

80-
#y_pred = ray.get(predict_clf_output[0].get_yref())
81-
y_pred = ray.get(predict_clf_output[0].get_Xref())
80+
y_pred = ray.get(predict_clf_output[0].get_yref())
81+
#y_pred = ray.get(predict_clf_output[0].get_Xref())
8282

8383

8484
report_codeflare = classification_report(y_test, y_pred)

notebooks/plot_nca_classification.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 4,
13+
"execution_count": 1,
1414
"metadata": {},
1515
"outputs": [],
1616
"source": [
@@ -37,7 +37,7 @@
3737
},
3838
{
3939
"cell_type": "code",
40-
"execution_count": 36,
40+
"execution_count": 2,
4141
"metadata": {},
4242
"outputs": [
4343
{
@@ -150,14 +150,14 @@
150150
},
151151
{
152152
"cell_type": "code",
153-
"execution_count": 37,
153+
"execution_count": 3,
154154
"metadata": {},
155155
"outputs": [
156156
{
157157
"name": "stderr",
158158
"output_type": "stream",
159159
"text": [
160-
"2021-06-08 16:33:25,975\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
160+
"2021-07-22 17:14:51,530\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
161161
]
162162
},
163163
{
@@ -243,7 +243,7 @@
243243
"\n",
244244
"knn_pipeline = rt.select_pipeline(pipeline_fitted, pipeline_fitted.get_xyrefs(node_knn)[0])\n",
245245
"knn_score = ray.get(rt.execute_pipeline(knn_pipeline, ExecutionType.SCORE, test_input)\n",
246-
" .get_xyrefs(node_knn)[0].get_Xref())\n",
246+
" .get_xyrefs(node_knn)[0].get_yref())\n",
247247
"\n",
248248
"# Plot the decision boundary. For that, we will assign a color to each\n",
249249
"# point in the mesh [x_min, x_max]x[y_min, y_max].\n",
@@ -254,7 +254,7 @@
254254
"predict_input.add_xy_arg(node_scalar, dm.Xy(meshinput, meshlabel))\n",
255255
"\n",
256256
"Z = ray.get(rt.execute_pipeline(knn_pipeline, ExecutionType.PREDICT, predict_input)\n",
257-
" .get_xyrefs(node_knn)[0].get_Xref())\n",
257+
" .get_xyrefs(node_knn)[0].get_yref())\n",
258258
"\n",
259259
"# Put the result into a color plot\n",
260260
"Z = Z.reshape(xx.shape)\n",
@@ -273,10 +273,10 @@
273273
"name = names[1]\n",
274274
"nca_pipeline = rt.select_pipeline(pipeline_fitted, pipeline_fitted.get_xyrefs(node_knn_post_nca)[0])\n",
275275
"nca_score = ray.get(rt.execute_pipeline(nca_pipeline, ExecutionType.SCORE, test_input)\n",
276-
" .get_xyrefs(node_knn_post_nca)[0].get_Xref())\n",
276+
" .get_xyrefs(node_knn_post_nca)[0].get_yref())\n",
277277
"\n",
278278
"Z = ray.get(rt.execute_pipeline(nca_pipeline, ExecutionType.PREDICT, predict_input)\n",
279-
" .get_xyrefs(node_knn_post_nca)[0].get_Xref())\n",
279+
" .get_xyrefs(node_knn_post_nca)[0].get_yref())\n",
280280
"\n",
281281
"# Put the result into a color plot\n",
282282
"Z = Z.reshape(xx.shape)\n",

notebooks/plot_rbm_logistic_classification.ipynb

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
"output_type": "stream",
2828
"text": [
2929
"Automatically created module for IPython interactive environment\n",
30-
"[BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.11s\n",
30+
"[BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.13s\n",
3131
"[BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.18s\n",
3232
"[BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.18s\n",
33-
"[BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.16s\n",
34-
"[BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.16s\n",
35-
"[BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.15s\n",
36-
"[BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.15s\n",
37-
"[BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.13s\n",
38-
"[BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.14s\n",
39-
"[BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.14s\n",
33+
"[BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.17s\n",
34+
"[BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.17s\n",
35+
"[BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.18s\n",
36+
"[BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.17s\n",
37+
"[BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.17s\n",
38+
"[BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.17s\n",
39+
"[BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.16s\n",
4040
"Logistic regression using RBM features:\n",
4141
" precision recall f1-score support\n",
4242
"\n",
@@ -207,60 +207,30 @@
207207
},
208208
{
209209
"cell_type": "code",
210-
"execution_count": 16,
210+
"execution_count": 3,
211211
"metadata": {},
212212
"outputs": [
213213
{
214214
"name": "stderr",
215215
"output_type": "stream",
216216
"text": [
217-
"2021-06-09 10:48:44,778\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
217+
"2021-07-22 17:16:19,742\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8267\u001b[39m\u001b[22m\n"
218218
]
219219
},
220220
{
221221
"name": "stdout",
222222
"output_type": "stream",
223223
"text": [
224-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.11s\n",
225-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.11s\n",
226-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.11s\n",
227-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.11s\n",
228-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.15s\n",
229-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.15s\n",
230-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.15s\n",
231-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.15s\n",
232-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.15s\n",
233-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.15s\n",
234-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.15s\n",
235-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.15s\n",
236-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.14s\n",
237-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.14s\n",
238-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.14s\n",
239-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.14s\n",
240-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.15s\n",
241-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.15s\n",
242-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.15s\n",
243-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.15s\n",
244-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.14s\n",
245-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.14s\n",
246-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.14s\n",
247-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.14s\n",
248-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.15s\n",
249-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.15s\n",
250-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.15s\n",
251-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.15s\n",
252-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.15s\n",
253-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.15s\n",
254-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.15s\n",
255-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.15s\n",
256-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.13s\n",
257-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.13s\n",
258-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.13s\n",
259-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.13s\n",
260-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.15s\n",
261-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.15s\n",
262-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.15s\n",
263-
"\u001b[2m\u001b[36m(pid=4523)\u001b[0m [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.15s\n",
224+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.16s\n",
225+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.22s\n",
226+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.22s\n",
227+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.22s\n",
228+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.22s\n",
229+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.21s\n",
230+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.21s\n",
231+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.21s\n",
232+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.21s\n",
233+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.47s\n",
264234
"Logistic regression using RBM features:\n",
265235
" precision recall f1-score support\n",
266236
"\n",
@@ -411,14 +381,14 @@
411381
"\n",
412382
"logistic_pipeline = rt.select_pipeline(pipeline_fitted, pipeline_fitted.get_xyrefs(node_logistic)[0])\n",
413383
"Y_pred = ray.get(rt.execute_pipeline(logistic_pipeline, ExecutionType.PREDICT, predict_input)\n",
414-
" .get_xyrefs(node_logistic)[0].get_Xref())\n",
384+
" .get_xyrefs(node_logistic)[0].get_yref())\n",
415385
"\n",
416386
"print(\"Logistic regression using RBM features:\\n%s\\n\" % (\n",
417387
" metrics.classification_report(Y_test, Y_pred)))\n",
418388
"\n",
419389
"raw_pixel_pipeline = rt.select_pipeline(pipeline_fitted, pipeline_fitted.get_xyrefs(node_raw_pixel)[0])\n",
420390
"Y_pred = ray.get(rt.execute_pipeline(raw_pixel_pipeline, ExecutionType.PREDICT, predict_input)\n",
421-
" .get_xyrefs(node_raw_pixel)[0].get_Xref())\n",
391+
" .get_xyrefs(node_raw_pixel)[0].get_yref())\n",
422392
"\n",
423393
"print(\"Logistic regression using raw pixel features:\\n%s\\n\" % (\n",
424394
" metrics.classification_report(Y_test, Y_pred)))\n",
@@ -438,6 +408,13 @@
438408
"\n",
439409
"ray.shutdown()"
440410
]
411+
},
412+
{
413+
"cell_type": "code",
414+
"execution_count": null,
415+
"metadata": {},
416+
"outputs": [],
417+
"source": []
441418
}
442419
],
443420
"metadata": {

0 commit comments

Comments
 (0)