Skip to content

Commit

Permalink
Add function to save primitive parameters as dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
paschalidoud committed Aug 4, 2019
1 parent 2e068ca commit 2c38d0e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ backports.functools_lru_cache

# Visualization
matplotlib
seaborn
#mayavi
#PyQt5
17 changes: 16 additions & 1 deletion scripts/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
add_dataset_parameters, add_gaussian_noise_layer_parameters, \
voxelizer_shape, add_tsdf_fusion_parameters, \
add_loss_options_parameters, add_loss_parameters
from utils import get_colors
from utils import get_colors, store_primitive_parameters

from learnable_primitives.common.dataset import get_dataset_type,\
compose_transformations
Expand Down Expand Up @@ -311,6 +311,8 @@ def main(argv):
R = quaternions_to_rotation_matrices(
y_hat[2].view(-1, 4)
).to("cpu").detach()
# get also the raw quaternions
quats = y_hat[2].view(-1, 4).to("cpu").detach().numpy()
translations = y_hat[1].to("cpu").view(args.n_primitives, 3)
translations = translations.detach().numpy()

Expand Down Expand Up @@ -350,6 +352,19 @@ def main(argv):
taperings[i, 0],
taperings[i, 1]
)
store_primitive_parameters(
size=tuple(shapes[i]),
shape=tuple(epsilons[i]),
rotation=tuple(quats[i]),
location=tuple(translations[i]),
tapering=tuple(taperings[i]),
probability=(probs[0, i],),
color=(colors[i % len(colors)]) + (1.0,),
filepath=os.path.join(
args.output_directory,
"primitive_%d.p" %(i,)
)
)
if probs[0, i] >= args.prob_threshold:
on_prims += 1
mlab.mesh(
Expand Down
26 changes: 26 additions & 0 deletions scripts/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas as pd
import pickle

import matplotlib
matplotlib.use("agg")
Expand Down Expand Up @@ -35,3 +36,28 @@ def parse_train_test_splits(train_test_splits_file, model_tags):

def get_colors(M):
return sns.color_palette("Paired")


def store_primitive_parameters(
size,
shape,
rotation,
location,
tapering,
probability,
color,
filepath
):
primitive_params = dict(
size=size,
shape=shape,
rotation=rotation,
location=location,
tapering=tapering,
probability=probability,
color=color
)
pickle.dump(
primitive_params,
open(filepath, "wb")
)

0 comments on commit 2c38d0e

Please sign in to comment.