forked from ntu-adl-ta/ADL21-HW1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
44 lines (33 loc) · 1.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import Iterable, List, Optional
class Vocab:
    """Map string tokens to contiguous integer ids.

    Id 0 is reserved for ``PAD`` and id 1 for ``UNK``; the tokens passed to
    the constructor receive ids 2, 3, ... in first-seen order, so
    ``tokens[i]`` is always the token whose id is ``i``.
    """

    PAD = "[PAD]"
    UNK = "[UNK]"

    def __init__(self, vocab: Iterable[str]) -> None:
        """Build the token -> id table.

        Args:
            vocab: Tokens to register. Repeated tokens, and occurrences of the
                special tokens ``PAD``/``UNK``, are ignored so ids stay
                contiguous and the reserved ids never move.
        """
        self.token2idx = {Vocab.PAD: 0, Vocab.UNK: 1}
        for token in vocab:
            # Skipping already-known tokens keeps ids gap-free; re-assigning
            # them would leave holes in the id range and could let a
            # user-supplied "[PAD]"/"[UNK]" override the reserved ids.
            if token not in self.token2idx:
                self.token2idx[token] = len(self.token2idx)

    @property
    def pad_id(self) -> int:
        """Id of the padding token (0)."""
        return self.token2idx[Vocab.PAD]

    @property
    def unk_id(self) -> int:
        """Id used for out-of-vocabulary tokens (1)."""
        return self.token2idx[Vocab.UNK]

    @property
    def tokens(self) -> List[str]:
        """All tokens ordered by id (dict insertion order), so ``tokens[i]`` has id ``i``."""
        return list(self.token2idx.keys())

    def token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or ``unk_id`` if it is not in the vocabulary."""
        return self.token2idx.get(token, self.unk_id)

    def encode(self, tokens: List[str]) -> List[int]:
        """Map each token in ``tokens`` to its id (unknown tokens -> ``unk_id``)."""
        return [self.token_to_id(token) for token in tokens]

    def encode_batch(
        self, batch_tokens: List[List[str]], to_len: Optional[int] = None
    ) -> List[List[int]]:
        """Encode a batch of token sequences and bring them to a common length.

        Args:
            batch_tokens: Token sequences to encode.
            to_len: Target length. If ``None``, the length of the longest
                encoded sequence is used (0 for an empty batch).

        Returns:
            One id list per input sequence, truncated or right-padded with
            ``pad_id`` to exactly ``to_len`` items.
        """
        batch_ids = [self.encode(tokens) for tokens in batch_tokens]
        if to_len is None:
            # default=0 keeps max() from raising ValueError on an empty batch.
            to_len = max((len(ids) for ids in batch_ids), default=0)
        return pad_to_len(batch_ids, to_len, self.pad_id)
def pad_to_len(seqs: List[List[int]], to_len: int, padding: int) -> List[List[int]]:
    """Truncate or right-pad every sequence to exactly ``to_len`` items.

    Args:
        seqs: Sequences to normalize; they are not modified.
        to_len: Target length; must be non-negative.
        padding: Value appended to sequences shorter than ``to_len``.

    Returns:
        New lists, one per input sequence, each of length ``to_len``.

    Raises:
        ValueError: If ``to_len`` is negative. A negative slice bound would
            otherwise silently drop items from the end of each sequence.
    """
    if to_len < 0:
        raise ValueError(f"to_len must be non-negative, got {to_len}")
    # Slicing copies, so callers' lists are never aliased; a non-positive
    # repeat count yields [] for sequences that are already long enough.
    return [seq[:to_len] + [padding] * (to_len - len(seq)) for seq in seqs]