Skip to content

Commit

Permalink
Updates server
Browse files Browse the repository at this point in the history
Use Bing v7.0
Properly cache Sessions
Log in a way more compatible with Sanic, and log to a file
Dismiss errors when a new search is done
  • Loading branch information
chrisc36 committed Nov 22, 2017
1 parent 7f13a4e commit d3feaba
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 101 deletions.
72 changes: 45 additions & 27 deletions docqa/server/qa_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def __init__(self, text: List[List[str]], original_text: str, token_spans: np.nd
class QaSystem(object):
"""
End-to-end QA system, uses web-requests to get relevant documents and a model
to scores candidate answer spans.
to score candidate answer spans.
"""
# TODO fix logging level

_split_regex = re.compile("\s*\n\s*") # split includes whitespace to avoid empty paragraphs

Expand All @@ -48,27 +49,37 @@ def __init__(self,
model: Union[ParagraphQuestionModel, ModelDir],
loader: ResourceLoader=ResourceLoader(),
bing_api_key=None,
bing_version="v5.0",
tagme_api_key=None,
blacklist_trivia_sites: bool=False,
bing_api_7=False,
n_dl_threads: int=5,
span_bound:int=8,
span_bound: int=8,
tagme_threshold: Optional[float]=0.2,
download_timeout: int=None,
n_web_docs=10):
n_web_docs=10,
loop=None):
self.log = logging.getLogger('qa_system')
self.tagme_threshold = tagme_threshold
self.n_web_docs = n_web_docs
self.blacklist_trivia_sites = blacklist_trivia_sites
self.tagme_api_key = tagme_api_key

self.client_sess = ClientSession(loop=loop)

if bing_api_key is not None:
self.searcher = AsyncWebSearcher(bing_api_key, bing_api_7)
if bing_version is None:
raise ValueError("Must specify a Bing version if using a bing_api key")
self.searcher = AsyncWebSearcher(bing_api_key, bing_version, loop=loop)
self.text_extractor = AsyncBoilerpipeCliExtractor(n_dl_threads, download_timeout)
else:
self.text_extractor = None
self.searcher = None
self.wiki_corpus = WikiCorpus(wiki_cache, keep_inverse_mapping=True)

if self.tagme_threshold is not None:
self.wiki_corpus = WikiCorpus(wiki_cache, keep_inverse_mapping=True, loop=loop)
else:
self.wiki_corpus = None

self.paragraph_splitter = paragraph_splitter
self.paragraph_selector = paragraph_selector
self.model_dir = model
Expand All @@ -92,7 +103,7 @@ def __init__(self,

self.model.set_input_spec(ParagraphAndQuestionSpec(None), voc, loader)

self.sess = tf.Session()
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
with self.sess.as_default():
pred = self.model.get_prediction()

Expand Down Expand Up @@ -124,34 +135,35 @@ def _preprocess(self, paragraphs: List[WebParagraph]) -> List[WebParagraph]:
else:
return paragraphs

async def answer_question(self, question: str) -> Tuple[np.ndarray, List[WebParagraph]]:
"""
Answer a question using web search, return the paragraphs and per-span confidence scores
in the form of a (batch, max_num_context_tokens, max_num_context_tokens) array
"""

context = await self.get_question_context(question)
question = self.tokenizer.tokenize_paragraph_flat(question)
t0 = time.perf_counter()
out = self._get_span_scores(question, context)
self.log.info("Computing answer spans took %.5f seconds" % (time.perf_counter() - t0))
return out

async def answer_question_spans(self, question: str) -> Tuple[np.ndarray, np.ndarray, List[WebParagraph]]:
    """
    Answer a question using web search, return the top spans and confidence scores for each paragraph

    The question text is first used to collect context paragraphs via
    `get_question_context`, then tokenized; the paragraphs are preprocessed and
    paired with the tokenized question before being run through the model.

    :param question: the raw (untokenized) question text
    :return: `(spans, scores, paragraphs)`, where `spans` and `scores` are the values
        fetched from `self.span` and `self.score`, and `paragraphs` are the
        preprocessed paragraphs the model was run on
    """

    paragraphs = await self.get_question_context(question)
    question = self.tokenizer.tokenize_paragraph_flat(question)
    # Preprocess exactly once, the paragraphs returned to the caller must be the
    # same (preprocessed) ones the model scored
    paragraphs = self._preprocess(paragraphs)
    t0 = time.perf_counter()
    qa_pairs = [ParagraphAndQuestion(c.get_context(), question, None, "") for c in paragraphs]
    encoded = self.model.encode(qa_pairs, False)
    spans, scores = self.sess.run([self.span, self.score], encoded)
    # Lazy logger arguments: the message is only formatted if the record is emitted
    self.log.info("Computing answer spans took %.5f seconds", time.perf_counter() - t0)
    return spans, scores, paragraphs

async def answer_question(self, question: str) -> Tuple[np.ndarray, List[WebParagraph]]:
    """
    Answer a question using web search, return the paragraphs and per-span confidence scores
    in the form of a (batch, max_num_context_tokens, max_num_context_tokens) array

    :param question: the raw (untokenized) question text
    :return: whatever `_get_span_scores` produces for the tokenized question and the
        context gathered by `get_question_context`
    """

    # Lazy logger arguments: the message is only formatted if the record is emitted
    self.log.info("Answering question \"%s\" with web search", question)
    context = await self.get_question_context(question)
    question = self.tokenizer.tokenize_paragraph_flat(question)
    # Only the model scoring is timed, not the web search/download above
    t0 = time.perf_counter()
    out = self._get_span_scores(question, context)
    self.log.info("Computing answer spans took %.5f seconds", time.perf_counter() - t0)
    return out

def answer_with_doc(self, question: str, doc: str) -> Tuple[np.ndarray, List[WebParagraph]]:
""" Answer a question using the given text as a document """

Expand Down Expand Up @@ -194,15 +206,13 @@ def _split_document(self, para: List[ParagraphWithInverse], source_name: str,
on_token += n_tokens
return tokenized_paragraphs

async def _tagme(self, question: str):
    """
    Send `question` to the TAGME API and return the annotations that carry a "title".

    The request goes through the shared `self.client_sess` rather than opening a
    new session per call.

    :param question: the raw question text to annotate
    :return: list of annotation dicts from the response's "annotations" field that
        contain a "title" key; an empty list if the response has no such field
    """
    payload = {"text": question,
               "long_text": 3,
               "lang": "en",
               "gcube-token": self.tagme_api_key}
    async with self.client_sess.get(url=TAGME_API, params=payload) as resp:
        data = await resp.json()

    annotations = data.get("annotations") if isinstance(data, dict) else None
    if annotations is None:
        # A response without annotations should not take the whole
        # request down with a bare KeyError, continue without entity links
        self.log.warning("TAGME response did not contain any annotations")
        return []
    return [ann_json for ann_json in annotations if "title" in ann_json]

async def get_question_context(self, question: str) -> List[WebParagraph]:
Expand Down Expand Up @@ -262,4 +272,12 @@ async def get_question_context(self, question: str) -> List[WebParagraph]:
if len(tokenized_paragraphs) == 0:
return []
question = self.tokenizer.tokenize_sentence(question)
return self.paragraph_selector.prune(question, tokenized_paragraphs)
return self.paragraph_selector.prune(question, tokenized_paragraphs)

def close(self):
    """
    Release the resources held by this object: the wiki corpus and the web
    searcher (when present), the tensorflow session and the HTTP client session.

    Every resource is closed, in that order, even if closing an earlier one raises;
    the exception is re-raised once all of them have been attempted.
    """
    from contextlib import ExitStack

    # ExitStack runs callbacks last-in-first-out, so register in reverse
    # of the order we want the resources closed in
    with ExitStack() as stack:
        # NOTE(review): in aiohttp >= 3.0 `ClientSession.close()` is a coroutine and
        # would need to be awaited -- confirm against the pinned aiohttp version
        stack.callback(self.client_sess.close)
        stack.callback(self.sess.close)
        if self.searcher is not None:
            stack.callback(self.searcher.close)
        if self.wiki_corpus is not None:
            stack.callback(self.wiki_corpus.close)
134 changes: 87 additions & 47 deletions docqa/server/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
import sys
import time
import ujson
from os import environ
Expand All @@ -8,6 +9,7 @@
import numpy as np
import tensorflow as tf
from sanic import Sanic, response
from sanic.config import LOGGING
from sanic.exceptions import ServerError
from sanic.response import json

Expand All @@ -21,17 +23,13 @@
from docqa.text_preprocessor import WithIndicators
from docqa.utils import ResourceLoader, LoadFromPath


"""
Server for the demo. The server uses the async/await framework so the API calls and
downloads can be done asynchronously, allowing the server can answer other queries in the
web downloads can be done asynchronously, allowing the server to answer other queries in the
meantime.
"""


log = logging.getLogger('server')


class AnswerSpan(object):
def __init__(self, conf: float, start: int, end: int):
self.conf = conf
Expand Down Expand Up @@ -139,16 +137,19 @@ def main():
parser.add_argument('-t', '--tokens', type=int, default=400,
help='Number of tokens to use per paragraph')
parser.add_argument('--vec_dir', help='Location to find word vectors')
parser.add_argument('--n_paragraphs', type=int, default=12,
parser.add_argument('--n_paragraphs', type=int, default=15,
help="Number of paragraphs to run the model on")
parser.add_argument('--paragraphs_to_return', type=int, default=10,
help="Number of paragraphs return to the frontend")
parser.add_argument('--span_bound', type=int, default=8,
help="Max span size to return as an answer")

parser.add_argument('--tagme_api_key', help="Key to use for TAGME (tagme.d4science.org/tagme)")
parser.add_argument('--bing_api_key', help="Key to use for bing searches")
parser.add_argument('--bing_version', choices=["5", "7"], help="Bing search version")
parser.add_argument('--tagme_thresh', default=0.2, type=float)
parser.add_argument('--no_wiki', action="store_true", help="Dont use TAGME")
parser.add_argument('--bing_version', choices=["v5.0", "v7.0"], default="v5.0",
help='Version of Bing API to use (must be compatible with the API key)')
parser.add_argument('--tagme_thresh', default=0.2, type=float,
help="TAGME threshold for when to use the identified docs")
parser.add_argument('--n_web', type=int, default=10, help='Number of web docs to fetch')
parser.add_argument('--blacklist_trivia_sites', action="store_true",
help="Don't use trivia websites")
Expand All @@ -157,7 +158,8 @@ def main():
parser.add_argument('--n_dl_threads', type=int, default=5,
help="Number of threads to download documents with")
parser.add_argument('--request_timeout', type=int, default=60)
parser.add_argument('--download_timeout', type=int, default=25)
parser.add_argument('--download_timeout', type=int, default=25,
help="Who long to wait before timing out downloads")
parser.add_argument('--workers', type=int, default=1,
help="Number of server workers")
parser.add_argument('--debug', default=None, choices=["random_model", "dummy_qa"])
Expand Down Expand Up @@ -187,53 +189,88 @@ def main():
else:
loader = ResourceLoader()

if args.debug == "dummy_qa":
qa = DummyQa()
else:
qa = QaSystem(
args.wiki_cache,
MergeParagraphs(args.tokens),
ShallowOpenWebRanker(args.n_paragraphs),
args.voc,
model,
loader,
bing_api_key,
tagme_api_key=tagme_api_key,
bing_api_7=args.bing_version == "7",
n_dl_threads=args.n_dl_threads,
blacklist_trivia_sites=args.blacklist_trivia_sites,
download_timeout=args.download_timeout,
span_bound=span_bound,
tagme_threshold=None if args.no_wiki else args.tagme_thresh,
n_web_docs=args.n_web
)

logging.propagate = False
formatter = logging.Formatter("%(asctime)s: %(levelname)s: %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logging.root.addHandler(handler)
logging.root.setLevel(logging.DEBUG)
# Update Sanic's logging to register our class's loggers
log_config = LOGGING
formatter = "%(asctime)s: %(levelname)s: %(message)s"
log_config["formatters"]['my_formatter'] = {
'format': formatter,
'datefmt': '%Y-%m-%d %H:%M:%S',
}
log_config['handlers']['stream_handler'] = {
'class': "logging.StreamHandler",
'formatter': 'my_formatter',
'stream': sys.stderr
}
log_config['handlers']['file_handler'] = {
'class': "logging.FileHandler",
'formatter': 'my_formatter',
'filename': 'logging.log'
}

# It looks like we have to go and name every logger our own code might
# use in order to register it with Sanic
log_config["loggers"]['qa_system'] = {
'level': 'INFO',
'handlers': ['stream_handler', 'file_handler'],
}
log_config["loggers"]['downloader'] = {
'level': 'INFO',
'handlers': ['stream_handler', 'file_handler'],
}
log_config["loggers"]['server'] = {
'level': 'INFO',
'handlers': ['stream_handler', 'file_handler'],
}

app = Sanic()
app.config.REQUEST_TIMEOUT = args.request_timeout
log = logging.getLogger('server')

@app.listener('before_server_start')
async def setup_qa(app, loop):
    """ Build the QA backend once the server's event loop exists and store it on `app.qa` """
    # To play nice with aiohttp's async ClientSession objects, we need to construct the QaSystem
    # inside the event loop.
    if args.debug == "dummy_qa":
        qa = DummyQa()
    else:
        qa = QaSystem(
            args.wiki_cache,
            MergeParagraphs(args.tokens),
            ShallowOpenWebRanker(args.n_paragraphs),
            args.voc,
            model,
            loader,
            bing_api_key,
            bing_version=args.bing_version,
            tagme_api_key=tagme_api_key,
            n_dl_threads=args.n_dl_threads,
            blacklist_trivia_sites=args.blacklist_trivia_sites,
            download_timeout=args.download_timeout,
            span_bound=span_bound,
            # Without a TAGME key there is nothing to threshold, so disable TAGME entirely
            tagme_threshold=None if (tagme_api_key is None) else args.tagme_thresh,
            n_web_docs=args.n_web,
            # Bind the client sessions to the loop this worker is actually running on
            loop=loop,
        )
    app.qa = qa

@app.listener('after_server_stop')
async def teardown_qa(app, loop):
    """ Release the QA backend's resources once the server has stopped """
    # `app.qa` is missing if startup failed before it was assigned, and the backend
    # in use might not define `close()`; in both cases there is nothing to release
    qa = getattr(app, "qa", None)
    close = getattr(qa, "close", None)
    if close is not None:
        close()

@app.route("/answer")
async def answer(request):
try:
question = request.args["question"][0]
if question == "":
return response.json(
{'message': 'No question given'},
status=400
)
spans, paras = await qa.answer_question(question)
return response.json({'message': 'No question given'}, status=400)
spans, paras = await app.qa.answer_question(question)
answers = select_answers(paras, spans, span_bound, 10)
answers = answers[:args.paragraphs_to_return]
best_span = max(answers[0].answers, key=lambda x: x.conf)
log.info("Answered \"%s\" (with web search): \"%s\"", question, answers[0].original_text[best_span.start:best_span.end])
return json([x.to_json() for x in answers])
except Exception as e:
log.info("Error: " + str(e))

raise ServerError("Server Error", status_code=500)
raise ServerError(e, status_code=500)

@app.route('/answer-from', methods=['POST'])
async def answer_from(request):
Expand All @@ -247,16 +284,19 @@ async def answer_from(request):
doc = args["document"]
if len(doc) > 500000:
raise ServerError("Document too large", status_code=400)
spans, paras = qa.answer_with_doc(question, doc)
spans, paras = app.qa.answer_with_doc(question, doc)
answers = select_answers(paras, spans, span_bound, 10)
answers = answers[:args.paragraphs_to_return]
best_span = max(answers[0].answers, key=lambda x: x.conf)
log.info("Answered \"%s\" (with user doc): \"%s\"", question, answers[0].original_text[best_span.start:best_span.end])
return json([x.to_json() for x in answers])
except Exception as e:
log.info("Error: " + str(e))
raise ServerError("Server Error", status_code=500)
raise ServerError(e, status_code=500)

app.static('/', './docqa//server/static/index.html')
app.static('/about.html', './docqa//service/static/about.html')
app.run(host="0.0.0.0", port=8000, workers=args.workers, debug=False)
app.run(host="0.0.0.0", port=8000, workers=args.workers, debug=False, log_config=LOGGING)


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions docqa/server/static/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
return;
}

this.props.onExecute()
this.setState({working: true});

if (this.state.mode === "doc") {
Expand All @@ -145,7 +146,7 @@
var fileReader = new FileReader()
var self = this
fileReader.onload = function (fileLoadedEvent) {
if (console.log(fileLoadedEvent.error !== undefined)) {
if (fileLoadedEvent.error !== undefined) {
self.setState({working: false});
self.props.onError("Error uploading file")
return
Expand Down Expand Up @@ -255,6 +256,10 @@
this.setState({error: null});
}

onStartExecutingQuery = () => {
this.setState({error: null});
}

onError = (msg) => {
this.setState({error: msg});
}
Expand Down Expand Up @@ -333,7 +338,7 @@ <h2>Question Answering by Reading Documents Demo</h2>
return (
<div style={{"maxWidth":1000, "margin": "auto"}}>
{introDisplay}
<SearchBar onError={this.onError} onAnswer={this.onAnswer}></SearchBar>
<SearchBar onError={this.onError} onAnswer={this.onAnswer} onExecute={this.onStartExecutingQuery}></SearchBar>
{alert}
{answerDisplay}
</div>
Expand Down
Loading

0 comments on commit d3feaba

Please sign in to comment.