diff --git a/README.md b/README.md index bdc88e241..6328e1fda 100644 --- a/README.md +++ b/README.md @@ -72,19 +72,6 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): masks, _, _ = predictor.predict() ``` -or from Hugging Face, as follows: - -```python -import torch -from sam2.sam2_image_predictor import SAM2ImagePredictor - -predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") - -with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): - predictor.set_image() - masks, _, _ = predictor.predict() -``` - Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) for static image use cases. SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) for automatic mask generation in images. @@ -110,7 +97,26 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): ... ``` -or from Hugging Face, as follows: +Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos. + +## Load from Hugging Face + +Alternatively, models can also be loaded from Hugging Face using the `from_pretrained` method: + +For image prediction: + +```python +import torch +from sam2.sam2_image_predictor import SAM2ImagePredictor + +predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +For video prediction: ```python import torch @@ -123,8 +129,6 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): masks, _, _ = predictor.predict() ``` -Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos. - ## Model Description | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |