Commit 36ad550
Remove unnecessary utf8 encoding causing problems under Windows (cast…
stekiri authored Oct 4, 2021
1 parent 371f7d7 commit 36ad550
Showing 6 changed files with 26 additions and 19 deletions.
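In short: every affected call site wrapped text as JString(text.encode('utf-8')), handing UTF-8 bytes to pyjnius; this caused problems under Windows for non-ASCII input. With pyjnius >= 1.3.0 (see the requirements.txt change below), JString accepts a Python str directly, so the explicit encode is dropped everywhere. A minimal sketch of the pattern, using the JString import path from the test change below:

    from pyserini.pyclass import JString

    text = 'zoölogy'
    # before: JString(text.encode('utf-8'))  <- problematic under Windows for non-ASCII text
    # after: pass the str directly; pyjnius >= 1.3.0 handles the conversion
    JString(text)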
2 changes: 1 addition & 1 deletion pyserini/analysis/_base.py

@@ -151,7 +151,7 @@ def analyze(self, text: str) -> List[str]:
         List[str]
             List of tokens corresponding to the output of the analyzer.
         """
-        results = JAnalyzerUtils.analyze(self.analyzer, JString(text.encode('utf-8')))
+        results = JAnalyzerUtils.analyze(self.analyzer, JString(text))
         tokens = []
         for token in results.toArray():
            tokens.append(token)
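With the encode call gone, analysis accepts non-ASCII input directly. A usage sketch, assuming the wrapper class in this file is exposed as pyserini.analysis.Analyzer (the class name is an assumption; get_lucene_analyzer appears in the diffs below):

    from pyserini.analysis import Analyzer, get_lucene_analyzer

    # Assumption: Analyzer wraps a Lucene analyzer and exposes the analyze() method patched above.
    analyzer = Analyzer(get_lucene_analyzer())
    tokens = analyzer.analyze('zoölogy')  # non-ASCII text now passes straight into JString
    print(tokens)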
20 changes: 10 additions & 10 deletions pyserini/index/_base.py

@@ -216,9 +216,9 @@ def analyze(self, text: str, analyzer=None) -> List[str]:
             List of tokens corresponding to the output of the analyzer.
         """
         if analyzer is None:
-            results = JAnalyzerUtils.analyze(JString(text.encode('utf-8')))
+            results = JAnalyzerUtils.analyze(JString(text))
         else:
-            results = JAnalyzerUtils.analyze(analyzer, JString(text.encode('utf-8')))
+            results = JAnalyzerUtils.analyze(analyzer, JString(text))
         tokens = []
         for token in results.toArray():
             tokens.append(token)
@@ -256,7 +256,7 @@ def get_term_counts(self, term: str, analyzer: Optional[JAnalyzer] = get_lucene_
         if analyzer is None:
             analyzer = get_lucene_analyzer(stemming=False, stopwords=False)

-        term_map = self.object.getTermCountsWithAnalyzer(self.reader, JString(term.encode('utf-8')), analyzer)
+        term_map = self.object.getTermCountsWithAnalyzer(self.reader, JString(term), analyzer)

         return term_map.get(JString('docFreq')), term_map.get(JString('collectionFreq'))

@@ -276,9 +276,9 @@ def get_postings_list(self, term: str, analyzer=get_lucene_analyzer()) -> List[P
             List of :class:`Posting` objects corresponding to the postings list for the term.
         """
         if analyzer is None:
-            postings_list = self.object.getPostingsListForAnalyzedTerm(self.reader, JString(term.encode('utf-8')))
+            postings_list = self.object.getPostingsListForAnalyzedTerm(self.reader, JString(term))
         else:
-            postings_list = self.object.getPostingsListWithAnalyzer(self.reader, JString(term.encode('utf-8')),
+            postings_list = self.object.getPostingsListWithAnalyzer(self.reader, JString(term),
                                                                     analyzer)

         if postings_list is None:
@@ -309,7 +309,7 @@ def get_document_vector(self, docid: str) -> Optional[Dict[str, int]]:
             return None
         doc_vector_dict = {}
         for term in doc_vector_map.keySet().toArray():
-            doc_vector_dict[term] = doc_vector_map.get(JString(term.encode('utf-8')))
+            doc_vector_dict[term] = doc_vector_map.get(JString(term))
         return doc_vector_dict

     def get_term_positions(self, docid: str) -> Optional[Dict[str, int]]:
@@ -333,7 +333,7 @@ def get_term_positions(self, docid: str) -> Optional[Dict[str, int]]:
             return None
         term_position_map = {}
         for term in java_term_position_map.keySet().toArray():
-            term_position_map[term] = java_term_position_map.get(JString(term.encode('utf-8'))).toArray()
+            term_position_map[term] = java_term_position_map.get(JString(term)).toArray()
         return term_position_map

     def doc(self, docid: str) -> Optional[Document]:
@@ -430,11 +430,11 @@ def compute_bm25_term_weight(self, docid: str, term: str, analyzer=get_lucene_an
         """
         if analyzer is None:
             return self.object.getBM25AnalyzedTermWeightWithParameters(self.reader, JString(docid),
-                                                                        JString(term.encode('utf-8')),
+                                                                        JString(term),
                                                                         float(k1), float(b))
         else:
             return self.object.getBM25UnanalyzedTermWeightWithParameters(self.reader, JString(docid),
-                                                                          JString(term.encode('utf-8')), analyzer,
+                                                                          JString(term), analyzer,
                                                                           float(k1), float(b))

     def compute_query_document_score(self, docid: str, query: str, similarity=None):
@@ -492,6 +492,6 @@ def stats(self) -> Dict[str, int]:

         index_stats_dict = {}
         for term in index_stats_map.keySet().toArray():
-            index_stats_dict[term] = index_stats_map.get(JString(term.encode('utf-8')))
+            index_stats_dict[term] = index_stats_map.get(JString(term))

         return index_stats_dict
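All of these IndexReader paths now hand raw Python strings to JString. A short sketch of the affected calls, assuming the class is exposed as pyserini.index.IndexReader and that '/path/to/index' is a placeholder for a local Lucene index:

    from pyserini.index import IndexReader

    index_reader = IndexReader('/path/to/index')  # placeholder path, assumption
    # get_term_counts() and stats() are among the methods patched above;
    # no manual .encode('utf-8') is needed for non-ASCII terms.
    df, cf = index_reader.get_term_counts('zoölogy')
    print(df, cf)
    print(index_reader.stats())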
4 changes: 2 additions & 2 deletions pyserini/search/_impact_searcher.py

@@ -113,7 +113,7 @@ def search(self, q: str, k: int = 10, fields=dict()) -> List[JImpactSearcherResu
         jquery = JHashMap()
         for (token, weight) in encoded_query.items():
             if token in self.idf and self.idf[token] > self.min_idf:
-                jquery.put(JString(token.encode('utf8')), JFloat(weight))
+                jquery.put(JString(token), JFloat(weight))

         if not fields:
             hits = self.object.search(jquery, k)
@@ -154,7 +154,7 @@ def batch_search(self, queries: List[str], qids: List[str],
             jquery = JHashMap()
             for (token, weight) in encoded_query.items():
                 if token in self.idf and self.idf[token] > self.min_idf:
-                    jquery.put(JString(token.encode('utf8')), JFloat(weight))
+                    jquery.put(JString(token), JFloat(weight))
             query_lst.add(jquery)

         for qid in qids:
10 changes: 5 additions & 5 deletions pyserini/search/_searcher.py

@@ -111,9 +111,9 @@ def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGene
         hits = None
         if query_generator:
             if not fields:
-                hits = self.object.search(query_generator, JString(q.encode('utf8')), k)
+                hits = self.object.search(query_generator, JString(q), k)
             else:
-                hits = self.object.searchFields(query_generator, JString(q.encode('utf8')), jfields, k)
+                hits = self.object.searchFields(query_generator, JString(q), jfields, k)
         elif isinstance(q, JQuery):
             # Note that RM3 requires the notion of a query (string) to estimate the appropriate models. If we're just
             # given a Lucene query, it's unclear what the "query" is for this estimation. One possibility is to extract
@@ -127,9 +127,9 @@ def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGene
             hits = self.object.search(q, k)
         else:
             if not fields:
-                hits = self.object.search(JString(q.encode('utf8')), k)
+                hits = self.object.search(JString(q), k)
             else:
-                hits = self.object.searchFields(JString(q.encode('utf8')), jfields, k)
+                hits = self.object.searchFields(JString(q), jfields, k)

         docids = set()
         filtered_hits = []
@@ -176,7 +176,7 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads
         query_strings = JArrayList()
         qid_strings = JArrayList()
         for query in queries:
-            jq = JString(query.encode('utf8'))
+            jq = JString(query)
             query_strings.add(jq)

         for qid in qids:
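The searcher side is the same pattern applied to queries. A usage sketch, assuming the class in this file is the familiar SimpleSearcher from pyserini.search and using a placeholder index path:

    from pyserini.search import SimpleSearcher

    searcher = SimpleSearcher('/path/to/index')  # placeholder path, assumption
    hits = searcher.search('zoölogy', k=10)      # the query str goes to JString unencoded
    for hit in hits:
        print(hit.docid, hit.score)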
2 changes: 1 addition & 1 deletion requirements.txt

@@ -1,7 +1,7 @@
 Cython>=0.29.21
 numpy>=1.18.1
 pandas>=1.1.5
-pyjnius>=1.2.1
+pyjnius>=1.3.0
 scikit-learn>=0.22.1
 scipy>=1.4.1
 tqdm
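The pin moves to pyjnius>=1.3.0 because that release fixed JString construction from a Python str containing non-ASCII characters (see the test below). A defensive sanity check one could run in an environment script; the packaging dependency here is an assumption, not part of this commit:

    from importlib.metadata import version  # Python 3.8+
    from packaging.version import Version   # assumption: packaging is installed

    if Version(version('pyjnius')) < Version('1.3.0'):
        raise RuntimeError('pyjnius >= 1.3.0 is required for JString to accept non-ASCII str')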
7 changes: 7 additions & 0 deletions tests/test_index_reader.py

@@ -25,6 +25,7 @@
 from sklearn.naive_bayes import MultinomialNB

 from pyserini import analysis, index, search
+from pyserini.pyclass import JString
 from pyserini.vectorizer import BM25Vectorizer, TfidfVectorizer


@@ -346,6 +347,12 @@ def test_index_stats(self):
         self.assertEqual(3204, self.index_reader.stats()['documents'])
         self.assertEqual(14363, self.index_reader.stats()['unique_terms'])

+    def test_jstring_encoding(self):
+        # When using pyjnius in a version prior to 1.3.0, creating a JString with non-ASCII characters resulted in a
+        # failure. This test simply ensures that a compatible version of pyjnius is used. More details can be found in
+        # the discussion here: https://github.com/castorini/pyserini/issues/770
+        JString('zoölogy')
+
     def tearDown(self):
         os.remove(self.tarball_name)
         shutil.rmtree(self.index_dir)
