Mini sbibm #1335

Open
wants to merge 11 commits into base: main
extended to something reasonable
manuelgloeckler committed Dec 19, 2024
commit 6fa9b96ae52b00d9757053d3fa5c36867062a34e
50 changes: 50 additions & 0 deletions test_example.py
@@ -0,0 +1,50 @@
from datetime import datetime
import pytest

import shutil


def pytest_addoption(parser):
parser.addoption(
"--print-harvest",
action="store_true",
default=False,
help="Print the harvest results at the end of the test session",
)


def pytest_terminal_summary(terminalreporter, exitstatus, config):
# Only print the harvest results if the --print-harvest flag is used
if config.getoption("--print-harvest"):
# Dynamically center the summary title in the terminal width
terminal_width = shutil.get_terminal_size().columns
summary_text = " short test summary info "
centered_line = summary_text.center(terminal_width, '=')
terminalreporter.write_line(centered_line)


@pytest.mark.parametrize('p', ['world', 'self'], ids=str)
def test_foo(p, results_bag):
"""
A dummy test, parametrized so that it is executed twice
"""

# Let's store some things in the results bag
results_bag.nb_letters = len(p)
results_bag.current_time = datetime.now().isoformat()


def test_synthesis(fixture_store):
"""
In this test we inspect the contents of the fixture store so far, and
check that the 'results_bag' entry contains a dict <test_id>: <results_bag>
"""
    # get the 'results_bag' entry from the fixture store
results = fixture_store["results_bag"]

# print what is available for the 'results_bag' entry
print("\n--- Harvested Test Results ---")
for k, v in results.items():
print(k)
for kk, vv in v.items():
print(kk, vv)
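
As a side note, here is a minimal sketch (illustrative only, not part of the diff) of how the harvested entries could also be collapsed into a single summary mapping inside such a synthesis test. It relies only on the fixture_store["results_bag"] mapping and the .items() iteration already used above; the test name test_summarize is hypothetical:

def test_summarize(fixture_store):
    # One row per executed test id, each row holding the keys stored in its
    # results_bag (here: nb_letters and current_time from the two test_foo runs).
    rows = {
        test_id: dict(bag.items())
        for test_id, bag in fixture_store["results_bag"].items()
    }
    print(rows)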
146 changes: 108 additions & 38 deletions tests/bm_test.py
@@ -1,31 +1,27 @@
import pytest
import torch

from sbi.inference import NPE, NRE
from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.utils.metrics import c2st

from .mini_sbibm import get_task

# There should probably be some user control over this
SEED = 0
TASKS = ["two_moons", "linear_mvg_2d", "gaussian_linear", "slcp"]
NUM_SIMULATIONS = 2000
EVALUATION_POINTS = 4  # Currently only 3 observations are tested, for speed

@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', ['two_moons'], ids=str)
@pytest.mark.parametrize('density_estimator', ["maf", "nsf"], ids=str)
def test_benchmark_npe_methods(
task_name, density_estimator, results_bag, method=None, num_simulations=1000, seed=0
):
torch.manual_seed(seed)
task = get_task(task_name)
thetas, xs = task.get_data(num_simulations)
assert thetas.shape[0] == num_simulations
assert xs.shape[0] == num_simulations
TRAIN_KWARGS = {
# "training_batch_size": 200, # To speed up training
}

inference = NPE(density_estimator=density_estimator)
_ = inference.append_simulations(thetas, xs).train()
# Amortized benchmarking

posterior = inference.build_posterior()

def standard_eval_c2st_loop(posterior, task):
metrics = []
Review comment (Contributor):
This would be better called c2st_scores or something rather than the generic metrics, unless we are calculating more metrics here.

for i in range(1, 2): # Currently only one observation tested for speed
for i in range(1, EVALUATION_POINTS):
x_o = task.get_observation(i)
posterior_samples = task.get_reference_posterior_samples(i)
approx_posterior_samples = posterior.sample((1000,), x=x_o)
@@ -37,43 +33,117 @@ def test_benchmark_npe_methods(
mean_c2st = sum(metrics) / len(metrics)
# Convert to float rounded to 3 decimal places
mean_c2st = float(f"{mean_c2st:.3f}")
return mean_c2st


DENSITY_estimators = ["mdn", "made", "maf", "nsf", "maf_rqs"] # "Kinda exhaustive"
DENSITY_estimators = ["maf", "nsf"] # Fast


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
@pytest.mark.parametrize('density_estimator', DENSITY_estimators, ids=str)
def test_benchmark_npe_methods(task_name, density_estimator, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

print(thetas.shape, xs.shape)

inference = NPE(prior, density_estimator=density_estimator)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

# Cache results
results_bag.metric = mean_c2st
results_bag.num_simulations = num_simulations
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NPE_" + density_estimator


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', ['two_moons'], ids=str)
def test_benchmark_nre_methods(task_name, results_bag, num_simulations=1000, seed=0):
torch.manual_seed(seed)
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_nre_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(num_simulations)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()
assert thetas.shape[0] == num_simulations
assert xs.shape[0] == num_simulations

inference = NRE(prior)
_ = inference.append_simulations(thetas, xs).train()
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

metrics = []
for i in range(1, 2):
x_o = task.get_observation(i)
posterior_samples = task.get_reference_posterior_samples(i)
approx_posterior_samples = posterior.sample((1000,), x=x_o)
if isinstance(approx_posterior_samples, tuple):
approx_posterior_samples = approx_posterior_samples[0]
c2st_val = c2st(posterior_samples[:1000], approx_posterior_samples)
metrics.append(c2st_val)

mean_c2st = sum(metrics) / len(metrics)
# Convert to float rounded to 3 decimal places
mean_c2st = float(f"{mean_c2st:.3f}")
mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = num_simulations
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NRE"


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_nle_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

inference = NLE(prior)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NLE"


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_fmpe_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

inference = FMPE(prior)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "FMPE"


@pytest.mark.benchmark
@pytest.mark.parametrize('task_name', TASKS, ids=str)
def test_benchmark_npse_methods(task_name, results_bag):
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

inference = NPSE(prior)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()

mean_c2st = standard_eval_c2st_loop(posterior, task)

results_bag.metric = mean_c2st
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = "NPSE"
53 changes: 51 additions & 2 deletions tests/conftest.py
@@ -1,6 +1,10 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import pickle
import shutil
from logging import warning
from pathlib import Path
from shutil import rmtree

import pytest
import torch
@@ -79,7 +83,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
terminalreporter.write_line(colored_line)

if harvested_fixture_data is not None:
terminalreporter.write_line("Harvested Fixture Data:")
terminalreporter.write_line("Amortized inference:")

results = harvested_fixture_data["results_bag"]

@@ -131,7 +135,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
val = data.get((m, t), "N/A")
# Convert metric to string with formatting if needed
# e.g. format(val, ".3f") if val is a float
val_str = str(val)
val_str = format(val, ".3f")
row += val_str.center(task_col_widths[t] + 2)
terminalreporter.write_line(row)

@@ -149,3 +153,48 @@ def mcmc_params_accurate() -> dict:
def mcmc_params_fast() -> dict:
"""Fixture for MCMC parameters for fast tests."""
return dict(num_chains=1, thin=1, warmup_steps=1)


# Pytest-harvest xdist support - not sure if we need it (for me, xdist is always slower).


# Define the folder in which the temporary workers' results will be stored
RESULTS_PATH = Path('./.xdist_results/')
RESULTS_PATH.mkdir(exist_ok=True)


def pytest_harvest_xdist_init():
# reset the recipient folder
if RESULTS_PATH.exists():
rmtree(RESULTS_PATH)
RESULTS_PATH.mkdir(exist_ok=False)
return True


def pytest_harvest_xdist_worker_dump(worker_id, session_items, fixture_store):
# persist session_items and fixture_store in the file system
with open(RESULTS_PATH / ('%s.pkl' % worker_id), 'wb') as f:
try:
pickle.dump((session_items, fixture_store), f)
except Exception as e:
            warning(
                "Error while pickling worker %s's harvested results: [%s] %s",
                worker_id,
                e.__class__,
                e,
            )
return True


def pytest_harvest_xdist_load():
# restore the saved objects from file system
workers_saved_material = dict()
for pkl_file in RESULTS_PATH.glob('*.pkl'):
wid = pkl_file.stem
with pkl_file.open('rb') as f:
workers_saved_material[wid] = pickle.load(f)
return workers_saved_material


def pytest_harvest_xdist_cleanup():
# delete all temporary pickle files
rmtree(RESULTS_PATH)
return True
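
Usage note (an assumption, not verified in this diff): with these pytest-harvest hooks in place, the harvested results_bag data should also survive a parallel run such as "pytest -m benchmark -n auto", since every xdist worker dumps its fixture store to .xdist_results/<worker_id>.pkl and the controller process reloads and merges those pickles before printing the summary.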
9 changes: 9 additions & 0 deletions tests/mini_sbibm/__init__.py
@@ -1,8 +1,17 @@
from .gaussian_linear import GaussianLinear
from .linear_mvg import LinearMVG2d
from .slcp import Slcp
from .two_moons import TwoMoons


def get_task(name: str):
if name == "two_moons":
return TwoMoons()
elif name == "linear_mvg_2d":
return LinearMVG2d()
elif name == "gaussian_linear":
return GaussianLinear()
elif name == "slcp":
return Slcp()
else:
raise ValueError(f"Unknown task {name}")
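
For reference, the benchmark tests in tests/bm_test.py use this factory roughly as follows (the literal values are illustrative):

task = get_task("two_moons")
thetas, xs = task.get_data(2000)                      # simulated training data
prior = task.get_prior()
x_o = task.get_observation(1)                         # observation for evaluation point 1
reference = task.get_reference_posterior_samples(1)   # samples used by the C2ST metric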
Binary file added tests/mini_sbibm/files/slcp/samples_1.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_10.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_2.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_3.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_4.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_5.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_6.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_7.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_8.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/samples_9.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_1.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_10.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_2.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_3.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_4.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_5.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_6.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_7.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_8.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/theta_o_9.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_1.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_10.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_2.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_3.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_4.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_5.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_6.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_7.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_8.pt
Binary file not shown.
Binary file added tests/mini_sbibm/files/slcp/x_o_9.pt
Binary file not shown.
55 changes: 55 additions & 0 deletions tests/mini_sbibm/gaussian_linear.py
@@ -0,0 +1,55 @@
from functools import partial
from typing import Callable

import torch
from torch.distributions import Distribution, MultivariateNormal

from sbi.simulators.linear_gaussian import (
diagonal_linear_gaussian,
true_posterior_linear_gaussian_mvn_prior,
)

from .base_task import Task


class GaussianLinear(Task):
def __init__(self):
self.simulator_scale = 0.1
self.dim = 5
super().__init__("gaussian_linear")

def theta_dim(self) -> int:
return self.dim

def x_dim(self) -> int:
return self.dim

def get_reference_posterior_samples(self, idx: int) -> torch.Tensor:
x_o = self.get_observation(idx)
posterior = true_posterior_linear_gaussian_mvn_prior(
x_o,
torch.zeros(self.dim),
self.simulator_scale * torch.eye(self.dim),
torch.zeros(self.dim),
torch.eye(self.dim),
)

return posterior.sample((10_000,))

def get_true_parameters(self, idx: int) -> torch.Tensor:
torch.manual_seed(idx)
return self.get_prior().sample()

def get_observation(self, idx: int) -> torch.Tensor:
theta_o = self.get_true_parameters(idx)
x_o = self.get_simulator()(theta_o[None, :])[0]
return x_o

def get_prior(self) -> Distribution:
return MultivariateNormal(torch.zeros(self.dim), torch.eye(self.dim))

def get_simulator(self) -> Callable:
return partial(
diagonal_linear_gaussian,
std=self.simulator_scale,
)
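
The base_task.Task class itself is not part of this diff. Judging from GaussianLinear and from the calls made in tests/bm_test.py, a new task would be expected to implement roughly the surface sketched below; this outline is inferred rather than copied from base_task.py, and MyTask is a hypothetical name:

from typing import Callable

import torch
from torch.distributions import Distribution

from .base_task import Task


class MyTask(Task):  # hypothetical additional benchmark task
    def __init__(self):
        super().__init__("my_task")  # task name

    def theta_dim(self) -> int: ...                  # parameter dimensionality
    def x_dim(self) -> int: ...                      # data dimensionality
    def get_prior(self) -> Distribution: ...
    def get_simulator(self) -> Callable: ...
    def get_true_parameters(self, idx: int) -> torch.Tensor: ...
    def get_observation(self, idx: int) -> torch.Tensor: ...
    def get_reference_posterior_samples(self, idx: int) -> torch.Tensor: ...

get_data(num_simulations), as called by the benchmark tests, appears to be provided by the base class on top of get_prior and get_simulator.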