-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add app module to support object detection visualization with opencv
- Loading branch information
1 parent
6efc26a
commit b1e5ba1
Showing
10 changed files
with
432 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
|
||
from . import object_detection | ||
|
||
# Re-export the public names declared by the object_detection subpackage.
__all__ = list(object_detection.__all__)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
""" | ||
This module supports object detection visualization with OpenCV, | ||
helping developers quickly run object detection models on an image | ||
and display the detection result. | ||
""" | ||
from . import object_detector | ||
from . import utils | ||
|
||
# Public API of the package: the names exported by the detector module plus
# the names exported by the utils subpackage.  The second ``extend`` used to
# repeat ``object_detector.__all__``, which duplicated the detector names and
# left the imported ``utils`` names out entirely.
__all__ = []
__all__.extend(object_detector.__all__)
__all__.extend(utils.__all__)
12 changes: 12 additions & 0 deletions
12
tinyms/app/object_detection/configs/tinyms/0.3/ssd300_shanshui.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"model_net": "ssd300", | ||
"class_num": 11, | ||
"dataset": "shanshui", | ||
"download_from_hub": true, | ||
"checkpoint_url": "https://tinyms-hub.obs.cn-north-4.myhuaweicloud.com/tinyms/0.3/ssd300_v3_shanshui2021/ssd300.ckpt", | ||
"sha256": "66c4b4878ea7f7d20f5cff3b5de2d325698dd73becf79e79882fe835e0b8bf26", | ||
"checkpoint_path": "/etc/tinyms/object_detection/ssd300_shanshui", | ||
"checkpoint_name": "ssd300.ckpt", | ||
"description": "This object detection hosts a ssd300 model predicting for shanshui protected animal dataset." | ||
} | ||
|
12 changes: 12 additions & 0 deletions
12
tinyms/app/object_detection/configs/tinyms/0.3/ssd300_voc.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"model_net": "ssd300", | ||
"class_num": 21, | ||
"dataset": "voc", | ||
"download_from_hub": true, | ||
"checkpoint_url": "https://tinyms-hub.obs.cn-north-4.myhuaweicloud.com/tinyms/0.3/ssd300_v1_voc2007/ssd300.ckpt", | ||
"sha256": "29ada5f9a903267b424c10e543d1419538905e958dcd0a2d6c5ad563c2b31ce1", | ||
"checkpoint_path": "/etc/tinyms/object_detection/ssd300_voc", | ||
"checkpoint_name": "ssd300.ckpt", | ||
"description": "This object detection hosts a ssd300 model predicting for voc dataset." | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
import os | ||
import cv2 | ||
import numpy as np | ||
|
||
import tinyms as ts | ||
from tinyms import model | ||
from tinyms.vision import voc_transform, shanshui_tranform | ||
|
||
__all__ = ['ObjectDetector', 'object_detection_predict'] | ||
|
||
# Maps the ``model_net`` value of a config to the callable that builds the
# network; ``ObjectDetector.model_build`` looks the builder up here.
model_checker = {"ssd300": model.ssd300_mobilenetv2}

# Maps the ``dataset`` value of a config to the transform object used for
# pre- and post-processing.
# NOTE(review): ``shanshui_tranform`` (sic) is spelled exactly as it is
# imported from ``tinyms.vision``; do not "fix" it here without renaming the
# import as well.
transform_checker = {"voc": voc_transform, "shanshui": shanshui_tranform}
|
||
|
||
class ObjectDetector:
    r'''
    ObjectDetector is a high-level class defined for building the model,
    preprocessing the input image, predicting and postprocessing the
    prediction output data.

    Args:
        config (dict): model config parsed from the json file under the
            app/object_detection/configs dir. ``None`` is treated as an empty
            config so that later calls fail with the descriptive errors raised
            by each method instead of an ``AttributeError`` on ``None``.
    '''

    def __init__(self, config=None):
        self.config = {} if config is None else config

    def _get_transform(self):
        r'''
        Look up the transform registered for the configured ``dataset``.

        Returns:
            the entry of ``transform_checker`` matching ``config['dataset']``.

        Raises:
            KeyError: if the dataset is missing or has no registered transform.
        '''
        transform_func = transform_checker.get(self.config.get('dataset'))
        if not transform_func:
            err_msg = 'Currently dataset only supports {} transform!'.format(str(list(transform_checker.keys())))
            raise KeyError(err_msg)
        return transform_func

    def data_preprocess(self, input):
        r'''
        Preprocess the input image.

        Args:
            input (numpy.ndarray): the input image. It is converted with
                ``cv2.COLOR_BGR2RGB``, so a 3-dimensional
                (height, width, channel) array is required.

        Returns:
            tuple, the (height, width) of the input image.
            the preprocessed image produced by the dataset transform
            (presumably a numpy.ndarray, as ``convert2tensor`` requires one).

        Raises:
            TypeError: if ``input`` is not a numpy.ndarray.
            ValueError: if ``input`` is not 3-dimensional.
            KeyError: if the configured dataset has no registered transform.
        '''
        if not isinstance(input, np.ndarray):
            err_msg = 'The input type should be numpy.ndarray, got {}.'.format(type(input))
            raise TypeError(err_msg)
        # Fail with a readable message instead of an opaque tuple-unpacking
        # error when a grayscale (2-D) or batched (4-D) array is passed.
        if input.ndim != 3:
            err_msg = 'The input image should have 3 dimensions (height, width, channel), got shape {}.'.format(
                input.shape)
            raise ValueError(err_msg)
        image_height, image_width = input.shape[:2]
        image_shape = (image_height, image_width)
        cvt_input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)

        transform_func = self._get_transform()
        transform_input = transform_func(cvt_input)
        return image_shape, transform_input

    def convert2tensor(self, transform_input):
        r'''
        Convert the numpy data to the tensor format.

        Args:
            transform_input (numpy.ndarray): the preprocessed image.

        Returns:
            Tensor, the converted image with a new leading axis added.

        Raises:
            TypeError: if ``transform_input`` is not a numpy.ndarray.
        '''
        if not isinstance(transform_input, np.ndarray):
            err_msg = 'The transform_input type should be numpy.ndarray, got {}.'.format(type(transform_input))
            raise TypeError(err_msg)
        input_tensor = ts.expand_dims(ts.array(list(transform_input)), 0)
        return input_tensor

    def model_build(self, is_training=False):
        r'''
        Build the object detection model to predict the image.

        Args:
            is_training (bool): forwarded to the network builder. Default: False.

        Returns:
            model.Model, generated object detection model.

        Raises:
            KeyError: if the configured ``model_net`` is not supported.
            ValueError: if ``class_num`` is missing or not a positive integer.
        '''
        model_net = model_checker.get(self.config.get('model_net'))
        if not model_net:
            err_msg = 'Currently model_net only supports {}!'.format(str(list(model_checker.keys())))
            raise KeyError(err_msg)

        class_num = self.config.get('class_num')
        # A missing or non-integer class_num used to surface as a TypeError
        # from the ``<=`` comparison; report it with the intended message.
        is_integer = isinstance(class_num, (int, np.integer)) and not isinstance(class_num, bool)
        if not is_integer or class_num <= 0:
            err_msg = 'The class_num should be an integer greater than 0, got {}.'.format(class_num)
            raise ValueError(err_msg)

        net = model_net(class_num=class_num, is_training=is_training)
        serve_model = model.Model(net)
        return serve_model

    def model_load_and_predict(self, serve_model, input_tensor):
        r'''
        Load the checkpoint into the object detection model and predict the image.

        Args:
            serve_model (model.Model): object detection model.
            input_tensor (Tensor): the converted input image.

        Returns:
            model.Model, object detection model loaded with the checkpoint file.
            the predictions output returned by ``serve_model.predict``.

        Raises:
            TypeError: if ``checkpoint_path`` is empty, or ``checkpoint_name``
                is missing or does not end with ``.ckpt``.
            FileNotFoundError: if the checkpoint file does not exist.
        '''
        ckpt_path = self.config.get('checkpoint_path')
        if not ckpt_path:
            err_msg = 'The ckpt_path {} can not be none.'.format(ckpt_path)
            raise TypeError(err_msg)

        ckpt_name = self.config.get('checkpoint_name')
        # A missing checkpoint_name used to crash with AttributeError on
        # ``None.endswith``; treat it like any other unsupported name.
        if not isinstance(ckpt_name, str) or not ckpt_name.endswith('.ckpt'):
            err_msg = 'Currently model only supports `ckpt` format, got {}.'.format(ckpt_name)
            raise TypeError(err_msg)

        ckpt_file = os.path.join(ckpt_path, ckpt_name)
        if not os.path.isfile(ckpt_file):
            raise FileNotFoundError("The model checkpoint file path {} does not exist!".format(ckpt_file))
        serve_model.load_checkpoint(ckpt_file)

        predictions_output = serve_model.predict(input_tensor)
        return serve_model, predictions_output

    def data_postprocess(self, predictions_output, image_shape):
        r'''
        Postprocess the predictions output data.

        Args:
            predictions_output (list): predictions output data; its first two
                elements are concatenated along the last axis.
            image_shape (tuple): the shape of the input image, as returned by
                ``data_preprocess``.

        Returns:
            the result of the dataset transform's ``postprocess`` method
            (documented upstream as a dict of bounding-box data).

        Raises:
            KeyError: if the configured dataset has no registered transform.
        '''
        output_np = ts.concatenate((predictions_output[0], predictions_output[1]), axis=-1).asnumpy()
        transform_func = self._get_transform()
        bbox_data = transform_func.postprocess(output_np, image_shape)
        return bbox_data
|
||
|
||
def object_detection_predict(input, object_detector, is_training=False):
    r'''
    Run the whole detection pipeline (preprocess, build, load, predict,
    postprocess) on one image with a single call.

    Args:
        input (numpy.ndarray): the input image.
        object_detector (ObjectDetector): the instance of the ObjectDetector class.
        is_training (bool): forwarded to ``ObjectDetector.model_build``. Default: False.

    Returns:
        dict, the postprocess result.

    Raises:
        TypeError: if ``object_detector`` is not an ObjectDetector instance.
    '''
    if isinstance(object_detector, ObjectDetector):
        # Each stage feeds the next one; the model is rebuilt and the
        # checkpoint reloaded on every call.
        shape, preprocessed = object_detector.data_preprocess(input)
        tensor = object_detector.convert2tensor(preprocessed)
        net = object_detector.model_build(is_training=is_training)
        _loaded_model, raw_predictions = object_detector.model_load_and_predict(net, tensor)
        return object_detector.data_postprocess(raw_predictions, shape)

    raise TypeError('The object_detector is not the instance of ObjectDetector')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
from . import config_util, view_util | ||
|
||
# Public API of the utils subpackage: config helpers first, then view helpers.
__all__ = [*config_util.__all__, *view_util.__all__]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
import os | ||
import json | ||
|
||
from tinyms.hub.utils.download import download_file_from_url | ||
|
||
# The exported name must match the function defined below exactly; the former
# entry 'load_and parse_config' (with a space) named nothing in this module,
# which makes ``from ... import *`` raise AttributeError.
__all__ = ['load_and_parse_config']
|
||
|
||
def _download_ckeckpoint(checkpoint_url, sha256, checkpoint_path):
    r'''
    Validate the hub download settings and fetch the checkpoint file.

    Args:
        checkpoint_url (str): URL of the checkpoint file.
        sha256 (str): hash passed to the downloader as ``hash_sha256``.
        checkpoint_path (str): location passed to the downloader as ``save_path``.

    Raises:
        ValueError: if any of the three arguments is empty.
    '''
    # Checked in this order so the first missing setting is the one reported:
    # url, then path, then sha256.
    required_settings = (
        ('checkpoint_url', checkpoint_url),
        ('checkpoint_path', checkpoint_path),
        ('sha256', sha256),
    )
    for setting_name, setting_value in required_settings:
        if setting_value:
            continue
        raise ValueError(
            'When set download_from_hub to true, the {} can not be empty.'.format(setting_name))

    download_file_from_url(checkpoint_url, hash_sha256=sha256, save_path=checkpoint_path)
|
||
|
||
def load_and_parse_config(config_path):
    r'''
    Load and parse the json config file of the object detection model.

    When the config sets ``download_from_hub`` to a true value, the checkpoint
    described by ``checkpoint_url``, ``sha256`` and ``checkpoint_path`` is
    downloaded before the config is returned.

    Args:
        config_path (str): the config json file path.

    Returns:
        dict, the model configuration.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        ValueError: if the file is not valid json, if its top-level value is
            not a json object, or if a download is requested with an empty
            url, hash or path.
    '''
    # Check if config_path existed
    if not os.path.exists(config_path):
        raise FileNotFoundError("The config file path {} does not exist!".format(config_path))

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        configs = json.load(f)

    # ``.get`` below needs a mapping; reject lists/scalars with a clear error.
    if not isinstance(configs, dict):
        raise ValueError(
            "The config file {} should contain a json object, got {}.".format(config_path, type(configs)))

    # The download runs after the config file has been closed so the handle
    # is not held open for the duration of a network transfer.
    if configs.get('download_from_hub'):
        _download_ckeckpoint(configs.get('checkpoint_url'),
                             configs.get('sha256'),
                             configs.get('checkpoint_path'))
    return configs
Oops, something went wrong.