Skip to content

Commit

Permalink
Add deterministic strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Oct 28, 2023
1 parent 616b3cd commit f0b6c0c
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 0 deletions.
6 changes: 6 additions & 0 deletions astra/torch/al/acquisitions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ def acquire_scores(self, *args, **kwargs):
raise NotImplementedError("This method must be implemented in a subclass")


class DeterministicAcquisition(Acquisition):
@abstractmethod
def acquire_scores(self, *args, **kwargs):
raise NotImplementedError("This method must be implemented in a subclass")


class RandomAcquisition(Acquisition):
def acquire_scores(self, logits: torch.Tensor) -> torch.Tensor:
# logits shape (n_mc_samples, pool_dim, n_classes)
Expand Down
80 changes: 80 additions & 0 deletions astra/torch/al/strategies/deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from astra.torch.al.strategies.base import Strategy
from astra.torch.al.acquisitions.base import DeterministicAcquisition
from astra.torch.al.errors import AcquisitionMismatchError

from typing import Sequence, Dict, Union


class DeterministicStrategy(Strategy):
def __init__(
self,
acquisitions: Union[DeterministicAcquisition, Sequence[DeterministicAcquisition]],
inputs: torch.Tensor,
outputs: torch.Tensor,
):
"""Base class for query strategies
Args:
acquisitions: A sequence of acquisition functions.
inputs: A tensor of inputs.
outputs: A tensor of outputs.
"""
super().__init__(acquisitions, inputs, outputs)

for name, acquisition in self.acquisitions.items():
if not isinstance(acquisition, DeterministicAcquisition):
raise AcquisitionMismatchError(DeterministicAcquisition, acquisition)

def query(
self,
net: nn.Module,
pool_indices: torch.Tensor,
context_indices: torch.Tensor = None,
n_query_samples: int = 1,
n_mc_samples: int = 10,
batch_size: int = None,
) -> Dict[str, torch.Tensor]:
"""Monte Carlo query strategy
Args:
net: A neural network with dropout layers.
pool_indices: The indices of the pool set.
context_indices: This argument is ignored.
n_query_samples: Number of samples to query.
n_mc_samples: This argument is ignored.
batch_size: Batch size for the data loader.
Returns:
best_indices: A dictionary of acquisition names and the corresponding best indices.
"""
assert isinstance(pool_indices, torch.Tensor), f"pool_indices must be a torch.Tensor, got {type(pool_indices)}"

if batch_size is None:
batch_size = len(pool_indices)

data_loader = DataLoader(self.dataset[pool_indices])

# put the model on eval mode
net.eval()

with torch.no_grad():
logits_list = []
for x, _ in data_loader:
logits = net(x)
logits_list.append(logits)
logits = torch.cat(logits_list, dim=1) # (pool_dim, n_classes)

best_indices = {}
for acq_name, acquisition in self.acquisitions.items():
scores = acquisition.acquire_scores(logits)
index = torch.topk(scores, n_query_samples).indices
selected_indices = pool_indices[index]
best_indices[acq_name] = selected_indices

return best_indices
3 changes: 3 additions & 0 deletions astra/torch/al/strategies/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def query(

data_loader = DataLoader(self.dataset)

# put model on eval mode
net.eval()

with torch.no_grad():
# Get all features
features_list = []
Expand Down
4 changes: 4 additions & 0 deletions astra/torch/al/strategies/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def query(

data_loader = DataLoader(self.dataset[pool_indices])

# Put model on eval mode
for model in net:
model.eval()

with torch.no_grad():
logits_list = []
for x, _ in data_loader:
Expand Down

0 comments on commit f0b6c0c

Please sign in to comment.