Skip to content

Commit

Permalink
Update plotly function
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkMnDragon committed Dec 21, 2024
1 parent c7a476e commit 1341a09
Showing 1 changed file with 41 additions and 3 deletions.
44 changes: 41 additions & 3 deletions rectified_flow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,13 @@ def visualize_2d_trajectories_plotly(
# Create figure
fig = go.Figure()

# Collect all x and y values
all_x, all_y = [], []

# Plot ground truth samples
if D1_gt_samples is not None:
all_x.extend(D1_gt_samples[:, dim0])
all_y.extend(D1_gt_samples[:, dim1])
fig.add_trace(
go.Scatter(
x=D1_gt_samples[:, dim0],
Expand All @@ -295,6 +300,9 @@ def visualize_2d_trajectories_plotly(
current_trace_index = len(fig.data)

for trajectory_name, xtraj in trajectory_data.items():
all_x.extend(xtraj[:, :, dim0].ravel())
all_y.extend(xtraj[:, :, dim1].ravel())

particle_color = colors[trajectory_name]["particle_color"]
trajectory_color = colors[trajectory_name]["trajectory_color"]
marker_symbol = colors[trajectory_name]["marker"]
Expand Down Expand Up @@ -398,7 +406,7 @@ def visualize_2d_trajectories_plotly(
),
opacity=alpha_particles,
name=f"{trajectory_name} x_t",
showlegend=True,
showlegend=False,
)
)
frame_trace_indices.append(trace_index)
Expand Down Expand Up @@ -434,6 +442,37 @@ def visualize_2d_trajectories_plotly(
)
]

min_x, max_x = np.min(all_x), np.max(all_x)
min_y, max_y = np.min(all_y), np.max(all_y)
delta_x = 0.02 * (max_x - min_x)
delta_y = 0.02 * (max_y - min_y)

fig.update_xaxes(
range=[min_x - delta_x, max_x + delta_x],
showgrid=True,
gridcolor="white",
gridwidth=1,
griddash="dot",
showticklabels=False,
showline=False,
zeroline=False,
mirror=False,
dtick=2.0,
)

fig.update_yaxes(
range=[min_y - delta_y, max_y + delta_y],
showgrid=True,
gridcolor="white",
gridwidth=1,
griddash="dot",
showticklabels=False,
showline=False,
zeroline=False,
mirror=False,
dtick=2.0,
)

# Update figure layout
fig.update_layout(
sliders=sliders,
Expand All @@ -458,12 +497,11 @@ def visualize_2d_trajectories_plotly(
],
}
],
xaxis_title=f"Dimension {dim0}",
yaxis_title=f"Dimension {dim1}",
title=title,
showlegend=show_legend,
height=600,
width=900,
# autosize=True,
)

# Add frames
Expand Down

0 comments on commit 1341a09

Please sign in to comment.