Skip to content

Commit

Permalink
support static graph for kunlun
Browse files Browse the repository at this point in the history
  • Loading branch information
QingshuChen authored Dec 7, 2020
1 parent f7a385f commit 50252a1
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 10 deletions.
34 changes: 34 additions & 0 deletions configs/deeplabv3p_xception65_optic_kunlun.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# 数据集配置
DATASET:
DATA_DIR: "./dataset/optic_disc_seg/"
NUM_CLASSES: 2
TEST_FILE_LIST: "./dataset/optic_disc_seg/test_list.txt"
TRAIN_FILE_LIST: "./dataset/optic_disc_seg/train_list.txt"
VAL_FILE_LIST: "./dataset/optic_disc_seg/val_list.txt"
VIS_FILE_LIST: "./dataset/optic_disc_seg/test_list.txt"

# 预训练模型配置
MODEL:
MODEL_NAME: "deeplabv3p"
DEFAULT_NORM_TYPE: "bn"
DEEPLAB:
BACKBONE: "xception_65"

# 其他配置
TRAIN_CROP_SIZE: (512, 512)
EVAL_CROP_SIZE: (512, 512)
AUG:
AUG_METHOD: "unpadding"
FIX_RESIZE_SIZE: (512, 512)
BATCH_SIZE: 2
TRAIN:
PRETRAINED_MODEL_DIR: "./pretrained_model/deeplabv3p_xception65_bn_coco/"
MODEL_SAVE_DIR: "./saved_model/deeplabv3p_xception65_bn_optic/"
SNAPSHOT_EPOCH: 2
TEST:
TEST_MODEL: "./saved_model/deeplabv3p_xception65_bn_optic/final"
SOLVER:
NUM_EPOCHS: 20
LR: 0.001
LR_POLICY: "poly"
OPTIMIZER: "adam"
18 changes: 15 additions & 3 deletions legacy/pdseg/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ def parse_args():
parser.add_argument(
'--use_gpu',
dest='use_gpu',
help='Use gpu or cpu',
help='Use xpu, gpu or cpu',
action='store_true',
default=False)
parser.add_argument(
'--use_xpu',
dest='use_xpu',
help='Use xpu, gpu or cpu',
action='store_true',
default=False)
parser.add_argument(
Expand All @@ -68,7 +74,7 @@ def parse_args():
return parser.parse_args()


def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_xpu=False, use_mpio=False, **kwargs):
np.set_printoptions(precision=5, suppress=True)

startup_prog = fluid.Program()
Expand Down Expand Up @@ -97,7 +103,13 @@ def data_generator():
data_generator, drop_last=False, batch_size=cfg.BATCH_SIZE)

# Get device environment
places = fluid.cuda_places() if use_gpu else fluid.cpu_places()
if use_gpu:
places = fluid.cuda_places()
elif use_xpu:
xpu_id = int(os.environ.get('FLAGS_selected_xpus', 0))
places = [fluid.XPUPlace(xpu_id)]
else:
fluid.cpu_places()
place = places[0]
dev_count = len(places)
print("#Device count: {}".format(dev_count))
Expand Down
32 changes: 25 additions & 7 deletions legacy/pdseg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def parse_args():
parser.add_argument(
'--use_gpu',
dest='use_gpu',
help='Use gpu or cpu',
help='Use gpu, xpu or cpu',
action='store_true',
default=False)
parser.add_argument(
'--use_xpu',
dest='use_xpu',
help='Use xpu, gpu or cpu',
action='store_true',
default=False)
parser.add_argument(
Expand Down Expand Up @@ -219,8 +225,16 @@ def data_generator():

# Get device environment
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
xpu_id = int(os.environ.get('FLAGS_selected_xpus', 0))
if args.use_gpu:
place = fluid.CUDAPlace(gpu_id)
places = fluid.cuda_places()
elif args.use_xpu:
place = fluid.XPUPlace(xpu_id)
places = [place]
else:
place = fluid.CPUPlace()
places = fluid.cpu_places()

# Get number of GPU
dev_count = cfg.NUM_TRAINERS if cfg.NUM_TRAINERS > 1 else len(places)
Expand Down Expand Up @@ -263,10 +277,13 @@ def data_generator():
print_info(
"Sync BatchNorm strategy will not be effective if GPU device"
" count <= 1")
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=avg_loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
if args.use_xpu:
compiled_train_prog = train_prog
else:
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=avg_loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)

# Resume training
begin_epoch = cfg.SOLVER.BEGIN_EPOCH
Expand Down Expand Up @@ -408,6 +425,7 @@ def data_generator():
cfg=cfg,
ckpt_dir=ckpt_dir,
use_gpu=args.use_gpu,
use_xpu=args.use_xpu,
use_mpio=args.use_mpio)
if args.use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou, step)
Expand Down

0 comments on commit 50252a1

Please sign in to comment.