Skip to content

Commit

Permalink
Update model2.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tzm224 authored Nov 10, 2024
1 parent c307a14 commit eacda1d
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/Exp5/model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ def __init__(self, embedding_size):
super(Mix_Pooling, self).__init__()

def forward(self, seq, mask):
mask = (~mask).unsqueeze(-1).float() # 反转mask并扩展最后一维
cls = seq[:, 1, :] # 获取第一个位置的CLS token
mask = (~mask).unsqueeze(-1).float()
cls = seq[:, 1, :]

# 将mask位置置为 -1e9
seq = seq * mask + (1.0 - mask) * -1e9

# 在维度1 (max_len) 上进行 max pooling,保持输入的形状
seq, max_indices = torch.max(seq, dim=1)

if self.training:
Expand Down Expand Up @@ -82,9 +80,9 @@ def forward(self, seq, length, mz):
final_state = torch.cat([seq, length, mz], dim=1)
return self.mlp(final_state)

class Seq2CCS(nn.Module):
class PEP2CCS(nn.Module):
def __init__(self, num_layers, embedding_size, num_heads, dropout, p, max_len=64):
super(Seq2CCS, self).__init__()
super(PEP2CCS, self).__init__()
self.max_len = max_len
self.Embedding = Embedding(embedding_size, max_len)
self.Encoder = EncoderLayer(num_layers, embedding_size, num_heads, dropout)
Expand Down

0 comments on commit eacda1d

Please sign in to comment.