Skip to content

Commit

Permalink
Merge pull request alibaba#169 from alibaba/artist_script
Browse files (browse the repository at this point in the history)
update painter example script
  • Loading branch information
chywang authored Jul 29, 2022
2 parents 120c733 + 105e2e1 commit b023b4d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 113 deletions.
69 changes: 19 additions & 50 deletions examples/text2image_generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

1. 下载数据
```shell
if [ ! -f ./tmp/T2I_train.txt ]; then
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_train.txt
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_val.txt
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_test.txt
mkdir tmp/
mv *.txt tmp/
if [ ! -f ./tmp/MUGE_train_text_imgbase64.tsv ]; then
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_train_text_imgbase64.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_val_text_imgbase64.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_test.text.tsv
fi
```

Expand Down Expand Up @@ -38,57 +36,27 @@ img = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64)))
```

## 模型训练
1. 预训练

1. 模型微调
```shell
if [ ! -f ./tmp/vqgan_f16_16384.bin ]; then
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/easynlp_modelzoo/alibaba-pai/vqgan_f16_16384.bin
mv vqgan_f16_16384.bin tmp/
fi

python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/appzoo_tutorials/text2image_generation/main.py \
--mode=train \
--tables=./tmp/T2I_train.txt,./tmp/T2I_val.txt \
--input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
--first_sequence=text \
--second_sequence=imgbase64 \
--checkpoint_dir=./tmp/artist_model_pretrain \
--learning_rate=4e-5 \
--epoch_num=1 \
--random_seed=42 \
--logging_steps=100 \
--save_checkpoint_steps=200 \
--sequence_length=288 \
--micro_batch_size=8 \
--app_name=text2image_generation \
--user_defined_parameters='
vqgan_ckpt_path=./tmp/vqgan_f16_16384.bin
size=256
text_len=32
img_len=256
img_vocab_size=16384
'
```

2. 模型微调
```shell
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/appzoo_tutorials/text2image_generation/main.py \
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main.py \
--mode=train \
--tables=./tmp/T2I_train.txt,./tmp/T2I_val.txt \
--worker_gpu=1 \
--tables=./tmp/MUGE_train_text_imgbase64.tsv,./tmp/MUGE_val_text_imgbase64.tsv \
--input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
--first_sequence=text \
--second_sequence=imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune \
--checkpoint_dir=./tmp/finetune_model \
--learning_rate=4e-5 \
--epoch_num=1 \
--epoch_num=40 \
--random_seed=42 \
--logging_steps=100 \
--save_checkpoint_steps=200 \
--save_checkpoint_steps=1000 \
--sequence_length=288 \
--micro_batch_size=8 \
--micro_batch_size=16 \
--app_name=text2image_generation \
--user_defined_parameters='
pretrain_model_name_or_path=alibaba-pai/pai-shenbi-base-zh
pretrain_model_name_or_path=alibaba-pai/pai-painter-base-zh
size=256
text_len=32
img_len=256
Expand All @@ -98,16 +66,17 @@ if [ ! -f ./tmp/vqgan_f16_16384.bin ]; then

2. 模型预测
```shell
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/appzoo_tutorials/text2image_generation/main.py \
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main.py \
--mode=predict \
--tables=./tmp/T2I_test.txt \
--worker_gpu=1 \
--tables=./tmp/MUGE_test.text.tsv \
--input_schema=idx:str:1,text:str:1 \
--first_sequence=text \
--outputs=./tmp/T2I_outputs.txt \
--outputs=./tmp/T2I_outputs.tsv \
--output_schema=idx,text,gen_imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune \
--checkpoint_dir=./tmp/finetune_model \
--sequence_length=288 \
--micro_batch_size=8 \
--micro_batch_size=16 \
--app_name=text2image_generation \
--user_defined_parameters='
size=256
Expand Down
55 changes: 20 additions & 35 deletions examples/text2image_generation/run_appzoo_cli_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,73 @@ mode=$2
cur_path=$PWD/../../
cd ${cur_path}

# Download whl
if [ ! -f ./tmp/easynlp-0.0.5-py3-none-any.whl ]; then
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/geely_app/easynlp-0.0.5-py3-none-any.whl
fi

# Download data
if [ ! -f ./tmp/T2I_train.tsv ]; then
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_train.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_val.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_test.tsv
fi

# Download artist-large ckpt
if [ ! -f ./tmp/artist-large-zh.tgz ]; then
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/geely_app/artist-large-zh.tgz
tar zxvf ./tmp/artist-large-zh.tgz -C ./tmp
if [ ! -f ./tmp/MUGE_train_text_imgbase64.tsv ]; then
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_train_text_imgbase64.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_val_text_imgbase64.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_test.text.tsv
fi

pip install ./tmp/easynlp-0.0.5-py3-none-any.whl

if [ "$mode" = "pretrain" ]; then
easynlp \
easynlp \
--mode=train \
--worker_gpu=1 \
--tables=./tmp/T2I_train.tsv,./tmp/T2I_val.tsv \
--tables=./tmp/MUGE_train_text_imgbase64.tsv,./tmp/MUGE_val_text_imgbase64.tsv \
--input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
--first_sequence=text \
--second_sequence=imgbase64 \
--checkpoint_dir=./tmp/artist_model_pretrain \
--checkpoint_dir=./tmp/continue_pretrain_model/ \
--learning_rate=4e-5 \
--epoch_num=1 \
--epoch_num=40 \
--random_seed=42 \
--logging_steps=100 \
--save_checkpoint_steps=200 \
--save_checkpoint_steps=1000 \
--sequence_length=288 \
--micro_batch_size=8 \
--micro_batch_size=16 \
--app_name=text2image_generation \
--user_defined_parameters='
pretrain_model_name_or_path=./tmp/artist-large-zh
pretrain_model_name_or_path=alibaba-pai/pai-painter-base-zh
size=256
text_len=32
img_len=256
img_vocab_size=16384
'

'

elif [ "$mode" = "finetune" ]; then
easynlp \
--mode=train \
--worker_gpu=1 \
--tables=./tmp/T2I_train.tsv,./tmp/T2I_val.tsv \
--tables=./tmp/MUGE_train_text_imgbase64.tsv,./tmp/MUGE_val_text_imgbase64.tsv \
--input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
--first_sequence=text \
--second_sequence=imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune \
--checkpoint_dir=./tmp/finetune_model \
--learning_rate=4e-5 \
--epoch_num=1 \
--epoch_num=40 \
--random_seed=42 \
--logging_steps=100 \
--save_checkpoint_steps=200 \
--save_checkpoint_steps=1000 \
--sequence_length=288 \
--micro_batch_size=8 \
--app_name=text2image_generation \
--user_defined_parameters='
pretrain_model_name_or_path=./tmp/artist_model_pretrain
pretrain_model_name_or_path=./tmp/continue_pretrain_model/
size=256
text_len=32
img_len=256
img_vocab_size=16384
'
'


elif [ "$mode" = "predict" ]; then
easynlp \
--mode=predict \
--worker_gpu=1 \
--tables=./tmp/T2I_test.tsv \
--tables=./tmp/MUGE_test.text.tsv \
--input_schema=idx:str:1,text:str:1 \
--first_sequence=text \
--outputs=./tmp/T2I_outputs.tsv \
--output_schema=idx,text,gen_imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune\
--checkpoint_dir=./tmp/finetune_model \
--sequence_length=288 \
--micro_batch_size=8 \
--app_name=text2image_generation \
Expand Down
48 changes: 20 additions & 28 deletions examples/text2image_generation/run_user_defined_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,37 @@ NNODES=1
NODE_RANK=0

# Download data
if [ ! -f ./tmp/T2I_train.tsv ]; then
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_train.tsv
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_val.tsv
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_test.tsv
mkdir tmp/
mv *.tsv tmp/
if [ ! -f ./tmp/MUGE_train_text_imgbase64.tsv ]; then
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_train_text_imgbase64.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_val_text_imgbase64.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/painter_text2image/MUGE_test.text.tsv
fi

# Download artist-large ckpt
if [ ! -f ./tmp/artist-large-zh.tgz ]; then
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/geely_app/artist-large-zh.tgz
tar zxvf ./tmp/artist-large-zh.tgz -C ./tmp
fi

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
mode=$2


if [ "$mode" = "pretrain" ]; then
if [ ! -f ./tmp/vqgan_f16_16384.bin ]; then
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/easynlp_modelzoo/alibaba-pai/vqgan_f16_16384.bin
mv vqgan_f16_16384.bin tmp/
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/easynlp_modelzoo/alibaba-pai/vqgan_f16_16384.bin
fi

python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main.py \
--mode=train \
--worker_gpu=1 \
--tables=./tmp/T2I_train.tsv,./tmp/T2I_val.tsv \
--tables=./tmp/MUGE_train_text_imgbase64.tsv,./tmp/MUGE_val_text_imgbase64.tsv \
--input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
--first_sequence=text \
--second_sequence=imgbase64 \
--checkpoint_dir=./tmp/artist_model_pretrain \
--checkpoint_dir=./tmp/pretrain_model \
--learning_rate=4e-5 \
--epoch_num=1 \
--epoch_num=40 \
--random_seed=42 \
--logging_steps=100 \
--save_checkpoint_steps=200 \
--save_checkpoint_steps=1000 \
--sequence_length=288 \
--micro_batch_size=8 \
--micro_batch_size=16 \
--app_name=text2image_generation \
--user_defined_parameters='
vqgan_ckpt_path=./tmp/vqgan_f16_16384.bin
Expand All @@ -68,21 +60,21 @@ elif [ "$mode" = "finetune" ]; then
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main.py \
--mode=train \
--worker_gpu=1 \
--tables=./tmp/T2I_train.tsv,./tmp/T2I_val.tsv \
--tables=./tmp/MUGE_train_text_imgbase64.tsv,./tmp/MUGE_val_text_imgbase64.tsv \
--input_schema=idx:str:1,text:str:1,imgbase64:str:1 \
--first_sequence=text \
--second_sequence=imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune \
--checkpoint_dir=./tmp/finetune_model \
--learning_rate=4e-5 \
--epoch_num=1 \
--epoch_num=40 \
--random_seed=42 \
--logging_steps=100 \
--save_checkpoint_steps=200 \
--save_checkpoint_steps=1000 \
--sequence_length=288 \
--micro_batch_size=8 \
--micro_batch_size=16 \
--app_name=text2image_generation \
--user_defined_parameters='
pretrain_model_name_or_path=artist-base-zh
pretrain_model_name_or_path=alibaba-pai/pai-painter-base-zh
size=256
text_len=32
img_len=256
Expand All @@ -94,20 +86,20 @@ elif [ "$mode" = "predict" ]; then
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main.py \
--mode=predict \
--worker_gpu=1 \
--tables=./tmp/T2I_test.tsv \
--tables=./tmp/MUGE_test.text.tsv \
--input_schema=idx:str:1,text:str:1 \
--first_sequence=text \
--outputs=./tmp/T2I_outputs.tsv \
--output_schema=idx,text,gen_imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune \
--checkpoint_dir=./tmp/finetune_model \
--sequence_length=288 \
--micro_batch_size=8 \
--micro_batch_size=16 \
--app_name=text2image_generation \
--user_defined_parameters='
size=256
text_len=32
img_len=256
img_vocab_size=16384
max_generated_num=4
max_generated_num=1
'
fi

0 comments on commit b023b4d

Please sign in to comment.