diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 8811cb5a..a5303b0f 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -101,7 +101,8 @@ AudioEncoder::AudioEncoder( const torch::Tensor wf, int sampleRate, std::string_view fileName, - std::optional bitRate) + std::optional bitRate, + std::optional numChannels) : wf_(validateWf(wf)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; @@ -125,7 +126,7 @@ AudioEncoder::AudioEncoder( ", make sure it's a valid path? ", getFFMPEGErrorStringFromErrorCode(status)); - initializeEncoder(sampleRate, bitRate); + initializeEncoder(sampleRate, bitRate, numChannels); } AudioEncoder::AudioEncoder( @@ -133,7 +134,8 @@ AudioEncoder::AudioEncoder( int sampleRate, std::string_view formatName, std::unique_ptr avioContextHolder, - std::optional bitRate) + std::optional bitRate, + std::optional numChannels) : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; @@ -151,12 +153,13 @@ AudioEncoder::AudioEncoder( avFormatContext_->pb = avioContextHolder_->getAVIOContext(); - initializeEncoder(sampleRate, bitRate); + initializeEncoder(sampleRate, bitRate, numChannels); } void AudioEncoder::initializeEncoder( int sampleRate, - std::optional bitRate) { + std::optional bitRate, + std::optional numChannels) { // We use the AVFormatContext's default codec for that // specific format/container. const AVCodec* avCodec = @@ -174,6 +177,12 @@ void AudioEncoder::initializeEncoder( // well when "-b:a" isn't specified. avCodecContext_->bit_rate = bitRate.value_or(0); + desiredNumChannels_ = static_cast(numChannels.value_or(wf_.sizes()[0])); + validateNumChannels(*avCodec, desiredNumChannels_); + // The avCodecContext layout defines the layout of the encoded output, it's + // not related to the input sampes. + setDefaultChannelLayout(avCodecContext_, desiredNumChannels_); + validateSampleRate(*avCodec, sampleRate); avCodecContext_->sample_rate = sampleRate; @@ -182,8 +191,6 @@ void AudioEncoder::initializeEncoder( // what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); - setDefaultChannelLayout(avCodecContext_, static_cast(wf_.sizes()[0])); - int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK( status == AVSUCCESS, @@ -228,7 +235,9 @@ void AudioEncoder::encode() { avFrame->format = AV_SAMPLE_FMT_FLTP; avFrame->sample_rate = avCodecContext_->sample_rate; avFrame->pts = 0; - setChannelLayout(avFrame, avCodecContext_); + // We set the channel layout of the frame to the default layout corresponding + // to the input samples' number of channels + setDefaultChannelLayout(avFrame, static_cast(wf_.sizes()[0])); auto status = av_frame_get_buffer(avFrame.get(), 0); TORCH_CHECK( @@ -293,8 +302,10 @@ void AudioEncoder::encodeInnerLoop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& srcAVFrame) { bool mustConvert = - (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP && - srcAVFrame != nullptr); + (srcAVFrame != nullptr && + (avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP || + getNumChannels(srcAVFrame) != desiredNumChannels_)); + UniqueAVFrame convertedAVFrame; if (mustConvert) { if (!swrContext_) { @@ -304,15 +315,14 @@ void AudioEncoder::encodeInnerLoop( srcAVFrame->sample_rate, // No sample rate conversion srcAVFrame->sample_rate, srcAVFrame, - getNumChannels(srcAVFrame) // No num_channel conversion - )); + desiredNumChannels_)); } convertedAVFrame = convertAudioAVFrameSamples( swrContext_, srcAVFrame, avCodecContext_->sample_fmt, srcAVFrame->sample_rate, // No sample rate conversion - getNumChannels(srcAVFrame)); // No num_channel conversion + desiredNumChannels_); TORCH_CHECK( convertedAVFrame->nb_samples == srcAVFrame->nb_samples, "convertedAVFrame->nb_samples=", diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index bf31c31b..afbc1d3f 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -13,6 +13,9 @@ class AudioEncoder { // like passing 0, which results in choosing the minimum supported bit rate. // Passing 44_100 could result in output being 44000 if only 44000 is // supported. + // + // TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc. + // into an AudioStreamOptions struct, or similar. AudioEncoder( const torch::Tensor wf, // The *output* sample rate. We can't really decide for the user what it @@ -21,20 +24,23 @@ class AudioEncoder { // encoding will still work but audio will be distorted. int sampleRate, std::string_view fileName, - std::optional bitRate = std::nullopt); + std::optional bitRate = std::nullopt, + std::optional numChannels = std::nullopt); AudioEncoder( const torch::Tensor wf, int sampleRate, std::string_view formatName, std::unique_ptr avioContextHolder, - std::optional bitRate = std::nullopt); + std::optional bitRate = std::nullopt, + std::optional numChannels = std::nullopt); void encode(); torch::Tensor encodeToTensor(); private: void initializeEncoder( int sampleRate, - std::optional bitRate = std::nullopt); + std::optional bitRate = std::nullopt, + std::optional numChannels = std::nullopt); void encodeInnerLoop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& srcAVFrame); @@ -44,6 +50,9 @@ class AudioEncoder { UniqueAVCodecContext avCodecContext_; int streamIndex_; UniqueSwrContext swrContext_; + // TODO-ENCODING: desiredNumChannels should just be part of an options struct, + // see other TODO above. + int desiredNumChannels_ = -1; const torch::Tensor wf_; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index a8740b1f..b412517e 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -88,21 +88,71 @@ void setDefaultChannelLayout( #endif } -void setChannelLayout( - UniqueAVFrame& dstAVFrame, - const UniqueAVCodecContext& avCodecContext) { +void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels) { #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 - auto status = av_channel_layout_copy( - &dstAVFrame->ch_layout, &avCodecContext->ch_layout); - TORCH_CHECK( - status == AVSUCCESS, - "Couldn't copy channel layout to avFrame: ", - getFFMPEGErrorStringFromErrorCode(status)); + AVChannelLayout channel_layout; + av_channel_layout_default(&channel_layout, numChannels); + avFrame->ch_layout = channel_layout; #else - dstAVFrame->channel_layout = avCodecContext->channel_layout; - dstAVFrame->channels = avCodecContext->channels; + uint64_t channel_layout = av_get_default_channel_layout(numChannels); + avFrame->channel_layout = channel_layout; + avFrame->channels = numChannels; +#endif +} +void validateNumChannels(const AVCodec& avCodec, int numChannels) { +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + if (avCodec.ch_layouts == nullptr) { + // If we can't validate, we must assume it'll be fine. If not, FFmpeg will + // eventually raise. + return; + } + // FFmpeg doc indicate that the ch_layouts array is terminated by a zeroed + // layout, so checking for nb_channels == 0 should indicate its end. + for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) { + if (numChannels == avCodec.ch_layouts[i].nb_channels) { + return; + } + } + // At this point it seems that the encoder doesn't support the requested + // number of channels, so we error out. + std::stringstream supportedNumChannels; + for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) { + if (i > 0) { + supportedNumChannels << ", "; + } + supportedNumChannels << avCodec.ch_layouts[i].nb_channels; + } +#else + if (avCodec.channel_layouts == nullptr) { + // can't validate, same as above. + return; + } + for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) { + if (numChannels == + av_get_channel_layout_nb_channels(avCodec.channel_layouts[i])) { + return; + } + } + // At this point it seems that the encoder doesn't support the requested + // number of channels, so we error out. + std::stringstream supportedNumChannels; + for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) { + if (i > 0) { + supportedNumChannels << ", "; + } + supportedNumChannels << av_get_channel_layout_nb_channels( + avCodec.channel_layouts[i]); + } #endif + TORCH_CHECK( + false, + "Desired number of channels (", + numChannels, + ") is not supported by the ", + "encoder. Supported number of channels are: ", + supportedNumChannels.str(), + "."); } namespace { diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index d0d3a682..07b7443e 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -151,9 +151,9 @@ void setDefaultChannelLayout( UniqueAVCodecContext& avCodecContext, int numChannels); -void setChannelLayout( - UniqueAVFrame& dstAVFrame, - const UniqueAVCodecContext& avCodecContext); +void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels); + +void validateNumChannels(const AVCodec& avCodec, int numChannels); void setChannelLayout( UniqueAVFrame& dstAVFrame, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 1355045a..c6e43d09 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( - "encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> ()"); + "encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()"); m.def( - "encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None) -> Tensor"); + "encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); @@ -391,8 +391,10 @@ void encode_audio_to_file( const at::Tensor wf, int64_t sample_rate, std::string_view file_name, - std::optional bit_rate = std::nullopt) { - AudioEncoder(wf, validateSampleRate(sample_rate), file_name, bit_rate) + std::optional bit_rate = std::nullopt, + std::optional num_channels = std::nullopt) { + AudioEncoder( + wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels) .encode(); } @@ -400,14 +402,16 @@ at::Tensor encode_audio_to_tensor( const at::Tensor wf, int64_t sample_rate, std::string_view format, - std::optional bit_rate = std::nullopt) { + std::optional bit_rate = std::nullopt, + std::optional num_channels = std::nullopt) { auto avioContextHolder = std::make_unique(); return AudioEncoder( wf, validateSampleRate(sample_rate), format, std::move(avioContextHolder), - bit_rate) + bit_rate, + num_channels) .encodeToTensor(); } diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 3507df44..11751e32 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -164,14 +164,22 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. # TODO-ENCODING: rename wf to samples @register_fake("torchcodec_ns::encode_audio_to_file") def encode_audio_to_file_abstract( - wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None + wf: torch.Tensor, + sample_rate: int, + filename: str, + bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, ) -> None: return @register_fake("torchcodec_ns::encode_audio_to_tensor") def encode_audio_to_tensor_abstract( - wf: torch.Tensor, sample_rate: int, format: str, bit_rate: Optional[int] = None + wf: torch.Tensor, + sample_rate: int, + format: str, + bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index bee05d0a..469fbefb 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -31,12 +31,14 @@ def to_file( dest: Union[str, Path], *, bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, ) -> None: _core.encode_audio_to_file( wf=self._samples, sample_rate=self._sample_rate, filename=dest, bit_rate=bit_rate, + num_channels=num_channels, ) def to_tensor( @@ -44,10 +46,12 @@ def to_tensor( format: str, *, bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, ) -> Tensor: return _core.encode_audio_to_tensor( wf=self._samples, sample_rate=self._sample_rate, format=format, bit_rate=bit_rate, + num_channels=num_channels, ) diff --git a/test/test_ops.py b/test/test_ops.py index 5fb4d350..789ef8a9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -6,6 +6,7 @@ import io import os +import re from functools import partial os.environ["TORCH_LOGS"] = "output_code" @@ -1158,6 +1159,20 @@ def test_bad_input(self, tmp_path): wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter" ) + for num_channels in (0, 3): + with pytest.raises( + RuntimeError, + match=re.escape( + f"Desired number of channels ({num_channels}) is not supported" + ), + ): + encode_audio_to_file( + wf=torch.rand(2, 10), + sample_rate=16_000, + filename="ok.mp3", + num_channels=num_channels, + ) + @pytest.mark.parametrize( "encode_method", (encode_audio_to_file, encode_audio_to_tensor) ) @@ -1194,8 +1209,9 @@ def test_round_trip(self, encode_method, output_format, tmp_path): @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) + @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac")) - def test_against_cli(self, asset, bit_rate, output_format, tmp_path): + def test_against_cli(self, asset, bit_rate, num_channels, output_format, tmp_path): # Encodes samples with our encoder and with the FFmpeg CLI, and checks # that both decoded outputs are equal @@ -1206,6 +1222,7 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path): subprocess.run( ["ffmpeg", "-i", str(asset.path)] + (["-b:a", f"{bit_rate}"] if bit_rate is not None else []) + + (["-ac", f"{num_channels}"] if num_channels is not None else []) + [ str(encoded_by_ffmpeg), ], @@ -1219,9 +1236,19 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path): sample_rate=asset.sample_rate, filename=str(encoded_by_us), bit_rate=bit_rate, + num_channels=num_channels, ) - rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None) + if output_format == "wav": + rtol, atol = 0, 1e-4 + elif output_format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2: + # Not sure why, this one needs slightly higher tol. With default + # tolerances, the check fails on ~1% of the samples, so that's + # probably fine. It might be that the FFmpeg CLI doesn't rely on + # libswresample for converting channels? + rtol, atol = 0, 1e-3 + else: + rtol, atol = None, None torch.testing.assert_close( self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us), @@ -1231,8 +1258,11 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path): @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) + @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac")) - def test_tensor_against_file(self, asset, bit_rate, output_format, tmp_path): + def test_tensor_against_file( + self, asset, bit_rate, num_channels, output_format, tmp_path + ): if get_ffmpeg_major_version() == 4 and output_format == "wav": pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") @@ -1242,6 +1272,7 @@ def test_tensor_against_file(self, asset, bit_rate, output_format, tmp_path): sample_rate=asset.sample_rate, filename=str(encoded_file), bit_rate=bit_rate, + num_channels=num_channels, ) encoded_tensor = encode_audio_to_tensor( @@ -1249,6 +1280,7 @@ def test_tensor_against_file(self, asset, bit_rate, output_format, tmp_path): sample_rate=asset.sample_rate, format=output_format, bit_rate=bit_rate, + num_channels=num_channels, ) torch.testing.assert_close( @@ -1305,6 +1337,42 @@ def test_contiguity(self): encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 ) + @pytest.mark.parametrize("num_channels_input", (1, 2)) + @pytest.mark.parametrize("num_channels_output", (1, 2, None)) + @pytest.mark.parametrize( + "encode_method", (encode_audio_to_file, encode_audio_to_tensor) + ) + def test_num_channels( + self, num_channels_input, num_channels_output, encode_method, tmp_path + ): + # We just check that the num_channels parameter is respected. + # Correctness is checked in other tests (like test_against_cli()) + + sample_rate = 16_000 + source_samples = torch.rand(num_channels_input, 1_000) + format = "mp3" + + if encode_method is encode_audio_to_file: + encoded_path = tmp_path / f"output.{format}" + encode_audio_to_file( + wf=source_samples, + sample_rate=sample_rate, + filename=str(encoded_path), + num_channels=num_channels_output, + ) + encoded_source = encoded_path + else: + encoded_source = encode_audio_to_tensor( + wf=source_samples, + sample_rate=sample_rate, + format=format, + num_channels=num_channels_output, + ) + + if num_channels_output is None: + num_channels_output = num_channels_input + assert self.decode(encoded_source).shape[0] == num_channels_output + if __name__ == "__main__": pytest.main()