Skip to content

Commit

Permalink
Renaming plot_style and draw_style to be consistent with mode behaviour.
Browse files Browse the repository at this point in the history
  • Loading branch information
quantgirluk committed Dec 3, 2024
1 parent a88edc9 commit 1ff6518
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions aleatory/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,18 @@ def times_to_increments(times):
return check_increments(times)


def plot_paths(times, paths, style="seaborn-v0_8-whitegrid", title=None, plot_style="linear", **fig_kw):
def plot_paths(times, paths, style="seaborn-v0_8-whitegrid", title=None, mode="linear", **fig_kw):
with plt.style.context(style):
fig, ax = plt.subplots(**fig_kw)
for p in paths:
if plot_style == 'points':
if mode == 'points':
ax.scatter(times, p, s=7)
elif plot_style == 'steps':
elif mode == 'steps':
ax.step(times, p, where='post')
elif plot_style == 'linear':
elif mode == 'linear':
ax.plot(times, p)
else:
raise ValueError("plot_style must be 'points', 'steps', or 'linear'.")
raise ValueError("mode must be 'points', 'steps', or 'linear'.")
ax.set_title(title)
ax.set_xlabel('$t$')
ax.set_ylabel('$X(t)$')
Expand All @@ -85,7 +85,7 @@ def plot_paths_random_walk(*args, times, paths, style="seaborn-v0_8-whitegrid",
color = plt.gca().lines[-1].get_color()
ax.plot(times, p, 'o', color=color)
else:
raise ValueError("plot_style must be 'points', 'steps', or 'points+steps'.")
raise ValueError("mode must be 'points', 'steps', or 'points+steps'.")
ax.set_title(title)
ax.set_xlabel('$t$')
ax.set_ylabel('$X(t)$')
Expand Down
6 changes: 3 additions & 3 deletions tests/test_charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ def test_charts_orientation():
process.draw(n=100, N=200, envelope=False, orientation='vertical', figsize=(12, 6))


def test_charts_styles():
def test_charts_modes():
for sty in ["steps", "points", "linear"]:
for process in [Vasicek()]:
process.plot(n=100, N=200, figsize=(12, 6), plot_style=sty)
process.draw(n=100, N=200, envelope=False, figsize=(12, 6), draw_style=sty)
process.plot(n=100, N=200, figsize=(12, 6), mode=sty)
process.draw(n=100, N=200, envelope=False, figsize=(12, 6), mode=sty)


def test_poisson():
Expand Down

0 comments on commit 1ff6518

Please sign in to comment.