Skip to content

Fix progressbar with nested compound step samplers #7776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ def _progressbar_config(n_chains=1):
return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
return stats
def _make_update_stats_functions():
def update_stats(step_stats, chain_idx):
return step_stats

return update_stats
return (update_stats,)

# Hack for creating the class correctly when unpickling.
def __getnewargs_ex__(self):
Expand Down Expand Up @@ -332,16 +332,11 @@ def _progressbar_config(self, n_chains=1):

return columns, stats

def _make_update_stats_function(self):
update_fns = [method._make_update_stats_function() for method in self.methods]

def update_stats(stats, step_stats, chain_idx):
for step_stat, update_fn in zip(step_stats, update_fns):
stats = update_fn(stats, step_stat, chain_idx)

return stats

return update_stats
def _make_update_stats_functions(self):
update_functions = []
for method in self.methods:
update_functions.extend(method._make_update_stats_functions())
return update_functions


def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
Expand Down
16 changes: 4 additions & 12 deletions pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,19 +248,11 @@ def _progressbar_config(n_chains=1):
return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
if isinstance(step_stats, list):
step_stats = step_stats[0]
def _make_update_stats_functions():
def update_stats(stats):
return {key: stats[key] for key in ("diverging", "step_size", "tree_size")}

if not step_stats["tune"]:
stats["divergences"][chain_idx] += step_stats["diverging"]

stats["step_size"][chain_idx] = step_stats["step_size"]
stats["tree_size"][chain_idx] = step_stats["tree_size"]
return stats

return update_stats
return (update_stats,)


# A proposal for the next position
Expand Down
20 changes: 8 additions & 12 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,18 +346,14 @@ def _progressbar_config(n_chains=1):
return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
if isinstance(step_stats, list):
step_stats = step_stats[0]

stats["tune"][chain_idx] = step_stats["tune"]
stats["accept_rate"][chain_idx] = step_stats["accept"]
stats["scaling"][chain_idx] = step_stats["scaling"]

return stats

return update_stats
def _make_update_stats_functions():
def update_stats(step_stats):
return {
"accept_rate" if key == "accept" else key: step_stats[key]
for key in ("tune", "accept", "scaling")
}

return (update_stats,)


def tune(scale, acc_rate):
Expand Down
15 changes: 4 additions & 11 deletions pymc/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,8 @@ def _progressbar_config(n_chains=1):
return columns, stats

@staticmethod
def _make_update_stats_function():
def update_stats(stats, step_stats, chain_idx):
if isinstance(step_stats, list):
step_stats = step_stats[0]
def _make_update_stats_functions():
def update_stats(step_stats):
return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}}

stats["tune"][chain_idx] = step_stats["tune"]
stats["nstep_out"][chain_idx] = step_stats["nstep_out"]
stats["nstep_in"][chain_idx] = step_stats["nstep_in"]

return stats

return update_stats
return (update_stats,)
41 changes: 31 additions & 10 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,9 +806,8 @@ def __init__(
progressbar=progressbar,
progressbar_theme=progressbar_theme,
)

self.progress_stats = progress_stats
self.update_stats = step_method._make_update_stats_function()
self.update_stats_functions = step_method._make_update_stats_functions()

self._show_progress = show_progress
self.divergences = 0
Expand Down Expand Up @@ -883,27 +882,49 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
if not tuning and stats and stats[0].get("diverging"):
self.divergences += 1

self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
more_updates = (
{stat: value[chain_idx] for stat, value in self.progress_stats.items()}
if self.full_stats
else {}
)
if self.full_stats:
# TODO: Index by chain already?
chain_progress_stats = [
update_states_fn(step_stats)
for update_states_fn, step_stats in zip(
self.update_stats_functions, stats, strict=True
)
]
all_step_stats = {}
for step_stats in chain_progress_stats:
for key, val in step_stats.items():
if key in all_step_stats:
continue
Copy link
Preview

Copilot AI May 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'continue' statement immediately after checking if the key exists prevents the subsequent duplicate-key handling code (which calculates a new suffix) from ever executing. Consider removing 'continue' so that duplicate keys are properly renamed.

Suggested change
continue

Copilot uses AI. Check for mistakes.

count = (
sum(step_key.startswith(f"{key}_") for step_key in all_step_stats) + 1
)
all_step_stats[f"{key}_{count}"] = val
else:
all_step_stats[key] = val

else:
all_step_stats = {}

# more_updates = (
# {stat: value[chain_idx] for stat, value in progress_stats.items()}
# if self.full_stats
# else {}
# )

self._progress.update(
self.tasks[chain_idx],
completed=draw,
draws=draw,
sampling_speed=speed,
speed_unit=unit,
**more_updates,
**all_step_stats,
)

if is_last:
self._progress.update(
self.tasks[chain_idx],
draws=draw + 1 if not self.combined_progress else draw,
**more_updates,
**all_step_stats,
refresh=True,
)

Expand Down
Loading