Skip to content

Commit

Permalink
Add unit tests to check if Convs and SeaNets are causal and streamable.
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryli27 committed Oct 9, 2024
1 parent 35a9f43 commit c2a26be
Show file tree
Hide file tree
Showing 5 changed files with 348 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# VsCode
.vscode/

Cargo.lock
*~
*.safetensors
Expand Down
156 changes: 156 additions & 0 deletions moshi/moshi/modules/conv_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import functools
import torch
import torch.nn as nn
import pytest

from .conv import StreamingConv1d, StreamingConvTranspose1d


CONV1D_DATA = [
# batch_size,in_channels,out_channels,seq_len,kernel_size
pytest.param(
3, 4, 5, 10, 6,
id='small conv1d test 1',
),
pytest.param(
4, 5, 6, 10, 7,
id='small conv1d test 2',
),
pytest.param(
5, 6, 7, 10, 2,
id='small conv1d test 3',
),
pytest.param(
1, 512, 512, 256, 7,
id='large conv1d test 1',
),
]

CONV1D_TRANSPOSE_DATA = [
# batch_size,in_channels,out_channels,seq_len,kernel_size,stride
pytest.param(
3, 4, 5, 10, 6, 1,
id='small conv1d transpose test 1',
),
pytest.param(
4, 5, 6, 10, 7, 2,
id='small conv1d transpose test 2',
),
pytest.param(
5, 6, 7, 10, 4, 3,
id='small conv1d transpose test 3',
),
pytest.param(
1, 512, 512, 256, 7, 2,
id='large conv1d transpose test 1',
),
]

def _init_weights(module, generator=None):
for name, param in module.named_parameters():
if "weight" in name:
nn.init.xavier_uniform_(param, generator=generator)
elif "bias" in name:
nn.init.constant_(param, 0.0)
else:
nn.init.xavier_uniform_(param, generator=generator)



@pytest.mark.parametrize("batch_size,in_channels,out_channels,seq_len,kernel_size", CONV1D_DATA)
def test_conv1d(batch_size,in_channels,out_channels,seq_len,kernel_size):
"""Test that StreamingConv1d() calls are causal. Having new inputs does not change the previous output."""
assert seq_len > kernel_size

layer = StreamingConv1d(in_channels, out_channels, kernel_size, causal=True, norm="none", pad_mode="constant")


generator = torch.Generator()
generator = generator.manual_seed(41)
layer.apply(functools.partial(_init_weights, generator=generator))

shape = (batch_size, in_channels, seq_len,)
input_hidden_states = torch.rand(shape)

expected_output = layer(input_hidden_states)

for end_index in range(kernel_size, seq_len+1):
actual_output = layer(input_hidden_states[..., :end_index])
torch.testing.assert_close(actual_output, expected_output[..., :actual_output.shape[-1]],
msg=lambda original_msg: f"Failed at end_index={end_index}: \n{original_msg}")


@pytest.mark.parametrize("batch_size,in_channels,out_channels,seq_len,kernel_size", CONV1D_DATA)
def test_conv1d_streaming(batch_size,in_channels,out_channels,seq_len,kernel_size):
"""Test that StreamingConv1d() streaming works as expected."""
assert seq_len > kernel_size

layer = StreamingConv1d(in_channels, out_channels, kernel_size, causal=True, norm="none", pad_mode="constant")


generator = torch.Generator()
generator = generator.manual_seed(41)
layer.apply(functools.partial(_init_weights, generator=generator))

shape = (batch_size, in_channels, seq_len,)
input_hidden_states = torch.rand(shape)
expected_output = layer(input_hidden_states)

start_index = 0
actual_outputs = []
with layer.streaming(batch_size=batch_size):
for end_index in range(kernel_size, seq_len+1):
actual_output = layer(input_hidden_states[..., start_index:end_index])
start_index = end_index
actual_outputs.append(actual_output)
actual_outputs = torch.concat(actual_outputs, axis=-1)

torch.testing.assert_close(actual_outputs, expected_output)


@pytest.mark.parametrize("batch_size,in_channels,out_channels,seq_len,kernel_size,stride", CONV1D_TRANSPOSE_DATA)
def test_conv1d_transpose(batch_size,in_channels,out_channels,seq_len,kernel_size,stride):
"""Test that StreamingConvTranspose1d() calls are causal. Having new inputs does not change the previous output."""
assert seq_len > kernel_size

layer = StreamingConvTranspose1d(in_channels, out_channels, kernel_size,stride, causal=True, norm="none")

generator = torch.Generator()
generator = generator.manual_seed(41)
layer.apply(functools.partial(_init_weights, generator=generator))

shape = (batch_size, in_channels, seq_len,)
input_hidden_states = torch.rand(shape)
expected_output = layer(input_hidden_states)

for end_index in range(kernel_size, seq_len+1):
actual_output = layer(input_hidden_states[..., :end_index])
torch.testing.assert_close(actual_output, expected_output[..., :actual_output.shape[-1]],
msg=lambda original_msg: f"Failed at end_index={end_index}: \n{original_msg}")


@pytest.mark.parametrize("batch_size,in_channels,out_channels,seq_len,kernel_size,stride", CONV1D_TRANSPOSE_DATA)
def test_conv1d_transpose_streaming(batch_size,in_channels,out_channels,seq_len,kernel_size,stride):
"""Test that StreamingConvTranspose1d() streaming works as expected."""
assert seq_len > kernel_size

layer = StreamingConvTranspose1d(in_channels, out_channels, kernel_size,stride, causal=True, norm="none")

generator = torch.Generator()
generator = generator.manual_seed(41)
layer.apply(functools.partial(_init_weights, generator=generator))

shape = (batch_size, in_channels, seq_len,)
input_hidden_states = torch.rand(shape)
expected_output = layer(input_hidden_states)

start_index = 0
actual_outputs = []
with layer.streaming(batch_size=batch_size):
for end_index in range(kernel_size, seq_len+1):
actual_output = layer(input_hidden_states[..., start_index:end_index])
start_index = end_index
actual_outputs.append(actual_output)
actual_outputs = torch.concat(actual_outputs, axis=-1)

torch.testing.assert_close(actual_outputs, expected_output)
187 changes: 187 additions & 0 deletions moshi/moshi/modules/seanet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import functools
import torch
import torch.nn as nn
import pytest

from .seanet import SEANetResnetBlock, SEANetDecoder


SEANET_RESNET_DATA = [
# batch_size,dim,res_layer_index,seq_len,kernel_size
pytest.param(
3, 4, 1, 10, 6,
id='small resnet test 1',
),
pytest.param(
4, 5, 2, 10, 7,
id='small resnet test 2',
),
pytest.param(
5, 6, 4, 10, 2,
id='small resnet test 3',
),
pytest.param(
1, 512, 2, 256, 7,
id='large resnet test 1',
),
]
NUM_TIMESTEPS_DATA = [
pytest.param(
1,
id='length 1',
),
pytest.param(
2,
id='length 2',
),
pytest.param(
10,
id='length 10',
),
pytest.param(
100,
id='length 100',
),
]

SEANET_KWARGS_DATA = [
pytest.param(
{
"channels": 1,
"dimension": 8,
"causal": True,
"n_filters": 2,
"n_residual_layers": 1,
"activation": "ELU",
"compress": 2,
"dilation_base": 2,
"disable_norm_outer_blocks": 0,
"kernel_size": 7,
"residual_kernel_size": 3,
"last_kernel_size": 3,
# We train using weight_norm but then the weights are pre-processed for inference so
# that we can use a normal convolution.
"norm": "none",
"pad_mode": "constant",
"ratios": [5],
"true_skip": True,
},
id='Tiny SEANet',
),

pytest.param(
{
"channels": 1,
"dimension": 512,
"causal": True,
"n_filters": 64,
"n_residual_layers": 1,
"activation": "ELU",
"compress": 2,
"dilation_base": 2,
"disable_norm_outer_blocks": 0,
"kernel_size": 7,
"residual_kernel_size": 3,
"last_kernel_size": 3,
# We train using weight_norm but then the weights are pre-processed for inference so
# that we can use a normal convolution.
"norm": "none",
"pad_mode": "constant",
"ratios": [8, 6, 5, 4],
"true_skip": True,
},
id='Large SEANet',
),
]


def _init_weights(module, generator=None):
for name, param in module.named_parameters():
if "weight" in name:
nn.init.xavier_uniform_(param, generator=generator)
elif "bias" in name:
nn.init.constant_(param, 0.0)
else:
nn.init.xavier_uniform_(param, generator=generator)




@pytest.mark.parametrize("batch_size,dim,res_layer_index,seq_len,kernel_size", SEANET_RESNET_DATA)
def test_resnet(batch_size,dim,res_layer_index,seq_len,kernel_size):
"""Test that SEANetResnetBlock() calls are causal. Having new inputs does not change the previous output."""
assert seq_len > kernel_size

dilation_base = 2
layer = SEANetResnetBlock(dim=dim, dilations=[dilation_base**res_layer_index, 1], pad_mode="constant", causal=True)


generator = torch.Generator()
generator = generator.manual_seed(41)
layer.apply(functools.partial(_init_weights, generator=generator))

shape = (batch_size, dim, seq_len,)
input_hidden_states = torch.rand(shape)

expected_output = layer(input_hidden_states)

for end_index in range(kernel_size, seq_len+1):
actual_output = layer(input_hidden_states[..., :end_index])
torch.testing.assert_close(actual_output, expected_output[..., :actual_output.shape[-1]], msg=lambda original_msg: f"Failed at end_index={end_index}: \n{original_msg}")



@pytest.mark.parametrize("batch_size,dim,res_layer_index,seq_len,kernel_size", SEANET_RESNET_DATA)
def test_resnet_streaming(batch_size,dim,res_layer_index,seq_len,kernel_size):
"""Test that SEANetResnetBlock() streaming works as expected."""
assert seq_len > kernel_size

dilation_base = 2
layer = SEANetResnetBlock(dim=dim, dilations=[dilation_base**res_layer_index, 1], pad_mode="constant", causal=True)


generator = torch.Generator()
generator = generator.manual_seed(41)
layer.apply(functools.partial(_init_weights, generator=generator))

shape = (batch_size, dim, seq_len,)
input_hidden_states = torch.rand(shape)

expected_output = layer(input_hidden_states)

start_index = 0
actual_outputs = []
with layer.streaming(batch_size=batch_size):
for end_index in range(kernel_size, seq_len+1):
actual_output = layer(input_hidden_states[..., start_index:end_index])
start_index = end_index
actual_outputs.append(actual_output)
actual_outputs = torch.concat(actual_outputs, axis=-1)

torch.testing.assert_close(actual_outputs, expected_output)


@pytest.mark.parametrize("num_timesteps", NUM_TIMESTEPS_DATA)
@pytest.mark.parametrize("seanet_kwargs", SEANET_KWARGS_DATA)
def test_nonstreaming_causal_decode(num_timesteps, seanet_kwargs):
"""Test that the SEANetDecoder does not depend on future inputs."""

device = 'cuda' if torch.cuda.is_available() else 'cpu'
decoder = SEANetDecoder(**seanet_kwargs).to(device=device)

generator = torch.Generator(device=device)
generator = generator.manual_seed(41)
decoder.apply(functools.partial(_init_weights, generator=generator))

rand_generator = torch.Generator(device=device)
rand_generator.manual_seed(2147483647)
with torch.no_grad():
codes = torch.randn(1, seanet_kwargs['dimension'], num_timesteps, generator=rand_generator, device=device) # [B, K = 8, T]
expected_decoded = decoder(codes)

num_timesteps = codes.shape[-1]
for t in range(num_timesteps):
current_codes = codes[..., :t+1]
actual_decoded = decoder(current_codes)
torch.testing.assert_close(expected_decoded[..., :actual_decoded.shape[-1]], actual_decoded,
msg=lambda original_msg: f"Failed at t={t}: \n{original_msg}")
1 change: 1 addition & 0 deletions moshi/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"sphn >= 0.1.4",
"torch >= 2.2.0, < 2.5",
"aiohttp>=3.10.5, <3.11",
"pytest >= 8.3.3",
]
authors = [{name="Laurent Mazaré", email="[email protected]"}]
maintainers = [{name="Laurent Mazaré", email="[email protected]"}]
Expand Down
1 change: 1 addition & 0 deletions moshi/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ torch==2.2.0
numpy==1.26.4
aiohttp>=3.10.5, <3.11
huggingface-hub==0.24.6
pytest==8.3.3

0 comments on commit c2a26be

Please sign in to comment.