Skip to content

Commit

Permalink
Update plot style "steps" to use where="post"
Browse files Browse the repository at this point in the history
  • Loading branch information
quantgirluk committed Dec 1, 2024
1 parent 44b8210 commit 02c73dd
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions aleatory/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def plot_paths(times, paths, style="seaborn-v0_8-whitegrid", title=None, plot_st
if plot_style == 'points':
ax.scatter(times, p, s=7)
elif plot_style == 'steps':
ax.step(times, p)
ax.step(times, p, where='post')
elif plot_style == 'linear':
ax.plot(times, p)
else:
Expand All @@ -77,9 +77,13 @@ def plot_paths_random_walk(*args, times, paths, style="seaborn-v0_8-whitegrid",
if plot_style == 'points':
ax.scatter(times, p, s=7)
elif plot_style == 'steps':
ax.step(times, p)
ax.step(times, p, where='post')
elif plot_style == 'linear':
ax.plot(times, p, *args)
elif plot_style == 'points+steps':
ax.step(times, p, where='post')
color = plt.gca().lines[-1].get_color()
ax.plot(times, p, 'o', color=color)
else:
raise ValueError("plot_style must be 'points', 'steps', or 'linear'.")
ax.set_title(title)
Expand Down Expand Up @@ -149,7 +153,7 @@ def draw_paths_horizontal(times, paths, N, expectations=None, title=None, KDE=Fa
if draw_style == 'points':
ax1.scatter(times, paths[i], s=7, color=cm(colors[i]))
elif draw_style == 'steps':
ax1.step(times, paths[i], color=cm(colors[i]))
ax1.step(times, paths[i], color=cm(colors[i]), where='post')
elif draw_style == 'linear':
ax1.plot(times, paths[i], '-', lw=1.0, color=cm(colors[i]))
else:
Expand All @@ -159,7 +163,7 @@ def draw_paths_horizontal(times, paths, N, expectations=None, title=None, KDE=Fa
ax1.plot(times, expectations, '--', lw=1.75, label='$E[X_t]$')
ax1.legend()
if envelope:
ax1.fill_between(times, upper, lower, alpha=0.25, color='grey')
ax1.fill_between(times, upper, lower, alpha=0.25, color='silver')
plt.subplots_adjust(wspace=0.025, hspace=0.025)

else:
Expand Down Expand Up @@ -231,7 +235,7 @@ def draw_paths_vertical(times, paths, N, expectations, title=None, KDE=False, ma
if draw_style == 'points':
ax1.scatter(times, paths[i], s=7, color=cm(colors[i]))
elif draw_style == 'steps':
ax1.step(times, paths[i], color=cm(colors[i]))
ax1.step(times, paths[i], color=cm(colors[i]), where='post')
elif draw_style == 'linear':
ax1.plot(times, paths[i], '-', lw=1.0, color=cm(colors[i]))
else:
Expand Down

0 comments on commit 02c73dd

Please sign in to comment.