Skip to content

Commit

Permalink
[MedicalSeg] Fix infer problem (PaddlePaddle#2013)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyutang authored Apr 15, 2022
1 parent 49e03b8 commit 6c6dcf9
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 28 deletions.
6 changes: 4 additions & 2 deletions contrib/MedicalSeg/README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ MedicalSeg 是一个简单易使用的全流程 3D 医学图像分割工具包

#### **COVID-19 CT scans 上的分割结果**

| 主干网络 | 分辨率 | 学习率 | 训练轮数 | mDice | 链接 |

| 骨干网络 | 分辨率 | 学习率 | 训练轮数 | mDice | 链接 |
|:-:|:-:|:-:|:-:|:-:|:-:|
|-|128x128x128|0.001|15000|97.04%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=9db5c1e11ebc82f9a470f01a9114bd3c)|
|-|128x128x128|0.0003|15000|92.70%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0fb90ee5a6ea8821c0d61a6857ba4614)|

#### **MRISpineSeg 上的分割结果**

| 主干网络 | 分辨率 | 学习率 | 训练轮数 | mDice(20 classes) | Dice(16 classes) | 链接 |

| 骨干网络 | 分辨率 | 学习率 | 训练轮数 | mDice(20 classes) | Dice(16 classes) | 链接 |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|-|512x512x12|0.1|15000|74.41%| 88.17% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=36504064c740e28506f991815bd21cc7)|
|-|512x512x12|0.5|15000|74.69%| 89.14% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=08b0f9f62ebb255cdfc93fd6bd8f2c06)|
Expand Down
40 changes: 22 additions & 18 deletions contrib/MedicalSeg/deploy/python/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import codecs
import os
import sys
import codecs
import warnings
import argparse

LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(LOCAL_PATH, '..', '..'))

import yaml
import numpy as np
import functools
import numpy as np

from paddle.inference import create_predictor, PrecisionType
from paddle.inference import Config as PredictConfig
Expand Down Expand Up @@ -383,29 +384,32 @@ def _preprocess(self, img):
"""
if not "npy" in img:
image_files = get_image_list(img, None, None)
warnings.warn("The image path is {}, please make sure this is the images you want to infer".format(image_files))
savepath = os.path.dirname(img)
pre = [
HUnorm,
functools.partial(
resample, # TODO: config preprocess in deply.yaml to set params
resample, # TODO: config preprocess in deply.yaml(export) to set params
new_shape=[128, 128, 128],
order=1)
]

for f in tqdm(image_files, total=len(image_files)):
f_np = Prep.load_medical_data(f)
if pre is not None:
for op in pre:
f_np = op(f_np)

# Set image to a uniform format before save.
f_np = f_np.astype("float32")

np.save(
os.path.join(
savepath, f.split("/")[-1].split(
".", maxsplit=1)[0]),
f_np)
for f in image_files:
f_nps = Prep.load_medical_data(f)
for f_np in f_nps:
if pre is not None:
for op in pre:
f_np = op(f_np)

# Set image to a uniform format before save.
if isinstance(f_np, tuple):
f_np = f_np[0]
f_np = f_np.astype("float32")

np.save(
os.path.join(
savepath, f.split("/")[-1].split(
".", maxsplit=1)[0]), f_np)

img = img.split(".", maxsplit=1)[0] + ".npy"

Expand Down
1 change: 0 additions & 1 deletion contrib/MedicalSeg/medicalseg/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def __init__(self,
savepath=seg_env.DATA_HOME,
extrapath=seg_env.DATA_HOME)
elif not os.path.exists(self.dataset_root):
print()
raise ValueError(
"The `dataset_root` don't exist please specify the correct path to data."
)
Expand Down
11 changes: 6 additions & 5 deletions contrib/MedicalSeg/tools/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,10 @@ def load_medical_data(f):
images = [sitk.DICOMOrient(img, 'LPS') for img in images]
f_nps = [sitk.GetArrayFromImage(img) for img in images]

# previous line already swap to xyz
# f_nps = [np.transpose(f_np, [1, 2, 0]) for f_np in f_nps] # swap to xyz
# if previous line not swap to xyz
if f_nps[0].shape[0] != f_nps[0].shape[1]:
f_nps = [np.transpose(f_np, [1, 2, 0]) for f_np in f_nps] # swap to xyz

elif filename.endswith(
(".mha", ".mhd", "nrrd"
)): # validate mhd on lung and mri with correct spacing_resample
Expand Down Expand Up @@ -233,14 +235,13 @@ def load_save(self):
["images", "labels", "images_test"][i])):

# load data will transpose the image from "zyx" to "xyz"
spacing = dataset_json_dict["training"][
osp.basename(f).split(".")[0]]["spacing"] if i == 0 else None
f_nps = Prep.load_medical_data(f)

for volume_idx, f_np in enumerate(f_nps):
for op in pre:
if op.__name__ == "resample":
spacing = dataset_json_dict["training"][
osp.basename(f).split(".")[0]][
"spacing"] if i == 0 else None
f_np, new_spacing = op(
f_np,
spacing=spacing) # (960, 15, 960) if transpose
Expand Down
3 changes: 1 addition & 2 deletions contrib/MedicalSeg/tools/prepare_mri_spine_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ def generate_txt(self, train_split=1.0):

if __name__ == "__main__":
prep = Prep_mri_spine()
if not os.path.isfile(prep.dataset_json_path):
prep.generate_dataset_json(
prep.generate_dataset_json(
modalities=('MRI-T2', ),
labels={
0: "Background",
Expand Down

0 comments on commit 6c6dcf9

Please sign in to comment.