Skip to content

Audio encoding: support custom num_channels #693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ AudioEncoder::AudioEncoder(
const torch::Tensor wf,
int sampleRate,
std::string_view fileName,
std::optional<int64_t> bitRate)
std::optional<int64_t> bitRate,
std::optional<int64_t> numChannels)
: wf_(validateWf(wf)) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
Expand All @@ -125,15 +126,16 @@ AudioEncoder::AudioEncoder(
", make sure it's a valid path? ",
getFFMPEGErrorStringFromErrorCode(status));

initializeEncoder(sampleRate, bitRate);
initializeEncoder(sampleRate, bitRate, numChannels);
}

AudioEncoder::AudioEncoder(
const torch::Tensor wf,
int sampleRate,
std::string_view formatName,
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
std::optional<int64_t> bitRate)
std::optional<int64_t> bitRate,
std::optional<int64_t> numChannels)
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
Expand All @@ -151,12 +153,13 @@ AudioEncoder::AudioEncoder(

avFormatContext_->pb = avioContextHolder_->getAVIOContext();

initializeEncoder(sampleRate, bitRate);
initializeEncoder(sampleRate, bitRate, numChannels);
}

void AudioEncoder::initializeEncoder(
int sampleRate,
std::optional<int64_t> bitRate) {
std::optional<int64_t> bitRate,
std::optional<int64_t> numChannels) {
// We use the AVFormatContext's default codec for that
// specific format/container.
const AVCodec* avCodec =
Expand All @@ -174,6 +177,12 @@ void AudioEncoder::initializeEncoder(
// well when "-b:a" isn't specified.
avCodecContext_->bit_rate = bitRate.value_or(0);

desiredNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
validateNumChannels(*avCodec, desiredNumChannels_);
// The avCodecContext layout defines the layout of the encoded output, it's
// not related to the input samples.
setDefaultChannelLayout(avCodecContext_, desiredNumChannels_);

validateSampleRate(*avCodec, sampleRate);
avCodecContext_->sample_rate = sampleRate;

Expand All @@ -182,8 +191,6 @@ void AudioEncoder::initializeEncoder(
// what the `.sample_fmt` defines.
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);

setDefaultChannelLayout(avCodecContext_, static_cast<int>(wf_.sizes()[0]));

int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
TORCH_CHECK(
status == AVSUCCESS,
Expand Down Expand Up @@ -228,7 +235,9 @@ void AudioEncoder::encode() {
avFrame->format = AV_SAMPLE_FMT_FLTP;
avFrame->sample_rate = avCodecContext_->sample_rate;
avFrame->pts = 0;
setChannelLayout(avFrame, avCodecContext_);
// We set the channel layout of the frame to the default layout corresponding
// to the input samples' number of channels
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
Copy link
Contributor

@scotts scotts May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to think about this for a few moments to convince myself it's correct, so it may be worth putting in a comment: the default channel layout should be the channel layout of the provided waveform. The desired channel layout only comes in if we need to do any conversions in the encoding inner loop.


auto status = av_frame_get_buffer(avFrame.get(), 0);
TORCH_CHECK(
Expand Down Expand Up @@ -293,8 +302,10 @@ void AudioEncoder::encodeInnerLoop(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& srcAVFrame) {
bool mustConvert =
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
srcAVFrame != nullptr);
(srcAVFrame != nullptr &&
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
getNumChannels(srcAVFrame) != desiredNumChannels_));

UniqueAVFrame convertedAVFrame;
if (mustConvert) {
if (!swrContext_) {
Expand All @@ -304,15 +315,14 @@ void AudioEncoder::encodeInnerLoop(
srcAVFrame->sample_rate, // No sample rate conversion
srcAVFrame->sample_rate,
srcAVFrame,
getNumChannels(srcAVFrame) // No num_channel conversion
));
desiredNumChannels_));
}
convertedAVFrame = convertAudioAVFrameSamples(
swrContext_,
srcAVFrame,
avCodecContext_->sample_fmt,
srcAVFrame->sample_rate, // No sample rate conversion
getNumChannels(srcAVFrame)); // No num_channel conversion
desiredNumChannels_);
TORCH_CHECK(
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
"convertedAVFrame->nb_samples=",
Expand Down
15 changes: 12 additions & 3 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class AudioEncoder {
// like passing 0, which results in choosing the minimum supported bit rate.
// Passing 44_100 could result in output being 44000 if only 44000 is
// supported.
//
// TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc.
// into an AudioStreamOptions struct, or similar.
AudioEncoder(
const torch::Tensor wf,
// The *output* sample rate. We can't really decide for the user what it
Expand All @@ -21,20 +24,23 @@ class AudioEncoder {
// encoding will still work but audio will be distorted.
int sampleRate,
std::string_view fileName,
std::optional<int64_t> bitRate = std::nullopt);
std::optional<int64_t> bitRate = std::nullopt,
std::optional<int64_t> numChannels = std::nullopt);
AudioEncoder(
const torch::Tensor wf,
int sampleRate,
std::string_view formatName,
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
std::optional<int64_t> bitRate = std::nullopt);
std::optional<int64_t> bitRate = std::nullopt,
std::optional<int64_t> numChannels = std::nullopt);
void encode();
torch::Tensor encodeToTensor();

private:
void initializeEncoder(
int sampleRate,
std::optional<int64_t> bitRate = std::nullopt);
std::optional<int64_t> bitRate = std::nullopt,
std::optional<int64_t> numChannels = std::nullopt);
void encodeInnerLoop(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& srcAVFrame);
Expand All @@ -44,6 +50,9 @@ class AudioEncoder {
UniqueAVCodecContext avCodecContext_;
int streamIndex_;
UniqueSwrContext swrContext_;
// TODO-ENCODING: desiredNumChannels should just be part of an options struct,
// see other TODO above.
int desiredNumChannels_ = -1;

const torch::Tensor wf_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a comment here that wf stands for "wave form", and it's the original audio data passed to us by the user would be helpful. I know this is not directly related to the changes in this PR, but I keep having to remind myself of this fact.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a TODO somewhere to rename wf to samples, like in our Python API. That should make it more obvious


Expand Down
72 changes: 61 additions & 11 deletions src/torchcodec/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,71 @@ void setDefaultChannelLayout(
#endif
}

void setChannelLayout(
UniqueAVFrame& dstAVFrame,
const UniqueAVCodecContext& avCodecContext) {
void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels) {
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
auto status = av_channel_layout_copy(
&dstAVFrame->ch_layout, &avCodecContext->ch_layout);
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't copy channel layout to avFrame: ",
getFFMPEGErrorStringFromErrorCode(status));
AVChannelLayout channel_layout;
av_channel_layout_default(&channel_layout, numChannels);
avFrame->ch_layout = channel_layout;
#else
dstAVFrame->channel_layout = avCodecContext->channel_layout;
dstAVFrame->channels = avCodecContext->channels;
uint64_t channel_layout = av_get_default_channel_layout(numChannels);
avFrame->channel_layout = channel_layout;
avFrame->channels = numChannels;
#endif
}

void validateNumChannels(const AVCodec& avCodec, int numChannels) {
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
if (avCodec.ch_layouts == nullptr) {
// If we can't validate, we must assume it'll be fine. If not, FFmpeg will
// eventually raise.
return;
}
// FFmpeg docs indicate that the ch_layouts array is terminated by a zeroed
// layout, so checking for nb_channels == 0 should indicate its end.
for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) {
if (numChannels == avCodec.ch_layouts[i].nb_channels) {
return;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment here saying that we've now entered the error path might be helpful - I think this is less obvious because we're in an #if block.

// At this point it seems that the encoder doesn't support the requested
// number of channels, so we error out.
std::stringstream supportedNumChannels;
for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) {
if (i > 0) {
supportedNumChannels << ", ";
}
supportedNumChannels << avCodec.ch_layouts[i].nb_channels;
}
#else
if (avCodec.channel_layouts == nullptr) {
// can't validate, same as above.
return;
}
for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) {
if (numChannels ==
av_get_channel_layout_nb_channels(avCodec.channel_layouts[i])) {
return;
}
}
// At this point it seems that the encoder doesn't support the requested
// number of channels, so we error out.
std::stringstream supportedNumChannels;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto about error path, partially so both arms have the same structure.

for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) {
if (i > 0) {
supportedNumChannels << ", ";
}
supportedNumChannels << av_get_channel_layout_nb_channels(
avCodec.channel_layouts[i]);
}
#endif
TORCH_CHECK(
false,
"Desired number of channels (",
numChannels,
") is not supported by the ",
"encoder. Supported number of channels are: ",
supportedNumChannels.str(),
".");
}

namespace {
Expand Down
6 changes: 3 additions & 3 deletions src/torchcodec/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ void setDefaultChannelLayout(
UniqueAVCodecContext& avCodecContext,
int numChannels);

void setChannelLayout(
UniqueAVFrame& dstAVFrame,
const UniqueAVCodecContext& avCodecContext);
void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels);

void validateNumChannels(const AVCodec& avCodec, int numChannels);

void setChannelLayout(
UniqueAVFrame& dstAVFrame,
Expand Down
16 changes: 10 additions & 6 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
m.def(
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> ()");
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
m.def(
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None) -> Tensor");
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
Expand Down Expand Up @@ -391,23 +391,27 @@ void encode_audio_to_file(
const at::Tensor wf,
int64_t sample_rate,
std::string_view file_name,
std::optional<int64_t> bit_rate = std::nullopt) {
AudioEncoder(wf, validateSampleRate(sample_rate), file_name, bit_rate)
std::optional<int64_t> bit_rate = std::nullopt,
std::optional<int64_t> num_channels = std::nullopt) {
AudioEncoder(
wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels)
.encode();
}

at::Tensor encode_audio_to_tensor(
const at::Tensor wf,
int64_t sample_rate,
std::string_view format,
std::optional<int64_t> bit_rate = std::nullopt) {
std::optional<int64_t> bit_rate = std::nullopt,
std::optional<int64_t> num_channels = std::nullopt) {
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
return AudioEncoder(
wf,
validateSampleRate(sample_rate),
format,
std::move(avioContextHolder),
bit_rate)
bit_rate,
num_channels)
.encodeToTensor();
}

Expand Down
12 changes: 10 additions & 2 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,22 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
# TODO-ENCODING: rename wf to samples
@register_fake("torchcodec_ns::encode_audio_to_file")
def encode_audio_to_file_abstract(
wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None
wf: torch.Tensor,
sample_rate: int,
filename: str,
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
) -> None:
return


@register_fake("torchcodec_ns::encode_audio_to_tensor")
def encode_audio_to_tensor_abstract(
wf: torch.Tensor, sample_rate: int, format: str, bit_rate: Optional[int] = None
wf: torch.Tensor,
sample_rate: int,
format: str,
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
) -> torch.Tensor:
return torch.empty([], dtype=torch.long)

Expand Down
4 changes: 4 additions & 0 deletions src/torchcodec/encoders/_audio_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,27 @@ def to_file(
dest: Union[str, Path],
*,
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
) -> None:
_core.encode_audio_to_file(
wf=self._samples,
sample_rate=self._sample_rate,
filename=dest,
bit_rate=bit_rate,
num_channels=num_channels,
)

def to_tensor(
self,
format: str,
*,
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
) -> Tensor:
return _core.encode_audio_to_tensor(
wf=self._samples,
sample_rate=self._sample_rate,
format=format,
bit_rate=bit_rate,
num_channels=num_channels,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no tests for the public API right now. I will soon migrate most of the existing encoder ops tests into testing the public Python APIs.

)
Loading
Loading