Skip to content

Commit

Permalink
Fix get_emb_size bug when feeding num_buckets as an int to Categorify
Browse files Browse the repository at this point in the history
  • Loading branch information
rnyak authored Feb 1, 2021
1 parent e4731ab commit 4606b0c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nvtabular/ops/categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def _get_embeddings_dask(paths, cat_names, buckets=None, freq_limit=0):
embeddings = {}
if isinstance(freq_limit, int):
freq_limit = {name: freq_limit for name in cat_names}
if isinstance(buckets, int):
buckets = {name: buckets for name in cat_names}
for col in _get_embedding_order(cat_names):
path = paths.get(col)
num_rows = cudf.io.read_parquet_metadata(path)[0] if path else 0
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,30 @@ def test_categorify_freq_limit(tmpdir, freq_limit, buckets, search_sort):
)


def test_categorify_hash_bucket():
    """Categorify with an integer ``num_buckets`` should hash every
    categorical column into the range [0, num_buckets) and report that
    bucket count as the embedding cardinality."""
    n_buckets = 10
    source_df = cudf.DataFrame(
        {
            "Authors": ["User_A", "User_A", "User_E", "User_B", "User_C"],
            "Engaging_User": ["User_B", "User_B", "User_A", "User_D", "User_D"],
            "Post": [1, 2, 3, 4, 5],
        }
    )
    columns = ["Authors", "Engaging_User"]

    workflow = nvt.Workflow(columns >> ops.Categorify(num_buckets=n_buckets))
    dataset = nvt.Dataset(source_df)
    workflow.fit(dataset)
    transformed = workflow.transform(dataset).to_ddf().compute()

    # every hashed id must land inside the bucket range
    for col in columns:
        assert transformed[col].max() <= (n_buckets - 1)

    # embedding size reported for each hashed column equals the bucket count
    embedding_sizes = nvt.ops.get_embedding_sizes(workflow)
    for col in columns:
        assert embedding_sizes[col][0] == n_buckets


@pytest.mark.parametrize("groups", [[["Author", "Engaging-User"]], "Author"])
def test_joingroupby_multi(tmpdir, groups):

Expand Down

0 comments on commit 4606b0c

Please sign in to comment.