From a25e67e31271b8648a08666e488ab19a0b072eca Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 31 Oct 2024 14:19:19 -0700
Subject: [PATCH] add value residual learning, given people are interested in this again due to notebooklm

---
 README.md                        |  9 ++++++++
 setup.py                         |  2 +-
 soundstorm_pytorch/soundstorm.py | 38 ++++++++++++++++++++++++++------
 3 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 44d064e..845eae6 100644
--- a/README.md
+++ b/README.md
@@ -230,3 +230,12 @@ generated_speech = model.generate(
     primaryClass = {cs.CL}
 }
 ```
+
+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title  = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year   = {2024},
+    url    = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
diff --git a/setup.py b/setup.py
index 0093338..a576eca 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.11',
+  version = '0.5.0',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py
index 68064b2..6686206 100644
--- a/soundstorm_pytorch/soundstorm.py
+++ b/soundstorm_pytorch/soundstorm.py
@@ -307,7 +307,9 @@ def forward(
         context = None,
         mask = None,
         rotary_emb = None,
-        attn_bias = None
+        attn_bias = None,
+        return_values = False,
+        value_residual = None
     ):
         n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
         context = default(context, x)
@@ -315,6 +317,9 @@
         q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
+        if exists(value_residual):
+            v = 0.5 * (v + value_residual)
+
         if exists(rotary_emb):
             q = apply_rotary_pos_emb(rotary_emb, q)
             k = apply_rotary_pos_emb(rotary_emb, k)
@@ -322,7 +327,12 @@
         out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
-        return self.to_out(out)
+        out = self.to_out(out)
+
+        if not return_values:
+            return out
+
+        return out, v
 
 class FeedForward(Module):
     def __init__(
@@ -418,18 +428,26 @@
         x,
         mask = None,
         rotary_emb = None,
-        attn_bias = None
+        attn_bias = None,
+        attn_value_residual = None,
+        return_values = False
     ):
         x = self.ff1(x) + x
 
         if exists(self.gateloop):
             x = self.gateloop(x) + x
 
-        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
+        attn_out, attn_values = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias, value_residual = attn_value_residual, return_values = True)
+        x = attn_out + x
+
         x = self.conv(x, mask = mask) + x
         x = self.ff2(x) + x
         x = self.post_norm(x)
-        return x
+
+        if not return_values:
+            return x
+
+        return x, attn_values
 
 # Conformer
 
@@ -484,14 +502,20 @@ def forward(self, x, mask = None):
         rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None
         attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None
 
+        attn_value_residual = None
+
         for block in self.layers:
-            x = block(
+            x, attn_values = block(
                 x,
                 mask = mask,
                 rotary_emb = rotary_emb,
-                attn_bias = attn_bias
+                attn_bias = attn_bias,
+                attn_value_residual = attn_value_residual,
+                return_values = True
             )
 
+            attn_value_residual = default(attn_value_residual, attn_values)
+
         return x
 
 # conformer with sum reduction across quantized tokens at the beginning, along with heads
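
Note (not part of the patch): the diff wires value residual learning through the Conformer by keeping the first attention block's value tensor and having every later block average its own values with it (`v = 0.5 * (v + value_residual)`), while `default(attn_value_residual, attn_values)` pins the residual to the first block's values. The sketch below is a minimal, self-contained illustration of that wiring outside of soundstorm-pytorch; `ToyAttention`, `ToyStack`, and the dimensions are made-up names for illustration only, not part of the library.

```python
import torch
from torch import nn

class ToyAttention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x, value_residual = None, return_values = False):
        b, n, d, h = *x.shape, self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: t.reshape(b, n, h, -1).transpose(1, 2), (q, k, v))

        # value residual: average this layer's values with the first layer's values
        if value_residual is not None:
            v = 0.5 * (v + value_residual)

        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim = -1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, d)
        out = self.to_out(out)

        return (out, v) if return_values else out

class ToyStack(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.layers = nn.ModuleList([ToyAttention(dim) for _ in range(depth)])

    def forward(self, x):
        value_residual = None

        for layer in self.layers:
            out, values = layer(x, value_residual = value_residual, return_values = True)
            x = out + x

            # keep only the *first* layer's values as the residual for all later layers,
            # mirroring `default(attn_value_residual, attn_values)` in the patch
            value_residual = values if value_residual is None else value_residual

        return x

x = torch.randn(1, 16, 64)
print(ToyStack(64, depth = 4)(x).shape)  # torch.Size([1, 16, 64])
```

The patch uses a fixed 50/50 mix of current and first-layer values, which keeps the change parameter-free and leaves a model with no `value_residual` passed in (the first block) behaving exactly as before.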