Skip to content

Commit

Permalink
show_samples for variable-horizon prediction in colab
Browse files Browse the repository at this point in the history
  • Loading branch information
jannerm committed May 23, 2022
1 parent c4bdf45 commit f77b0b0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion diffuser/datasets/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def unnormalize(self, x, eps=1e-4):
x : [ -1, 1 ]
'''
if x.max() > 1 + eps or x.min() < -1 - eps:
print(f'[ datasets/mujoco ] Warning: sample out of range | ({x.min():.4f}, {x.max():.4f})')
# print(f'[ datasets/mujoco ] Warning: sample out of range | ({x.min():.4f}, {x.max():.4f})')
x = np.clip(x, -1, 1)

## [ -1, 1 ] --> [ 0, 1 ]
Expand Down
22 changes: 20 additions & 2 deletions diffuser/utils/colab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import numpy as np
import einops
import matplotlib.pyplot as plt
from tqdm import tqdm

try:
Expand Down Expand Up @@ -51,7 +52,7 @@ def run_diffusion(model, dataset, obs, n_samples=1, device='cuda:0', **diffusion

def show_diffusion(renderer, observations, n_repeat=100, substep=1, filename='diffusion.mp4', savebase='/content/videos'):
'''
observations : [ n_diffusion_steps x batch_size x horizon x observation_dim ]
observations : [ n_diffusion_steps x batch_size x horizon x observation_dim ]
'''
mkdir(savebase)
savepath = os.path.join(savebase, filename)
Expand All @@ -78,7 +79,7 @@ def show_diffusion(renderer, observations, n_repeat=100, substep=1, filename='di

def show_sample(renderer, observations, filename='sample.mp4', savebase='/content/videos'):
'''
observations : [ batch_size x horizon x observation_dim ]
observations : [ batch_size x horizon x observation_dim ]
'''

mkdir(savebase)
Expand All @@ -97,6 +98,23 @@ def show_sample(renderer, observations, filename='sample.mp4', savebase='/conten
show_video(savepath, height=200)


def show_samples(renderer, observations_l, figsize=12):
'''
observations_l : [ [ n_diffusion_steps x batch_size x horizon x observation_dim ], ... ]
'''

images = []
for observations in observations_l:
path = observations[-1]
img = renderer.composite(None, path)
images.append(img)
images = np.concatenate(images, axis=0)

plt.imshow(images)
plt.axis('off')
plt.gcf().set_size_inches(figsize, figsize)


def show_video(path, height=400):
video = io.open(path, 'r+b').read()
encoded = base64.b64encode(video)
Expand Down

0 comments on commit f77b0b0

Please sign in to comment.