-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLabel Propagation Algorithm.cpp
118 lines (106 loc) · 3.95 KB
/
Label Propagation Algorithm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// 目前的最终版本-提交
int n=queryUtil->getVertNum();
map<int,int> node_label;
map<int,set<int>> node_ins;
map<int,map<int,int>> label_counts;
for(int i=0;i<n;i++){
node_label[i]=i;
}
if(!directed){
for(int kk=0;kk<k;kk++){
int flag=0;
for(int i=0;i<n;i++){
map<int,int> label_counts2;
for(auto& pred:pred_set){
// map<int,int> label_counts;
set<int> used;
used.insert(i);
int out_size=queryUtil->getOutSize(i,pred);
int in_size=queryUtil->getInSize(i,pred);
for(int in_count=0;in_count<in_size;in_count++){
int node=queryUtil->getInVertID(i,pred,in_count);
if(used.find(node)!=used.end())continue;
used.insert(node);
if(label_counts2.find(node_label[node])==label_counts2.end())label_counts2[node_label[node]]=1;
else label_counts2[node_label[node]]+=1;
}
for(int out_count=0;out_count<out_size;out_count++){
int node=queryUtil->getOutVertID(i,pred,out_count);
if(used.find(node)!=used.end())continue;
used.insert(node);
if(label_counts2.find(node_label[node])==label_counts2.end())label_counts2[node_label[node]]=1;
else label_counts2[node_label[node]]+=1;
}
}
int max_label=-1;
int max_count=0;
for(auto& label_count:label_counts2){
if(label_count.second>max_count){
max_count=label_count.second;
max_label=label_count.first;
}
}
if(max_label!=node_label[i]&&max_label!=-1){
node_label[i]=max_label;
flag=1;
}
}
if(flag==0)break;
}
}else{
for(int i=0;i<n;i++){
set<int> used;
used.insert(i);
for(auto& pred:pred_set){
int in_size=queryUtil->getInSize(i,pred);
for(int in_count=0;in_count<in_size;in_count++){
int node=queryUtil->getInVertID(i,pred,in_count);
if(used.find(node)!=used.end())continue;
used.insert(node);
if(label_counts[i].find(node_label[node])==label_counts[i].end())label_counts[i][node_label[node]]=1;
else label_counts[i][node_label[node]]+=1;
node_ins[node].insert(i);
}
}
used.clear();
}
for(int kk=0;kk<k;kk++){
int flag=0;
for(int i=0;i<n;i++){
int max_label=-1;
int max_count=0;
for(auto& label_count:label_counts[i]){
if(label_count.second>max_count){
max_count=label_count.second;
max_label=label_count.first;
}
}
if(max_label!=node_label[i]&&max_label!=-1){
flag=1;
for(auto& node_in:node_ins[i]){
label_counts[node_in][node_label[i]]-=1;
if(label_counts[node_in].find(max_label)==label_counts[node_in].end())label_counts[node_in][max_label]=1;
else label_counts[node_in][max_label]++;
}
node_label[i]=max_label;
}
}
if(flag==0)break;
}
}
map<int,vector<int>> res;
for(auto& node_labelx:node_label){
res[node_labelx.second].push_back(node_labelx.first);
}
stringstream ss;
int res_len=res.size();
ss<<"[";
int res_index=1;
for(auto& re:res){
ss<<"[";
ss<<queryUtil->getPathString(re.second);
ss<<"]";
if(res_index++!=res_len)ss<<",";
}
ss<<"]";
return ss.str();