Skip to content

Commit

Permalink
Fix Keras compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
d4nst committed Nov 17, 2017
1 parent e09a2d4 commit d6741aa
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,11 @@ def __init__(self, input, input_shape=None, color_mode='rgb', batch_size=64,

super(RotNetDataGenerator, self).__init__(N, batch_size, shuffle, seed)

def next(self):
with self.lock:
# get input data index and size of the current batch
index_array, _, current_batch_size = next(self.index_generator)
def _get_batches_of_transformed_samples(self, index_array):
# create array to hold the images
batch_x = np.zeros((current_batch_size,) + self.input_shape, dtype='float32')
batch_x = np.zeros((len(index_array),) + self.input_shape, dtype='float32')
# create array to hold the labels
batch_y = np.zeros(current_batch_size, dtype='float32')
batch_y = np.zeros(len(index_array), dtype='float32')

# iterate through the current batch
for i, j in enumerate(index_array):
Expand Down Expand Up @@ -287,6 +284,13 @@ def next(self):

return batch_x, batch_y

def next(self):
with self.lock:
# get input data index and size of the current batch
index_array, _, current_batch_size = next(self.index_generator)
# create array to hold the images
return self._get_batches_of_transformed_samples(index_array)


def display_examples(model, input, num_images=5, size=None, crop_center=False,
crop_largest_rect=False, preprocess_func=None, save_path=None):
Expand Down

0 comments on commit d6741aa

Please sign in to comment.