Skip to content

Commit

Permalink
add regularization parameter NMF scikit
Browse files Browse the repository at this point in the history
  • Loading branch information
silviatti committed May 20, 2021
1 parent dada89c commit cac4230
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions octis/models/NMF_scikit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

class NMF_scikit(AbstractModel):

def __init__(self, num_topics=100, init=None, alpha=0, l1_ratio=0, use_partitions=True):
def __init__(self, num_topics=100, init=None, alpha=0, l1_ratio=0, regularization='both',
use_partitions=True):
"""
Initialize NMF model
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(self, num_topics=100, init=None, alpha=0, l1_ratio=0, use_partition
self.hyperparameters["init"] = init
self.hyperparameters["alpha"] = alpha
self.hyperparameters["l1_ratio"] = l1_ratio
self.hyperparameters['regularization'] = regularization
self.use_partitions = use_partitions

self.id2word = None
Expand Down Expand Up @@ -126,9 +128,9 @@ def train_model(self, dataset, hyperparameters=None, topics=10):
#hyperparameters["corpus"] = self.id_corpus
#hyperparameters["id2word"] = self.id2word
self.hyperparameters.update(hyperparameters)

model = NMF(n_components=self.hyperparameters["num_topics"], init=self.hyperparameters["init"],
alpha=self.hyperparameters["alpha"], l1_ratio=self.hyperparameters["l1_ratio"])
alpha=self.hyperparameters["alpha"], l1_ratio=self.hyperparameters["l1_ratio"],
regularization=self.hyperparameters['regularization'])

W = model.fit_transform(self.id_corpus)
#W = W / W.sum(axis=1, keepdims=True)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ gensim<=3.8.3
nltk
pandas
spacy
scikit-learn
scikit-learn>=0.24.2
scikit-optimize>=0.8.1
matplotlib
torch
Expand Down

0 comments on commit cac4230

Please sign in to comment.