Skip to content

Commit

Permalink
Pretty good with darts
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Feb 11, 2015
1 parent 03c99ed commit 892aa0e
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 72 deletions.
145 changes: 75 additions & 70 deletions bin/darts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,84 +188,89 @@ int main(int argc, char** argv)
truth[i] = truth[i] - 1; // 1~1000 to 0~999
}
iftruth.close();
const double accuracy_guarantee = 0.99;
const double epsilon = 1 - accuracy_guarantee;
const double confidence = 0.95;
double min_lambda = 0;
double max_lambda = ((1 - epsilon) * max_rewards - min_rewards) / epsilon;
printf("Binary search between %lf and %lf\n", min_lambda, max_lambda);
// Top-sort for used
for (i = 0; i < 1000; i++)
{
ccv_darts_tree_t* node = used[wnids[i]];
node->inverse_high = 0;
for (map<string, ccv_darts_tree_t*>::iterator it = node->super.begin(); it != node->super.end(); ++it)
recompute_inverse_high(it->second);
}
vector<ccv_darts_tree_t*> sort;
for (i = 0; i > -(int)used.size(); i--) // Arbitrary large negative numbers
{
int found = 0;
for (map<string, ccv_darts_tree_t*>::iterator it = used.begin(); it != used.end(); ++it)
if (it->second->inverse_high == i)
sort.push_back(it->second), found = 1;
if (!found)
break;
}
int max_high = abs(i);
double* accuracy_at_high = (double*)calloc(max_high, sizeof(double));
assert(used.size() == sort.size());
for (i = 0; i < 25; i++)
const double accuracy_guarantees[4] = {0.85, 0.9, 0.95, 0.99};
int t;
for (t = 0; t < 3; t++)
{
double current_lambda = (min_lambda + max_lambda) / 2.0;
double correct = 0;
for (j = 0; j < max_high; j++)
accuracy_at_high[j] = 0;
for (j = 0; j < 50000; j++)
const double accuracy_guarantee = accuracy_guarantees[t];
const double epsilon = 1 - accuracy_guarantee;
const double confidence = 0.95;
double min_lambda = 0;
double max_lambda = ((1 - epsilon) * max_rewards - min_rewards) / epsilon;
printf("Binary search between %lf and %lf\n", min_lambda, max_lambda);
// Top-sort for used
for (i = 0; i < 1000; i++)
{
if (j % 291 == 0 || j == 49999)
FLUSH("At %d / %d, going over %d / %d", i + 1, 25, j + 1, 50000);
for (map<string, ccv_darts_tree_t*>::iterator it = used.begin(); it != used.end(); ++it)
it->second->probs = 0;
for (k = 0; k < 1000; k++)
{
ccv_darts_tree_t* node = used[wnids[k]];
node->probs = probs[j * 1000 + k];
}
for (vector<ccv_darts_tree_t*>::iterator it = sort.begin(); it != sort.end(); ++it)
{
if ((*it)->inverse_high < 0)
recompute_probs(*it);
(*it)->correct = 0;
}
ccv_darts_tree_t* truth_node = used[wnids[truth[j]]];
recompute_correctness(1, truth_node);
string max_wnid = "";
double max_rewards = 0;
ccv_darts_tree_t* node = used[wnids[i]];
node->inverse_high = 0;
for (map<string, ccv_darts_tree_t*>::iterator it = node->super.begin(); it != node->super.end(); ++it)
recompute_inverse_high(it->second);
}
vector<ccv_darts_tree_t*> sort;
for (i = 0; i > -(int)used.size(); i--) // Arbitrary large negative numbers
{
int found = 0;
for (map<string, ccv_darts_tree_t*>::iterator it = used.begin(); it != used.end(); ++it)
if (it->second->inverse_high == i)
sort.push_back(it->second), found = 1;
if (!found)
break;
}
int max_high = abs(i);
double* accuracy_at_high = (double*)calloc(max_high, sizeof(double));
assert(used.size() == sort.size());
for (i = 0; i < 25; i++)
{
double current_lambda = (min_lambda + max_lambda) / 2.0;
double correct = 0;
for (j = 0; j < max_high; j++)
accuracy_at_high[j] = 0;
for (j = 0; j < 50000; j++)
{
double rewards = (it->second->info_gain + current_lambda) * it->second->probs;
if (rewards > max_rewards)
if (j % 291 == 0 || j == 49999)
FLUSH("At %d / %d, going over %d / %d", i + 1, 25, j + 1, 50000);
for (map<string, ccv_darts_tree_t*>::iterator it = used.begin(); it != used.end(); ++it)
it->second->probs = 0;
for (k = 0; k < 1000; k++)
{
ccv_darts_tree_t* node = used[wnids[k]];
node->probs = probs[j * 1000 + k];
}
for (vector<ccv_darts_tree_t*>::iterator it = sort.begin(); it != sort.end(); ++it)
{
if ((*it)->inverse_high < 0)
recompute_probs(*it);
(*it)->correct = 0;
}
ccv_darts_tree_t* truth_node = used[wnids[truth[j]]];
recompute_correctness(1, truth_node);
string max_wnid = "";
double max_rewards = 0;
for (map<string, ccv_darts_tree_t*>::iterator it = used.begin(); it != used.end(); ++it)
{
max_wnid = it->first;
max_rewards = rewards;
double rewards = (it->second->info_gain + current_lambda) * it->second->probs;
if (rewards > max_rewards)
{
max_wnid = it->first;
max_rewards = rewards;
}
}
assert(max_wnid.size() > 0);
correct += used[max_wnid]->correct;
accuracy_at_high[abs(used[max_wnid]->inverse_high)] += 1;
}
assert(max_wnid.size() > 0);
correct += used[max_wnid]->correct;
accuracy_at_high[abs(used[max_wnid]->inverse_high)] += 1;
double accuracy = correct / 50000.0;
double accuracy_lower_bound = binofit(accuracy, 50000, confidence);
if (accuracy_lower_bound > accuracy_guarantee)
max_lambda = current_lambda;
else
min_lambda = current_lambda;
FLUSH("At %d / %d, lambda %lf, at accuracy %.3lf%%, accuracy lower bound %.3lf%%\n", i + 1, 25, current_lambda, accuracy * 100, accuracy_lower_bound * 100);
printf("accuracy at: (%d, %.3lf%%)", 0, accuracy_at_high[0] * 100 / 50000.0);
for (j = 1; j < max_high; j++)
printf(", (%d, %.3lf%%)", j, accuracy_at_high[j] * 100 / 50000.0);
printf("\n");
}
double accuracy = correct / 50000.0;
double accuracy_lower_bound = binofit(accuracy, 50000, confidence);
if (accuracy_lower_bound > accuracy_guarantee)
max_lambda = current_lambda;
else
min_lambda = current_lambda;
printf("\nAt %d / %d, lambda %lf, at accuracy %.3lf%%, accuracy lower bound %.3lf%%\n", i + 1, 25, current_lambda, accuracy * 100, accuracy_lower_bound * 100);
printf("accuracy at: (%d, %.3lf%%)", 0, accuracy_at_high[0] * 100 / 50000.0);
for (j = 1; j < max_high; j++)
printf(", (%d, %.3lf%%)", j, accuracy_at_high[j] * 100 / 50000.0);
printf("\n");
}
free(probs);
return 0;
Expand Down
4 changes: 2 additions & 2 deletions bin/image-net.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ int main(int argc, char** argv)
for (i = 0; i < depth; i++)
{
layer_params[i].w.decay = 0.0005;
layer_params[i].w.learn_rate = 0.0001;
layer_params[i].w.learn_rate = 0.01;
layer_params[i].w.momentum = 0.9;
layer_params[i].bias.decay = 0;
layer_params[i].bias.learn_rate = 0.0001;
layer_params[i].bias.learn_rate = 0.01;
layer_params[i].bias.momentum = 0.9;
}
// set the two full connect layers to last with dropout rate at 0.5
Expand Down

0 comments on commit 892aa0e

Please sign in to comment.