diff --git a/tf_pose/train.py b/tf_pose/train.py index df1e9ad4..5203b093 100644 --- a/tf_pose/train.py +++ b/tf_pose/train.py @@ -31,10 +31,10 @@ parser.add_argument('--model', default='mobilenet_v2_1.4', help='model name') parser.add_argument('--datapath', type=str, default='/data/public/rw/coco/annotations') parser.add_argument('--imgpath', type=str, default='/data/public/rw/coco/') - parser.add_argument('--batchsize', type=int, default=96) - parser.add_argument('--gpus', type=int, default=1) - parser.add_argument('--max-epoch', type=int, default=300) - parser.add_argument('--lr', type=str, default='0.01') + parser.add_argument('--batchsize', type=int, default=64) + parser.add_argument('--gpus', type=int, default=4) + parser.add_argument('--max-epoch', type=int, default=600) + parser.add_argument('--lr', type=str, default='0.001') parser.add_argument('--tag', type=str, default='test') parser.add_argument('--checkpoint', type=str, default='')