forked from riffusion/riffusion-hobby
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample_clips_test.py
88 lines (69 loc) · 2.35 KB
/
sample_clips_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import shutil
import typing as T

import pydub

from riffusion.cli import sample_clips

from .test_case import TestCase
class SampleClipsTest(TestCase):
    """
    Test riffusion.cli sample-clips
    """

    @staticmethod
    def default_params() -> T.Dict[str, T.Any]:
        """
        Return keyword arguments for ``sample_clips`` shared by the tests.

        A new dict is built on every call, so individual tests can override
        entries without affecting each other.
        """
        return dict(
            num_clips=3,
            duration_ms=5678,
            mono=False,
            extension="wav",
            seed=42,
        )

    def test_sample_clips(self) -> None:
        """
        Test sample-clips with default params.
        """
        params = self.default_params()
        self.helper_test_with_params(params)

    def test_mono(self) -> None:
        """
        Test sample-clips with mono=True.
        """
        params = self.default_params()
        params["mono"] = True
        params["num_clips"] = 1
        self.helper_test_with_params(params)

    def test_mp3(self) -> None:
        """
        Test sample-clips with extension=mp3.

        Skipped when pydub's configured encoder executable cannot be found.
        """
        # pydub stores an encoder *name* in ``converter`` and falls back to the
        # plain string "ffmpeg" when no encoder is installed, so an ``is None``
        # check alone never triggers. Resolve the executable to decide whether
        # mp3 export can actually run.
        converter = pydub.AudioSegment.converter
        if not converter or shutil.which(converter) is None:
            self.skipTest("skipping, ffmpeg not found")

        params = self.default_params()
        params["extension"] = "mp3"
        params["num_clips"] = 1
        self.helper_test_with_params(params)

    def helper_test_with_params(self, params: T.Dict[str, T.Any]) -> None:
        """
        Run sample-clips with the given params and validate what it writes.

        Args:
            params: Keyword arguments forwarded to ``sample_clips``. The keys
                ``num_clips``, ``duration_ms``, ``mono`` and ``extension`` are
                also read here to build the expectations.

        Asserts that the output directory contains exactly
        ``params["num_clips"]`` files. Each file must have the expected
        extension, a duration equal to ``params["duration_ms"]`` (rounded to
        the millisecond), and 1 channel when ``mono`` is set, otherwise 2.
        """
        audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
        output_dir = self.get_tmp_dir("sample_clips_")

        sample_clips(
            audio=str(audio_path),
            output_dir=str(output_dir),
            **params,
        )

        # Sorting makes iteration order (and thus failure output) deterministic.
        clip_paths = sorted(output_dir.iterdir())
        self.assertEqual(len(clip_paths), params["num_clips"])

        expected_suffix = f".{params['extension']}"
        expected_channels = 1 if params["mono"] else 2

        for clip_path in clip_paths:
            # Check that it has the right extension
            self.assertEqual(clip_path.suffix, expected_suffix, msg=str(clip_path))

            # Check that it has the right duration
            segment = pydub.AudioSegment.from_file(clip_path)
            self.assertEqual(
                round(segment.duration_seconds * 1000),
                params["duration_ms"],
                msg=str(clip_path),
            )

            # Check that it has the right number of channels
            self.assertEqual(segment.channels, expected_channels, msg=str(clip_path))
if __name__ == "__main__":
    # Executing this file directly (rather than via a test runner) hands
    # control to the project's shared test entry point.
    TestCase.main()