Skip to content

Commit

Permalink
ENH: pad or crop volume
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoineTheb committed Jul 31, 2024
1 parent 460dac8 commit 856586b
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
75 changes: 75 additions & 0 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,81 @@ def resample_volume(img, ref_img=None, volume_shape=None, iso_min=False,
return nib.Nifti1Image(data2.astype(data.dtype), affine2)


def reshape_volume(
img, volume_shape, mode='constant', cval=0
):
""" Reshape a volume to a specified shape by padding or cropping. The
new volume is centered wrt the old volume in world space.
Parameters
----------
img : nib.Nifti1Image
The input image.
volume_shape : tuple of 3 ints
The desired shape of the volume.
mode : str, optional
Padding mode. See np.pad for more information. Default is 'constant'.
cval: float, optional
Value to use for padding when mode is 'constant'. Default is 0.
Returns
-------
reshaped_img : nib.Nifti1Image
The reshaped image.
"""

data = img.get_fdata(dtype=np.float32)
affine = img.affine

# Compute the difference between the desired shape and the current shape
diff = (np.array(volume_shape) - np.array(data.shape[:3])) // 2

# Compute the offset to center the data
offset = (np.array(volume_shape) - np.array(data.shape[:3])) % 2

# Compute the padding values (before and after) for all axes
pad_width = np.zeros((len(data.shape), 2), dtype=int)
for i in range(3):
pad_width[i, 0] = int(max(0, diff[i] + offset[i]))
pad_width[i, 1] = int(max(0, diff[i]))

# If dealing with 4D data, do not pad the last dimension
if len(data.shape) == 4:
pad_width[3, :] = [0, 0]

# Pad the data
kwargs = {
'mode': mode,
}
# Add constant_values only if mode is 'constant'
# Otherwise, it will raise an error
if mode == 'constant':
kwargs['constant_values'] = cval
padded_data = np.pad(data, pad_width, **kwargs)

# Compute the cropping values (before and after) for all axes
crop_width = np.zeros((len(data.shape), 2))
for i in range(3):
crop_width[i, 0] = -diff[i] - offset[i]
crop_width[i, 1] = np.ceil(padded_data.shape[i] + diff[i])

# If dealing with 4D data, do not crop the last dimension
if len(data.shape) == 4:
crop_width[3, :] = [0, data.shape[3]]

# Crop the data
cropped_data = crop(
padded_data, np.maximum(0, crop_width[:, 0]).astype(int),
crop_width[:, 1].astype(int))

# Compute the new affine
translation = voxel_to_world(crop_width[:, 0], affine)
new_affine = np.copy(affine)
new_affine[0:3, 3] = translation[0:3]

return nib.Nifti1Image(cropped_data, new_affine)


def mask_data_with_default_cube(data):
"""Masks data outside a default cube (Cube: data.shape/3 centered)
Expand Down
94 changes: 94 additions & 0 deletions scripts/scil_volume_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Script to reshape a volume to match the resolution of another
reference volume or to the resolution specified as in argument. The resulting
volume will be centered in world space wrt. the reference volume or the
specified resolution.
This script will pad or crop the volume to match the desired shape.
To interpolate, use scil_volume_resample.py.
"""

import argparse
import logging

import nibabel as nib

from scilpy.io.utils import (add_verbose_arg, add_overwrite_arg,
assert_inputs_exist, assert_outputs_exist)
from scilpy.image.volume_operations import reshape_volume


def _build_arg_parser():
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter)

p.add_argument('in_image',
help='Path of the input volume.')
p.add_argument('out_image',
help='Path of the resampled volume.')

res_group = p.add_mutually_exclusive_group(required=True)
res_group.add_argument(
'--ref',
help='Reference volume to resample to.')
res_group.add_argument(
'--volume_size', nargs='+', type=int,
help='Sets the size for the volume. If the value is set to is Y, '
'it will resample to a shape of Y x Y x Y.')

p.add_argument(
'--mode', default='constant',
choices=['constant', 'edge', 'wrap', 'reflect'],
help="Padding mode.\nconstant: pads with a constant value.\n"
"edge: repeats the edge value.\nDefaults to [%(default)s].")
p.add_argument('--constant_value', type=float, default=0,
help='Value to use for padding when mode is constant.')

add_verbose_arg(p)
add_overwrite_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

# Checking args
assert_inputs_exist(parser, args.in_image, args.ref)
assert_outputs_exist(parser, args, args.out_image)

if args.volume_size and (not len(args.volume_size) == 1 and
not len(args.volume_size) == 3):
parser.error('--volume_size takes in either 1 or 3 arguments.')

logging.info('Loading raw data from %s', args.in_image)

img = nib.load(args.in_image)

ref_img = None
if args.ref:
ref_img = nib.load(args.ref)
volume_shape = ref_img.shape[:3]
else:
if len(args.volume_size) == 1:
volume_shape = [args.volume_size[0]] * 3
else:
volume_shape = args.volume_size

# Resampling volume
reshaped_img = reshape_volume(img, volume_shape,
mode=args.mode,
cval=args.constant_value)

# Saving results
logging.info('Saving reshaped data to %s', args.out_image)
nib.save(reshaped_img, args.out_image)


if __name__ == '__main__':
main()

0 comments on commit 856586b

Please sign in to comment.