add preprocess hook to load_dataset
Signed-off-by: Jie Pu <[email protected]>
jaypume committed Mar 29, 2021
1 parent ec6966e commit c29a638
Showing 3 changed files with 34 additions and 41 deletions.
@@ -16,13 +16,25 @@
 from tensorflow import keras
 
 import sedna
-from sedna.ml_model import save_model
 from network import GlobalModelInspectionCNN
+from sedna.ml_model import save_model
+
+
+def image_process(line):
+    import keras.preprocessing.image as img_preprocessing
+    file_path, label = line.split(',')
+    img = img_preprocessing.load_img(file_path).resize((128, 128))
+    data = img_preprocessing.img_to_array(img) / 255.0
+    label = [0, 1] if int(label) == 0 else [1, 0]
+    data = np.array(data)
+    label = np.array(label)
+    return [data, label]
 
 
 def main():
     # load dataset.
-    train_data = sedna.load_train_dataset(data_format="txt", with_image=True)
+    train_data = sedna.load_train_dataset(data_format="txt",
+                                          preprocess_fun=image_process)
 
     x = np.array([tup[0] for tup in train_data])
     y = np.array([tup[1] for tup in train_data])
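
The training script now owns its preprocessing: image_process maps one raw annotation line to a (data, label) pair, and load_train_dataset applies it to every line. A quick sanity check of the hook, assuming image_process above is in scope and Pillow/Keras are installed; the temporary path and label are made up for illustration:

    import numpy as np
    from PIL import Image

    # Write a throwaway image so the hook can be exercised end to end.
    Image.new("RGB", (64, 48)).save("/tmp/defect_001.jpg")

    # "path,label" in, preprocessed sample out.
    data, label = image_process("/tmp/defect_001.jpg,0")
    print(data.shape)  # (128, 128, 3): resized, channels-last, scaled to [0, 1]
    print(label)       # [0 1]: one-hot encoding of class 0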
@@ -30,8 +30,7 @@ def main():
     class_names = sedna.context.get_parameters("class_names")
 
     # load dataset.
-    train_data = sedna.load_train_dataset(data_format='txt',
-                                          with_image=False)
+    train_data = sedna.load_train_dataset(data_format='txt')
 
     # read parameters from deployment config.
     obj_threshold = sedna.context.get_parameters("obj_threshold")
@@ -88,13 +87,13 @@ def main():
     model = Interface()
 
     sedna.incremental_learning.train(model=model,
-                                     train_data=train_data,
-                                     epochs=epochs,
-                                     batch_size=batch_size,
-                                     class_names=class_names,
-                                     input_shape=input_shape,
-                                     obj_threshold=obj_threshold,
-                                     nms_threshold=nms_threshold)
+                                    train_data=train_data,
+                                    epochs=epochs,
+                                    batch_size=batch_size,
+                                    class_names=class_names,
+                                    input_shape=input_shape,
+                                    obj_threshold=obj_threshold,
+                                    nms_threshold=nms_threshold)
 
 
 if __name__ == '__main__':
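For this worker the flag simply disappears: with no hook supplied, load_train_dataset now hands back the raw annotation lines, which is exactly what the incremental-learning trainer consumed before. A minimal sketch of the two call styles against the new API (the lambda and the line format are illustrative):

    import sedna

    # Default: each sample is one raw line from the txt index file,
    # with the dataset root path already prepended by _load_txt_dataset.
    raw_lines = sedna.load_train_dataset(data_format='txt')

    # Optional: any callable taking one raw line and returning one sample.
    pairs = sedna.load_train_dataset(data_format='txt',
                                     preprocess_fun=lambda l: l.strip().split(','))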
lib/sedna/dataset/dataset.py (12 additions & 30 deletions)
@@ -18,47 +18,47 @@
 choice 2: a high level Dataset object not compatible with tensorflow,
 but it's unified in our framework.
 """
-
-import fileinput
 import logging
 import os
 
-import numpy as np
-
 from sedna.common.config import BaseConfig
 
 LOG = logging.getLogger(__name__)
 
 
-def _load_dataset(dataset_url, format, **kwargs):
+def _load_dataset(dataset_url, format, preprocess_fun=None, **kwargs):
     if dataset_url is None:
         LOG.warning(f'dataset_url is None, please check the url.')
         return None
     if format == 'txt':
         LOG.info(
             f"dataset format is txt, now loading txt from [{dataset_url}]")
-        if kwargs.get('with_image'):
-            return _load_txt_dataset_with_image(dataset_url)
+        samples = _load_txt_dataset(dataset_url)
+        if preprocess_fun:
+            new_samples = [preprocess_fun(s) for s in samples]
         else:
-            return _load_txt_dataset(dataset_url)
+            new_samples = samples
+        return new_samples
 
 
-def load_train_dataset(data_format, **kwargs):
+def load_train_dataset(data_format, preprocess_fun=None, **kwargs):
     """
     :param data_format: txt
     :param kwargs:
     :return: Dataset
     """
-    return _load_dataset(BaseConfig.train_dataset_url, data_format, **kwargs)
+    return _load_dataset(BaseConfig.train_dataset_url, data_format,
+                         preprocess_fun, **kwargs)
 
 
-def load_test_dataset(data_format, **kwargs):
+def load_test_dataset(data_format, preprocess_fun=None, **kwargs):
     """
     :param data_format: txt
     :param kwargs:
     :return: Dataset
     """
-    return _load_dataset(BaseConfig.test_dataset_url, data_format, **kwargs)
+    return _load_dataset(BaseConfig.test_dataset_url, data_format,
+                         preprocess_fun, **kwargs)
 
 
 def _load_txt_dataset(dataset_url):
@@ -68,21 +68,3 @@ def _load_txt_dataset(dataset_url):
         lines = f.readlines()
     new_lines = [root_path + os.path.sep + l for l in lines]
     return new_lines
-
-
-def _load_txt_dataset_with_image(dataset_url):
-    import keras.preprocessing.image as img_preprocessing
-    root_path = os.path.dirname(dataset_url)
-    img_data = []
-    img_label = []
-    for line in fileinput.input(dataset_url):
-        file_path, label = line.split(',')
-        file_path = (file_path.replace("\\", os.path.sep)
-                     .replace("/", os.path.sep))
-        file_path = os.path.join(root_path, file_path)
-        img = img_preprocessing.load_img(file_path).resize((128, 128))
-        img_data.append(img_preprocessing.img_to_array(img) / 255.0)
-        img_label += [(0, 1)] if int(label) == 0 else [(1, 0)]
-    data_set = [(np.array(line[0]), np.array(line[1]))
-                for line in zip(img_data, img_label)]
-    return data_set
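
The deleted helper baked one policy (path normalization, 128x128 resize, /255 scaling, one-hot labels) into the dataset module; the hook moves that policy to callers. A hedged sketch of how the old with_image=True behavior could be rebuilt on top of the new API; this equivalent is illustrative, not part of the commit:

    import os

    import numpy as np
    import keras.preprocessing.image as img_preprocessing

    import sedna

    def image_line_to_sample(line):
        # One _load_txt_dataset_with_image iteration, expressed as a hook:
        # "path,label" -> (image array scaled to [0, 1], one-hot label).
        # _load_txt_dataset has already prepended the dataset root path.
        file_path, label = line.strip().split(',')
        file_path = file_path.replace("\\", os.path.sep).replace("/", os.path.sep)
        img = img_preprocessing.load_img(file_path).resize((128, 128))
        data = img_preprocessing.img_to_array(img) / 255.0
        onehot = (0, 1) if int(label) == 0 else (1, 0)
        return np.array(data), np.array(onehot)

    train_data = sedna.load_train_dataset(data_format='txt',
                                          preprocess_fun=image_line_to_sample)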
