-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpredict_toy_beam.sh
30 lines (27 loc) · 998 Bytes
/
predict_toy_beam.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Decode the toy_reverse dev set with a trained model using beam search
# (beam width 5).
#
# Outputs (under ${PRED_DIR}):
#   predictions.txt  stdout of `bin.infer` (DecodeText task)
#   beams.npz        beam data written by the DumpBeams task
#
# Run from the directory that contains the `bin` package, so that
# `python3 -m bin.infer` can import it.
#
# Overridable environment:
#   TMPDIR        base directory for MODEL_DIR (default: /tmp)
#   NMT_DATA_DIR  toy_reverse data root (default: ${HOME}/nmt_data/toy_reverse)

NMT_DATA_DIR="${NMT_DATA_DIR:-${HOME}/nmt_data/toy_reverse}"

export MODEL_DIR="${TMPDIR:-/tmp}/nmt_tutorial"
# Only MODEL_DIR, PRED_DIR and DEV_SOURCES are used by the inference command
# below. The remaining exports are kept so that anything relying on them in
# the environment keeps working.
# NOTE(review): confirm whether they can be dropped.
export VOCAB_SOURCE="${NMT_DATA_DIR}/train/vocab.sources.txt"
export VOCAB_TARGET="${NMT_DATA_DIR}/train/vocab.targets.txt"
export TRAIN_SOURCES="${NMT_DATA_DIR}/train/sources.txt"
export TRAIN_TARGETS="${NMT_DATA_DIR}/train/targets.txt"
export DEV_SOURCES="${NMT_DATA_DIR}/dev/sources.txt"
export DEV_TARGETS="${NMT_DATA_DIR}/dev/targets.txt"
export DEV_TARGETS_REF="${NMT_DATA_DIR}/dev/targets.txt"
export TRAIN_STEPS=1000
export PRED_DIR="${MODEL_DIR}/pred"

# Quoted so that a TMPDIR containing spaces does not split the path.
# `-p` also creates MODEL_DIR as the parent of PRED_DIR.
mkdir -p -- "$MODEL_DIR" "$PRED_DIR"

# The --tasks / --model_params / --input_pipeline arguments are YAML.
# Their indentation is significant:
#   - `params` must nest under the DumpBeams item.
#   - `source_files` must nest under the pipeline's `params`.
python3 -m bin.infer \
  --tasks "
    - class: DecodeText
    - class: DumpBeams
      params:
        file: ${PRED_DIR}/beams.npz" \
  --model_dir "$MODEL_DIR" \
  --model_params "
    inference.beam_search.beam_width: 5" \
  --input_pipeline "
    class: ParallelTextInputPipeline
    params:
      source_files:
        - ${DEV_SOURCES}" \
  > "${PRED_DIR}/predictions.txt"