Skip to content

Commit

Permalink
Add tflite inference
Browse files Browse the repository at this point in the history
  • Loading branch information
psykana committed Sep 1, 2021
1 parent 023da36 commit 4af7f83
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ data/raw/*
logs/*
models/*
processing/__pycache__
build/*
!/**/.gitkeep
52 changes: 52 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
cmake_minimum_required(VERSION 3.16)
project(inference C CXX)

set(CMAKE_CXX_STANDARD 11)

set(TENSORFLOW_SOURCE_DIR "" CACHE PATH
"Directory that contains the TensorFlow project"
)

file(GLOB model_dirs "${CMAKE_CURRENT_LIST_DIR}/models/*")
list(LENGTH model_dirs list_length)
if(list_length LESS_EQUAL 1)
message( FATAL_ERROR "Models directory is empty" )
endif()
list(SORT model_dirs COMPARE NATURAL ORDER DESCENDING)
list(GET model_dirs 0 latest_model_dir)
if(NOT EXISTS "${latest_model_dir}/model.tflite")
message( FATAL_ERROR "Could not find a .tflite model in ${latest_model_dir}" )
endif()

execute_process(
WORKING_DIRECTORY ${latest_model_dir}
COMMAND xxd -i model.tflite ${CMAKE_BINARY_DIR}/model.cc
)

if(NOT TENSORFLOW_SOURCE_DIR)
get_filename_component(TENSORFLOW_SOURCE_DIR
"${CMAKE_CURRENT_LIST_DIR}/inference_src/tensorflow"
ABSOLUTE
)
endif()

add_subdirectory(
"${TENSORFLOW_SOURCE_DIR}/tensorflow/lite"
"${CMAKE_CURRENT_BINARY_DIR}/tensorflow-lite"
EXCLUDE_FROM_ALL
)

find_package(PkgConfig REQUIRED)
pkg_search_module(FFTW REQUIRED fftw3 IMPORTED_TARGET)
include_directories(PkgConfig::FFTW)

add_executable(inference
inference_src/inference.cpp
${CMAKE_BINARY_DIR}/model.cc
)

target_link_libraries(inference
tensorflow-lite
PkgConfig::FFTW
m
)
132 changes: 132 additions & 0 deletions inference_src/inference.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#define _CRT_SECURE_NO_WARNINGS
#define _USE_MATH_DEFINES
#include <math.h>
#include <iomanip>

#include "AudioFile/AudioFile.h"

#include <fftw3.h>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/optional_debug_tools.h"

#define TFLITE_MINIMAL_CHECK(x) \
if (!(x)) { \
fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
exit(1); \
}

class FlatBufferModel {
// Build a model based on a file. Return a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromFile(
const char* filename,
tflite::ErrorReporter* error_reporter);

// Build a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
// is destroyed. Return a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
const char* buffer,
size_t buffer_size,
tflite::ErrorReporter* error_reporter);
};

int main() {
const std::string path = "SA1.WAV";
const int N = 1024;
const int numBins = (N / 2) + 1;
const int frameSize = 512;
const int overlap = 256;
const int step = frameSize - overlap;
const int channel = 0;

double* in;
in = (double*)fftw_malloc(sizeof(double) * N);
memset(in, 0, N * sizeof(in[0]));

double* hann;
hann = (double*)malloc(sizeof(double) * frameSize);
for (int i = 0; i < frameSize; i++) {
hann[i] = 0.5 * (1 - cos(2 * M_PI * i / (frameSize - 1)));
}

double* bins;
bins = (double*)malloc(sizeof(double) * numBins);

fftw_complex* out;
out = (fftw_complex*)fftw_malloc(sizeof(fftw_complex) * numBins);

fftw_plan p;
p = fftw_plan_dft_r2c_1d(N, in, out, FFTW_ESTIMATE);

AudioFile<double> audioFile;
audioFile.load(path);
audioFile.printSummary();

int numSamples = audioFile.getNumSamplesPerChannel();

int frameNum = int(floor(numSamples / overlap)) - 1;

// Load the model
std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
TFLITE_MINIMAL_CHECK(model != nullptr);

// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder builder(*model, resolver);
std::unique_ptr<tflite::Interpreter> interpreter;
builder(&interpreter);
TFLITE_MINIMAL_CHECK(interpreter != nullptr);

TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
//printf("=== Pre-invoke Interpreter State ===\n");
//tflite::PrintInterpreterState(interpreter.get());

float* input_data_ptr = interpreter->typed_input_tensor<float>(0);
float* output_data_ptr = interpreter->typed_output_tensor<float>(0);

std::ofstream txt;
txt.open("result.txt", std::ios::out | std::ios::trunc);

for (int i = 0; i < frameNum; i++) {
int offset = int(i * step);
memcpy(in, &audioFile.samples[channel][offset], frameSize * sizeof(audioFile.samples[channel][0]));

for (int j = 0; j < frameSize; j++) {
in[j] = in[j] * hann[j];
}

fftw_execute(p);

for (int k = 0; k < (N / 2) + 1; k++) {
double realVal = out[k][0];
double imagVal = out[k][1];
bins[k] = realVal * realVal + imagVal * imagVal;
//std::cout << k << ": " << bins[k] << std::endl;
}

// Fill `input`.
float* input_data_ptr = interpreter->typed_input_tensor<float>(0);
for (int i = 0; i < numBins; i++) {
*(input_data_ptr) = (float)bins[i];
//std::cout << i << ": " << *(input_data_ptr) << std::endl;
input_data_ptr++;
}

TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
//printf("\n\n=== Post-invoke Interpreter State ===\n");
//tflite::PrintInterpreterState(interpreter.get());

txt << *output_data_ptr << std::endl;
}

txt.close();
fftw_destroy_plan(p);
fftw_free(in);
fftw_free(out);
free(bins);
return 0;
}

0 comments on commit 4af7f83

Please sign in to comment.