Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 360425904
  • Loading branch information
roark-google authored and copybara-github committed Mar 2, 2021
1 parent 76e9fdf commit fd7b47a
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 60 deletions.
41 changes: 19 additions & 22 deletions mozolm/grpc/mozolm_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,32 @@

namespace mozolm {
namespace grpc {

namespace {

// Uses random number to choose position according to returned distribution.
int GetRandomPosition(
int64 normalization,
const std::vector<std::pair<int64, int32>>& count_idx_pair_vector) {
const std::vector<std::pair<double, int32>>& prob_idx_pair_vector) {
absl::BitGen gen;
const double thresh = absl::Uniform(gen, 0, 100);
double total_prob = 0.0;
const double norm = normalization / 100.0;
int pos = 0;
while (total_prob < thresh &&
pos < static_cast<int64>(count_idx_pair_vector.size())) {
pos < static_cast<int64>(prob_idx_pair_vector.size())) {
total_prob +=
static_cast<double>(count_idx_pair_vector[pos++].first) / norm;
static_cast<double>(prob_idx_pair_vector[pos++].first);
}
if (pos > 0) --pos;
return pos;
}
} // namespace

bool MozoLMClient::GetLMScores(
const std::string& context_string, int initial_state, int64* normalization,
std::vector<std::pair<int64, int32>>* count_idx_pair_vector) {
const std::string& context_string, int initial_state, double* normalization,
std::vector<std::pair<double, int32>>* prob_idx_pair_vector) {
GOOGLE_CHECK_NE(completion_client_, nullptr);
GOOGLE_CHECK(completion_client_->GetLMScore(context_string, initial_state,
timeout_, normalization,
count_idx_pair_vector));
prob_idx_pair_vector));
return *normalization > 0;
}

Expand All @@ -85,15 +83,15 @@ bool MozoLMClient::RandGen(const std::string& context_string,
int chosen = -1; // Initialize to non-zero to enter loop.
while (success && chosen != 0 &&
static_cast<int>(result->length()) < max_length) {
std::vector<std::pair<int64, int32>> count_idx_pair_vector;
int64 normalization;
std::vector<std::pair<double, int32>> prob_idx_pair_vector;
double normalization;
bool success = GetLMScores(/*context_string=*/"", state, &normalization,
&count_idx_pair_vector);
&prob_idx_pair_vector);
if (success) {
const int pos = GetRandomPosition(normalization, count_idx_pair_vector);
const int pos = GetRandomPosition(prob_idx_pair_vector);
GOOGLE_CHECK_GE(pos, 0);
GOOGLE_CHECK_LT(pos, count_idx_pair_vector.size());
chosen = count_idx_pair_vector[pos].second;
GOOGLE_CHECK_LT(pos, prob_idx_pair_vector.size());
chosen = prob_idx_pair_vector[pos].second;
if (chosen > 0) {
// Only updates if not end-of-string.
const std::string next_sym = utf8::EncodeUnicodeChar(chosen);
Expand All @@ -112,18 +110,17 @@ bool MozoLMClient::RandGen(const std::string& context_string,

bool MozoLMClient::OneKbestSample(int k_best, const std::string& context_string,
std::string* result) {
std::vector<std::pair<int64, int32>> count_idx_pair_vector;
int64 normalization;
std::vector<std::pair<double, int32>> prob_idx_pair_vector;
double normalization;
const bool success = GetLMScores(context_string, /*initial_state=*/-1,
&normalization, &count_idx_pair_vector);
&normalization, &prob_idx_pair_vector);
if (success) {
*result = std::to_string(k_best) + "-best prob continuations:";
double norm = normalization / 100.0;
for (int i = 0; i < k_best; i++) {
// TODO: fix for general utf8 symbols.
*result = absl::StrFormat(
"%s %c(%5.2f)", *result, count_idx_pair_vector[i].second,
static_cast<double>(count_idx_pair_vector[i].first) / norm);
*result = absl::StrFormat("%s %c(%5.2f)", *result,
prob_idx_pair_vector[i].second,
prob_idx_pair_vector[i].first);
}
}
return success;
Expand Down
6 changes: 3 additions & 3 deletions mozolm/grpc/mozolm_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ class MozoLMClient {
bool RandGen(const std::string& context_string, std::string* result);

private:
// Requests LMScores from model, populates vector of count/index pairs and
// Requests LMScores from model, populates vector of prob/index pairs and
// updates normalization count, returning true if successful.
bool GetLMScores(const std::string& context_string, int initial_state,
int64* normalization,
std::vector<std::pair<int64, int32>>* count_idx_pair_vector);
double* normalization,
std::vector<std::pair<double, int32>>* prob_idx_pair_vector);

// Requests next state from model and returns result.
int64 GetNextState(const std::string& context_string, int initial_state);
Expand Down
24 changes: 12 additions & 12 deletions mozolm/grpc/mozolm_client_async_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ namespace mozolm {
namespace grpc {
namespace {

// Retrieves reverse frequency sorted vector of count/utf8-symbol pairs and
// Retrieves reverse probability sorted vector of prob/utf8-symbol pairs and
// normalization.
void RetrieveLMScores(
LMScores response, int64* normalization,
std::vector<std::pair<int64, int32>>* count_idx_pair_vector) {
count_idx_pair_vector->reserve(response.counts_size());
for (int i = 0; i < response.counts_size(); i++) {
count_idx_pair_vector->push_back(
std::make_pair(response.counts(i), response.utf8_syms(i)));
LMScores response, double* normalization,
std::vector<std::pair<double, int32>>* prob_idx_pair_vector) {
prob_idx_pair_vector->reserve(response.probabilities_size());
for (int i = 0; i < response.probabilities_size(); i++) {
prob_idx_pair_vector->push_back(
std::make_pair(response.probabilities(i), response.utf8_syms(i)));
}
std::sort(count_idx_pair_vector->begin(), count_idx_pair_vector->end());
std::reverse(count_idx_pair_vector->begin(), count_idx_pair_vector->end());
std::sort(prob_idx_pair_vector->begin(), prob_idx_pair_vector->end());
std::reverse(prob_idx_pair_vector->begin(), prob_idx_pair_vector->end());
*normalization = response.normalization();
}

Expand All @@ -59,8 +59,8 @@ MozoLMClientAsyncImpl::MozoLMClientAsyncImpl(

bool MozoLMClientAsyncImpl::GetLMScore(
const std::string& context_str, int initial_state, double timeout,
int64* normalization,
std::vector<std::pair<int64, int32>>* count_idx_pair_vector) {
double* normalization,
std::vector<std::pair<double, int32>>* prob_idx_pair_vector) {
// Sets up ClientContext, request and response.
::grpc::ClientContext context;
context.set_deadline(gpr_time_add(
Expand All @@ -81,7 +81,7 @@ bool MozoLMClientAsyncImpl::GetLMScore(
GOOGLE_LOG(ERROR) << status.error_message();
} else {
// Retrieves information from response if RPC call was successful.
RetrieveLMScores(response, normalization, count_idx_pair_vector);
RetrieveLMScores(response, normalization, prob_idx_pair_vector);
}
return status.ok();
}
Expand Down
4 changes: 2 additions & 2 deletions mozolm/grpc/mozolm_client_async_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class MozoLMClientAsyncImpl {
// Seeks the language models scores given the initial state and context
// string. Any errors are logged.
bool GetLMScore(const std::string& context_str, int initial_state,
double timeout, int64* normalization,
std::vector<std::pair<int64, int32>>* count_idx_pair_vector);
double timeout, double* normalization,
std::vector<std::pair<double, int32>>* prob_idx_pair_vector);

// Seeks the next model state given the initial state and context string. Any
// errors are logged.
Expand Down
2 changes: 1 addition & 1 deletion mozolm/grpc/mozolm_server_async_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class MozoLMServerAsyncImpl : public MozoLMServer::AsyncService {
const GetContextRequest* request,
NextState* response);

// Updates the counts/norm by count and advances state, returning counts at
// Updates the counts/norm by count and advances state, returning probs at
// new state.
::grpc::Status HandleRequest(::grpc::ServerContext* context,
const UpdateLMScoresRequest* request,
Expand Down
24 changes: 16 additions & 8 deletions mozolm/grpc/mozolm_server_async_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace mozolm {
namespace grpc {
namespace {

constexpr float kFloatDelta = 0.00001; // Delta for float comparisons.

using ::grpc::Status;
using ::grpc::ServerContext;

Expand All @@ -53,7 +55,7 @@ void CheckGetLMScoresError(int state, ::grpc::StatusCode error_code) {
}

// Check that a call to GetLMScores succeeds and returns the expected
// counts and normalization.
// probabilities and normalization.
void CheckGetLMScores(int state) {
MozoLMServerAsyncImplMock server;
ServerContext context;
Expand All @@ -63,10 +65,11 @@ void CheckGetLMScores(int state) {
const GetContextRequest* request_ptr(&request);
Status status = server.HandleRequest(&context, request_ptr, &response);
ASSERT_TRUE(status.ok());
ASSERT_EQ(response.counts_size(), 28);
ASSERT_EQ(response.normalization(), 28);
ASSERT_EQ(response.probabilities_size(), 28);
ASSERT_NEAR(response.normalization(), 28.0, kFloatDelta);
double uniform_value = static_cast<double>(1.0) / static_cast<double>(28);
for (int i = 0; i < 28; i++) {
ASSERT_EQ(response.counts(i), 1);
ASSERT_NEAR(response.probabilities(i), uniform_value, kFloatDelta);
}
}

Expand Down Expand Up @@ -99,13 +102,18 @@ void CheckUpdateLMScoresContent(int state, int count) {
Status status = server.HandleRequest(&context, request_ptr, &response);
ASSERT_TRUE(status.ok());
// Check if the response matches the request
ASSERT_EQ(response.counts_size(), 28);
ASSERT_EQ(response.normalization(), 28 + count);
ASSERT_EQ(response.probabilities_size(), 28);
ASSERT_NEAR(response.normalization(), 28.0 + count, kFloatDelta);
double rest_value =
static_cast<double>(1.0) / static_cast<double>(28 + count);
for (int i = 0; i < 28; i++) {
if (i == state) {
ASSERT_EQ(response.counts(i), 1 + count);
ASSERT_NEAR(
response.probabilities(i),
static_cast<double>(1 + count) / static_cast<double>(28 + count),
kFloatDelta);
} else {
ASSERT_EQ(response.counts(i), 1);
ASSERT_NEAR(response.probabilities(i), rest_value, kFloatDelta);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion mozolm/grpc/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ message UpdateLMScoresRequest {
}

service MozoLMServer {
// Returns the counts and normalization for given state.
// Returns the probs and normalization for given state.
rpc GetLMScores(GetContextRequest) returns (LMScores) {
// errors: invalid state;
}
Expand Down
2 changes: 1 addition & 1 deletion mozolm/models/language_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class LanguageModel {
// than zero, the model will start at the start state of the model.
int ContextState(const std::string &context = "", int init_state = -1);

// Copies the counts and normalization from the given state into the response.
// Copies the probs and normalization from the given state into the response.
virtual bool ExtractLMScores(int state, LMScores* response) {
return false; // Requires a derived class to complete.
}
Expand Down
10 changes: 6 additions & 4 deletions mozolm/models/lm_scores.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ message LMScores {
// Individual symbols for which counts are returned.
repeated int32 utf8_syms = 1;

// Counts for each of the symbols.
repeated int64 counts = 2;
// Probabilities for each of the symbols.
repeated double probabilities = 2;

// Sum of counts for normalizations.
int64 normalization = 3;
// Normalization to recreate (smoothed) counts from probabilities. This may
// have some value when mixing models, for methods that, e.g., take into
// account the number of observations when calculating mixing values.
double normalization = 3;
}
2 changes: 1 addition & 1 deletion mozolm/models/opengrm_ngram_char_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class OpenGrmNGramCharModel : public LanguageModel {
// Provides the state reached from state following utf8_sym.
int NextState(int state, int utf8_sym) override;

// Copies the counts and normalization from the given state into the response.
// Copies the probs and normalization from the given state into the response.
bool ExtractLMScores(int state, LMScores* response) override;

// Updates the count for the utf8_sym and normalization at the current state.
Expand Down
6 changes: 4 additions & 2 deletions mozolm/models/simple_bigram_char_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ absl::StatusOr<std::vector<int32>> ReadVocabSymbols(
}

absl::Status ReadCountMatrix(const std::string& in_counts, int rows,
std::vector<int64>* utf8_normalizer,
std::vector<double>* utf8_normalizer,
std::vector<std::vector<int64>>* bigram_matrix) {
int idx = 0;
std::ifstream infile(in_counts);
Expand Down Expand Up @@ -158,7 +158,9 @@ bool SimpleBigramCharModel::ExtractLMScores(int state, LMScores* response) {
response->set_normalization(utf8_normalizer_[state]);
for (size_t i = 0; i < bigram_counts_[state].size(); i++) {
response->add_utf8_syms(utf8_indices_[i]);
response->add_counts(bigram_counts_[state][i]);
response->add_probabilities(
static_cast<double>(bigram_counts_[state][i]) /
utf8_normalizer_[state]);
}
}
return valid_state;
Expand Down
6 changes: 3 additions & 3 deletions mozolm/models/simple_bigram_char_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@ class SimpleBigramCharModel : public LanguageModel {
// Provides the state reached from state following utf8_sym.
int NextState(int state, int utf8_sym) override;

// Copies the counts and normalization from the given state into the response.
// Copies the probs and normalization from the given state into the response.
bool ExtractLMScores(int state, LMScores* response)
ABSL_LOCKS_EXCLUDED(normalizer_lock_, counts_lock_) override;

// Updates the count for the utf8_sym and normalization at the current state.
// Updates the counts for the utf8_sym and normalization at the current state.
bool UpdateLMCounts(int32 state, int32 utf8_sym, int64 count)
ABSL_LOCKS_EXCLUDED(normalizer_lock_, counts_lock_) override;

private:
std::vector<int32> utf8_indices_; // utf8 symbols in vocabulary.
std::vector<int32> vocab_indices_; // dimension is utf8 symbol, stores index.
// stores normalization constant for each item in vocabulary.
std::vector<int64> utf8_normalizer_ ABSL_GUARDED_BY(normalizer_lock_);
std::vector<double> utf8_normalizer_ ABSL_GUARDED_BY(normalizer_lock_);
absl::Mutex normalizer_lock_; // protects normalizer information.
// Stores counts for each bigram in dense square matrix.
std::vector<std::vector<int64>> bigram_counts_ ABSL_GUARDED_BY(counts_lock_);
Expand Down

0 comments on commit fd7b47a

Please sign in to comment.