Skip to content

Commit

Permalink
Separating loss and gradient computation from model
Browse files Browse the repository at this point in the history
Summary:
This commit splits the computation of the loss and the subsequent gradient into different classes. Each Loss class implements its own logic and contains the underlying data needed for the computation.

There is a behavioural change:
- now, `NegativeSampling` also uses the sigmoid output instead of the softmax output for prediction. Before this commit, it used sigmoid for training and softmax for prediction.

We are passing a lot of information to the `Loss` classes. There are two things we should think about next:
- `State` class
- `Model` classes instead of `Loss` classes

Reviewed By: EdouardGrave

Differential Revision: D13359871

fbshipit-source-id: 2f53eaafb800a9a2742817aa113af5f6bd7e282e
  • Loading branch information
Celebio authored and facebook-github-bot committed Feb 22, 2019
1 parent c35edc3 commit 9ddcabd
Show file tree
Hide file tree
Showing 13 changed files with 627 additions and 407 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(HEADER_FILES
src/densematrix.h
src/dictionary.h
src/fasttext.h
src/loss.h
src/matrix.h
src/meter.h
src/model.h
Expand All @@ -36,6 +37,7 @@ set(SOURCE_FILES
src/densematrix.cc
src/dictionary.cc
src/fasttext.cc
src/loss.cc
src/main.cc
src/matrix.cc
src/meter.cc
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

CXX = c++
CXXFLAGS = -pthread -std=c++0x -march=native
OBJS = args.o matrix.o dictionary.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
OBJS = args.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
INCLUDES = -I.

opt: CXXFLAGS += -O3 -funroll-loops
Expand All @@ -29,6 +29,9 @@ matrix.o: src/matrix.cc src/matrix.h
dictionary.o: src/dictionary.cc src/dictionary.h src/args.h
$(CXX) $(CXXFLAGS) -c src/dictionary.cc

loss.o: src/loss.cc src/loss.h src/basematrix.h src/real.h
$(CXX) $(CXXFLAGS) -c src/loss.cc

productquantizer.o: src/productquantizer.cc src/productquantizer.h src/utils.h
$(CXX) $(CXXFLAGS) -c src/productquantizer.cc

Expand Down
43 changes: 32 additions & 11 deletions src/fasttext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include "fasttext.h"
#include "loss.h"
#include "quantmatrix.h"

#include <algorithm>
Expand All @@ -28,6 +29,24 @@ bool comparePairs(
const std::pair<real, std::string>& l,
const std::pair<real, std::string>& r);

std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
loss_name lossName = args_->loss;
switch (lossName) {
case loss_name::hs:
return std::make_shared<HierarchicalSoftmaxLoss>(
output, getTargetCounts());
case loss_name::ns:
return std::make_shared<NegativeSamplingLoss>(
output, args_->neg, getTargetCounts());
case loss_name::softmax:
return std::make_shared<SoftmaxLoss>(output);
case loss_name::ova:
return std::make_shared<OneVsAllLoss>(output);
default:
throw std::runtime_error("Unknown loss");
}
}

// Default constructor: the model starts unquantized and word vectors are
// computed lazily (wordVectors_ stays null until first needed).
FastText::FastText() : quant_(false), wordVectors_(nullptr) {}

void FastText::addInputVector(Vector& vec, int32_t ind) const {
Expand Down Expand Up @@ -237,8 +256,8 @@ void FastText::loadModel(std::istream& in) {
}
output_->load(in);

model_ =
std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
}

void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
Expand Down Expand Up @@ -297,7 +316,6 @@ void FastText::quantize(const Args& qargs) {
std::dynamic_pointer_cast<DenseMatrix>(input_);
std::shared_ptr<DenseMatrix> output =
std::dynamic_pointer_cast<DenseMatrix>(output_);

if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {
auto idx = selectEmbeddings(qargs.cutoff);
dict_->prune(idx);
Expand All @@ -314,6 +332,8 @@ void FastText::quantize(const Args& qargs) {
args_->lr = qargs.lr;
args_->thread = qargs.thread;
args_->verbose = qargs.verbose;
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input, output, args_, loss, 0);
startThreads();
}
}
Expand All @@ -327,8 +347,8 @@ void FastText::quantize(const Args& qargs) {
}

quant_ = true;
model_ =
std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
}

void FastText::supervised(
Expand Down Expand Up @@ -393,7 +413,7 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
const {
std::vector<int32_t> line;
std::vector<int32_t> labels;
std::vector<std::pair<real, int32_t>> predictions;
Predictions predictions;

while (in.peek() != EOF) {
line.clear();
Expand All @@ -411,7 +431,7 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
void FastText::predict(
int32_t k,
const std::vector<int32_t>& words,
std::vector<std::pair<real, int32_t>>& predictions,
Predictions& predictions,
real threshold) const {
if (words.empty()) {
return;
Expand All @@ -433,7 +453,7 @@ bool FastText::predictLine(

std::vector<int32_t> words, labels;
dict_->getLine(in, words, labels);
std::vector<std::pair<real, int32_t>> linePredictions;
Predictions linePredictions;
predict(k, words, linePredictions, threshold);
for (const auto& p : linePredictions) {
predictions.push_back(
Expand Down Expand Up @@ -624,7 +644,8 @@ void FastText::trainThread(int32_t threadId) {
std::ifstream ifs(args_->input);
utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);

Model model(input_, output_, args_, getTargetCounts(), threadId);
assert(model_);
Model model(*model_, threadId);

const int64_t ntokens = dict_->ntokens();
int64_t localTokenCount = 0;
Expand Down Expand Up @@ -742,9 +763,9 @@ void FastText::train(const Args& args) {
input_ = createRandomMatrix();
}
output_ = createTrainOutputMatrix();
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
startThreads();
model_ =
std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
}

void FastText::startThreads() {
Expand Down
3 changes: 2 additions & 1 deletion src/fasttext.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class FastText {
std::shared_ptr<Matrix> createRandomMatrix() const;
std::shared_ptr<Matrix> createTrainOutputMatrix() const;
std::vector<int64_t> getTargetCounts() const;
std::shared_ptr<Loss> createLoss(std::shared_ptr<Matrix>& output);

bool quant_;
int32_t version;
Expand Down Expand Up @@ -111,7 +112,7 @@ class FastText {
void predict(
int32_t k,
const std::vector<int32_t>& words,
std::vector<std::pair<real, int32_t>>& predictions,
Predictions& predictions,
real threshold = 0.0) const;

bool predictLine(
Expand Down
Loading

0 comments on commit 9ddcabd

Please sign in to comment.