Skip to content

Commit

Permalink
accept kwargs in auto_mask_generator
Browse files Browse the repository at this point in the history
  • Loading branch information
haithamkhedr committed Aug 13, 2024
1 parent 1191677 commit fd5125b
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions sam2/automatic_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
**kwargs,
) -> None:
"""
Using a SAM 2 model, generates masks for the entire image.
Expand Down Expand Up @@ -148,6 +149,23 @@ def __init__(
self.use_m2m = use_m2m
self.multimask_output = multimask_output

@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2AutomaticMaskGenerator): The loaded model.
"""
from sam2.build_sam import build_sam2_hf

sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model, **kwargs)

@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
Expand Down

0 comments on commit fd5125b

Please sign in to comment.