Skip to content

Rename 'wf' to 'samples' in AudioEncoder #701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@ namespace facebook::torchcodec {

namespace {

torch::Tensor validateWf(torch::Tensor wf) {
torch::Tensor validateSamples(torch::Tensor samples) {
TORCH_CHECK(
wf.dtype() == torch::kFloat32,
"waveform must have float32 dtype, got ",
wf.dtype());
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
samples.dtype() == torch::kFloat32,
"samples must have float32 dtype, got ",
samples.dtype());
TORCH_CHECK(
samples.dim() == 2,
"samples must have 2 dimensions, got ",
samples.dim());

// We enforce this, but if we get user reports we should investigate whether
// that's actually needed.
int numChannels = static_cast<int>(wf.sizes()[0]);
int numChannels = static_cast<int>(samples.sizes()[0]);
TORCH_CHECK(
numChannels <= AV_NUM_DATA_POINTERS,
"Trying to encode ",
Expand All @@ -26,7 +29,7 @@ torch::Tensor validateWf(torch::Tensor wf) {
AV_NUM_DATA_POINTERS,
" channels per frame.");

return wf.contiguous();
return samples.contiguous();
}

void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
Expand Down Expand Up @@ -71,7 +74,7 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = {

AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
// Find a sample format that the encoder supports. We prefer using FLT[P],
// since this is the format of the input waveform. If FLTP isn't supported
// since this is the format of the input samples. If FLTP isn't supported
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
// into the format with the highest resolution.
if (avCodec.sample_fmts == nullptr) {
Expand All @@ -98,11 +101,11 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
AudioEncoder::~AudioEncoder() {}

AudioEncoder::AudioEncoder(
const torch::Tensor wf,
const torch::Tensor samples,
int sampleRate,
std::string_view fileName,
const AudioStreamOptions& audioStreamOptions)
: wf_(validateWf(wf)) {
: samples_(validateSamples(samples)) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
Expand All @@ -129,12 +132,13 @@ AudioEncoder::AudioEncoder(
}

AudioEncoder::AudioEncoder(
const torch::Tensor wf,
const torch::Tensor samples,
int sampleRate,
std::string_view formatName,
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
const AudioStreamOptions& audioStreamOptions)
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
: samples_(validateSamples(samples)),
avioContextHolder_(std::move(avioContextHolder)) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
Expand Down Expand Up @@ -176,8 +180,8 @@ void AudioEncoder::initializeEncoder(
// well when "-b:a" isn't specified.
avCodecContext_->bit_rate = desiredBitRate.value_or(0);

outNumChannels_ =
static_cast<int>(audioStreamOptions.numChannels.value_or(wf_.sizes()[0]));
outNumChannels_ = static_cast<int>(
audioStreamOptions.numChannels.value_or(samples_.sizes()[0]));
validateNumChannels(*avCodec, outNumChannels_);
// The avCodecContext layout defines the layout of the encoded output, it's
// not related to the input samples.
Expand All @@ -186,9 +190,9 @@ void AudioEncoder::initializeEncoder(
validateSampleRate(*avCodec, sampleRate);
avCodecContext_->sample_rate = sampleRate;

// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
// may need to convert the wf into a supported output sample format, which is
// what the `.sample_fmt` defines.
// Input samples are expected to be FLTP. Not all encoders support FLTP, so we
// may need to convert the samples into a supported output sample format,
// which is what the `.sample_fmt` defines.
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);

int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
Expand Down Expand Up @@ -237,7 +241,7 @@ void AudioEncoder::encode() {
avFrame->pts = 0;
// We set the channel layout of the frame to the default layout corresponding
// to the input samples' number of channels
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
setDefaultChannelLayout(avFrame, static_cast<int>(samples_.sizes()[0]));

auto status = av_frame_get_buffer(avFrame.get(), 0);
TORCH_CHECK(
Expand All @@ -247,10 +251,10 @@ void AudioEncoder::encode() {

AutoAVPacket autoAVPacket;

uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
uint8_t* psamples = static_cast<uint8_t*>(samples_.data_ptr());
int numSamples = static_cast<int>(samples_.sizes()[1]); // per channel
int numEncodedSamples = 0; // per channel
int numBytesPerSample = static_cast<int>(wf_.element_size());
int numBytesPerSample = static_cast<int>(samples_.element_size());
int numBytesPerChannel = numSamples * numBytesPerSample;

status = avformat_write_header(avFormatContext_.get(), nullptr);
Expand All @@ -270,11 +274,13 @@ void AudioEncoder::encode() {
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;

for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
for (int ch = 0; ch < samples_.sizes()[0]; ch++) {
std::memcpy(
avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
avFrame->data[ch],
psamples + ch * numBytesPerChannel,
numBytesToEncode);
}
pwf += numBytesToEncode;
psamples += numBytesToEncode;

// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
// that the frame buffers are allocated to a big enough size. Here, we reset
Expand Down
8 changes: 4 additions & 4 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ class AudioEncoder {
// Passing 44_100 could result in output being 44000 if only 44000 is
// supported.
AudioEncoder(
const torch::Tensor wf,
const torch::Tensor samples,
// TODO-ENCODING: update this comment when we support an output sample
// rate. This will become the input sample rate.
// The *output* sample rate. We can't really decide for the user what it
// should be. Particularly, the sample rate of the input waveform should
// should be. Particularly, the sample rate of the input samples should
// match this, and that's up to the user. If sample rates don't match,
// encoding will still work but audio will be distorted.
int sampleRate,
std::string_view fileName,
const AudioStreamOptions& audioStreamOptions);
AudioEncoder(
const torch::Tensor wf,
const torch::Tensor samples,
int sampleRate,
std::string_view formatName,
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
Expand All @@ -51,7 +51,7 @@ class AudioEncoder {

int outNumChannels_ = -1;

const torch::Tensor wf_;
const torch::Tensor samples_;

// Stores the AVIOContext for the output tensor buffer.
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
Expand Down
12 changes: 6 additions & 6 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
m.def(
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
m.def(
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
Expand Down Expand Up @@ -388,7 +388,7 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
}

void encode_audio_to_file(
const at::Tensor wf,
const at::Tensor samples,
int64_t sample_rate,
std::string_view file_name,
std::optional<int64_t> bit_rate = std::nullopt,
Expand All @@ -399,12 +399,12 @@ void encode_audio_to_file(
audioStreamOptions.bitRate = bit_rate;
audioStreamOptions.numChannels = num_channels;
AudioEncoder(
wf, validateSampleRate(sample_rate), file_name, audioStreamOptions)
samples, validateSampleRate(sample_rate), file_name, audioStreamOptions)
.encode();
}

at::Tensor encode_audio_to_tensor(
const at::Tensor wf,
const at::Tensor samples,
int64_t sample_rate,
std::string_view format,
std::optional<int64_t> bit_rate = std::nullopt,
Expand All @@ -416,7 +416,7 @@ at::Tensor encode_audio_to_tensor(
audioStreamOptions.bitRate = bit_rate;
audioStreamOptions.numChannels = num_channels;
return AudioEncoder(
wf,
samples,
validateSampleRate(sample_rate),
format,
std::move(avioContextHolder),
Expand Down
5 changes: 2 additions & 3 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,9 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
return torch.empty([], dtype=torch.long)


# TODO-ENCODING: rename wf to samples
@register_fake("torchcodec_ns::encode_audio_to_file")
def encode_audio_to_file_abstract(
wf: torch.Tensor,
samples: torch.Tensor,
sample_rate: int,
filename: str,
bit_rate: Optional[int] = None,
Expand All @@ -175,7 +174,7 @@ def encode_audio_to_file_abstract(

@register_fake("torchcodec_ns::encode_audio_to_tensor")
def encode_audio_to_tensor_abstract(
wf: torch.Tensor,
samples: torch.Tensor,
sample_rate: int,
format: str,
bit_rate: Optional[int] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/torchcodec/encoders/_audio_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def to_file(
num_channels: Optional[int] = None,
) -> None:
_core.encode_audio_to_file(
wf=self._samples,
samples=self._samples,
sample_rate=self._sample_rate,
filename=dest,
bit_rate=bit_rate,
Expand All @@ -49,7 +49,7 @@ def to_tensor(
num_channels: Optional[int] = None,
) -> Tensor:
return _core.encode_audio_to_tensor(
wf=self._samples,
samples=self._samples,
sample_rate=self._sample_rate,
format=format,
bit_rate=bit_rate,
Expand Down
10 changes: 6 additions & 4 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,22 +1101,24 @@ def test_bad_input(self, tmp_path):

with pytest.raises(RuntimeError, match="must have float32 dtype, got int"):
encode_audio_to_file(
wf=torch.arange(10, dtype=torch.int),
samples=torch.arange(10, dtype=torch.int),
sample_rate=10,
filename=valid_output_file,
)
with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"):
encode_audio_to_file(
wf=torch.rand(3), sample_rate=10, filename=valid_output_file
samples=torch.rand(3), sample_rate=10, filename=valid_output_file
)

with pytest.raises(RuntimeError, match="No such file or directory"):
encode_audio_to_file(
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
samples=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
)
with pytest.raises(RuntimeError, match="check the desired extension"):
encode_audio_to_file(
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
samples=torch.rand(2, 10),
sample_rate=10,
filename="./file.bad_extension",
)


Expand Down
Loading