Skip to content

Commit

Permalink
add pretrained model support, fix main
Browse files (browse the repository at this point in the history)
  • Loading branch information
listofbanned committed Oct 26, 2022
1 parent 07f14ac commit 9d9e398
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import uvicorn


device = 'cuda:0'
app = FastAPI()

device = 'cuda:0'
image_size = 384


def load_image_from_url(img_url, image_size, device):
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
def load_image_from_url(img, image_size, device):
raw_image = Image.open(requests.get(img, stream=True).raw).convert('RGB')

w,h = raw_image.size
# display(raw_image.resize((w//5,h//5)))
Expand All @@ -33,17 +35,12 @@ def main():
return {'response': 'ok'}


@app.get('/image_captioning/{img_url}')
async def exec_image_captioning(img_url: str):
@app.get('/image_captioning')
def exec_image_captioning(img: str):
# img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
print(img)
try:
image_size = 384
image = load_image_from_url(img_url, image_size, device)

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'

model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)
image = load_image_from_url(img, image_size, device)

with torch.no_grad():
# beam search
Expand All @@ -56,4 +53,8 @@ async def exec_image_captioning(img_url: str):


if __name__ == '__main__':
    # Build the captioning model once, before the server starts, from a
    # local checkpoint file.
    # NOTE(review): these names are bound only when this file is executed as
    # a script; starting the app through an external ASGI launcher would skip
    # this block -- confirm the handler is never served that way.
    model_url = 'models/model_base_capfilt_large.pth'
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()  # switch the model to evaluation mode
    model = model.to(device)
    # uvicorn expects an integer port; the original passed the string '8080'.
    uvicorn.run(app, host='0.0.0.0', port=8080)

0 comments on commit 9d9e398

Please sign in to comment.