Skip to content

Commit

Permalink
Fix multiprocess preprocessing
Browse files: browse the repository at this point in the history
  • Loading branch information
adaminsky committed Feb 21, 2023
1 parent 0272b69 commit 0672d42
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions octis/preprocessing/preprocessing.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,6 +5,7 @@
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from tqdm.contrib.concurrent import process_map # or thread_map
from tqdm import tqdm
from pathlib import Path
from octis.dataset.dataset import Dataset
from collections import Counter
Expand Down Expand Up @@ -158,9 +159,10 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
if self.num_processes is not None:
# with Pool(self.num_processes) as p:
# docs = p.map(self.simple_preprocessing_steps, docs)
docs = process_map(self.simple_preprocessing_steps, docs, max_workers=self.num_processes, chunksize=1)
chunksize = max(1, len(docs) // self.num_processes)
docs_list = process_map(self.simple_preprocessing_steps, docs, max_workers=self.num_processes, chunksize=chunksize)
else:
docs = self.simple_preprocessing_steps(docs)
docs = list(map(self.simple_preprocessing_steps, tqdm(docs)))
if self.lowercase:
self.preprocessing_steps.append("lowercase")
if self.remove_punctuation:
Expand Down Expand Up @@ -310,27 +312,24 @@ def _foo(self, docs, vocabulary, labels_path):
return final_docs, []
'''

def simple_preprocessing_steps(self, docs):
tmp_docs = []
for d in docs:
new_d = d
new_d = new_d.replace('\n', '')
new_d = new_d.replace('\t', '')
if self.lowercase:
new_d = new_d.lower()
if self.lemmatize:
if self.remove_stopwords_spacy:
new_d = ' '.join([token.lemma_ for token in self.spacy_model(new_d) if not token.is_stop])
elif self.stopwords:
new_d = ' '.join(
[token.lemma_ for token in self.spacy_model(new_d) if token.lemma_ not in set(self.stopwords)])
else:
new_d = ' '.join([token.lemma_ for token in self.spacy_model(new_d)])
def simple_preprocessing_steps(self, doc):
new_d = doc
new_d = new_d.replace('\n', '')
new_d = new_d.replace('\t', '')
if self.lowercase:
new_d = new_d.lower()
if self.lemmatize:
if self.remove_stopwords_spacy:
new_d = ' '.join([token.lemma_ for token in self.spacy_model(new_d) if not token.is_stop])
elif self.stopwords:
new_d = ' '.join(
[token.lemma_ for token in self.spacy_model(new_d) if token.lemma_ not in set(self.stopwords)])
else:
new_d = ' '.join([token.lemma_ for token in self.spacy_model(new_d)])

if self.remove_punctuation:
new_d = new_d.translate(str.maketrans(self.punctuation, ' ' * len(self.punctuation)))
if self.remove_numbers:
new_d = new_d.translate(str.maketrans("0123456789", ' ' * len("0123456789")))
new_d = " ".join(new_d.split())
tmp_docs.append(new_d)
return tmp_docs
if self.remove_punctuation:
new_d = new_d.translate(str.maketrans(self.punctuation, ' ' * len(self.punctuation)))
if self.remove_numbers:
new_d = new_d.translate(str.maketrans("0123456789", ' ' * len("0123456789")))
new_d = " ".join(new_d.split())
return new_d

0 comments on commit 0672d42

Please sign in to comment.