Skip to content

Commit

Permalink
Update dataset_stats() (ultralytics#3593)
Browse files Browse the repository at this point in the history
@kalenmike this is a PR to add image filenames and labels to our stats dictionary and to save the dictionary to JSON. Save location is next to the train labels.cache file. The single JSON contains all stats for entire dataset.

Usage example:
```python
from utils.datasets import *

dataset_stats('coco128.yaml', verbose=True)
```
  • Loading branch information
glenn-jocher authored Jun 12, 2021
1 parent 4984cf5 commit 7a565f1
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import glob
import hashlib
import json
import logging
import math
import os
Expand Down Expand Up @@ -1105,12 +1106,20 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
continue
x = []
dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
if split == 'train':
cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
x = np.array(x) # shape(128x80)
stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()}}
stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}}

# Save, print and return
with open(cache_path.with_suffix('.json'), 'w') as f:
json.dump(stats, f) # save stats *.json
if verbose:
print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
# print(json.dumps(stats, indent=2, sort_keys=False))
return stats

0 comments on commit 7a565f1

Please sign in to comment.