Skip to content

Commit

Permalink
Addressing castorini#121 - Using Anserini from Python (castorini#139)
Browse files Browse the repository at this point in the history
Added py4j gateway to use Anserini from Python - initial implementation.
  • Loading branch information
salman1993 authored and lintool committed Mar 8, 2017
1 parent ab354d9 commit e9d4053
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ eval/trec_eval.9.0/
bin/
*.iml
.idea
resources/
14 changes: 14 additions & 0 deletions docs/py4j-gateway.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
### How to use the py4j Gateway

#### Steps
- Build the Maven package and assemble the app.
- Start the Pyserini gateway server from Java to open up a socket for communication.
```
mvn clean package appassembler:assemble
sh target/appassembler/bin/Pyserini
```
- Make sure py4j is installed for Python; if it is not, install it with: sudo pip install py4j
- Python tries to connect to a JVM with a gateway (localhost on port 25333).
- Python program can now initialize a Java Gateway object.
- The Python program, search_web_collection, in src/main/python contains a search method that takes a
query string and a number of hits, and returns a list of document IDs.
11 changes: 10 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@
<mainClass>io.anserini.search.SearchWebCollection</mainClass>
<name>SearchWebCollection</name>
</program>
<program>
<mainClass>io.anserini.py4j.Pyserini</mainClass>
<name>Pyserini</name>
</program>
<program>
<mainClass>io.anserini.eval.Eval</mainClass>
<name>Eval</name>
Expand Down Expand Up @@ -390,12 +394,17 @@
<artifactId>stanford-corenlp</artifactId>
<version>3.7.0</version>
</dependency>

<dependency>
<groupId>edu.stanford.nlp</groupId>
<artifactId>stanford-corenlp</artifactId>
<version>3.7.0</version>
<classifier>models</classifier>
</dependency>

<dependency>
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
<version>0.10.4</version>
</dependency>
</dependencies>
</project>
130 changes: 130 additions & 0 deletions src/main/java/io/anserini/py4j/Pyserini.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package io.anserini.py4j;

import io.anserini.index.IndexUtils;
import io.anserini.rerank.RerankerContext;
import io.anserini.rerank.ScoredDocuments;
import io.anserini.util.AnalyzerUtils;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.FSDirectory;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.apache.lucene.queryparser.classic.ParseException;
import io.anserini.rerank.IdentityReranker;
import io.anserini.rerank.RerankerCascade;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import java.util.List;
import java.util.ArrayList;
import java.util.TreeMap;
import java.util.Map;
import java.util.SortedMap;
import static io.anserini.index.generator.LuceneDocumentGenerator.FIELD_ID;
import static io.anserini.index.generator.LuceneDocumentGenerator.FIELD_BODY;

import py4j.GatewayServer;

/**
 * Py4J entry point that exposes searching a Lucene index to Python.
 *
 * <p>An instance must be given an index with {@link #initializeWithIndex(String)} before
 * {@link #search(String, int)} or {@link #getRawDocument(String)} is called.
 *
 * @author s43moham on 06/03/17.
 * @project anserini
 */
public class Pyserini {

  private IndexReader reader = null;
  private IndexUtils indexUtils = null;

  public Pyserini() {}

  /**
   * Opens the index at {@code indexDir} and makes it the target of subsequent searches.
   *
   * <p>If an index was already open, its reader is closed once the new one has been opened
   * successfully, so repeated calls do not leak readers.
   *
   * @param indexDir path to the index directory
   * @throws IllegalArgumentException if the path does not exist, is not a directory, or is not readable
   * @throws Exception if the index cannot be opened
   */
  public void initializeWithIndex(String indexDir) throws Exception {
    Path indexPath = Paths.get(indexDir);

    if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
      throw new IllegalArgumentException(
          indexDir + " does not exist, is not a directory, or is not readable.");
    }

    IndexReader newReader = DirectoryReader.open(FSDirectory.open(indexPath));
    IndexUtils newIndexUtils;
    try {
      newIndexUtils = new IndexUtils(indexDir);
    } catch (Exception e) {
      // Do not leak the reader we just opened if the second half of initialization fails.
      try {
        newReader.close();
      } catch (IOException closeFailure) {
        e.addSuppressed(closeFailure);
      }
      throw e;
    }

    // Swap in the new state first, then release the reader from any earlier initialization.
    IndexReader previousReader = this.reader;
    this.reader = newReader;
    this.indexUtils = newIndexUtils;
    if (previousReader != null) {
      previousReader.close();
    }
  }

  /**
   * Runs every topic against the index and collects the ids of the documents returned by the
   * reranker cascade.
   *
   * <p>Topics are processed in ascending key order and the resulting ids are appended to a single
   * list, so the ids for different topics are not separated in the result.
   *
   * @param topics queries, keyed by topic id
   * @param similarity similarity installed on the searcher
   * @param numHits maximum number of hits retrieved per topic
   * @param cascade rerankers applied to the retrieved hits of each topic
   * @param useQueryParser if true, parse each query with the classic Lucene query parser;
   *     otherwise build a bag-of-words query from the analyzed query string
   * @param keepstopwords if true, analyze with an empty stopword set
   * @return document ids (the {@code FIELD_ID} field) for all topics
   * @throws IllegalStateException if {@link #initializeWithIndex(String)} has not been called
   * @throws IOException if searching the index fails
   * @throws ParseException if {@code useQueryParser} is set and a query cannot be parsed
   */
  public List<String> search(SortedMap<Integer, String> topics, Similarity similarity, int numHits,
      RerankerCascade cascade, boolean useQueryParser, boolean keepstopwords)
      throws IOException, ParseException {
    ensureInitialized();

    List<String> docids = new ArrayList<String>();
    IndexSearcher searcher = new IndexSearcher(reader);
    searcher.setSimilarity(similarity);

    EnglishAnalyzer ea = keepstopwords ? new EnglishAnalyzer(CharArraySet.EMPTY_SET) : new EnglishAnalyzer();
    QueryParser queryParser = new QueryParser(FIELD_BODY, ea);
    queryParser.setDefaultOperator(QueryParser.Operator.OR);

    for (Map.Entry<Integer, String> entry : topics.entrySet()) {
      int qID = entry.getKey();
      String queryString = entry.getValue();
      Query query = useQueryParser
          ? queryParser.parse(queryString)
          : AnalyzerUtils.buildBagOfWordsQuery(FIELD_BODY, ea, queryString);

      TopDocs rs = searcher.search(query, numHits);
      List<String> queryTokens = AnalyzerUtils.tokenize(ea, queryString);
      RerankerContext context = new RerankerContext(searcher, query, String.valueOf(qID), queryString,
          queryTokens, FIELD_BODY, null);
      ScoredDocuments docs = cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context);
      for (int i = 0; i < docs.documents.length; i++) {
        docids.add(docs.documents[i].getField(FIELD_ID).stringValue());
      }
    }

    return docids;
  }

  /**
   * Searches the index for a single query string.
   *
   * <p>Uses BM25 (k1 = 0.9, b = 0.4), a bag-of-words query, default English stopword handling
   * and an identity reranker.
   *
   * @param query query string
   * @param numHits maximum number of hits to retrieve
   * @return document ids of the hits
   * @throws IllegalStateException if {@link #initializeWithIndex(String)} has not been called
   * @throws IOException if searching the index fails
   * @throws ParseException declared by the underlying search; the query parser is not used here
   */
  public List<String> search(String query, int numHits) throws IOException, ParseException {
    // for now, using BM25 similarity - not branching on args.bm25 or args.ql
    float k1 = 0.9f;
    float b = 0.4f;
    Similarity similarity = new BM25Similarity(k1, b);

    // for now, creating Topics map and appending query and setting id=1
    SortedMap<Integer, String> topics = new TreeMap<>();
    int id = 1;
    topics.put(id, query);

    // for now, using IdentityReranker - not branching on args.rm3
    RerankerCascade cascade = new RerankerCascade();
    cascade.add(new IdentityReranker());

    return search(topics, similarity, numHits, cascade, false, false);
  }

  /**
   * Looks up the raw document for a document id via {@code IndexUtils}.
   *
   * @param docid document id
   * @return whatever {@code IndexUtils.getRawDocument} returns for {@code docid}
   * @throws IllegalStateException if {@link #initializeWithIndex(String)} has not been called
   * @throws Exception if the lookup fails
   */
  public String getRawDocument(String docid) throws Exception {
    ensureInitialized();
    return indexUtils.getRawDocument(docid);
  }

  /** Fails fast with a descriptive error instead of a NullPointerException when no index is open. */
  private void ensureInitialized() {
    if (reader == null || indexUtils == null) {
      throw new IllegalStateException("No index has been opened; call initializeWithIndex(indexDir) first.");
    }
  }

  /**
   * Starts a py4j {@code GatewayServer} whose entry point is a new {@code Pyserini} instance.
   *
   * @param argv ignored
   * @throws Exception if the gateway cannot be started
   */
  public static void main(String[] argv) throws Exception {
    System.out.println("starting Gateway Server...");
    GatewayServer gatewayServer = new GatewayServer(new Pyserini());
    gatewayServer.start();
    System.out.println("started!");
  }
}
10 changes: 5 additions & 5 deletions src/main/java/io/anserini/search/SearchWebCollection.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public void close() throws IOException {
*/

public void search(SortedMap<Integer, String> topics, String submissionFile, Similarity similarity, int numHits, RerankerCascade cascade,
boolean useQueryParser, boolean keepstopwords) throws IOException, ParseException {
boolean useQueryParser, boolean keepstopwords) throws IOException, ParseException {


IndexSearcher searcher = new IndexSearcher(reader);
Expand All @@ -123,8 +123,8 @@ public void search(SortedMap<Integer, String> topics, String submissionFile, Sim

int qID = entry.getKey();
String queryString = entry.getValue();
Query query = useQueryParser? queryParser.parse(queryString) :
AnalyzerUtils.buildBagOfWordsQuery(FIELD_BODY, ea, queryString);
Query query = useQueryParser? queryParser.parse(queryString) :
AnalyzerUtils.buildBagOfWordsQuery(FIELD_BODY, ea, queryString);

/**
* For Web Tracks 2010,2011,and 2012; an experimental run consists of the top 10,000 documents for each topic query.
Expand Down Expand Up @@ -154,7 +154,7 @@ public void search(SortedMap<Integer, String> topics, String submissionFile, Sim
}

public void search(SortedMap<Integer, String> topics, String submissionFile, Similarity similarity, int numHits, RerankerCascade cascade)
throws IOException, ParseException {
throws IOException, ParseException {
search(topics, submissionFile, similarity, numHits, cascade, false, false);
}

Expand Down Expand Up @@ -232,4 +232,4 @@ public static void main(String[] args) throws Exception {
final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS);
LOG.info("Total " + topics.size() + " topics searched in " + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss"));
}
}
}
18 changes: 18 additions & 0 deletions src/main/python/search_web_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

from py4j.java_gateway import JavaGateway

# Path of the Lucene index to open. Set the ANSERINI_INDEX environment variable to point at
# your own index; the fallback is the path this script was originally written against.
DEFAULT_INDEX = "/home/s43moham/indexes/lucene-index.TrecQA.pos+docvectors+rawdocs/"

# Connect to the already-running JVM gateway, create a Pyserini object inside the JVM
# and have it open the index.
gateway = JavaGateway()
index = gateway.jvm.java.lang.String(os.environ.get("ANSERINI_INDEX", DEFAULT_INDEX))
pyserini = gateway.jvm.io.anserini.py4j.Pyserini()
pyserini.initializeWithIndex(index)

# query = "Airbus Subsidies"
# hits = 30
# gateway.help(pyserini)
def search(query_string, num_hits):
    """Search the opened index through the JVM-side Pyserini object.

    Passes `query_string` and `num_hits` straight to `pyserini.search` and
    returns its result (the document IDs of the hits).
    """
    return pyserini.search(query_string, num_hits)

# docid = "FT943-5123"
def raw_doc(docid):
    """Fetch the raw document for `docid` through the JVM-side Pyserini object.

    Returns whatever `pyserini.getRawDocument` returns for that document ID.
    """
    return pyserini.getRawDocument(docid)

0 comments on commit e9d4053

Please sign in to comment.