feat: openai baselines plot util
songlei00 committed Dec 5, 2024
1 parent 49e644b commit 29e4035
Showing 1 changed file with 93 additions and 48 deletions.
141 changes: 93 additions & 48 deletions bbo/benchmarks/analyzers/plot_utils.py
@@ -1,37 +1,54 @@
# https://github.com/openai/baselines/blob/master/baselines/common/plot_util.py

import logging
import os
from typing import Sequence, Dict, List
from typing import Dict, List, Callable
from attrs import define, field, validators
from collections import defaultdict, namedtuple

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

from bbo.utils.problem_statement import ProblemStatement
from bbo.utils.metric_config import ObjectiveMetricGoal
from bbo.utils.trial import Trial
from bbo.benchmarks.analyzers.utils import trials2df

logger = logging.getLogger(__name__)


Result = namedtuple('Result', ['name', 'data'])

def default_xy_fn(r: Result):
return r.data['x'], r.data['y']

def default_split_fn(r: Result):
return ''

def default_group_fn(r: Result):
return r.name


@define
class PlotUtil:
_problem_statement: ProblemStatement = field(
validator=validators.instance_of(ProblemStatement)
)
_save_dir: str | None = field(
validator=validators.optional(validators.instance_of(str)), default=None
)
_xy_fn: Callable = field()
_split_fn: Callable = field(default=default_split_fn, kw_only=True)
_group_fn: Callable = field(default=default_group_fn, kw_only=True)
_xlabel: str | None = field(
validator=validators.optional(validators.instance_of(str)), default=None, kw_only=True
)
_ylabel: str | None = field(
validator=validators.optional(validators.instance_of(str)), default=None, kw_only=True
)
_title: str | None = field(
validator=validators.optional(validators.instance_of(str)), default=None, kw_only=True
)
_save_path: str | None = field(
validator=validators.optional(validators.instance_of(str)),
default='log/plot_result.pdf', kw_only=True
)
_legend_show: bool = field(default=True, validator=validators.instance_of(bool), kw_only=True)
_ncols: int | None = field(
default=None, validator=validators.optional(validators.instance_of(int)), kw_only=True
)
_shaded_err: bool = field(default=True, validator=validators.instance_of(bool), kw_only=True)
_shaded_std: bool = field(default=True, validator=validators.instance_of(bool), kw_only=True)

_label2trials: Dict[str, Sequence[Trial]] = field(factory=dict, init=False)
_allresults: List[Result] = field(factory=list, init=False)
_colors: List[str] = field(factory=lambda: [
'blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise',
@@ -46,38 +63,66 @@ class PlotUtil:
'ytick.labelsize': 20,
}, init=False)

def add_trials(self, trials: Sequence[Trial], label: str):
self._label2trials[label] = trials

def del_trials(self, label: str):
self._label2trials.pop(label, None)

def plot(self, besty=False):
for m in self._problem_statement.objective.metrics:
with plt.rc_context(self._params):
_, ax = plt.subplots(dpi=300)
for i, (label, trials) in enumerate(self._label2trials.items()):
df = trials2df(trials)
if besty:
if m.goal == ObjectiveMetricGoal.MAXIMIZE:
y = df[m.name].cummax()
else:
y = df[m.name].cummin()
else:
y = df[m.name]
ax.plot(df.index, y, label=label, color=self._colors[i%len(self._colors)])
def add_result(self, name: str, df: pd.DataFrame):
self._allresults.append(Result(name, df))

def plot(self):
rc_params = matplotlib.rcParams
matplotlib.rcParams.update(self._params)

assert len(self._allresults) > 0
sk2r = defaultdict(list)
for r in self._allresults:
sk2r[self._split_fn(r)].append(r)

ncols = int(self._ncols or np.ceil(np.sqrt(len(sk2r))))
nrows = int(np.ceil(len(sk2r) / ncols))
_, axarr = plt.subplots(nrows, ncols, squeeze=False, dpi=300, figsize=(10*ncols, 6*nrows))

groups = list(set(self._group_fn(result) for result in self._allresults))
groups.sort()
g2l = dict()
for i, sk in enumerate(sorted(sk2r.keys())):
idx_row, idx_col = i // ncols, i % ncols
ax = axarr[idx_row][idx_col]
ax.set_title(sk)
sresults = sk2r[sk]
gresults = defaultdict(list)
for r in sresults:
group = self._group_fn(r)
x, y = self._xy_fn(r)
gresults[group].append((x, y))

for j, group in enumerate(groups):
xys = gresults[group]
if not any(xys):
continue
color = self._colors[j % len(self._colors)]
x = xys[0][0]
ys = [xy[1] for xy in xys]
ymean = np.mean(ys, axis=0)
ystd = np.std(ys, axis=0)
ystderr = ystd / np.sqrt(len(ys))
l, = ax.plot(x, ymean, color=color, label=group)
g2l[group] = l
if self._shaded_err:
ax.fill_between(x, ymean-ystderr, ymean+ystderr, color=color, alpha=.4)
if self._shaded_std:
ax.fill_between(x, ymean-ystd, ymean+ystd, color=color, alpha=.2)

if self._legend_show and any(g2l.keys()):
axarr[0][-1].legend(g2l.values(), g2l.keys(), loc=2, bbox_to_anchor=(1, 1))
if self._xlabel is not None:
for ax in axarr[-1]:
ax.set_xlabel(self._xlabel)
if self._ylabel is not None:
for ax in axarr[:, 0]:
ax.set_ylabel(self._ylabel)
ax.set_title(self._title)
ax.legend()
if self._save_dir is not None:
if not os.path.exists(self._save_dir):
os.makedirs(self._save_dir)
if besty:
file_name = os.path.join(self._save_dir, m.name + '_best.pdf')
else:
file_name = os.path.join(self._save_dir, m.name + '.pdf')
plt.savefig(file_name, bbox_inches='tight')
logger.debug('Plot to {}'.format(file_name))
else:
plt.show()
plt.tight_layout()
matplotlib.rcParams.update(rc_params)

if self._save_path is not None:
save_dir = os.path.dirname(self._save_path)
if any(save_dir) and not os.path.exists(save_dir):
os.makedirs(save_dir)
plt.savefig(self._save_path)
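
For context, a minimal usage sketch of the API this diff introduces (not part of the commit). It assumes the `_problem_statement`/`_save_dir` fields shown above belong to the removed version, that attrs exposes the fields under their underscore-stripped names (`xy_fn`, `group_fn`, `xlabel`, ...), and the optimizer names and data below are made up; `default_xy_fn` reads the `'x'` and `'y'` columns of each result's DataFrame, and `group_fn` decides which curves get averaged into one line.

import numpy as np
import pandas as pd

from bbo.benchmarks.analyzers.plot_utils import PlotUtil, default_xy_fn

rng = np.random.default_rng(0)

plot_util = PlotUtil(
    default_xy_fn,                            # xy_fn: read the 'x'/'y' columns of each result
    group_fn=lambda r: r.name.split('/')[0],  # average curves that share an optimizer name
    xlabel='trial',
    ylabel='best value',
    save_path='log/plot_result.pdf',
)

# Hypothetical data: two optimizers, three seeds each, on a minimization problem.
for optimizer in ('random', 'cmaes'):
    for seed in range(3):
        x = np.arange(100)
        y = np.minimum.accumulate(rng.normal(size=100) + (0.5 if optimizer == 'random' else 0.0))
        plot_util.add_result(f'{optimizer}/seed{seed}', pd.DataFrame({'x': x, 'y': y}))

plot_util.plot()  # one subplot per split key, one mean curve per group, std/stderr shading

Curves in the same group are stacked and averaged with np.mean(ys, axis=0), so every result in a group should share the same x grid.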
