Skip to content

Commit

Permalink
fix pred
Browse files Browse the repository at this point in the history
  • Loading branch information
akirasosa committed Mar 4, 2021
1 parent e1995ed commit 20f972b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/003.pred-nrms.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,16 @@ def make_main_sub(logits: np.ndarray):
slices = np.concatenate(([0], np.cumsum(cand_sizes.values)))
slices = [slice(a, b) for a, b in zip(slices, slices[1:])]

assert len(df_b['b_id'].values) == slices

sub_rows = []
for b_id, s in tqdm(zip(df_b['b_id'].values, slices), total=len(df_b)):
rank = (logits[s] * -1).argsort().argsort() + 1
rank = ','.join(rank.astype(str))
sub_rows.append(f'{b_id} [{rank}]')

return pd.DataFrame(
index=df_b.index,
index=df_b['b_id'],
data=sub_rows,
columns=['preds'],
)
Expand Down
5 changes: 4 additions & 1 deletion src/004.pred-popularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,23 @@ def make_popularity_sub(logits: np.ndarray):
'../data/mind-large',
drop_no_hist=False,
)
df_b = df_b[df_b['histories'].apply(len) == 0]
df_b = df_b[df_b['split'] == 'test']

cand_sizes = df_b['candidates'].apply(len)
slices = np.concatenate(([0], np.cumsum(cand_sizes.values)))
slices = [slice(a, b) for a, b in zip(slices, slices[1:])]

assert len(df_b['b_id'].values) == slices

sub_rows = []
for b_id, s in tqdm(zip(df_b['b_id'].values, slices), total=len(df_b)):
rank = (logits[s] * -1).argsort().argsort() + 1
rank = ','.join(rank.astype(str))
sub_rows.append(f'{b_id} [{rank}]')

return pd.DataFrame(
index=df_b.index,
index=df_b['b_id'],
data=sub_rows,
columns=['preds'],
)
Expand Down

0 comments on commit 20f972b

Please sign in to comment.