
Commit

Merge branch 'facebookresearch:main' into main
phlong3105 authored Aug 13, 2024
2 parents 7cd5423 + 0db838b commit 1eaf093
Showing 14 changed files with 323 additions and 420 deletions.
22 changes: 20 additions & 2 deletions INSTALL.md
@@ -5,6 +5,7 @@
- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.

Then, install SAM 2 from the root of this repository via
```bash
@@ -22,11 +23,13 @@ This would also skip the post-processing step at runtime (removing small holes a

By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.)

If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, you can still use SAM 2 for both image and video applications, but the post-processing step (removing small holes and sprinkles in the output masks) will be skipped. This shouldn't affect the results in most cases.
If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases.

If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows
```bash
pip uninstall -y SAM-2; rm -f sam2/*.so; SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[demo]"
pip uninstall -y SAM-2 && \
rm -f ./sam2/*.so && \
SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[demo]"
```

Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.
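For reference, a quick sketch (assuming PyTorch is already installed) to see which CUDA version your PyTorch build was compiled against, so you can match it to the toolkit reported by `nvcc --version`:
```python
import torch

# The CUDA version PyTorch was built against (e.g. "12.1"); the CUDA toolkit used to
# build the SAM 2 extension should match this, not necessarily the driver version.
print("PyTorch:", torch.__version__)
print("PyTorch CUDA version:", torch.version.cuda)
print("CUDA available at runtime:", torch.cuda.is_available())
```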
@@ -101,6 +104,21 @@ In particular, if you have a lower PyTorch version than 2.3.1, it's recommended
We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/segment-anything-2/issues/22, https://github.com/facebookresearch/segment-anything-2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
</details>

<details>
<summary>
I got `CUDA error: no kernel image is available for execution on the device`
</summary>
<br/>

A possible cause is that the CUDA kernel was not compiled for your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This can happen if the installation is done in an environment different from the runtime (e.g. on a Slurm system, where the build and compute nodes may have different GPUs).

You can try pulling the latest code from the SAM 2 repo and setting the following environment variable before reinstalling
```bash
export TORCH_CUDA_ARCH_LIST="9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0"
```
to manually specify the CUDA capability in the compilation target that matches your GPU.
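
As a quick check (a sketch assuming PyTorch is installed in the runtime environment), you can print the compute capability of the GPU you actually run on and confirm it is covered by the list above:
```python
import torch

# Prints e.g. (8, 6) for an RTX 30-series GPU, i.e. CUDA capability 8.6; make sure this
# value appears in TORCH_CUDA_ARCH_LIST before rebuilding the extension.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU: {torch.cuda.get_device_name()} (capability {major}.{minor})")
else:
    print("No CUDA device is visible in this environment.")
```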
</details>

<details>
<summary>
I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
11 changes: 7 additions & 4 deletions README.md
@@ -19,8 +19,9 @@ SAM 2 needs to be installed first before use. The code requires `python>=3.10`,
```bash
git clone https://github.com/facebookresearch/segment-anything-2.git

cd segment-anything-2; pip install -e .
cd segment-anything-2 && pip install -e .
```
If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.

To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:

@@ -29,8 +30,9 @@ pip install -e ".[demo]"
```

Note:
1. It's recommended to create a new Python environment for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).

Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.

@@ -41,8 +43,9 @@ Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.
First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:

```bash
cd checkpoints
./download_ckpts.sh
cd checkpoints && \
./download_ckpts.sh && \
cd ..
```

or individually from:
78 changes: 57 additions & 21 deletions notebooks/automatic_mask_generator_example.ipynb

Large diffs are not rendered by default.

61 changes: 45 additions & 16 deletions notebooks/image_predictor_example.ipynb

Large diffs are not rendered by default.

395 changes: 70 additions & 325 deletions notebooks/video_predictor_example.ipynb

Large diffs are not rendered by default.

22 changes: 21 additions & 1 deletion sam2/automatic_mask_generator.py
@@ -53,6 +53,7 @@ def __init__(
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
**kwargs,
) -> None:
"""
Using a SAM 2 model, generates masks for the entire image.
@@ -148,6 +149,23 @@ def __init__(
self.use_m2m = use_m2m
self.multimask_output = multimask_output

@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2AutomaticMaskGenerator): The loaded model.
"""
from sam2.build_sam import build_sam2_hf

sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model, **kwargs)

@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
@@ -284,7 +302,9 @@ def _process_batch(
orig_h, orig_w = orig_size

# Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device)
points = torch.as_tensor(
points, dtype=torch.float32, device=self.predictor.device
)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)
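For context, a minimal usage sketch of the new `from_pretrained` entry point (the model ID and image path below are placeholders, not part of this diff):
```python
import numpy as np
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Placeholder Hugging Face model ID; extra keyword arguments are forwarded to the constructor.
mask_generator = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-large")

image = np.array(Image.open("example.jpg").convert("RGB"))
masks = mask_generator.generate(image)  # list of dicts, one per mask
print(f"Generated {len(masks)} masks")
```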
2 changes: 2 additions & 0 deletions sam2/build_sam.py
@@ -19,6 +19,7 @@ def build_sam2(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):

if apply_postprocessing:
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
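For context, a sketch of how these builders are typically called (the config and checkpoint names are assumptions for illustration); the added `**kwargs` lets wrappers such as `from_pretrained` forward extra constructor arguments without the builders rejecting them:
```python
import torch
from sam2.build_sam import build_sam2, build_sam2_video_predictor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed config/checkpoint names -- substitute the ones matching your checkpoint download.
model = build_sam2("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt", device=device)
video_predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt", device=device
)
```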
7 changes: 6 additions & 1 deletion sam2/modeling/position_encoding.py
@@ -211,6 +211,11 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
if freqs_cis.is_cuda:
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
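As a sanity check for this fallback, a small standalone sketch (using a real-valued tensor, since repeating complex tensors is exactly what can fail off-CUDA) showing that `expand` + `flatten` tiles dim 2 the same way `repeat` does:
```python
import torch

# Stand-in for freqs_cis with shape (batch, heads, seq, dim); r repeats along the seq dim.
x = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)
r = 2
tiled_repeat = x.repeat(1, 1, r, 1)
tiled_expand = x.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
assert torch.equal(tiled_repeat, tiled_expand)  # identical (batch, heads, r * seq, dim) tensors
```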
4 changes: 2 additions & 2 deletions sam2/modeling/sam2_base.py
@@ -567,10 +567,10 @@ def _prepare_memory_conditioned_features(
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
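A minimal sketch of the device-agnostic pattern adopted here (nothing below is specific to this diff): tensors that were offloaded to CPU are moved back with `.to(device)` instead of a hard-coded `.cuda()`, so CPU and MPS runs keep working:
```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for an offloaded "maskmem_features" tensor held on CPU between frames.
feats_cpu = torch.randn(1, 64, 32, 32)
feats = feats_cpu.to(device, non_blocking=True)  # no-op if it is already on `device`
```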
9 changes: 6 additions & 3 deletions sam2/sam2_image_predictor.py
@@ -24,6 +24,7 @@ def __init__(
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
**kwargs,
) -> None:
"""
Uses SAM-2 to calculate the image embedding for an image, and then
@@ -33,8 +34,10 @@
sam_model (Sam-2): The model to use for mask prediction.
mask_threshold (float): The threshold to use when converting mask logits
to binary masks. Masks are thresholded at 0 by default.
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
the maximum area of fill_hole_area in low_res_masks.
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
"""
super().__init__()
self.model = sam_model
@@ -77,7 +80,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
from sam2.build_sam import build_sam2_hf

sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model)
return cls(sam_model, **kwargs)

@torch.no_grad()
def set_image(
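For context, a minimal sketch of what forwarding `**kwargs` enables (the model ID and image path are placeholders): constructor options such as `mask_threshold` can now be passed straight through `from_pretrained`:
```python
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Placeholder model ID; mask_threshold is forwarded to the SAM2ImagePredictor constructor.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large", mask_threshold=0.0)

predictor.set_image(np.array(Image.open("example.jpg").convert("RGB")))
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[500, 375]]), point_labels=np.array([1])
)
```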
14 changes: 9 additions & 5 deletions sam2/sam2_video_predictor.py
@@ -45,11 +45,13 @@ def init_state(
async_loading_frames=False,
):
"""Initialize a inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
@@ -65,11 +67,11 @@ def init_state(
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
@@ -119,7 +121,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
from sam2.build_sam import build_sam2_video_predictor_hf

sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model)
return sam_model

def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
@@ -270,7 +272,8 @@ def add_new_points_or_box(
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
@@ -793,7 +796,8 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
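A hedged end-to-end sketch of how these device-agnostic pieces fit together (the model ID and frame directory are placeholders; output handling is omitted):
```python
import torch
from sam2.build_sam import build_sam2_video_predictor_hf

# Placeholder model ID; note that from_pretrained / build_sam2_video_predictor_hf now
# return the predictor directly instead of wrapping it a second time.
predictor = build_sam2_video_predictor_hf("facebook/sam2-hiera-large")

with torch.inference_mode():
    state = predictor.init_state(video_path="./videos/my_frames")  # folder of JPEG frames
    predictor.add_new_points_or_box(
        inference_state=state, frame_idx=0, obj_id=1,
        points=[[210, 350]], labels=[1],
    )
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        pass  # e.g. threshold mask_logits > 0 and save per-frame masks
```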
42 changes: 33 additions & 9 deletions sam2/utils/misc.py
@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
A list of video frames to be loaded asynchronously without blocking session start.
"""

def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
@@ -119,6 +127,7 @@ def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
# video_height and video_width be filled when loading the first image
self.video_height = None
self.video_width = None
self.compute_device = compute_device

# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ def __getitem__(self, index):
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
img = img.cuda(non_blocking=True)
img = img.to(self.compute_device, non_blocking=True)
self.images[index] = img
return img

@@ -167,6 +176,7 @@ def load_video_frames(
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -179,7 +189,15 @@
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")
raise NotImplementedError(
"Only JPEG frames are supported at this moment. For video files, you may use "
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
"```\n"
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
"```\n"
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
"ffmpeg to start the JPEG file from 00000.jpg."
)

frame_names = [
p
Expand All @@ -196,17 +214,22 @@ def load_video_frames(

if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
)
return lazy_images, lazy_images.video_height, lazy_images.video_width

images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
@@ -230,8 +253,9 @@ def fill_holes_in_mask_scores(mask, max_area):
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
5 changes: 3 additions & 2 deletions sam2/utils/transforms.py
@@ -105,8 +105,9 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
except Exception as e:
# Skip the post-processing step if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,