Skip to content

Commit

Permalink
Update imagenet prediction decoding utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 24, 2016
1 parent 4c01c0c commit d5f1250
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
7 changes: 5 additions & 2 deletions docs/templates/applications.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ All of these architectures are compatible with both TensorFlow and Theano, and u
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

model = ResNet50(weights='imagenet')

Expand All @@ -36,8 +37,10 @@ x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
print('Predicted:', decode_predictions(preds))
# print: [[u'n02504458', u'African_elephant']]
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])
# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]
```

### Extract features with VGG16
Expand Down
15 changes: 10 additions & 5 deletions keras/applications/imagenet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@ def preprocess_input(x, dim_ordering='default'):
return x


def decode_predictions(preds):
def decode_predictions(preds, top=5):
global CLASS_INDEX
assert len(preds.shape) == 2 and preds.shape[1] == 1000
if len(preds.shape) != 2 or preds.shape[1] != 1000:
raise ValueError('`decode_predictions` expects '
'a batch of predictions '
'(i.e. a 2D array of shape (samples, 1000)). '
'Found array with shape: ' + str(preds.shape))
if CLASS_INDEX is None:
fpath = get_file('imagenet_class_index.json',
CLASS_INDEX_PATH,
cache_subdir='models')
CLASS_INDEX = json.load(open(fpath))
indices = np.argmax(preds, axis=-1)
results = []
for i in indices:
results.append(CLASS_INDEX[str(i)])
for pred in preds:
top_indices = np.argpartition(pred, -top)[-top:][::-1]
result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
results.append(result)
return results

0 comments on commit d5f1250

Please sign in to comment.