Skip to content

Commit

Permalink
Replaced joinpath with / operator
Browse files: browse the repository at this point in the history
  • Loading branch information
sebastianruder committed May 29, 2018
1 parent 0916bc3 commit 3fce720
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
14 changes: 7 additions & 7 deletions courses/dl2/imdb_scripts/create_toks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,21 @@ def create_toks(dir_path, chunksize=24000, n_lbls=1, lang='en'):
sys.exit(1)
dir_path = Path(dir_path)
assert dir_path.exists(), f'Error: {dir_path} does not exist.'
df_trn = pd.read_csv(dir_path.joinpath('train.csv'), header=None, chunksize=chunksize)
df_val = pd.read_csv(dir_path.joinpath('val.csv'), header=None, chunksize=chunksize)
df_trn = pd.read_csv(dir_path / 'train.csv', header=None, chunksize=chunksize)
df_val = pd.read_csv(dir_path / 'val.csv', header=None, chunksize=chunksize)

tmp_path = dir_path.joinpath('tmp')
tmp_path.mkdir(exist_ok=True)
tok_trn, trn_labels = get_all(df_trn, n_lbls, lang='en')
tok_val, val_labels = get_all(df_val, n_lbls, lang='en')

np.save(tmp_path.joinpath('tok_trn.npy'), tok_trn)
np.save(tmp_path.joinpath('tok_val.npy'), tok_val)
np.save(tmp_path.joinpath('lbl_trn.npy'), trn_labels)
np.save(tmp_path.joinpath('lbl_val.npy'), val_labels)
np.save(tmp_path / 'tok_trn.npy', tok_trn)
np.save(tmp_path / 'tok_val.npy', tok_val)
np.save(tmp_path / 'lbl_trn.npy', trn_labels)
np.save(tmp_path / 'lbl_val.npy', val_labels)

trn_joined = [' '.join(o) for o in tok_trn]
open(tmp_path.joinpath('joined.txt'), 'w', encoding='utf-8').writelines(trn_joined)
open(tmp_path / 'joined.txt', 'w', encoding='utf-8').writelines(trn_joined)


if __name__ == '__main__': fire.Fire(create_toks)
12 changes: 6 additions & 6 deletions courses/dl2/imdb_scripts/tok2id.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ def tok2id(dir_path, max_vocab=30000, min_freq=1):
print(f'dir_path {dir_path} max_vocab {max_vocab} min_freq {min_freq}')
p = Path(dir_path)
assert p.exists(), f'Error: {p} does not exist.'
tmp_path = p.joinpath('tmp')
tmp_path = p / 'tmp'
assert tmp_path.exists(), f'Error: {tmp_path} does not exist.'

trn_tok = np.load(tmp_path.joinpath('tok_trn.npy'))
val_tok = np.load(tmp_path.joinpath('tok_val.npy'))
trn_tok = np.load(tmp_path / 'tok_trn.npy')
val_tok = np.load(tmp_path / 'tok_val.npy')

freq = Counter(p for o in trn_tok for p in o)
print(freq.most_common(25))
Expand All @@ -23,8 +23,8 @@ def tok2id(dir_path, max_vocab=30000, min_freq=1):
trn_lm = np.array([[stoi[o] for o in p] for p in trn_tok])
val_lm = np.array([[stoi[o] for o in p] for p in val_tok])

np.save(tmp_path.joinpath('trn_ids.npy'), trn_lm)
np.save(tmp_path.joinpath('val_ids.npy'), val_lm)
pickle.dump(itos, open(tmp_path.joinpath('itos.pkl'), 'wb'))
np.save(tmp_path / 'trn_ids.npy', trn_lm)
np.save(tmp_path / 'val_ids.npy', val_lm)
pickle.dump(itos, open(tmp_path / 'itos.pkl', 'wb'))

if __name__ == '__main__': fire.Fire(tok2id)

0 comments on commit 3fce720

Please sign in to comment.