Skip to content

Commit

Permalink
Add file upload
Browse files Browse the repository at this point in the history
  • Loading branch information
listofbanned committed Oct 26, 2022
1 parent 9d9e398 commit 4918512
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import HTMLResponse
from models.blip import blip_decoder
from fastapi import FastAPI
import uvicorn


Expand All @@ -16,7 +17,7 @@


def load_image_from_url(img, image_size, device):
raw_image = Image.open(requests.get(img, stream=True).raw).convert('RGB')
raw_image = Image.open(img).convert('RGB')

w,h = raw_image.size
# display(raw_image.resize((w//5,h//5)))
Expand All @@ -32,29 +33,40 @@ def load_image_from_url(img, image_size, device):

@app.get('/')
def main():
    """Serve a minimal HTML page with the multipart upload form.

    The form POSTs two fields to ``/upload``: a text field ``task``
    (selects the operation, e.g. ``image_captioning``) and a ``file``
    field carrying the image, encoded as multipart/form-data.

    Returns:
        HTMLResponse: the static form markup.
    """
    content = '''
<body>
<form action="/upload" enctype="multipart/form-data" method="post">
<input name="task" type="text">
<br />
<input name="file" type="file">
<br />
<input type="submit">
</form>
</body>
'''
    return HTMLResponse(content=content)


@app.post('/upload')
async def upload_image(task: str = Form(), file: UploadFile = Form()):
    """Run the requested task on an uploaded image.

    Args:
        task: form field naming the operation; only ``'image_captioning'``
            is supported.
        file: the uploaded image file (multipart/form-data).

    Returns:
        dict: ``{'Caption': <text>}`` on success, ``{'Error': <message>}``
        on failure or for an unrecognized task.
    """
    try:
        # file.file is the underlying file-like object; PIL can open it
        # directly, so no temp copy is needed.
        img = file.file
        if task == 'image_captioning':
            image = load_image_from_url(img, image_size, device)

            with torch.no_grad():
                # beam search
                caption = model.generate(image, sample=False, num_beams=3,
                                         max_length=20, min_length=5)
                # nucleus sampling
                # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
            return {'Caption': caption[0]}
        # Previously an unknown task fell through and returned null.
        return {'Error': f'unsupported task: {task}'}
    except Exception as e:
        # str(e): a bare exception instance is not JSON-serializable,
        # so returning `e` itself would make the response encoding fail.
        return {'Error': str(e)}


if __name__ == '__main__':
    # Load the BLIP base captioning checkpoint from the local models dir.
    model_url = './models/model_base_capfilt_large.pth'
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()  # inference mode: freeze dropout / batch-norm statistics
    model = model.to(device)
    # port must be an int — uvicorn rejects a string port.
    uvicorn.run(app, host='0.0.0.0', port=8080)

0 comments on commit 4918512

Please sign in to comment.