Skip to content

Commit

Permalink
解析 fmt/data
Browse files Browse the repository at this point in the history
  • Loading branch information
tybian committed Nov 14, 2024
1 parent 6adcf63 commit 057914d
Showing 1 changed file with 103 additions and 24 deletions.
127 changes: 103 additions & 24 deletions test_parse_riff.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from enum import Enum
from typing import List

import numpy as np


class FourCCType(Enum):
RIFF = 'RIFF'
Expand Down Expand Up @@ -46,11 +48,6 @@ class CueTiming:

@classmethod
def from_bytes(cls, data: bytes) -> CueTiming:
"""
形如 hex 01 start "data" 00 00 end
:param data:
:return:
"""
if len(data) != 24:
raise ValueError(f"Invalid cue timing data: {data}")
id_, start, _, _, _, end = struct.unpack('<II4sIII', data)
Expand All @@ -68,7 +65,6 @@ def __post_init__(self):

@classmethod
def from_bytes(cls, data: bytes) -> Cue:

id_, = struct.unpack('<I', data[:4])
cue = cls(id=id_, size=len(data))

Expand All @@ -79,15 +75,79 @@ def from_bytes(cls, data: bytes) -> Cue:
return cue


@dataclass()
class Format:
"""
12 ~ 16 字节为:fmt 大小为 16 bytes
16 ~ 19 4字节为过滤字节: 10 00 00 00
20 ~ 21 2字节为表示编码格式 1 为 pcm 编码
22 ~ 23 2字节为声道数, 1 为单声道, 2 为立体声
24 ~ 27 4字节为采样率
28 ~ 31 4字节为每秒字节数,即比特率
32 ~ 33 2字节为数据块长度(通道数 x 位宽)
34 ~ 35 2字节为采样位数
"""
encoding_code: int
channel_count: int
sample_rate: int
byte_rate: int
chunk_bytes: int = 0
sample_bits: int = 0

@property
def sample_width(self):
return self.byte_rate // self.sample_rate // 2

@classmethod
def from_bytes(cls, data: bytes) -> Format:
if len(data) != 16:
raise ValueError(f"Invalid format data: {data}")
encoding_code, channel_count, sample_rate, byte_rate, chunk_bytes, sample_bits = struct.unpack('<HHIIHH', data)
return cls(encoding_code, channel_count, sample_rate, byte_rate, chunk_bytes, sample_bits)


@dataclass()
class WaveData:
samples: np.ndarray

@classmethod
def from_bytes(cls, data: bytes, format_: Format) -> WaveData:
print(f'format={format_}')
sample_width = format_.sample_width
if sample_width == 2:
dtype = np.int16
elif sample_width == 3:
dtype = np.dtype([('int24', 'i4')]) # 24-bit int stored in 32-bit container
elif sample_width == 4:
dtype = np.int32
else:
raise ValueError(f"Unsupported sample width: {sample_width}")

samples = np.frombuffer(data, dtype=dtype)
if sample_width == 3:
samples = samples['int24'] >> 8 # Adjust 24-bit int to correct range
samples = samples.reshape(-1, format_.channel_count)

# 归一化float
if dtype == np.int16:
samples = samples / 2 ** 15
elif dtype == np.int32:
samples = samples / 2 ** 31
return cls(samples)

@dataclass()
class Chunk:
type: FourCC = None
length: int = 0
code: str = None
data: bytes = None
data_start: int = 0
length: int = 0
file: io.BytesIO = None
format: Format = None
wave_data: WaveData = None
label: Label = None
cue: Cue = None
sub_chunks: List[Chunk] = None
parent: Chunk = None

def __post_init__(self):
self.sub_chunks = self.sub_chunks or []
Expand All @@ -98,31 +158,51 @@ def prob(self, stream: io.BytesIO):
return -1
type_bytes, self.length = struct.unpack('<4sI', header)
self.type = FourCC(type_bytes.decode('ascii'))
self.data = stream.read(self.length)
self.data_start = stream.tell()
self.file = stream

if self.type.code == 'labl':
id_ = struct.unpack('<I', self.data[:4])[0]
self.file.seek(self.data_start)
id_ = struct.unpack('<I', self.file.read(4))[0]
self.label = Label(size=self.length, id=id_)

self.label.read(io.BytesIO(self.data[4:]))
self.label.read(self.file)

if self.type.code == 'cue ':
self.cue = Cue.from_bytes(self.data)
self.file.seek(self.data_start)
data = self.file.read(self.length)
self.cue = Cue.from_bytes(data)

if self.type.code == 'fmt ':
self.file.seek(self.data_start)
data = self.file.read(self.length)
self.format = Format.from_bytes(data)

if self.type.code == 'data':
self.file.seek(self.data_start)
data = self.file.read(self.length)
self.wave_data = WaveData.from_bytes(data, self.parent.get_format())

self.file.seek(self.data_start + self.length + (self.length % 2))
return 0

def get_format(self):
if self.parent:
return None
for sub_chunk in self.sub_chunks:
if sub_chunk.type.code == 'fmt ':
return sub_chunk.format
return None

def parse_sub_chunks(self):
if FourCCType(self.type.code) in [FourCCType.RIFF, FourCCType.LIST]:
sub_file = io.BytesIO(self.data)
code = sub_file.read(4)
self.file.seek(self.data_start)
code = self.file.read(4)
if len(code) < 4:
return
self.code = code.decode('ascii')
print(f'code={self.code}')
while sub_file.tell() < self.length:
sub_chunk = Chunk()
if sub_chunk.prob(sub_file) == 0:
print(f'sub_chunk_type: {sub_chunk.type}')
while self.file.tell() < self.data_start + self.length:
sub_chunk = Chunk(parent=self)
if sub_chunk.prob(self.file) == 0:
sub_chunk.parse_sub_chunks()
self.sub_chunks.append(sub_chunk)

Expand All @@ -147,18 +227,17 @@ class Riff:
@classmethod
def from_file(cls, filename):
with open(filename, "rb") as file:
return cls.from_steam(io.BytesIO(file.read()))
return cls.from_stream(io.BytesIO(file.read()))

@classmethod
def from_steam(cls, stream: io.BytesIO) -> Riff:
def from_stream(cls, stream: io.BytesIO) -> Riff:
root_chunk = Chunk()
if root_chunk.prob(stream) == 0:
root_chunk.parse_sub_chunks()

return cls(root_chunk.type, root_chunk.length, root_chunk.sub_chunks)

def show_chunk_layers(self, chunk=None, depth=0):

result = " " * depth + str(chunk) + "\n"
for subchunk in chunk.sub_chunks:
result += self.show_chunk_layers(subchunk, depth + 1)
Expand All @@ -167,4 +246,4 @@ def show_chunk_layers(self, chunk=None, depth=0):

if __name__ == "__main__":
riff = Riff.from_file(r'F:\tmp\test_ocenaudio_annotation\sil_2s_marker.wav')
print(riff)
print(riff.show_chunk_layers(riff))

0 comments on commit 057914d

Please sign in to comment.