forked from VowpalWabbit/vowpal_wabbit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlibsearch.h
120 lines (99 loc) · 4.45 KB
/
libsearch.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#ifndef LIBSEARCH_HOOKTASK_H
#define LIBSEARCH_HOOKTASK_H
#include "../vowpalwabbit/parser.h"
#include "../vowpalwabbit/parse_example.h"
#include "../vowpalwabbit/vw.h"
#include "../vowpalwabbit/search.h"
#include "../vowpalwabbit/search_hooktask.h"
using namespace std;
template<class INPUT, class OUTPUT> class SearchTask
{
public:
SearchTask(vw& vw_obj) : vw_obj(vw_obj), sch(*(Search::search*)vw_obj.searchstr)
{ bogus_example = VW::alloc_examples(vw_obj.p->lp.label_size, 1);
VW::read_line(vw_obj, bogus_example, (char*)"1 | x");
VW::parse_atomic_example(vw_obj, bogus_example, false);
VW::setup_example(vw_obj, bogus_example);
blank_line = VW::alloc_examples(vw_obj.p->lp.label_size, 1);
VW::read_line(vw_obj, blank_line, (char*)"");
VW::parse_atomic_example(vw_obj, blank_line, false);
VW::setup_example(vw_obj, blank_line);
HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
d->run_f = _search_run_fn;
d->run_setup_f = _search_setup_fn;
d->run_takedown_f = _search_takedown_fn;
d->run_object = this;
d->extra_data = NULL;
d->extra_data2 = NULL;
}
~SearchTask()
{ VW::dealloc_example(vw_obj.p->lp.delete_label, *bogus_example); free(bogus_example);
VW::dealloc_example(vw_obj.p->lp.delete_label, *blank_line); free(blank_line);
}
virtual void _run(Search::search&sch, INPUT& input_example, OUTPUT& output) {} // YOU MUST DEFINE THIS FUNCTION!
void _setup(Search::search&sch, INPUT& input_example, OUTPUT& output) {} // OPTIONAL
void _takedown(Search::search&sch, INPUT& input_example, OUTPUT& output) {} // OPTIONAL
void learn(INPUT& input_example, OUTPUT& output) { bogus_example->test_only = false; call_vw(input_example, output); }
void predict(INPUT& input_example, OUTPUT& output) { bogus_example->test_only = true; call_vw(input_example, output); }
protected:
vw& vw_obj;
Search::search& sch;
private:
example* bogus_example, *blank_line;
void call_vw(INPUT& input_example, OUTPUT& output)
{ HookTask::task_data* d = sch.template get_task_data<HookTask::task_data> (); // ugly calling convention :(
d->extra_data = (void*)&input_example;
d->extra_data2 = (void*)&output;
vw_obj.learn(bogus_example);
vw_obj.learn(blank_line); // this will cause our search_run_fn hook to get called
}
static void _search_run_fn(Search::search&sch)
{ HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL))
{ cerr << "error: calling _search_run_fn without setting run object" << endl;
throw exception();
}
((SearchTask*)d->run_object)->_run(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
}
static void _search_setup_fn(Search::search&sch)
{ HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL))
{ cerr << "error: calling _search_setup_fn without setting run object" << endl;
throw exception();
}
((SearchTask*)d->run_object)->_setup(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
}
static void _search_takedown_fn(Search::search&sch)
{ HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
if ((d->run_object == NULL) || (d->extra_data == NULL) || (d->extra_data2 == NULL))
{ cerr << "error: calling _search_takedown_fn without setting run object" << endl;
throw exception();
}
((SearchTask*)d->run_object)->_takedown(sch, *(INPUT*)d->extra_data, *(OUTPUT*)d->extra_data2);
}
};
/// Adapts one of VW's compiled-in Search::search_task implementations to the
/// SearchTask interface.
///
/// INPUT is the example sequence handed to the task's `run` callback. OUTPUT
/// is filled by Search::search::get_test_action_sequence after each run.
class BuiltInTask : public SearchTask< vector<example*>, vector<uint32_t> >
{
public:
  /// @param vw_obj VW instance, forwarded to SearchTask.
  /// @param task   compiled-in task whose callbacks are invoked. It is
  ///               dereferenced unconditionally and never freed here, so it
  ///               must be non-null and outlive this object.
  BuiltInTask(vw& vw_obj, Search::search_task* task)
    : SearchTask< vector<example*>, vector<uint32_t> >(vw_obj), my_task(task)
  { HookTask::task_data* hook_data = sch.get_task_data<HookTask::task_data>();
    // `initialize` is handed a local copy of num_actions. If it takes that
    // argument by reference (signature not visible here), anything it writes
    // back stays in this local and is not stored into hook_data.
    size_t action_count = hook_data->num_actions;
    if (my_task->initialize != NULL)
      my_task->initialize(sch, action_count, *hook_data->var_map);
  }

  /// Calls the task's optional `finish` callback, if it has one.
  ~BuiltInTask()
  { if (my_task->finish == NULL)
      return;
    my_task->finish(sch);
  }

  /// Runs the built-in task on `input_example`, then passes `output` to
  /// sch.get_test_action_sequence to collect the resulting actions.
  virtual void _run(Search::search& sch, vector<example*>& input_example, vector<uint32_t>& output)
  { my_task->run(sch, input_example);
    sch.get_test_action_sequence(output);
  }

protected:
  Search::search_task* my_task; // non-owning
};
#endif