Skip to content

Commit

Permalink
Update random_sampled_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 31, 2023
1 parent 23c08c9 commit a74389d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
9 changes: 3 additions & 6 deletions gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,21 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int

def random_sample_dict(data: Dict[str, jnp.ndarray],
num_samples: int,
seed: int = 42) -> Dict[str, jnp.ndarray]:
rng_key: jnp.ndarray) -> Dict[str, jnp.ndarray]:
"""Returns a dictionary with a smaller number of consistent random samples for each array.
Args:
data: Dictionary containing numpy arrays.
num_samples: Number of random samples required.
seed: Seed for the random number generator.
rng_key: Random number generator key
Returns:
Dictionary with the consistently sampled arrays.
"""

# Create a random key
key = jax.random.PRNGKey(seed)

# Generate unique random indices
num_data_points = len(next(iter(data.values())))
indices = jax.random.permutation(key, num_data_points)[:num_samples]
indices = jax.random.permutation(rng_key, num_data_points)[:num_samples]

return {key: value[indices] for key, value in data.items()}

Expand Down
27 changes: 22 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import sys
import numpy as onp
import jax.numpy as jnp
import jax.random as jra
from numpy.testing import assert_equal, assert_, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict
from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys


def test_sparse_img_processing():
Expand Down Expand Up @@ -56,7 +57,8 @@ def test_random_sample_size():
'c': jnp.array([10, 20, 30, 40, 50])
}
num_samples = 3
sampled_data = random_sample_dict(data, num_samples)
rng_key = jra.PRNGKey(123)
sampled_data = random_sample_dict(data, num_samples, rng_key)
for value in sampled_data.values():
assert_(len(value) == num_samples)

Expand All @@ -68,9 +70,24 @@ def test_random_sample_consistency():
'c': jnp.array([10, 20, 30, 40, 50])
}
num_samples = 3
seed = 123
sampled_data1 = random_sample_dict(data, num_samples, seed)
sampled_data2 = random_sample_dict(data, num_samples, seed)
rng_key = jra.PRNGKey(123)
sampled_data1 = random_sample_dict(data, num_samples, rng_key)
sampled_data2 = random_sample_dict(data, num_samples, rng_key)

for key in sampled_data1:
assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key]))


def test_random_sample_difference():
data = {
'a': jnp.array([1, 2, 3, 4, 5]),
'b': jnp.array([5, 4, 3, 2, 1]),
'c': jnp.array([10, 20, 30, 40, 50])
}
num_samples = 3
rng_key1, rng_key2 = get_keys()
sampled_data1 = random_sample_dict(data, num_samples, rng_key1)
sampled_data2 = random_sample_dict(data, num_samples, rng_key2)

for key in sampled_data1:
assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key]))

0 comments on commit a74389d

Please sign in to comment.