From eacda1d189f5feef8298dde147a7e3aab7542e73 Mon Sep 17 00:00:00 2001 From: tzm <89303510+tzm224@users.noreply.github.com> Date: Sun, 10 Nov 2024 12:44:10 +0800 Subject: [PATCH] Update model2.py --- src/Exp5/model2.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/Exp5/model2.py b/src/Exp5/model2.py index 62d315f..bb4976f 100644 --- a/src/Exp5/model2.py +++ b/src/Exp5/model2.py @@ -7,13 +7,11 @@ def __init__(self, embedding_size): super(Mix_Pooling, self).__init__() def forward(self, seq, mask): - mask = (~mask).unsqueeze(-1).float() # 反转mask并扩展最后一维 - cls = seq[:, 1, :] # 获取第一个位置的CLS token + mask = (~mask).unsqueeze(-1).float() + cls = seq[:, 1, :] - # 将mask位置置为 -1e9 seq = seq * mask + (1.0 - mask) * -1e9 - # 在维度1 (max_len) 上进行 max pooling,保持输入的形状 seq, max_indices = torch.max(seq, dim=1) if self.training: @@ -82,9 +80,9 @@ def forward(self, seq, length, mz): final_state = torch.cat([seq, length, mz], dim=1) return self.mlp(final_state) -class Seq2CCS(nn.Module): +class PEP2CCS(nn.Module): def __init__(self, num_layers, embedding_size, num_heads, dropout, p, max_len=64): - super(Seq2CCS, self).__init__() + super(PEP2CCS, self).__init__() self.max_len = max_len self.Embedding = Embedding(embedding_size, max_len) self.Encoder = EncoderLayer(num_layers, embedding_size, num_heads, dropout)