model.cc
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include "model.h"
#include "loss.h"
#include "utils.h"
#include <algorithm>
#include <stdexcept>
namespace fasttext {
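
// Per-thread training state: reusable hidden, output and gradient buffers,
// an RNG, and a running loss accumulated over the examples seen so far.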
Model::State::State(int32_t hiddenSize, int32_t outputSize, int32_t seed)
    : lossValue_(0.0),
      nexamples_(0),
      hidden(hiddenSize),
      output(outputSize),
      grad(hiddenSize),
      rng(seed) {}

real Model::State::getLoss() const {
  return lossValue_ / nexamples_;
}

void Model::State::incrementNExamples(real loss) {
  lossValue_ += loss;
  nexamples_++;
}

Model::Model(
    std::shared_ptr<Matrix> wi,
    std::shared_ptr<Matrix> wo,
    std::shared_ptr<Loss> loss,
    bool normalizeGradient)
    : wi_(wi), wo_(wo), loss_(loss), normalizeGradient_(normalizeGradient) {}
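
// The hidden representation is the average of the input-matrix rows
// indexed by `input`.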
void Model::computeHidden(const std::vector<int32_t>& input, State& state)
    const {
  Vector& hidden = state.hidden;
  wi_->averageRowsToVector(hidden, input);
}
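
// Computes the hidden vector for `input`, then lets the loss fill `heap`
// with up to k (score, label id) predictions whose score passes `threshold`.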
void Model::predict(
    const std::vector<int32_t>& input,
    int32_t k,
    real threshold,
    Predictions& heap,
    State& state) const {
  if (k == Model::kUnlimitedPredictions) {
    k = wo_->size(0); // output size
  } else if (k <= 0) {
    throw std::invalid_argument("k needs to be 1 or higher!");
  }
  heap.reserve(k + 1);
  computeHidden(input, state);
  loss_->predict(k, threshold, heap, state);
}
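
// One stochastic gradient step: compute the hidden vector, let the loss
// accumulate the hidden-layer gradient into state.grad (the trailing `true`
// requests backpropagation) while it updates the output matrix, then add
// that gradient back onto every input row of wi_.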
void Model::update(
    const std::vector<int32_t>& input,
    const std::vector<int32_t>& targets,
    int32_t targetIndex,
    real lr,
    State& state) {
  if (input.size() == 0) {
    return;
  }
  computeHidden(input, state);
  Vector& grad = state.grad;
  grad.zero();
  real lossValue = loss_->forward(targets, targetIndex, state, lr, true);
  state.incrementNExamples(lossValue);
  if (normalizeGradient_) {
    grad.mul(1.0 / input.size());
  }
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    wi_->addVectorToRow(grad, *it, 1.0);
  }
}
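
// Log with a small additive constant so that log(0) is never evaluated.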
real Model::std_log(real x) const {
  return std::log(x + 1e-5);
}

} // namespace fasttext
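
A minimal usage sketch (not part of model.cc), assuming the DenseMatrix and SoftmaxLoss classes declared elsewhere in this codebase (densematrix.h, loss.h), illustrative sizes and token ids, and that Model::Predictions is the (score, label id) vector alias from model.h:

#include "densematrix.h"
#include "loss.h"
#include "model.h"

#include <memory>
#include <vector>

using namespace fasttext;

int main() {
  const int64_t dim = 100, nLabels = 5;

  // Input embeddings (one row per word/subword) and output matrix (one row per label).
  std::shared_ptr<Matrix> wi = std::make_shared<DenseMatrix>(1000, dim);
  std::shared_ptr<Matrix> wo = std::make_shared<DenseMatrix>(nLabels, dim);
  std::shared_ptr<Loss> loss = std::make_shared<SoftmaxLoss>(wo);

  Model model(wi, wo, loss, /* normalizeGradient */ false);
  Model::State state(/* hiddenSize */ dim, /* outputSize */ nLabels, /* seed */ 42);

  // One SGD step on a bag of input ids against target label 3.
  std::vector<int32_t> input = {12, 345};
  std::vector<int32_t> targets = {3};
  model.update(input, targets, /* targetIndex */ 0, /* lr */ 0.05, state);

  // Top-2 predictions whose probability clears the threshold.
  Model::Predictions heap; // (score, label id) pairs
  model.predict(input, /* k */ 2, /* threshold */ 0.0, heap, state);
  return 0;
}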