forked from songyouwei/ABSA-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbert_spc.py
26 lines (23 loc) · 1.06 KB
/
bert_spc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# -*- coding: utf-8 -*-
# file: bert_spc.py
# author: songyouwei <[email protected]>
# Copyright (C) 2019. All Rights Reserved.
import torch
import torch.nn as nn
from layers.squeeze_embedding import SqueezeEmbedding
class BERT_SPC(nn.Module):
    """Classification head on top of a supplied BERT encoder.

    The encoder's pooled output is passed through dropout and a single
    linear layer to produce one logit per polarity class.
    """

    def __init__(self, bert, opt):
        """Build the classifier.

        Args:
            bert: The encoder. It is called as
                ``bert(ids, token_type_ids=..., output_all_encoded_layers=False)``
                and must return a 2-tuple whose second element is the
                pooled output.
            opt: Options object providing ``dropout`` (dropout probability),
                ``bert_dim`` (input width of the linear layer) and
                ``polarities_dim`` (number of output logits).
        """
        super(BERT_SPC, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(opt.dropout)
        self.dense = nn.Linear(opt.bert_dim, opt.polarities_dim)

    def forward(self, inputs):
        """Compute polarity logits for a batch.

        Args:
            inputs: Indexable sequence; ``inputs[0]`` holds the BERT token
                indices and ``inputs[1]`` the matching segment ids
                (presumably both shaped (batch, seq_len) -- confirm against
                the data loader).

        Returns:
            The output of the final linear layer applied to the
            dropout-regularised pooled encoder output.
        """
        text_bert_indices, bert_segments_ids = inputs[0], inputs[1]
        # Pass the segment ids by keyword: the second positional parameter
        # differs between BERT implementations (token_type_ids in some,
        # attention_mask in others), so a positional call can silently feed
        # segment ids in as the attention mask.
        # NOTE(review): no attention_mask is supplied, so the encoder falls
        # back to its own default; if batches are zero-padded, confirm that
        # attending to padding positions is intended.
        _, pooled_output = self.bert(
            text_bert_indices,
            token_type_ids=bert_segments_ids,
            output_all_encoded_layers=False,
        )
        pooled_output = self.dropout(pooled_output)
        logits = self.dense(pooled_output)
        return logits