diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index f177c19b..453ae0e0 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -8,16 +8,19 @@ namespace facebook::torchcodec { namespace { -torch::Tensor validateWf(torch::Tensor wf) { +torch::Tensor validateSamples(torch::Tensor samples) { TORCH_CHECK( - wf.dtype() == torch::kFloat32, - "waveform must have float32 dtype, got ", - wf.dtype()); - TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim()); + samples.dtype() == torch::kFloat32, + "samples must have float32 dtype, got ", + samples.dtype()); + TORCH_CHECK( + samples.dim() == 2, + "samples must have 2 dimensions, got ", + samples.dim()); // We enforce this, but if we get user reports we should investigate whether // that's actually needed. - int numChannels = static_cast<int>(wf.sizes()[0]); + int numChannels = static_cast<int>(samples.sizes()[0]); TORCH_CHECK( numChannels <= AV_NUM_DATA_POINTERS, "Trying to encode ", @@ -26,7 +29,7 @@ torch::Tensor validateWf(torch::Tensor wf) { AV_NUM_DATA_POINTERS, " channels per frame."); - return wf.contiguous(); + return samples.contiguous(); } void validateSampleRate(const AVCodec& avCodec, int sampleRate) { @@ -71,7 +74,7 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = { AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) { // Find a sample format that the encoder supports. We prefer using FLT[P], - // since this is the format of the input waveform. If FLTP isn't supported + // since this is the format of the input samples. If FLTP isn't supported // then we'll need to convert the AVFrame's format. Our heuristic is to encode // into the format with the highest resolution. 
if (avCodec.sample_fmts == nullptr) { @@ -98,11 +101,11 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) { AudioEncoder::~AudioEncoder() {} AudioEncoder::AudioEncoder( - const torch::Tensor wf, + const torch::Tensor samples, int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions) - : wf_(validateWf(wf)) { + : samples_(validateSamples(samples)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( @@ -129,12 +132,13 @@ AudioEncoder::AudioEncoder( } AudioEncoder::AudioEncoder( - const torch::Tensor wf, + const torch::Tensor samples, int sampleRate, std::string_view formatName, std::unique_ptr<AVIOToTensorContext> avioContextHolder, const AudioStreamOptions& audioStreamOptions) - : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) { + : samples_(validateSamples(samples)), + avioContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( @@ -176,8 +180,8 @@ void AudioEncoder::initializeEncoder( // well when "-b:a" isn't specified. avCodecContext_->bit_rate = desiredBitRate.value_or(0); - outNumChannels_ = - static_cast<int>(audioStreamOptions.numChannels.value_or(wf_.sizes()[0])); + outNumChannels_ = static_cast<int>( + audioStreamOptions.numChannels.value_or(samples_.sizes()[0])); validateNumChannels(*avCodec, outNumChannels_); // The avCodecContext layout defines the layout of the encoded output, it's // not related to the input sampes. @@ -186,9 +190,9 @@ void AudioEncoder::initializeEncoder( validateSampleRate(*avCodec, sampleRate); avCodecContext_->sample_rate = sampleRate; - // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we - // may need to convert the wf into a supported output sample format, which is - // what the `.sample_fmt` defines. + // Input samples are expected to be FLTP. 
Not all encoders support FLTP, so we + // may need to convert the samples into a supported output sample format, + // which is what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); @@ -237,7 +241,7 @@ void AudioEncoder::encode() { avFrame->pts = 0; // We set the channel layout of the frame to the default layout corresponding // to the input samples' number of channels - setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0])); + setDefaultChannelLayout(avFrame, static_cast<int>(samples_.sizes()[0])); auto status = av_frame_get_buffer(avFrame.get(), 0); TORCH_CHECK( @@ -247,10 +251,10 @@ void AudioEncoder::encode() { AutoAVPacket autoAVPacket; - uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr()); - int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel + uint8_t* psamples = static_cast<uint8_t*>(samples_.data_ptr()); + int numSamples = static_cast<int>(samples_.sizes()[1]); // per channel int numEncodedSamples = 0; // per channel - int numBytesPerSample = static_cast<int>(wf_.element_size()); + int numBytesPerSample = static_cast<int>(samples_.element_size()); int numBytesPerChannel = numSamples * numBytesPerSample; status = avformat_write_header(avFormatContext_.get(), nullptr); @@ -270,11 +274,13 @@ void AudioEncoder::encode() { std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples); int numBytesToEncode = numSamplesToEncode * numBytesPerSample; - for (int ch = 0; ch < wf_.sizes()[0]; ch++) { + for (int ch = 0; ch < samples_.sizes()[0]; ch++) { std::memcpy( - avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode); + avFrame->data[ch], + psamples + ch * numBytesPerChannel, + numBytesToEncode); } - pwf += numBytesToEncode; + psamples += numBytesToEncode; // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so // that the frame buffers are allocated to a big enough size. 
Here, we reset diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 08558b6b..bb746d04 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -15,18 +15,18 @@ class AudioEncoder { // Passing 44_100 could result in output being 44000 if only 44000 is // supported. AudioEncoder( - const torch::Tensor wf, + const torch::Tensor samples, // TODO-ENCODING: update this comment when we support an output sample // rate. This will become the input sample rate. // The *output* sample rate. We can't really decide for the user what it - // should be. Particularly, the sample rate of the input waveform should + // should be. Particularly, the sample rate of the input samples should // match this, and that's up to the user. If sample rates don't match, // encoding will still work but audio will be distorted. int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions); AudioEncoder( - const torch::Tensor wf, + const torch::Tensor samples, int sampleRate, std::string_view formatName, std::unique_ptr<AVIOToTensorContext> avioContextHolder, @@ -51,7 +51,7 @@ class AudioEncoder { int outNumChannels_ = -1; - const torch::Tensor wf_; + const torch::Tensor samples_; // Stores the AVIOContext for the output tensor buffer. std::unique_ptr<AVIOToTensorContext> avioContextHolder_; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index b25a84e3..4a1c414b 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( - "encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()"); + "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? 
num_channels=None) -> ()"); m.def( - "encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor"); + "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); @@ -388,7 +388,7 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( } void encode_audio_to_file( - const at::Tensor wf, + const at::Tensor samples, int64_t sample_rate, std::string_view file_name, std::optional<int64_t> bit_rate = std::nullopt, @@ -399,12 +399,12 @@ void encode_audio_to_file( audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; AudioEncoder( - wf, validateSampleRate(sample_rate), file_name, audioStreamOptions) + samples, validateSampleRate(sample_rate), file_name, audioStreamOptions) .encode(); } at::Tensor encode_audio_to_tensor( - const at::Tensor wf, + const at::Tensor samples, int64_t sample_rate, std::string_view format, std::optional<int64_t> bit_rate = std::nullopt, @@ -416,7 +416,7 @@ at::Tensor encode_audio_to_tensor( audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; return AudioEncoder( - wf, + samples, validateSampleRate(sample_rate), format, std::move(avioContextHolder), diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 11751e32..a68b51e2 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -161,10 +161,9 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
return torch.empty([], dtype=torch.long) -# TODO-ENCODING: rename wf to samples @register_fake("torchcodec_ns::encode_audio_to_file") def encode_audio_to_file_abstract( - wf: torch.Tensor, + samples: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None, @@ -175,7 +174,7 @@ def encode_audio_to_file_abstract( @register_fake("torchcodec_ns::encode_audio_to_tensor") def encode_audio_to_tensor_abstract( - wf: torch.Tensor, + samples: torch.Tensor, sample_rate: int, format: str, bit_rate: Optional[int] = None, diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index 469fbefb..3ad03912 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -34,7 +34,7 @@ def to_file( num_channels: Optional[int] = None, ) -> None: _core.encode_audio_to_file( - wf=self._samples, + samples=self._samples, sample_rate=self._sample_rate, filename=dest, bit_rate=bit_rate, @@ -49,7 +49,7 @@ def to_tensor( num_channels: Optional[int] = None, ) -> Tensor: return _core.encode_audio_to_tensor( - wf=self._samples, + samples=self._samples, sample_rate=self._sample_rate, format=format, bit_rate=bit_rate, diff --git a/test/test_ops.py b/test/test_ops.py index 5a5fe675..77be702b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1101,22 +1101,24 @@ def test_bad_input(self, tmp_path): with pytest.raises(RuntimeError, match="must have float32 dtype, got int"): encode_audio_to_file( - wf=torch.arange(10, dtype=torch.int), + samples=torch.arange(10, dtype=torch.int), sample_rate=10, filename=valid_output_file, ) with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"): encode_audio_to_file( - wf=torch.rand(3), sample_rate=10, filename=valid_output_file + samples=torch.rand(3), sample_rate=10, filename=valid_output_file ) with pytest.raises(RuntimeError, match="No such file or directory"): encode_audio_to_file( - wf=torch.rand(2, 10), sample_rate=10, 
filename="./bad/path.mp3" + samples=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3" ) with pytest.raises(RuntimeError, match="check the desired extension"): encode_audio_to_file( - wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension" + samples=torch.rand(2, 10), + sample_rate=10, + filename="./file.bad_extension", )