Raptor benchmarks

Run Raptor and related methods on MedMNIST-formatted datasets.

MedMNIST data format

MedMNIST datasets are stored in npz format with keys 'train_images', 'train_labels', and so on for the train, val, and test splits. The 'train_images' entry should contain volumes of shape N x H x W x D. The labels are assumed to be for classification with mutually exclusive classes, stored with shape N x 1, where each entry is a class index from 0 to C-1.

The datasets are expected to be in ../medmnist. If the data lives at a different path, a symlink can be created at this location for convenience.
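
For concreteness, a minimal sketch of reading such a dataset (the filename organmnist3d.npz is only an example):

import numpy as np

# Load a MedMNIST-style npz from the expected location
data = np.load('../medmnist/organmnist3d.npz')

train_x = data['train_images']  # volumes, shape (N, H, W, D)
train_y = data['train_labels']  # class indices, shape (N, 1), values in 0..C-1

print(train_x.shape, train_y.shape)
print('classes:', np.unique(train_y))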

Extended MedMNIST data format

We also store multi-label classification and regression tasks using a format similar to the one described above. For multi-label tasks, the '*_labels' entry is a binary indicator matrix of shape N x C. For regression, we store the quantitative labels (which can be multi-dimensional) also with shape N x C.
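
As a minimal sketch of writing such a dataset (the filename and shapes below are illustrative):

import numpy as np

N, C = 100, 5
volumes = np.zeros((N, 28, 28, 28), dtype=np.float32)  # N x H x W x D

# Multi-label task: binary indicator matrix of shape N x C
multi_labels = np.zeros((N, C), dtype=np.int64)
multi_labels[0, [1, 3]] = 1  # sample 0 belongs to classes 1 and 3

# Regression task: quantitative targets, also of shape N x C
reg_labels = np.random.randn(N, C).astype(np.float32)

np.savez('extended_example.npz', train_images=volumes, train_labels=multi_labels)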

We allow volumes to be stored outside the npz for performance or storage reasons. In that case, partial paths to the volumes (strings) should be stored in '*_images' instead of the volumes themselves. The training scripts should accept a root_path option so that the volumes can be located on different drives on different machines.
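
A minimal sketch of resolving such partial paths; the function name and directory layout here are illustrative, not the exact interface of the training scripts:

import os
import numpy as np

def load_volume(partial_path, root_path):
    # '*_images' stores strings such as 'volumes/case_0001.npy'
    return np.load(os.path.join(root_path, partial_path))

data = np.load('../medmnist/external_example.npz', allow_pickle=True)
vol = load_volume(str(data['train_images'][0]), root_path='/mnt/data/volumes')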

Raptor scripts overview

The scripts relevant to Raptor are:

  • embed.py: Uses Raptor to embed volumes.
  • fit.py: Logistic regression on precomputed embeddings (Raptor, VoCo)

Training scripts overview

  • fit.py: Logistic regression on precomputed embeddings (Raptor, VoCo)
  • fit_e2e.py: End-to-end training for pretrained models (PCTNet, SuPreM, ResNet, Merlin)
  • fit_mae.py: MAE training and classification training for standard ViT (MAE)
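
Conceptually, the fit.py stage amounts to fitting a linear classifier on embeddings produced by embed.py. The sketch below is illustrative only; the file names and the use of scikit-learn are assumptions, not the script's actual interface:

import numpy as np
from sklearn.linear_model import LogisticRegression

train_emb = np.load('embeddings/train.npy')  # N x D, from embed.py
train_y = np.load('embeddings/train_labels.npy').ravel()
test_emb = np.load('embeddings/test.npy')
test_y = np.load('embeddings/test_labels.npy').ravel()

clf = LogisticRegression(max_iter=1000).fit(train_emb, train_y)
print('test accuracy:', clf.score(test_emb, test_y))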

Adding new methods to test in our pipeline

Generally, it is recommended to add new comparison methods at the top of fit_mae.py so that a single script can conveniently run all experiments. For example, PCT-Net was added in the following way:

Import the model definition from a local path and define default options for this model:

elif args.model == 'pctnet':
    input_size = (64, 128, 128)

    # PCT-Net lives in a local checkout of MIS-FM
    sys.path.append('../etc/MIS-FM/')
    from net.pct_net import PCTNet
    from pymic.util.parse_config import parse_config
    from pymic.util.general import mixup, tensor_shape_match

    # Default options mirroring the authors' pretraining config
    class defaults:
        stage = 'train'
        cfg = '../etc/MIS-FM/demo/pctnet_pretrain.cfg'
        num_classes = 100  # allocate enough output planes for most tasks

    config = parse_config(defaults)
    config['network']['class_num'] = defaults.num_classes

Wrap the model to do classification:

    class PCTNetClassifier(PCTNet):
        def __init__(self, params):
            super().__init__(params)

            # Linear head over the concatenated pooled pyramid features
            self.fc = nn.Linear(128 + defaults.num_classes*2, num_classes)

        def forward(self, x):
            # Multi-resolution features from the pretrained encoder
            x0, x1, x2 = self.embeddings[self.resolution_mode](x)
            x2 = self.pyramid_ct(x2)
            # Global-average-pool each pyramid level and concatenate
            pooled = [F.adaptive_avg_pool3d(s, (1, 1, 1)).view(len(s), -1) for s in x2]
            pooled = torch.cat(pooled, -1)
            return self.fc(pooled)

    model = PCTNetClassifier(config['network']).to(device)

Load pretrained weights shared by the authors of the model:

    # Keep only weights whose names and shapes match the wrapped model, then
    # load non-strictly so the new classification head keeps its initialization.
    def load_pretrained_weights(network, pretrained_dict, device_ids):
        if len(device_ids) > 1:
            if hasattr(network.module, "get_parameters_to_load"):
                model_dict = network.module.get_parameters_to_load()
            else:
                model_dict = network.module.state_dict()
        else:
            if hasattr(network, "get_parameters_to_load"):
                model_dict = network.get_parameters_to_load()
            else:
                model_dict = network.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if \
            k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])}
        if len(device_ids) > 1:
            network.module.load_state_dict(pretrained_dict, strict=False)
        else:
            network.load_state_dict(pretrained_dict, strict=False)

    w = torch.load('../etc/MIS-FM/weights/pctnet_ct10k_volf.pt')['model_state_dict']
    load_pretrained_weights(model, w, [None])
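
Once wrapped and initialized, the model can be fine-tuned like any classifier. Below is a minimal sketch, assuming a hypothetical train_loader that yields (volume, label) batches; the actual loop in fit_mae.py may differ:

import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model.train()
for volumes, labels in train_loader:  # train_loader is an assumed DataLoader
    volumes = volumes.to(device)
    labels = labels.to(device).view(-1).long()
    logits = model(volumes)
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()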
