-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathEM_sort.cpp
211 lines (166 loc) · 7.11 KB
/
EM_sort.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
#include "globals.hh"
#include "Block.hh"
#include "ParallelBoundedQueue.hh"
#include "generic_EM_classes.hh"
#include "EM_sort.hh"
using namespace std;
// Interprets the null-terminated strings as non-negative decimal integers
// (no leading zeros allowed) and returns exactly:
// -1 if x < y
// 0 if x = y
// 1 if x > y
int compare_as_numbers(const char* x, const char* y){
    const std::size_t nx = strlen(x);
    const std::size_t ny = strlen(y);

    // Without leading zeros, the number with more digits is the larger one
    if(nx != ny) return nx < ny ? -1 : 1;

    // Same digit count: lexicographic order equals numeric order. strcmp only
    // guarantees the sign of its result, so collapse it to -1 / 0 / 1.
    const int c = strcmp(x, y);
    return (c > 0) - (c < 0);
}
// "Less than" comparison for variable-length binary records.
// The first 8 bytes of each record are decoded with parse_big_endian_int64_t
// and treated as the record length including those 8 bytes; the bytes after
// the header are the payload.
// Payloads are compared bytewise over their common length; if they agree on
// that prefix, the record with the smaller length is the smaller one.
// Returns true iff x orders strictly before y.
bool memcmp_variable_binary_records(const char* x, const char* y){
    const int64_t len_x = parse_big_endian_int64_t(x);
    const int64_t len_y = parse_big_endian_int64_t(y);

    // Compare only the payload bytes that both records have
    const int64_t common_payload = std::min(len_x - 8, len_y - 8);
    const int64_t order = memcmp(x + 8, y + 8, common_payload);

    if(order != 0) return order < 0;

    // Equal on the common prefix: the shorter record comes first
    return len_x < len_y;
}
// Merges the record files currently open in `reader` into the file currently
// open in `writer`, always emitting the smallest pending record according to cmp.
//
// cmp         : "less than" on pointers to raw records.
// merge_count : running merge number, used in the log line and incremented by
//               one before returning.
// reader      : must provide get_num_files() and
//               read_record(file_idx, char** buffer, int64_t* buffer_size), whose
//               return value is true when a record was read into the buffer.
//               (The buffer is passed by address, so presumably read_record may
//               reallocate it -- confirm against the reader classes.)
// writer      : must provide write(const char* record) and close_file().
//               The writer is closed by this function.
template <typename record_reader_t, typename record_writer_t>
void merge_files_generic(const std::function<bool(const char* x, const char* y)>& cmp, int64_t& merge_count, record_reader_t& reader, record_writer_t& writer){

    write_log("Doing merge number " + to_string(merge_count), LogLevel::MINOR);

    const int64_t n_files = reader.get_num_files();
    const int64_t initial_buf_size = 1024;

    // One record buffer per input file
    vector<char*> input_buffers;
    vector<int64_t> input_buffer_sizes;
    for(int64_t i = 0; i < n_files; i++){
        input_buffers.push_back(static_cast<char*>(malloc(initial_buf_size))); // Freed at the end of this function
        input_buffer_sizes.push_back(initial_buf_size);
    }

    auto cmp_wrap = [&](pair<char*, int64_t> x, pair<char*, int64_t> y) {
        return cmp(x.first, y.first);
    };

    // Priority queue: (record, file index).
    // Must be a multiset because a regular set will not add an element if it is
    // equal to another according to the comparison function.
    multiset<pair<char*, int64_t>, decltype(cmp_wrap)> Q(cmp_wrap);

    // Initialize priority queue with the first record of every file.
    // A file that yields no record must not enter the queue: its buffer would
    // hold uninitialized bytes that would then be compared and written out.
    for(int64_t i = 0; i < n_files; i++){
        if(reader.read_record(i, &input_buffers[i], &input_buffer_sizes[i]))
            Q.insert({input_buffers[i], i});
    }

    // Do the merge
    while(!Q.empty()){
        const auto smallest = Q.begin();
        char* record = smallest->first;
        const int64_t stream_idx = smallest->second;
        Q.erase(smallest); // pop

        // Write the record before its buffer is reused by the next read
        writer.write(record);

        // Refill from the same file; a file drops out once it is exhausted
        if(reader.read_record(stream_idx, &input_buffers[stream_idx], &input_buffer_sizes[stream_idx]))
            Q.insert({input_buffers[stream_idx], stream_idx});
    }

    writer.close_file();

    for(char* buffer : input_buffers) free(buffer);

    merge_count++;
}
// Generic external-memory sort driver shared by the record formats in this file.
//
// Phase 1: a single producer thread runs producer->run(Q, B) while every
//          consumer runs consumers[i]->run(Q, cmp) in its own thread. After all
//          threads have joined, the file names reported by get_outfilenames()
//          are collected as the initial set of block files.
// Phase 2: the block files are merged with merge_files_generic, at most
//          max_files at a time, round after round, until one file is left.
//          That file is renamed to `outfile`.
//
// infile, outfile : input path and output path.
// cmp             : "less than" on pointers to raw records.
// RAM_bytes       : memory budget from which the block size B is derived.
// producer        : block producer. Not owned by this function.
// consumers       : block consumers, one thread each. Must be non-empty. Not owned.
// reader, writer  : record reader / writer used for the merge rounds.
//
// Throws std::invalid_argument if `consumers` is empty.
template <typename record_reader_t, typename record_writer_t>
void EM_sort_generic(string infile, string outfile, const std::function<bool(const char* x, const char* y)>& cmp, int64_t RAM_bytes, Generic_Block_Producer* producer, vector<Generic_Block_Consumer*> consumers, record_reader_t& reader, record_writer_t& writer){

    // With no consumers the block size computation below would divide by zero,
    // and nothing would ever take blocks out of the queue.
    if(consumers.empty())
        throw std::invalid_argument("EM_sort_generic: at least one consumer is required");

    const int64_t max_files = 512; // Maximum number of files merged in one go
    const int64_t n_consumers = static_cast<int64_t>(consumers.size());

    // Number of blocks in the memory at once:
    // - 1 per consumer thread in processing
    // - 1 in the queue
    // - 1 with the producer loading (there is only one producer)
    // So if block size is B, we have (n_threads + 2)*B bytes in memory at a time
    // So we have the equation (n_threads + 2)*B = RAM_bytes. Solve for B:
    int64_t B = RAM_bytes / (n_consumers + 2);
    B = min(B, static_cast<int64_t>(std::filesystem::file_size(infile)) / n_consumers); // Make sure all threads have work

    ParallelBoundedQueue<Generic_Block*> Q(1); // 1 byte = basically only one block can be in the queue at a time
    vector<std::thread> threads;
    threads.reserve(n_consumers + 1);

    // Create producer
    threads.emplace_back([&Q, &B, &producer](){
        producer->run(Q, B);
    });

    // Create consumers
    for(int64_t i = 0; i < n_consumers; i++){
        threads.emplace_back([i, &Q, &cmp, &consumers](){
            write_log("Starting thread " + to_string(i), LogLevel::MINOR);
            consumers[i]->run(Q,cmp);
            write_log("Thread " + to_string(i) + ": done", LogLevel::MINOR);
        });
    }

    for(std::thread& t : threads) t.join();

    // The files written by the consumers form the first merge round
    vector<string> cur_round;
    for(Generic_Block_Consumer* consumer : consumers){
        for(const string& filename : consumer->get_outfilenames())
            cur_round.push_back(filename);
    }

    // Merge blocks
    int64_t merge_count = 0;
    while(cur_round.size() > 1){
        vector<string> next_round;
        const int64_t n_cur = static_cast<int64_t>(cur_round.size());
        for(int64_t i = 0; i < n_cur; i += max_files){
            const int64_t chunk_end = min(i + max_files, n_cur);
            const vector<string> to_merge(cur_round.begin() + i, cur_round.begin() + chunk_end);

            // Merge
            string round_file = get_temp_file_manager().create_filename();
            writer.open_file(round_file);
            reader.open_files(to_merge);
            merge_files_generic(cmp, merge_count, reader, writer); // Also closes the writer
            reader.close_files();
            next_round.push_back(round_file);

            // Clear files
            for(const string& merged : to_merge)
                get_temp_file_manager().delete_file(merged.c_str());
        }
        cur_round = std::move(next_round);
    }

    if(cur_round.empty()){
        // No block files were produced: function was called with empty input file.
        // Copy instead of rename so that infile is not consumed here; the
        // non-empty path below never renames infile either.
        namespace fs = std::filesystem;
        const bool same_file = fs::exists(outfile) && fs::equivalent(infile, outfile);
        if(!same_file)
            fs::copy_file(infile, outfile, fs::copy_options::overwrite_existing);
    }
    else{
        // Move final merge file to outfile
        assert(cur_round.size() == 1);
        std::filesystem::rename(cur_round[0], outfile);
        // NOTE(review): the file has just been renamed away; presumably this call
        // only makes the temp file manager forget the name -- confirm.
        get_temp_file_manager().delete_file(cur_round[0].c_str());
    }
}
// Sorts a file of constant size records of record_size bytes each.
//
// infile, outfile : input path and sorted output path.
// cmp             : "less than" on pointers to raw records.
// RAM_bytes       : memory budget passed on to EM_sort_generic.
// record_size     : size of one record in bytes.
// n_threads       : number of Block_Consumer objects (consumer threads) to use.
void EM_sort_constant_binary(string infile, string outfile, const std::function<bool(const char* x, const char* y)>& cmp, int64_t RAM_bytes, int64_t record_size, int64_t n_threads){

    // Owning pointers: producer and consumers are released even if the sort throws
    std::unique_ptr<Generic_Block_Producer> producer = std::make_unique<Constant_Block_Producer>(infile, record_size);
    vector<std::unique_ptr<Generic_Block_Consumer>> owned_consumers;

    // EM_sort_generic takes plain pointers, so keep a non-owning view alongside
    vector<Generic_Block_Consumer*> consumers;
    for(int64_t i = 0; i < n_threads; i++){
        owned_consumers.push_back(std::make_unique<Block_Consumer>(i));
        consumers.push_back(owned_consumers.back().get());
    }

    Constant_Record_Reader reader(record_size);
    Constant_Record_Writer writer(record_size);

    EM_sort_generic(infile, outfile, cmp, RAM_bytes, producer.get(), consumers, reader, writer);
}
// Sorts a file of variable-length records.
//
// infile, outfile : input path and sorted output path.
// cmp             : "less than" on pointers to raw records.
// RAM_bytes       : memory budget passed on to EM_sort_generic.
// n_threads       : number of Block_Consumer objects (consumer threads) to use.
void EM_sort_variable_length_records(string infile, string outfile, const std::function<bool(const char* x, const char* y)>& cmp, int64_t RAM_bytes, int64_t n_threads){

    // Owning pointers: producer and consumers are released even if the sort throws
    std::unique_ptr<Generic_Block_Producer> producer = std::make_unique<Variable_Block_Producer>(infile);
    vector<std::unique_ptr<Generic_Block_Consumer>> owned_consumers;

    // EM_sort_generic takes plain pointers, so keep a non-owning view alongside
    vector<Generic_Block_Consumer*> consumers;
    for(int64_t i = 0; i < n_threads; i++){
        owned_consumers.push_back(std::make_unique<Block_Consumer>(i));
        consumers.push_back(owned_consumers.back().get());
    }

    Variable_Record_Reader reader;
    Variable_Record_Writer writer;

    EM_sort_generic(infile, outfile, cmp, RAM_bytes, producer.get(), consumers, reader, writer);
}