forked from castorini/anserini
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Addressing castorini#121 - Using Anserini from Python (castorini#139)
Added py4j gateway to use Anserini from Python - initial implementation.
- Loading branch information
1 parent
ab354d9
commit e9d4053
Showing
6 changed files
with
178 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ eval/trec_eval.9.0/ | |
bin/ | ||
*.iml | ||
.idea | ||
resources/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
### How to use the py4j Gateway | ||
|
||
#### Steps | ||
- Build the Maven package and assemble the app. | ||
- Start the Pyserini gateway server (a py4j `GatewayServer`) from Java to open a socket for communication:
``` | ||
mvn clean package appassembler:assemble | ||
sh target/appassembler/bin/Pyserini | ||
``` | ||
- Make sure py4j is installed for Python; if it is not, install it with: `sudo pip install py4j`
- The Python side connects to the JVM through the gateway (localhost, port 25333).
- The Python program can now initialize a `JavaGateway` object.
- The Python program `search_web_collection` in `src/main/python` contains a `search` function that takes a
query string and a number of hits, and returns a list of document IDs.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
package io.anserini.py4j; | ||
|
||
import io.anserini.index.IndexUtils; | ||
import io.anserini.rerank.RerankerContext; | ||
import io.anserini.rerank.ScoredDocuments; | ||
import io.anserini.util.AnalyzerUtils; | ||
import org.apache.lucene.analysis.CharArraySet; | ||
import org.apache.lucene.analysis.en.EnglishAnalyzer; | ||
import org.apache.lucene.index.DirectoryReader; | ||
import org.apache.lucene.index.IndexReader; | ||
import org.apache.lucene.queryparser.classic.QueryParser; | ||
import org.apache.lucene.search.IndexSearcher; | ||
import org.apache.lucene.search.Query; | ||
import org.apache.lucene.search.ScoreDoc; | ||
import org.apache.lucene.search.TopDocs; | ||
import org.apache.lucene.store.FSDirectory; | ||
import java.io.IOException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.nio.file.Paths; | ||
import org.apache.lucene.queryparser.classic.ParseException; | ||
import io.anserini.rerank.IdentityReranker; | ||
import io.anserini.rerank.RerankerCascade; | ||
import org.apache.lucene.search.similarities.BM25Similarity; | ||
import org.apache.lucene.search.similarities.Similarity; | ||
import java.util.List; | ||
import java.util.ArrayList; | ||
import java.util.TreeMap; | ||
import java.util.Map; | ||
import java.util.SortedMap; | ||
import static io.anserini.index.generator.LuceneDocumentGenerator.FIELD_ID; | ||
import static io.anserini.index.generator.LuceneDocumentGenerator.FIELD_BODY; | ||
|
||
import py4j.GatewayServer; | ||
|
||
/** | ||
* @author s43moham on 06/03/17. | ||
* @project anserini | ||
*/ | ||
public class Pyserini { | ||
|
||
private IndexReader reader = null; | ||
private IndexUtils indexUtils = null; | ||
|
||
public Pyserini() {} | ||
|
||
public void initializeWithIndex(String indexDir) throws Exception { | ||
Path indexPath = Paths.get(indexDir); | ||
|
||
if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) { | ||
throw new IllegalArgumentException(indexDir + " does not exist or is not a directory."); | ||
} | ||
|
||
this.reader = DirectoryReader.open(FSDirectory.open(indexPath)); | ||
this.indexUtils = new IndexUtils(indexDir); | ||
} | ||
|
||
/** | ||
* Prints TREC submission file to the standard output stream. | ||
* | ||
* @param topics queries | ||
* @param similarity similarity | ||
* @throws IOException | ||
* @throws ParseException | ||
*/ | ||
|
||
public List<String> search(SortedMap<Integer, String> topics, Similarity similarity, int numHits, RerankerCascade cascade, | ||
boolean useQueryParser, boolean keepstopwords) throws IOException, ParseException { | ||
|
||
List<String> docids = new ArrayList<String>(); | ||
IndexSearcher searcher = new IndexSearcher(reader); | ||
searcher.setSimilarity(similarity); | ||
|
||
EnglishAnalyzer ea = keepstopwords ? new EnglishAnalyzer(CharArraySet.EMPTY_SET) : new EnglishAnalyzer(); | ||
QueryParser queryParser = new QueryParser(FIELD_BODY, ea); | ||
queryParser.setDefaultOperator(QueryParser.Operator.OR); | ||
|
||
for (Map.Entry<Integer, String> entry : topics.entrySet()) { | ||
|
||
int qID = entry.getKey(); | ||
String queryString = entry.getValue(); | ||
Query query = useQueryParser ? queryParser.parse(queryString) : | ||
AnalyzerUtils.buildBagOfWordsQuery(FIELD_BODY, ea, queryString); | ||
|
||
TopDocs rs = searcher.search(query, numHits); | ||
ScoreDoc[] hits = rs.scoreDocs; | ||
List<String> queryTokens = AnalyzerUtils.tokenize(ea, queryString); | ||
RerankerContext context = new RerankerContext(searcher, query, String.valueOf(qID), queryString, | ||
queryTokens, FIELD_BODY, null); | ||
ScoredDocuments docs = cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context); | ||
for (int i = 0; i < docs.documents.length; i++) { | ||
String docid = docs.documents[i].getField(FIELD_ID).stringValue(); | ||
docids.add(docid); | ||
} | ||
} | ||
|
||
return docids; | ||
} | ||
|
||
public List<String> search(String query, int numHits) throws IOException, ParseException { | ||
|
||
// for now, using BM25 similarity - not branching on args.bm25 or args.ql | ||
float k1 = 0.9f; | ||
float b = 0.4f; | ||
Similarity similarity = new BM25Similarity(k1, b); | ||
|
||
// for now, creating Topics map and appending query and setting id=1 | ||
SortedMap<Integer, String> topics = new TreeMap<>(); | ||
int id = 1; | ||
topics.put(id, query); | ||
|
||
// for now, using IdentityReranker - not branching on args.rm3 | ||
RerankerCascade cascade = new RerankerCascade(); | ||
cascade.add(new IdentityReranker()); | ||
|
||
List<String> docids = search(topics, similarity, numHits, cascade, false, false); | ||
return docids; | ||
} | ||
|
||
public String getRawDocument(String docid) throws Exception { | ||
return indexUtils.getRawDocument(docid); | ||
} | ||
|
||
public static void main(String[] argv) throws Exception { | ||
System.out.println("starting Gateway Server..."); | ||
GatewayServer gatewayServer = new GatewayServer(new Pyserini()); | ||
gatewayServer.start(); | ||
System.out.println("started!"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from py4j.java_gateway import JavaGateway | ||
|
||
import os

# Directory of the Lucene index to search. Set the ANSERINI_INDEX_PATH
# environment variable to point at your own index; the fallback is the path
# this script was originally written against.
index = os.environ.get(
    "ANSERINI_INDEX_PATH",
    "/home/s43moham/indexes/lucene-index.TrecQA.pos+docvectors+rawdocs/",
)

# Connect to the JVM gateway, which must already be running, then create the
# Java-side Pyserini object and load the index into it. py4j converts Python
# strings to java.lang.String automatically, so the path is passed directly.
gateway = JavaGateway()
pyserini = gateway.jvm.io.anserini.py4j.Pyserini()
pyserini.initializeWithIndex(index)
|
||
# query = "Airbus Subsidies" | ||
# hits = 30 | ||
# gateway.help(pyserini) | ||
def search(query_string, num_hits):
    """Forward a query to the Java-side ``Pyserini.search``.

    :param query_string: the query text
    :param num_hits: number of hits to request
    :return: the document IDs returned by the Java call
    """
    return pyserini.search(query_string, num_hits)
|
||
# docid = "FT943-5123" | ||
def raw_doc(docid):
    """Fetch the raw document for ``docid`` via the Java-side ``Pyserini.getRawDocument``.

    :param docid: document ID, e.g. one returned by ``search``
    :return: the value returned by the Java call
    """
    return pyserini.getRawDocument(docid)