Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mishgon committed Dec 5, 2018
2 parents 3639e6b + b557e2d commit a245aa6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
17 changes: 10 additions & 7 deletions unsupervised_mt/batch_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,26 @@ def __init__(self, dataset):
self.languages = dataset.languages
self.load_sentence = dataset.load_sentence
self.train_ids = {l: np.arange(len(dataset.train[l])) for l in self.languages}
self.test_ids = np.arange(len(dataset.test))
self.pad_index = {l: dataset.vocabs[l].get_pad(l) for l in self.languages}

# NOTE(review): this span appears to interleave two revisions of
# load_raw_monolingual_batch (no +/- markers or indentation are present here),
# so it is annotated in place rather than rewritten.
# Two signatures follow back to back; the second adds the `test` and `ids`
# keyword parameters.
def load_raw_monolingual_batch(self, batch_size, language, random_state=None):
def load_raw_monolingual_batch(self, batch_size, language, random_state=None, test=False, ids=None):
# Seed NumPy's global RNG only when the caller supplied a seed.
if random_state is not None:
np.random.seed(random_state)

# Variant that always draws `batch_size` ids from the training ids of
# `language` and loads those sentences.
random_ids = np.random.choice(self.train_ids[language], size=batch_size)
return [self.load_sentence(language, idx) for idx in random_ids]
# Variant that draws ids only when the caller passed none: from
# self.test_ids when `test` is true, otherwise from self.train_ids[language].
if ids is None:
ids = np.random.choice(self.test_ids if test else self.train_ids[language], size=batch_size)

# Header of load_monolingual_batch without `test`/`ids`; presumably the
# earlier form of that signature -- TODO confirm against the repository.
def load_monolingual_batch(self, batch_size, language, random_state=None):
# Load one sentence per selected id through self.load_sentence, forwarding `test`.
return [self.load_sentence(language, idx, test=test) for idx in ids]

# Build one tensor batch for `language`: the raw batch is passed to
# pad_monolingual_batch together with the language's pad index, the result is
# turned into a torch.long tensor, and dimensions 0 and 1 are swapped.
def load_monolingual_batch(self, batch_size, language, random_state=None, test=False, ids=None):
return torch.tensor(pad_monolingual_batch(
# NOTE(review): the next two lines look like the earlier and the later form
# of the same argument (the second forwards `test` and `ids`); presumably only
# one of them is meant to be passed -- TODO confirm against the repository.
self.load_raw_monolingual_batch(batch_size, language, random_state),
self.load_raw_monolingual_batch(batch_size, language, random_state, test=test, ids=ids),
self.pad_index[language]
), dtype=torch.long).transpose(0, 1)

def load_batch(self, batch_size, random_state=None):
    """Load one monolingual batch for every language.

    Calls ``self.load_monolingual_batch`` once per entry of
    ``self.languages``, forwarding ``batch_size`` and ``random_state``
    unchanged.

    Args:
        batch_size: forwarded to ``load_monolingual_batch``.
        random_state: forwarded to ``load_monolingual_batch``; the same
            value is used for every language.

    Returns:
        A dict mapping each language to the value returned by
        ``load_monolingual_batch`` for that language.
    """
    return {
        lang: self.load_monolingual_batch(batch_size, lang, random_state)
        for lang in self.languages
    }
def load_batch(self, batch_size, random_state=None, test=False, ids=None):
    """Load one monolingual batch for every language.

    Calls ``self.load_monolingual_batch`` once per entry of
    ``self.languages``. All arguments are forwarded unchanged, so every
    language receives the same ``random_state``, ``test`` flag and ``ids``.

    Args:
        batch_size: forwarded to ``load_monolingual_batch``.
        random_state: forwarded to ``load_monolingual_batch``.
        test: forwarded as the ``test`` keyword argument.
        ids: forwarded as the ``ids`` keyword argument.

    Returns:
        A dict mapping each language to the value returned by
        ``load_monolingual_batch`` for that language.
    """
    return {
        lang: self.load_monolingual_batch(
            batch_size, lang, random_state, test=test, ids=ids
        )
        for lang in self.languages
    }



Expand Down
8 changes: 3 additions & 5 deletions unsupervised_mt/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ def __init__(self, corp_paths, emb_paths, pairs_paths, max_length=10, test_size=
for l in self.languages
}

def load_sentence(self, language, idx, pad=0):
    """Return the vocabulary indices of one training sentence.

    Looks up ``self.train[language][idx]`` and converts it with the
    ``get_indices`` method of ``self.vocabs[language]``.

    Args:
        language: key into ``self.train`` and ``self.vocabs``.
        idx: position of the sentence inside ``self.train[language]``.
        pad: forwarded to ``get_indices`` as the ``pad`` keyword argument.

    Returns:
        Whatever ``get_indices`` returns for the selected sentence.
    """
    sentence = self.train[language][idx]
    return self.vocabs[language].get_indices(sentence, language=language, pad=pad)

def load_len(self, language, idx):
    """Return the length of the sentence loaded for ``language`` and ``idx``.

    The sentence is obtained through ``self.load_sentence(language, idx)``
    and measured with ``len``.
    """
    return len(self.load_sentence(language, idx))
def load_sentence(self, language, idx, pad=0, test=False):
    """Return the vocabulary indices of one training or test sentence.

    Args:
        language: key into ``self.train`` and ``self.vocabs``.
        idx: position of the sentence in the selected collection.
        pad: forwarded to ``get_indices`` as the ``pad`` keyword argument.
        test: when true, read from ``self.test`` instead of ``self.train``.

    Returns:
        Whatever ``self.vocabs[language].get_indices`` returns for the
        selected sentence.
    """
    if test:
        # Each test entry is indexed by position: element 0 for the language
        # named 'src', element 1 for any other language.
        # NOTE(review): presumably entries are (source, target) pairs and the
        # languages are 'src'/'tgt' -- confirm against the dataset setup.
        position = 0 if language == 'src' else 1
        sentence = self.test[idx][position]
    else:
        sentence = self.train[language][idx]
    return self.vocabs[language].get_indices(sentence, language=language, pad=pad)

def get_sos_index(self, language):
    """Return the result of ``get_sos(language)`` on the vocabulary of ``language``.

    The vocabulary object is looked up in ``self.vocabs`` by ``language``.
    """
    return self.vocabs[language].get_sos(language)
Expand Down

0 comments on commit a245aa6

Please sign in to comment.