Commit

fix(etm): fixes an error generated by feeding the model with a single-word document

Fixes #37
Luiz Matos committed Oct 31, 2021
1 parent 47c342a commit e5c4446
Showing 2 changed files with 4 additions and 15 deletions.
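For context on the failure this commit targets: the branch removed from octis/models/ETM_model/data.py below special-cased documents of length one by calling .squeeze() on their token and count containers. The sketch that follows is illustrative only (it is not taken from the repository and is not necessarily the exact traceback behind issue #37); it shows two ways that kind of special-casing can break on a one-word document.

import numpy as np

one_word_doc = [7]                # token ids stored as a plain Python list
try:
    one_word_doc.squeeze()        # lists have no squeeze() at all
except AttributeError as err:
    print(err)                    # 'list' object has no attribute 'squeeze'

arr = np.array([7])               # the same document as a one-element array
squeezed = arr.squeeze()          # squeeze() drops the only axis
print(squeezed.shape)             # () -- a 0-d array
try:
    len(squeezed)                 # 0-d arrays have no length either
except TypeError as err:
    print(err)                    # len() of unsized object

Deleting the branch sidesteps both cases: the plain loop over (word, count) pairs already handles documents of every length.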
6 changes: 3 additions & 3 deletions octis/models/ETM.py
@@ -123,7 +123,7 @@ def _train_epoch(self, epoch):
     self.optimizer.zero_grad()
     self.model.zero_grad()
     data_batch = data.get_batch(self.train_tokens, self.train_counts, ind, len(self.vocab.keys()),
-                                self.hyperparameters['embedding_size'], self.device)
+                                self.device)
     sums = data_batch.sum(1).unsqueeze(1)
     if self.hyperparameters['bow_norm']:
         normalized_data_batch = data_batch / sums
@@ -179,7 +179,7 @@ def _train_epoch(self, epoch):
     self.model.zero_grad()
     val_data_batch = data.get_batch(self.valid_tokens, self.valid_counts,
                                     ind, len(self.vocab.keys()),
-                                    self.hyperparameters['embedding_size'], self.device)
+                                    self.device)
     sums = val_data_batch.sum(1).unsqueeze(1)
     if self.hyperparameters['bow_norm']:
         val_normalized_data_batch = val_data_batch / sums
@@ -245,7 +245,7 @@ def inference(self):
     for idx, ind in enumerate(indices):
         data_batch = data.get_batch(self.test_tokens, self.test_counts,
                                     ind, len(self.vocab.keys()),
-                                    self.hyperparameters['embedding_size'], self.device)
+                                    self.device)
         sums = data_batch.sum(1).unsqueeze(1)
         if self.hyperparameters['bow_norm']:
             normalized_data_batch = data_batch / sums
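On the caller side, the only change in the three call sites above (training, validation, inference) is that self.hyperparameters['embedding_size'] is no longer forwarded: get_batch builds a dense bag-of-words batch from the vocabulary size alone. Below is a minimal usage sketch of the new signature, assuming OCTIS is installed; the corpus values are made up for illustration and do not come from the repository.

import numpy as np
import torch
from octis.models.ETM_model import data

# Stand-in corpus: document 0 contains words {2, 5}; document 1 is a single word.
tokens = [np.array([2, 5]), np.array([4])]
counts = [np.array([1, 3]), np.array([2])]

# New signature: no embedding_size argument.
batch = data.get_batch(tokens, counts, ind=[0, 1], vocab_size=8,
                       device=torch.device("cpu"))
print(batch.shape)                # (2, 8): one bag-of-words row per document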
13 changes: 1 addition & 12 deletions octis/models/ETM_model/data.py
@@ -1,24 +1,13 @@
 import os
 import random
 import pickle
 import numpy as np
 import torch
 import scipy.io

-def get_batch(tokens, counts, ind, vocab_size, emsize, device):
+def get_batch(tokens, counts, ind, vocab_size, device):
     """fetch input data by batch."""
     batch_size = len(ind)
     data_batch = np.zeros((batch_size, vocab_size))
     for i, doc_id in enumerate(ind):
         doc = tokens[doc_id]
         count = counts[doc_id]
-        #L = count.shape[1]
-        if len(doc) == 1:
-            doc = [doc.squeeze()]
-            count = [count.squeeze()]
-        else:
-            doc = doc#.squeeze()
-            count = count#.squeeze()
         if doc_id != -1:
             for j, word in enumerate(doc):
                 data_batch[i, word] = count[j]
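To see why no special case is needed, here is a self-contained sketch of the surviving core loop applied to a one-word document. The values are illustrative, only the loop structure mirrors get_batch, and the tensor conversion at the end of the function (folded out of the diff above) is omitted.

import numpy as np

tokens = [np.array([4])]          # a single-word document
counts = [np.array([2])]
vocab_size = 6
ind = [0]

data_batch = np.zeros((len(ind), vocab_size))
for i, doc_id in enumerate(ind):
    doc = tokens[doc_id]
    count = counts[doc_id]
    if doc_id != -1:              # same guard as in get_batch
        for j, word in enumerate(doc):
            data_batch[i, word] = count[j]

print(data_batch)                 # [[0. 0. 0. 0. 2. 0.]]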
