Skip to content

Commit

Permalink
Added autodetection of image extension in NDDS directory
Browse files Browse the repository at this point in the history
  • Loading branch information
tabula-rosa committed Jun 15, 2020
1 parent 6ccdd19 commit 4a74c81
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions dream/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ def is_ndds_dataset(input_dir, data_extension="json"):


def find_ndds_data_in_dir(
input_dir,
data_extension="json",
image_extension="png",
predictive=False,
requested_image_types="all",
input_dir, data_extension="json", image_extension=None, requested_image_types="all",
):

# Input argument handling
Expand All @@ -78,28 +74,44 @@ def find_ndds_data_in_dir(
assert os.path.exists(
input_dir
), 'Expected path "{}" to exist, but it does not.'.format(input_dir)
dirlist = os.listdir(input_dir)

assert isinstance(
data_extension, str
), 'Expected "data_extension" to be a string, but it is "{}".'.format(
type(data_extension)
)
assert isinstance(
image_extension, str
), 'Expected "image_extension" to be a string, but it is "{}".'.format(
type(image_extension)
)
data_full_ext = "." + data_extension

if image_extension is None:
# Auto detect based on list of image extensions to try
# In case there is a tie, prefer the first in this list
image_exts_to_try = ["png", "jpg"]
num_image_exts = []
for image_ext in image_exts_to_try:
num_image_exts.append(len([f for f in dirlist if f.endswith(image_ext)]))
max_num_image_exts = np.max(num_image_exts)
idx_max = np.where(num_image_exts == max_num_image_exts)[0]
if len(idx_max) == 1:
# Max exists once, so use the corresponding image extension
image_extension = image_exts_to_try[idx_max[0]]
else:
# Multiple cases of the same image extension
image_extension = image_exts_to_try[0]
else:
assert isinstance(
image_extension, str
), 'If specified, expected "image_extension" to be a string, but it is "{}".'.format(
type(image_extension)
)
image_full_ext = "." + image_extension

assert (
requested_image_types is None
or requested_image_types == "all"
or isinstance(requested_image_types, list)
), "Expected \"requested_image_types\" to be None, 'all', or a list of requested_image_types."

data_full_ext = "." + data_extension
image_full_ext = "." + image_extension

dirlist = os.listdir(input_dir)

# Read in json files
data_filenames = [f for f in dirlist if f.endswith(data_full_ext)]

Expand Down

0 comments on commit 4a74c81

Please sign in to comment.