Skip to content

Commit

Permalink
update and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
1989Ryan committed Feb 7, 2023
1 parent f793063 commit 313d88c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ RUN sudo apt-get install build-essential

RUN sudo pip3 install git+https://github.com/openai/CLIP.git
RUN sudo pip3 install \
gdown==3.13.1 \
gdown \
absl-py>=0.7.0 \
gym==0.17.3 \
pybullet>=3.0.4 \
Expand All @@ -71,6 +71,7 @@ RUN sudo pip3 install \
regex \
timm==0.5.4\
ffmpeg \

opencv-python==4.1.2.30
RUN sudo pip3 install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.2+cu113.html
RUN sudo pip3 install torch-geometric
Expand Down
46 changes: 23 additions & 23 deletions object_detector/run.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
from object_detector import Space, maskrcnn
from object_detector import maskrcnn
from object_detector.space.eval.ap import convert_to_boxes
from object_detector.maskrcnn.mask_rcnn import get_model_instance_segmentation
import torch

detector = {
'space': Space,
'maskrcnn': maskrcnn,
}
# detector = {
# 'space': Space,
# 'maskrcnn': maskrcnn,
# }

class objDetector():
def __init__(self, model_name, device) -> None:
self.model = detector[model_name]
self.model.to(device)
self.device = device
# class objDetector():
# def __init__(self, model_name, device) -> None:
# self.model = detector[model_name]
# self.model.to(device)
# self.device = device

@torch.no_grad()
def run(self, img):
boxes_pred = []
self.model.eval()
img.to(self.device)
loss, log = self.model(img, global_step=10000000000)
z_where, z_pres_prob = log['z_shere'], log['z_pres_prob']
z_where = z_where.detach().cpu()
z_pres_prob = z_pres_prob.detach().cpu().squeeze()
z_pres = z_pres_prob > 0.5
boxes_batch = convert_to_boxes(z_where, z_pres, z_pres_prob)
boxes_pred.extend(boxes_batch)
return boxes_pred
# @torch.no_grad()
# def run(self, img):
# boxes_pred = []
# self.model.eval()
# img.to(self.device)
# loss, log = self.model(img, global_step=10000000000)
# z_where, z_pres_prob = log['z_shere'], log['z_pres_prob']
# z_where = z_where.detach().cpu()
# z_pres_prob = z_pres_prob.detach().cpu().squeeze()
# z_pres = z_pres_prob > 0.5
# boxes_batch = convert_to_boxes(z_where, z_pres, z_pres_prob)
# boxes_pred.extend(boxes_batch)
# return boxes_pred

class maskrcnn_obj_detecotr():
'''mask rcnn object detector'''
Expand Down
3 changes: 1 addition & 2 deletions paragon/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def main(cfg):
total_num = 0
succ_num = 0
repeat = 1
obj_detect = maskrcnn_obj_detecotr('/home/zirui/paraground/trained_model/0_9_mask_rcnn.pt')

obj_detect = maskrcnn_obj_detecotr('./object_detector/maskrcnn/0_9_mask_rcnn.pt')
if cfg['dataset']['comp']:
print('testing tasks: 10 objects with compositional instructions')
else:
Expand Down

0 comments on commit 313d88c

Please sign in to comment.