Skip to content

Commit

Permalink
Preparing for public model release
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Apr 22, 2020
1 parent a79b0c1 commit 35aeca9
Showing 1 changed file with 54 additions and 56 deletions.
110 changes: 54 additions & 56 deletions classify_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
# classify_images.py
#
# This is a test driver for running our species classifier. It classifies
# one or more hard-coded image files. Think of this more like a demo notebook
# than like a script you'd actually use to run a zillion files through our
# classifier.
# one or more hard-coded image files, writing the top N results for each
# to a .csv file.
#
# You should set the options in the 'Options' cell before running.
#
Expand All @@ -19,11 +18,14 @@
# By default, both point to URLs and will be downloaded to a temp directory
# automatically.
#
# This code has been tested against pytorch 1.2, so dependencies can be installed as:
# Dependencies (may work against later versions, but this was the test environment):
#
# conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch
# pip install pretrainedmodels
# pip install "pillow<7"
# pip install pretrainedmodels==0.7.4
# pip install pillow==6.1.0
# pip install progressbar2==3.51.0
# pip install cupy-cuda100==7.3.0
# pip install torchnet==0.0.4
#
#######

Expand All @@ -48,7 +50,7 @@
# little path management. This also implicitly defers PyTorch imports.

# Directory to which you sync'd this repo.
api_root = r'c:\git\speciesclassification'
api_root = r'/home/coyote/git/speciesclassification'

# If not None, prepended to filenames. Most useful when filenames are coming from
# a .csv file.
Expand All @@ -63,13 +65,13 @@
# a single image file
#
# a directory, which is recursively enumerated
images_to_classify = r'g:\temp\elephant.jpg'
# images_to_classify = [r'g:\temp\elephant.jpg']
# images_to_classify = r'/data/species-classification/elephant.jpg'
# images_to_classify = [r'/data/species-classification/elephant.jpg']
# images_to_classify = 'image_list.csv'
# images_to_classify = r'g:\temp\animal_images'
images_to_classify = r'/data/species-classification/images/sample_images.2019.12.28'

# Classification results will be written here
classification_output_file = 'g:\temp\classification_output.csv'
classification_output_file = '/data/species-classification/classification_output.csv'

# Path to taxa.csv, for latin --> common mapping
#
Expand All @@ -89,7 +91,7 @@

# This must be True if detection is enabled. Classification can be run
# on the CPU or GPU.
use_gpu = False
use_gpu = True


#%% Constants
Expand Down Expand Up @@ -127,42 +129,6 @@
import api as speciesapi


#%% Build Latin --> common mapping

latin_to_common = {}

if taxonomy_path is not None:

print('Reading taxonomy file')

# Read taxonomy file; takes ~1 minute
df = pd.read_csv(taxonomy_path)
df = df.fillna('')

# Columns are:
#
# taxonID,scientificName,parentNameUsageID,taxonRank,vernacularName,wikipedia_url

# Create dictionary by ID

nRows = df.shape[0]

for index, row in df.iterrows():

latin_name = row['scientificName']
latin_name = latin_name.strip()
if len(latin_name)==0:
print('Warning: invalid scientific name at {}'.format(index))
latin_name = 'unknown'
common_name = row['vernacularName']
common_name = common_name.strip()
latin_name = latin_name.lower()
common_name = common_name.lower()
latin_to_common[latin_name] = common_name

print('Finished reading taxonomy file')


#%% Support functions

class DownloadProgressBar():
Expand Down Expand Up @@ -255,8 +221,41 @@ def do_latin_to_common(latin_name):
assert os.path.isfile(detection_model_path)


#%% Create the model(s)
#%% Build Latin --> common mapping

# Maps lower-cased scientific (Latin) names to lower-cased common names.
# Stays empty when taxonomy_path is None.
latin_to_common = {}

if taxonomy_path is not None:

    print('Reading taxonomy file from {}'.format(taxonomy_path))

    # Read taxonomy file; takes ~1 minute
    df = pd.read_csv(taxonomy_path)

    # Missing values become empty strings, so the .strip() calls below do
    # not hit NaN
    df = df.fillna('')

    # Columns are:
    #
    # taxonID,scientificName,parentNameUsageID,taxonRank,vernacularName,wikipedia_url

    # Not used in this cell; kept in case a later cell reads it
    nRows = df.shape[0]

    # Walk the two columns we need directly rather than via df.iterrows(),
    # which constructs a Series for every row
    for index, latin_name, common_name in zip(
            df.index, df['scientificName'], df['vernacularName']):

        latin_name = latin_name.strip()
        if len(latin_name) == 0:
            # Rows without a scientific name are all filed under 'unknown'
            print('Warning: invalid scientific name at {}'.format(index))
            latin_name = 'unknown'
        common_name = common_name.strip()

        # Keys and values are stored lower-cased; a repeated scientific name
        # overwrites the earlier entry
        latin_name = latin_name.lower()
        common_name = common_name.lower()
        latin_to_common[latin_name] = common_name

    print('Finished reading taxonomy file')


#%% Create the model(s)

print('Loading model')
model = speciesapi.DetectionClassificationAPI(classification_model_path,
Expand All @@ -270,10 +269,6 @@ def do_latin_to_common(latin_name):

queries = None

# If we specified a single image
if isinstance(images_to_classify,str) and (not images_to_classify.endswith('.csv')):
images_to_classify = [images_to_classify]

# If we specified a folder
if isinstance(images_to_classify,str) and os.path.isdir(images_to_classify):

Expand All @@ -294,13 +289,16 @@ def do_latin_to_common(latin_name):
queries = list(df_images.query_string)
assert(len(queries) == len(images))

# If we specified a list
elif isinstance(images_to_classify,list):
# If we specified a list or a single file
else:

if isinstance(images_to_classify,str):
images_to_classify = [images_to_classify]

assert isinstance(images_to_classify,list)
images = images_to_classify
queries = None
print('Processing list of {} images'.format(len(images)))
print('Processing a list of {} images'.format(len(images)))


#%% Classify images
Expand Down

0 comments on commit 35aeca9

Please sign in to comment.