Skip to content

Commit

Permalink
precomputed document topic distributions input feature (bmabey#85)
Browse files Browse the repository at this point in the history
* Added a feature to create a pyLDAvis visualization from precomputed doc_topic_dists (document-topic distributions). When these distributions are supplied, the model's inference step is skipped, which saves a substantial amount of compute time.

* Added a feature to create a pyLDAvis visualization from precomputed doc_topic_dists (document-topic distributions). When these distributions are supplied, the model's inference step is skipped, which saves a substantial amount of compute time.
  • Loading branch information
Alex Loosley authored and bmabey committed Feb 24, 2017
1 parent e57cd1d commit 7169fab
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions pyLDAvis/gensim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import funcy as fp
import numpy as np
import pandas as pd
from scipy.sparse import issparse
from past.builtins import xrange
from . import prepare as vis_prepare

Expand All @@ -34,18 +35,24 @@ def _extract_data(topic_model, corpus, dictionary, doc_topic_dists=None):
assert term_freqs.shape[0] == len(dictionary), 'Term frequencies and dictionary have different shape {} != {}'.format(term_freqs.shape[0], len(dictionary))
assert doc_lengths.shape[0] == len(corpus), 'Document lengths and corpus have different sizes {} != {}'.format(doc_lengths.shape[0], len(corpus))

if hasattr(topic_model, 'lda_alpha'):
num_topics = len(topic_model.lda_alpha)
else:
num_topics = topic_model.num_topics

if doc_topic_dists is None:
# If it's an HDP model.
if hasattr(topic_model, 'lda_beta'):
gamma = topic_model.inference(corpus)
else:
gamma, _ = topic_model.inference(corpus)
doc_topic_dists = gamma / gamma.sum(axis=1)[:, None]

if hasattr(topic_model, 'lda_alpha'):
num_topics = len(topic_model.lda_alpha)
else:
num_topics = topic_model.num_topics
if isinstance(doc_topic_dists, list):
doc_topic_dists = gensim.matutils.corpus2dense(doc_topic_dists, num_topics).T
elif issparse(doc_topic_dists):
doc_topic_dists = doc_topic_dists.T.todense()
doc_topic_dists = doc_topic_dists / doc_topic_dists.sum(axis=1)

assert doc_topic_dists.shape[1] == num_topics, 'Document topics and number of topics do not match {} != {}'.format(doc_topic_dists.shape[1], num_topics)

Expand Down

0 comments on commit 7169fab

Please sign in to comment.