Skip to content

Commit

Permalink
add argument for lr update rate / remove verbose
Browse files Browse the repository at this point in the history
  • Loading branch information
ajoulin committed Aug 8, 2016
1 parent cd5726e commit 7867de2
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 30 deletions.
44 changes: 22 additions & 22 deletions src/args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Args::Args() {
minn = 3;
maxn = 6;
thread = 12;
verbose = 10000;
lrUpdateRate = 100;
t = 1e-4;
label = "__label__";
}
Expand Down Expand Up @@ -60,6 +60,8 @@ void Args::parseArgs(int argc, char** argv) {
output = std::string(argv[ai + 1]);
} else if (strcmp(argv[ai], "-lr") == 0) {
lr = atof(argv[ai + 1]);
} else if (strcmp(argv[ai], "-lrUpdateRate") == 0) {
lrUpdateRate = atoi(argv[ai + 1]);
} else if (strcmp(argv[ai], "-dim") == 0) {
dim = atoi(argv[ai + 1]);
} else if (strcmp(argv[ai], "-ws") == 0) {
Expand Down Expand Up @@ -92,8 +94,6 @@ void Args::parseArgs(int argc, char** argv) {
maxn = atoi(argv[ai + 1]);
} else if (strcmp(argv[ai], "-thread") == 0) {
thread = atoi(argv[ai + 1]);
} else if (strcmp(argv[ai], "-verbose") == 0) {
verbose = atoi(argv[ai + 1]);
} else if (strcmp(argv[ai], "-t") == 0) {
t = atof(argv[ai + 1]);
} else if (strcmp(argv[ai], "-label") == 0) {
Expand All @@ -116,24 +116,24 @@ void Args::printHelp() {
std::cout
<< "\n"
<< "The following arguments are mandatory:\n"
<< " -input training file path\n"
<< " -output output file path\n\n"
<< " -input training file path\n"
<< " -output output file path\n\n"
<< "The following arguments are optional:\n"
<< " -lr learning rate [" << lr << "]\n"
<< " -dim size of word vectors [" << dim << "]\n"
<< " -ws size of the context window [" << ws << "]\n"
<< " -epoch number of epochs [" << epoch << "]\n"
<< " -minCount minimal number of word occurences [" << minCount << "]\n"
<< " -neg number of negatives sampled [" << neg << "]\n"
<< " -wordNgrams max length of word ngram [" << wordNgrams << "]\n"
<< " -loss loss function {ns, hs, softmax} [ns]\n"
<< " -bucket number of buckets [" << bucket << "]\n"
<< " -minn min length of char ngram [" << minn << "]\n"
<< " -maxn max length of char ngram [" << maxn << "]\n"
<< " -thread number of threads [" << thread << "]\n"
<< " -verbose how often to print to stdout [" << verbose << "]\n"
<< " -t sampling threshold [" << t << "]\n"
<< " -label labels prefix [" << label << "]\n"
<< " -lr learning rate [" << lr << "]\n"
<< " -lrUpdateRate change the rate of updates for the learning rate [" << lrUpdateRate << "]\n"
<< " -dim size of word vectors [" << dim << "]\n"
<< " -ws size of the context window [" << ws << "]\n"
<< " -epoch number of epochs [" << epoch << "]\n"
<< " -minCount minimal number of word occurences [" << minCount << "]\n"
<< " -neg number of negatives sampled [" << neg << "]\n"
<< " -wordNgrams max length of word ngram [" << wordNgrams << "]\n"
<< " -loss loss function {ns, hs, softmax} [ns]\n"
<< " -bucket number of buckets [" << bucket << "]\n"
<< " -minn min length of char ngram [" << minn << "]\n"
<< " -maxn max length of char ngram [" << maxn << "]\n"
<< " -thread number of threads [" << thread << "]\n"
<< " -t sampling threshold [" << t << "]\n"
<< " -label labels prefix [" << label << "]\n"
<< std::endl;
}

Expand All @@ -150,7 +150,7 @@ void Args::save(std::ofstream& ofs) {
ofs.write((char*) &(bucket), sizeof(int));
ofs.write((char*) &(minn), sizeof(int));
ofs.write((char*) &(maxn), sizeof(int));
ofs.write((char*) &(verbose), sizeof(int));
ofs.write((char*) &(lrUpdateRate), sizeof(int));
ofs.write((char*) &(t), sizeof(double));
}
}
Expand All @@ -168,7 +168,7 @@ void Args::load(std::ifstream& ifs) {
ifs.read((char*) &(bucket), sizeof(int));
ifs.read((char*) &(minn), sizeof(int));
ifs.read((char*) &(maxn), sizeof(int));
ifs.read((char*) &(verbose), sizeof(int));
ifs.read((char*) &(lrUpdateRate), sizeof(int));
ifs.read((char*) &(t), sizeof(double));
}
}
2 changes: 1 addition & 1 deletion src/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Args {
std::string test;
std::string output;
double lr;
int lrUpdateRate;
int dim;
int ws;
int epoch;
Expand All @@ -34,7 +35,6 @@ class Args {
int minn;
int maxn;
int thread;
int verbose;
double t;
std::string label;

Expand Down
16 changes: 10 additions & 6 deletions src/fasttext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,16 @@ void trainThread(Dictionary& dict, Matrix& input, Matrix& output,
model.setTargetCounts(dict.getCounts(entry_type::word));
}

real progress;
const int64_t ntokens = dict.ntokens();
int64_t tokenCount = 0;
int64_t tokenCount = 0, printCount = 0, deltaCount = 0;
double loss = 0.0;
int32_t nexamples = 0;
std::vector<int32_t> line, labels;
while (info::allWords < args.epoch * ntokens) {
tokenCount += dict.getLine(ifs, line, labels, model.rng);
deltaCount = dict.getLine(ifs, line, labels, model.rng);
tokenCount += deltaCount;
printCount += deltaCount;
if (args.model == model_name::sup) {
dict.addNgrams(line, args.wordNgrams);
supervised(model, line, labels, loss, nexamples);
Expand All @@ -232,17 +235,18 @@ void trainThread(Dictionary& dict, Matrix& input, Matrix& output,
} else if (args.model == model_name::sg) {
skipgram(dict, model, line, loss, nexamples);
}

if (tokenCount > args.verbose) {
if (tokenCount > args.lrUpdateRate) {
info::allWords += tokenCount;
info::allLoss += loss;
info::allN += nexamples;
tokenCount = 0;
loss = 0.0;
nexamples = 0;
real progress = real(info::allWords) / (args.epoch * ntokens);
progress = real(info::allWords) / (args.epoch * ntokens);
model.setLearningRate(args.lr * (1.0 - progress));
if (threadId == 0) printInfo(model, progress);
if (threadId == 0) {
printInfo(model, progress);
}
}
}
if (threadId == 0) {
Expand Down
2 changes: 1 addition & 1 deletion word-vector-example.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ make

./fasttext skipgram -input "${DATADIR}"/text9 -output "${RESULTDIR}"/text9 -lr 0.025 -dim 100 \
-ws 5 -epoch 1 -minCount 5 -neg 5 -loss ns -bucket 2000000 \
-minn 3 -maxn 6 -thread 4 -verbose 1000 -t 1e-4
-minn 3 -maxn 6 -thread 4 -t 1e-4 -lrUpdateRate 100

cut -f 1,2 "${DATADIR}"/rw/rw.txt | awk '{print tolower($0)}' | tr '\t' '\n' > "${DATADIR}"/queries.txt

Expand Down

0 comments on commit 7867de2

Please sign in to comment.