Skip to content

Commit

Permalink
Add vqa +more support, add task pages
Browse files Browse the repository at this point in the history
  • Loading branch information
listofbanned committed Nov 5, 2022
1 parent 13bb654 commit 7ae6aba
Showing 1 changed file with 97 additions and 10 deletions.
107 changes: 97 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,85 @@ def load_image_from_url(img_url, image_size, device):


@app.get('/')
def main():
content = '''
<body>
<h3>Available options</h3>
<ul>
<li>Image Captioning: <i>GET</i> <a href='/image_captioning' target='_blank' style='text-decoration: inherit'> `/img_captioning`</a></li>
<li>Visual Question Answering: <i>GET</i> <a href='/vqa' target='_blank' style='text-decoration: inherit'>`/vqa`</a></li>
<li>Feature Extraction: <i>GET</i> <a href='/feature_extraction' target='_blank' style='text-decoration: inherit'> `/feature_extraction`</a></li>
<li>Image Text Matching: <i>GET</i> <a href='/image_text_matching' target='_blank' style='text-decoration: inherit'> `/image_text_matching`</a></li>
</ul>
</body>
'''
return HTMLResponse(content=content)


@app.get('/image_captioning')
def main():
content = '''
<body>
<form action="/upload" enctype="multipart/form-data" method="post">
<input name="task" type="text" value="image_captioning" hidden>
<input name="file" type="file">
<br />
<br />
<input type="submit">
</form>
</body>
'''
return HTMLResponse(content=content)


@app.get('/vqa')
def main():
content = '''
<body>
<form action="/upload" enctype="multipart/form-data" method="post">
<input name="task" type="text" value="vqa" hidden>
<span>Question: </span><input name="question" type="text">
<br />
<br />
<input name="file" type="file">
<br />
<br />
<input type="submit">
</form>
</body>
'''
return HTMLResponse(content=content)


@app.get('/feature_extraction')
def main():
content = '''
<body>
<form action="/upload" enctype="multipart/form-data" method="post">
<span>Task: </span><input name="task" type="text">
<input name="task" type="text" value="feature_extraction" hidden>
<span>Question: </span><input name="question" type="text">
<br />
<br />
<input name="file" type="file">
<br />
<br />
<input type="submit">
</form>
</body>
'''
return HTMLResponse(content=content)


@app.get('/image_text_matching')
def main():
content = '''
<body>
<form action="/upload" enctype="multipart/form-data" method="post">
<input name="task" type="text" value="text_matching" hidden>
<span>Question: </span><input name="question" type="text">
<br />
<br />
<span>Mode (<i>itm or itc</i>): </span><input name="mode" type="text" value="itm">
<br />
<br />
<input name="file" type="file">
Expand All @@ -69,30 +143,43 @@ def main():


@app.post('/upload')
async def upload_image(task: str = Form(), file: UploadFile = Form()):
async def upload_image(
task: str = Form(),
question: str = Form(),
caption: str = Form(),
mode: str = Form(),
match_head: str = Form(),
file: UploadFile = Form()
):
try:
img = file.file
image = load_image(img, image_size, device)

if task == 'image_captioning':
with torch.no_grad():
# beam search
caption = image_captioning_model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
result = image_captioning_model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
# nucleus sampling
# caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
return {'Caption': caption[0]}
return {'Caption': result[0]}
if task == 'vqa':
with torch.no_grad():
caption = vqa_model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
return {'Caption': caption[0]}
answer = vqa_model(image, question, sample=False, num_beams=3, max_length=20, min_length=5)
return {'Answer': answer[0]}
if task == 'feature_extraction':
with torch.no_grad():
caption = feature_extraction_model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
return {'Caption': caption[0]}
result = feature_extraction_model(image, caption, mode) [0, 0]
return {'Result': result} # ?
if task == 'text_matching':
with torch.no_grad():
caption = image_text_matching_model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
return {'Caption': caption[0]}
if match_head == 'itm':
itm_output = image_text_matching_model(image, caption, match_head)
itm_score = torch.nn.functional.softmax(itm_output,dim=1)[:,1]
return {'The image and text is matched with a probability of %.4f'%itm_score}
elif match_head == 'itc':
itc_score = image_text_matching_model(image, caption, match_head)
return {'The image feature and text feature has a cosine similarity of %.4f'%itc_score}

except Exception as e:
return {'Error': e}

Expand Down

0 comments on commit 7ae6aba

Please sign in to comment.