Skip to content

Commit

Permalink
Add ability to use BirdNET generated custom classifiers with Analyzer…
Browse files Browse the repository at this point in the history
… class (#55)

* Add ability to use custom classifier (w tests)

* Add ability to use custom classifier in multiprocessing helper

* Rename label path to labels

* Add docs for using custom classifiers

* Update BirdNET-Analyzer submodule

* Remove lat/lon/date from custom classifier docs

* Fix example for custom classifier
  • Loading branch information
joeweiss authored May 11, 2023
1 parent 0dc1886 commit 2a87862
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 5 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,32 @@ recording.analyze()
print(recording.detections)
```

### Using a custom classifier with BirdNET-Analyzer

To use the a [model trained with BirdNET-Analyzer](https://github.com/kahst/BirdNET-Analyzer#training), use the `Analyzer` class.

```python
from birdnetlib import Recording
from birdnetlib.analyzer import Analyzer

# Load and initialize BirdNET-Analyzer with your own model/labels.

custom_model_path = "custom_classifiers/trogoniformes.tflite"
custom_labels_path = "custom_classifiers/trogoniformes.txt"

analyzer = Analyzer(
classifier_labels_path=custom_labels_path, classifier_model_path=custom_model_path
)

recording = Recording(
analyzer,
"sample.mp3",
min_conf=0.25,
)
recording.analyze()
print(recording.detections)
```

### Using BirdNET-Lite

To use the BirdNET-Lite model, use the `LiteAnalyzer` class.
Expand Down
19 changes: 17 additions & 2 deletions src/birdnetlib/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def as_dict(self):


class Analyzer:
def __init__(self, custom_species_list_path=None):
def __init__(
self,
custom_species_list_path=None,
classifier_model_path=None,
classifier_labels_path=None,
):
self.name = "Analyzer"
self.model_name = "BirdNET-Analyzer"
self.interpreter = None
Expand All @@ -58,6 +63,10 @@ def __init__(self, custom_species_list_path=None):
self.labels = []
self.results = []
self.custom_species_list = []

self.classifier_model_path = classifier_model_path
self.classifier_labels_path = classifier_labels_path

self.load_model()
self.load_labels()

Expand Down Expand Up @@ -164,7 +173,7 @@ def analyze_recording(self, recording):
)

# If recording has lon/lat, load cached list or predict a new species list.
if recording.lon and recording.lat:
if recording.lon and recording.lat and self.classifier_model_path == None:
print("recording has lon/lat")
self.set_predicted_species_list_from_position(recording)

Expand Down Expand Up @@ -197,6 +206,9 @@ def load_model(self):
print("load model")
# Load TFLite model and allocate tensors.
model_path = MODEL_PATH
if self.classifier_model_path:
print("loading custom classifier model")
model_path = self.classifier_model_path
num_threads = 1 # Default from BN-A config
self.interpreter = tflite.Interpreter(
model_path=model_path, num_threads=num_threads
Expand All @@ -216,6 +228,9 @@ def load_model(self):

def load_labels(self):
labels_file_path = LABEL_PATH
if self.classifier_labels_path:
print("loading custom classifier labels")
labels_file_path = self.classifier_labels_path
labels = []
with open(labels_file_path, "r") as lfile:
for line in lfile.readlines():
Expand Down
12 changes: 11 additions & 1 deletion src/birdnetlib/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ def process_from_queue(shared_queue, results=[], analyzers=None):
from birdnetlib.analyzer import Analyzer

analyzers.append(
Analyzer(custom_species_list_path=i["custom_species_list_path"])
Analyzer(
custom_species_list_path=i["custom_species_list_path"],
classifier_model_path=i["classifier_model_path"],
classifier_labels_path=i["classifier_labels_path"],
)
)

recordings = []
Expand Down Expand Up @@ -231,6 +235,12 @@ def process(self):
analyzer_args = [
{
"model_name": i.model_name,
"classifier_labels_path": i.classifier_labels_path
if hasattr(i, "classifier_labels_path")
else None,
"classifier_model_path": i.classifier_model_path
if hasattr(i, "classifier_model_path")
else None,
"custom_species_list_path": i.custom_species_list_path,
}
for i in self.analyzers
Expand Down
2 changes: 1 addition & 1 deletion tests/BirdNET-Analyzer
Submodule BirdNET-Analyzer updated 44 files
+1 −0 .gitignore
+93 −0 BirdNET-Analyzer-full.spec
+1 −1 Dockerfile
+108 −16 README.md
+10 −18 analyze.py
+1 −1 audio.py
+ checkpoints/V2.3/BirdNET_GLOBAL_3K_V2.3_Model/saved_model.pb
+ checkpoints/V2.3/BirdNET_GLOBAL_3K_V2.3_Model/variables/variables.data-00000-of-00001
+ checkpoints/V2.3/BirdNET_GLOBAL_3K_V2.3_Model/variables/variables.index
+2 −0 config.py
+18 −2 extra-hooks/hook-librosa.py
+285 −52 gui.py
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_af.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_ar.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_cs.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_da.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_de.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_es.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_fi.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_fr.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_hu.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_it.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_ja.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_ko.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_nl.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_no.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_pl.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_pt.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_ro.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_ru.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_sk.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_sl.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_sv.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_th.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_tr.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_uk.txt
+0 −0 labels/V2.3/BirdNET_GLOBAL_3K_V2.3_Labels_zh.txt
+20 −9 model.py
+22 −0 pyinstaller_analyze.py
+33 −0 pyinstaller_full.py
+22 −0 pyinstaller_gui.py
+3 −0 segments.py
+15 −7 train.py
+4 −0 utils.py
55 changes: 54 additions & 1 deletion tests/test_analyzer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from birdnetlib import Recording
from birdnetlib.analyzer import Analyzer
from birdnetlib.analyzer import Analyzer, MODEL_PATH, LABEL_PATH

from pprint import pprint
import pytest
Expand Down Expand Up @@ -156,6 +156,59 @@ def test_with_species_list():
recording.analyze()


def test_with_custom_classifier():

# Process file with command line utility, then process with python library and ensure equal commandline_results.

input_path = os.path.join(os.path.dirname(__file__), "test_files/soundscape.wav")
min_conf = 0.25
lon = -120.7463
lat = 35.4244
week_48 = 18

# Note, we're using the BirdNET_GLOBAL_3K_V2.3 as the "custom" classifier.

custom_model_path = MODEL_PATH
custom_labels_path = LABEL_PATH

analyzer = Analyzer(
classifier_labels_path=custom_labels_path,
classifier_model_path=custom_model_path,
)

recording = Recording(
analyzer,
input_path,
lon=lon,
lat=lat,
week_48=week_48,
min_conf=min_conf,
)
recording.analyze()

# Ensure that there is no BirdNET generated species list (from lon/lat or week)
# Custom classifiers do not use BirdNET generated species list as labels will likely not match.
assert analyzer.custom_species_list == []

assert len(recording.detections) == 12

detected_birds = [i["common_name"] for i in recording.detections]
assert detected_birds == [
"Black-capped Chickadee",
"Black-capped Chickadee",
"House Finch",
"Blue Jay",
"Blue Jay",
"Dark-eyed Junco",
"Dark-eyed Junco",
"Dark-eyed Junco",
"Dark-eyed Junco",
"Black-capped Chickadee",
"House Finch",
"House Finch",
]


def test_species_list_calls():

lon = -120.7463
Expand Down
41 changes: 41 additions & 0 deletions tests/test_multiprocessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from birdnetlib.batch import DirectoryMultiProcessingAnalyzer
from birdnetlib.analyzer_lite import LiteAnalyzer
from birdnetlib.analyzer import Analyzer, MODEL_PATH, LABEL_PATH
import tempfile
import shutil
import os
Expand Down Expand Up @@ -94,6 +95,46 @@ def test_process_defined_batch():
print("test_process_defined_batch completed in", time.time() - start)


def test_batch_with_custom_classifiers():
test_files = "tests/test_files"

start = time.time()

defined_date = datetime(year=2022, month=5, day=10)

# Note, we're using the BirdNET_GLOBAL_3K_V2.3 as the "custom" classifier.
custom_model_path = MODEL_PATH
custom_labels_path = LABEL_PATH
analyzer = Analyzer(
classifier_labels_path=custom_labels_path,
classifier_model_path=custom_model_path,
)

with tempfile.TemporaryDirectory() as input_dir:
# Copy test files to temp directory.
copytree(test_files, input_dir)
assert len(os.listdir(input_dir)) == 7
batch = DirectoryMultiProcessingAnalyzer(
input_dir, date=defined_date, min_conf=0.4, analyzers=[analyzer]
)
batch.process()
assert len(batch.directory_recordings) == 5
# Ensure date was used
assert batch.directory_recordings[0].config["date"] == defined_date
assert batch.directory_recordings[0].config["minimum_confidence"] == 0.4
# Ensure the default is BirdNET-Analyzer
assert batch.directory_recordings[0].config["model_name"] == "BirdNET-Analyzer"
test_result_with_detections = [
i
for i in batch.directory_recordings
if i.path.endswith("XC563936 - Soundscape.mp3")
][0]

assert len(test_result_with_detections.detections) == 13

print("test_batch completed in", time.time() - start)


def test_batch_error():
analyzer = LiteAnalyzer()
test_files = "tests/test_files"
Expand Down

0 comments on commit 2a87862

Please sign in to comment.