-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrnn_lstm_2.py
39 lines (30 loc) · 1.29 KB
/
rnn_lstm_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 2.构建网络
import torch
from torch import nn
import torch.nn.functional as F
class LSTMTagger(nn.Module):
    """Single-layer LSTM tagger that scores one sentence at a time.

    Pipeline: word indices -> ``nn.Embedding`` -> ``nn.LSTM`` -> ``nn.Linear``
    -> ``log_softmax`` over the tag set.

    The LSTM state lives on the instance in ``self.hidden``. Every ``forward``
    call both reads it and overwrites it with the LSTM's final state, so state
    (and the autograd graph that produced it) carries over between calls.
    Callers that treat sentences independently should reset it first with
    ``model.hidden = model.init_hidden()``.

    Args:
        embedding_dim: Size of each word embedding vector.
        hidden_dim: Number of features in the LSTM hidden/cell state.
        vocab_size: Number of rows in the embedding table (valid word indices
            are ``0 .. vocab_size - 1``).
        tagset_size: Number of output tags (classes).
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        # Initial (h_0, c_0) consumed by the first forward() call.
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh all-zero LSTM state ``(h_0, c_0)``.

        Each tensor has shape ``(1, 1, hidden_dim)`` (one layer, batch of one).
        The tensors are created with the same dtype and device as the LSTM
        weights, so a reset state stays compatible after ``model.to(device)``
        or ``model.double()``.
        """
        weight = self.lstm.weight_ih_l0
        shape = (1, 1, self.hidden_dim)
        return (
            torch.zeros(shape, dtype=weight.dtype, device=weight.device),
            torch.zeros(shape, dtype=weight.dtype, device=weight.device),
        )

    def forward(self, sentence):
        """Compute per-word tag log-probabilities for ``sentence``.

        Args:
            sentence: Tensor of word indices, one per word (e.g. the output of
                ``prepare_sequence``); its length is ``seq_len``.

        Returns:
            Tensor of shape ``(seq_len, tagset_size)`` holding log-probabilities
            normalized over tags (each row sums to 1 after ``exp()``).

        Side effect: replaces ``self.hidden`` with the LSTM's final state.
        """
        seq_len = len(sentence)
        # Look up one embedding vector per word.
        embeds = self.word_embeddings(sentence)
        # nn.LSTM (batch_first=False) expects (seq_len, batch, input_size);
        # the sentence is fed as a single batch element.
        lstm_out, self.hidden = self.lstm(embeds.view(seq_len, 1, -1), self.hidden)
        # Drop the batch axis so the linear layer sees (seq_len, hidden_dim).
        tag_space = self.hidden2tag(lstm_out.view(seq_len, -1))
        # Normalize over the tag dimension to get log-probabilities per word.
        return F.log_softmax(tag_space, dim=1)
def prepare_sequence(seq, to_ix):
    """Convert a sequence of tokens into a 1-D tensor of their indices.

    Args:
        seq: Iterable of tokens (e.g. a list of words).
        to_ix: Mapping from token to integer index.

    Returns:
        ``torch.long`` tensor with one index per token, in input order
        (shape ``(0,)`` for an empty ``seq``).

    Raises:
        Whatever ``to_ix[token]`` raises for an unknown token
        (``KeyError`` when ``to_ix`` is a plain ``dict``).
    """
    # torch.tensor with an explicit dtype replaces the legacy
    # torch.LongTensor type constructor.
    return torch.tensor([to_ix[w] for w in seq], dtype=torch.long)