Skip to content

Commit

Permalink
add beam search arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
theyorubayesian committed Jun 19, 2023
1 parent 081a8ab commit 64bae7d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
15 changes: 10 additions & 5 deletions scripts/lafand_mt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,27 @@
DATASET_DIR=data/lafand
TRAIN_BATCH_SIZE=16
EVAL_BATCH_SIZE=32
INFER_BATCH_SIZE=32
INFER_BATCH_SIZE=4
CHECKPOINT="gs://awarawa/T5_1_1_large/checkpoint_300000"
CHECKPOINT_PERIOD=auto
MODEL_SIZE="large"
EVAL_PERIOD=auto
BEAM_SEARCH_ALPHA=0.6
BEAM_SEARCH_WIDTH=5
# Please pass FEATURE_LENGTHS as string dictionary.
FEATURE_LENGTHS="{'inputs': 512, 'targets': 200}"
# We pretrained for 524288 steps if you use the final checkpoints.
# If you use any other checkpoint, take note of its pre-trained steps.
PRETRAINED_STEPS=300000
FT_NUM_EPOCHS=5
OUTPUT_DIR="arawat5_large_lafand_ibo_yor_zul"
OUTPUT_DIR="arawat5_large_lafand_hau_pcm_swa"
mkdir -p logs/$OUTPUT_DIR
REMOVE_CHECKPOINTS=true
# ---------------------------------------------

# LANGUAGES=("hau" "pcm" "swa" "ibo" "yor" "zul")
# LANGUAGES=("hau" "pcm" "swa")
LANGUAGES=("ibo" "yor" "zul")
LANGUAGES=("pcm" "swa")
# LANGUAGES=("ibo" "yor" "zul")
for language in ${LANGUAGES[@]}
do
# TODO: You can check the task name format in src/teva/tasks.py
Expand All @@ -47,7 +49,7 @@
[[ $EVAL_PERIOD == "auto" ]] && _EVAL_PERIOD=$num_steps_per_epoch || _EVAL_PERIOD=$EVAL_PERIOD
[[ $CHECKPOINT_PERIOD == "auto" ]] && _CHECKPOINT_PERIOD=$num_steps_per_epoch || _CHECKPOINT_PERIOD=$CHECKPOINT_PERIOD

for seed in 1 2 3
for seed in 1
do
seed_output_dir=runs/$OUTPUT_DIR/${task}_${seed}

Expand All @@ -66,6 +68,9 @@
--output_dir $seed_output_dir \
--cuda_12 \
--gin.infer_eval/utils.DatasetConfig.batch_size=$INFER_BATCH_SIZE \
--gin.models.EncoderDecoderModel.predict_batch_with_aux.num_decodes=$BEAM_SEARCH_WIDTH \
[email protected]_search \
--gin.decode.beam_search.alpha=$BEAM_SEARCH_ALPHA \
>& logs/$OUTPUT_DIR/${task}_${seed}_ft.log \
&& finetuned=true

Expand Down
8 changes: 8 additions & 0 deletions scripts/xlsum.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
CHECKPOINT_PERIOD=auto # If auto, we save checkpoint after every epoch. Otherwise set to value.
MODEL_SIZE="base"
EVAL_PERIOD=auto # If auto, we run evaluations after every epoch. Otherwise set to value.
BEAM_SEARCH_ALPHA=0.6
BEAM_SEARCH_WIDTH=5
# Please pass FEATURE_LENGTHS as string dictionary.
FEATURE_LENGTHS="{'inputs': 512, 'targets': 64}" # TODO: Change based on your task
# We pretrained for 524288 steps if you use the final checkpoints.
Expand Down Expand Up @@ -73,6 +75,9 @@
--model_size $MODEL_SIZE \
--output_dir $seed_output_dir \
--gin.infer_eval/utils.DatasetConfig.batch_size=$INFER_BATCH_SIZE \
--gin.models.EncoderDecoderModel.predict_batch_with_aux.num_decodes=$BEAM_SEARCH_WIDTH \
[email protected]_search \
--gin.decode.beam_search.alpha=$BEAM_SEARCH_ALPHA \
>& logs/$OUTPUT_DIR/${task}_${seed}_ft.log \
&& finetuned=true

Expand All @@ -92,6 +97,9 @@
# --batch_size $EVAL_BATCH_SIZE \
# --output_dir $seed_output_dir/eval_${checkpoint_steps} \
# --cuda_12 \
# --gin.models.EncoderDecoderModel.predict_batch_with_aux.num_decodes=4 \
# [email protected]_search \
# --gin.decode.beam_search.alpha=0.6 \
# >& logs/$OUTPUT_DIR/${task}_${seed}_eval_${checkpoint_steps}.log
# done

Expand Down

0 comments on commit 64bae7d

Please sign in to comment.