Skip to content

Commit

Permalink
Fix HF image predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
haithamkhedr committed Aug 12, 2024
1 parent dce7b54 commit 1191677
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 2 additions & 0 deletions sam2/build_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def build_sam2(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):

if apply_postprocessing:
Expand Down Expand Up @@ -47,6 +48,7 @@ def build_sam2_video_predictor(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
Expand Down
9 changes: 6 additions & 3 deletions sam2/sam2_image_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
**kwargs,
) -> None:
"""
Uses SAM-2 to calculate the image embedding for an image, and then
Expand All @@ -33,8 +34,10 @@ def __init__(
sam_model (Sam-2): The model to use for mask prediction.
mask_threshold (float): The threshold to use when converting mask logits
to binary masks. Masks are thresholded at 0 by default.
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
the maximum area of fill_hole_area in low_res_masks.
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
"""
super().__init__()
self.model = sam_model
Expand Down Expand Up @@ -77,7 +80,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
from sam2.build_sam import build_sam2_hf

sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model)
return cls(sam_model, **kwargs)

@torch.no_grad()
def set_image(
Expand Down
2 changes: 1 addition & 1 deletion sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
from sam2.build_sam import build_sam2_video_predictor_hf

sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model)
return sam_model

def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
Expand Down

0 comments on commit 1191677

Please sign in to comment.