diff --git a/examples/aishell/NST/conf/train_conformer.yaml b/examples/aishell/NST/conf/train_conformer.yaml index 8499de2e9..221c3635b 100644 --- a/examples/aishell/NST/conf/train_conformer.yaml +++ b/examples/aishell/NST/conf/train_conformer.yaml @@ -28,12 +28,37 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +dataset: asr dataset_conf: filter_conf: max_length: 1200 diff --git a/examples/aishell/NST/run_nst.sh b/examples/aishell/NST/run_nst.sh index 6a4eaf0e5..11af48624 100644 --- a/examples/aishell/NST/run_nst.sh +++ b/examples/aishell/NST/run_nst.sh @@ -72,7 +72,6 @@ data_type=shard num_utts_per_shard=1000 train_set=train train_config=conf/train_conformer.yaml -cmvn=true average_checkpoint=true target_pt=80 decode_checkpoint=$dir/$target_pt.pt @@ -113,9 +112,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then dist_backend="nccl" # the global_cmvn file need to be calculated by combining both supervised/unsupervised datasets, # and it should be positioned at data/${train_set}/global_cmvn . - cmvn_opts= - $cmvn && cp data/${train_set}/global_cmvn $dir/global_cmvn - $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn" # train.py rewrite $train_config to $dir/train.yaml with model input # and output dimension, and $dir/train.yaml will be used for inference @@ -133,14 +129,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --train_engine ${train_engine} \ --config $train_config \ --data_type $data_type \ - --symbol_table $dict \ --train_data data/$train_set/$data_list \ --cv_data data/dev/data.list \ ${checkpoint:+--checkpoint $checkpoint} \ --model_dir $dir \ --ddp.dist_backend $dist_backend \ --num_workers 1 \ - $cmvn_opts \ --pin_memory \ --deepspeed_config ${deepspeed_config} \ --deepspeed.save_states ${deepspeed_save_states} @@ -190,7 +184,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then --beam_size 10 \ --batch_size 1 \ --penalty 0.0 \ - --dict $dict \ --ctc_weight $ctc_weight \ --reverse_weight $reverse_weight \ --result_file $test_dir/text \ @@ -216,7 +209,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then --beam_size 10 \ --batch_size 1 \ --penalty 0.0 \ - --dict $dict \ --ctc_weight $ctc_weight \ --reverse_weight $reverse_weight \ --result_file $dev_dir/text \ @@ -275,7 +267,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --beam_size 10 \ --batch_size 1 \ --penalty 0.0 \ - --dict $dict \ --ctc_weight $ctc_weight \ --reverse_weight $reverse_weight \ --result_file data/train/${dir_split}data_sublist${job_num}/${hypo_name} \ diff --git a/examples/aishell/rnnt/conf/conformer_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_rnnt.yaml index 690743760..a70284870 100644 --- a/examples/aishell/rnnt/conf/conformer_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_rnnt.yaml @@ -49,6 +49,29 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid transducer+ctc+attention model: transducer model_conf: @@ -59,6 +82,7 @@ model_conf: length_normalized_loss: false reverse_weight: 0.3 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml index 5079ef988..b7ffb08d9 100644 --- a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -53,6 +53,29 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid transducer+ctc+attention model: transducer model_conf: @@ -63,6 +86,7 @@ model_conf: length_normalized_loss: false reverse_weight: 0.3 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml index 9a1b4ecac..b6abc3be4 100644 --- a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml +++ b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml @@ -45,6 +45,29 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid transducer+ctc+attention model: transducer model_conf: @@ -55,6 +78,7 @@ model_conf: length_normalized_loss: false reverse_weight: 0.3 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/rnnt/run.sh b/examples/aishell/rnnt/run.sh index 9b9e7ff6a..c04718adc 100644 --- a/examples/aishell/rnnt/run.sh +++ b/examples/aishell/rnnt/run.sh @@ -42,7 +42,6 @@ num_utts_per_shard=1000 train_set=train train_config=conf/conformer_u2pp_rnnt.yaml -cmvn=true dir=exp/conformer_rnnt checkpoint= @@ -92,11 +91,10 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then mkdir -p $(dirname $dict) echo " 0" > ${dict} # 0 is for "blank" in CTC echo " 1" >> ${dict} # must be 1 + echo " 2" >> $dict tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \ | tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \ - awk '{print $0 " " NR+1}' >> ${dict} - num_token=$(cat $dict | wc -l) - echo " $num_token" >> $dict + awk '{print $0 " " NR+2}' >> ${dict} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then @@ -118,9 +116,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') # Use "nccl" if it works, otherwise use "gloo" dist_backend="nccl" - cmvn_opts= - $cmvn && cp data/${train_set}/global_cmvn $dir - $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn" # train.py rewrite $train_config to $dir/train.yaml with model input # and output dimension, and $dir/train.yaml will be used for inference @@ -137,14 +132,12 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --train_engine ${train_engine} \ --config $train_config \ --data_type $data_type \ - --symbol_table $dict \ --train_data data/$train_set/data.list \ --cv_data data/dev/data.list \ ${checkpoint:+--checkpoint $checkpoint} \ --model_dir $dir \ --ddp.dist_backend $dist_backend \ --num_workers 1 \ - $cmvn_opts \ --pin_memory \ --deepspeed_config ${deepspeed_config} \ --deepspeed.save_states ${deepspeed_save_states} @@ -183,7 +176,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --beam_size 10 \ --batch_size 32 \ --penalty 0.0 \ - --dict $dict \ --ctc_weight $rescore_ctc_weight \ --transducer_weight $rescore_transducer_weight \ --attn_weight $rescore_attn_weight \ diff --git a/examples/aishell/s0/conf/train_conformer.yaml b/examples/aishell/s0/conf/train_conformer.yaml index b8ce511cd..ef0694cb0 100644 --- a/examples/aishell/s0/conf/train_conformer.yaml +++ b/examples/aishell/s0/conf/train_conformer.yaml @@ -28,12 +28,37 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_conformer_no_pos.yaml b/examples/aishell/s0/conf/train_conformer_no_pos.yaml index a2d5d03f5..c90a83904 100644 --- a/examples/aishell/s0/conf/train_conformer_no_pos.yaml +++ b/examples/aishell/s0/conf/train_conformer_no_pos.yaml @@ -28,12 +28,37 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_ebranchformer.yaml b/examples/aishell/s0/conf/train_ebranchformer.yaml index edc952295..5136f1ad9 100644 --- a/examples/aishell/s0/conf/train_ebranchformer.yaml +++ b/examples/aishell/s0/conf/train_ebranchformer.yaml @@ -31,12 +31,37 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_transformer.yaml b/examples/aishell/s0/conf/train_transformer.yaml index b7d7eee83..ef88c7420 100644 --- a/examples/aishell/s0/conf/train_transformer.yaml +++ b/examples/aishell/s0/conf/train_transformer.yaml @@ -23,12 +23,37 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_u2++_branchformer.yaml b/examples/aishell/s0/conf/train_u2++_branchformer.yaml index ef12c13a4..8702fbeb4 100644 --- a/examples/aishell/s0/conf/train_u2++_branchformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_branchformer.yaml @@ -37,13 +37,38 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false reverse_weight: 0.3 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_u2++_conformer.yaml b/examples/aishell/s0/conf/train_u2++_conformer.yaml index b4587bce3..fadc1c451 100644 --- a/examples/aishell/s0/conf/train_u2++_conformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_conformer.yaml @@ -33,13 +33,38 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false reverse_weight: 0.3 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml index c13b4b295..d4de4c440 100644 --- a/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml +++ b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml @@ -33,13 +33,38 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false reverse_weight: 0.3 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml index 3d0de82db..928227565 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml @@ -38,7 +38,31 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option @@ -46,6 +70,7 @@ model_conf: reverse_weight: 0.3 # dataset related +dataset: asr dataset_conf: batch_conf: batch_size: 16 diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml index 3b5a99a86..65af5845d 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v1_stream.yaml @@ -38,7 +38,31 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option @@ -46,6 +70,7 @@ model_conf: reverse_weight: 0.3 # dataset related +dataset: asr dataset_conf: batch_conf: batch_size: 16 diff --git a/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml b/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml index c23e1b64d..e14f3ba07 100644 --- a/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml +++ b/examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml @@ -38,7 +38,31 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option @@ -46,6 +70,7 @@ model_conf: reverse_weight: 0.3 # dataset related +dataset: asr dataset_conf: batch_conf: batch_size: 16 diff --git a/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml index 1eb280de2..0ffc064f2 100644 --- a/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml @@ -33,7 +33,31 @@ decoder_conf: self_attention_dropout_rate: 0.1 src_attention_dropout_rate: 0.1 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option @@ -41,6 +65,7 @@ model_conf: reverse_weight: 0.3 apply_non_blank_embedding: true # warning: had better use a well trained model as init model +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_u2++_transformer.yaml b/examples/aishell/s0/conf/train_u2++_transformer.yaml index 44b4d4be7..47c822637 100644 --- a/examples/aishell/s0/conf/train_u2++_transformer.yaml +++ b/examples/aishell/s0/conf/train_u2++_transformer.yaml @@ -26,13 +26,38 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false reverse_weight: 0.3 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_unified_conformer.yaml b/examples/aishell/s0/conf/train_unified_conformer.yaml index 978d3d91c..27a060e97 100644 --- a/examples/aishell/s0/conf/train_unified_conformer.yaml +++ b/examples/aishell/s0/conf/train_unified_conformer.yaml @@ -32,12 +32,37 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml b/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml index 8cf2b726d..ad49c28c5 100644 --- a/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml +++ b/examples/aishell/s0/conf/train_unified_conformer_ctl.yaml @@ -32,6 +32,29 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention model: ctl_model model_conf: @@ -42,6 +65,7 @@ model_conf: logit_temp: 0.4 n_negatives: 100 +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/conf/train_unified_transformer.yaml b/examples/aishell/s0/conf/train_unified_transformer.yaml index 9d7a38687..58506e50f 100644 --- a/examples/aishell/s0/conf/train_unified_transformer.yaml +++ b/examples/aishell/s0/conf/train_unified_transformer.yaml @@ -25,12 +25,37 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +tokenizer: char +tokenizer_conf: + symbol_table_path: 'data/dict/lang_char.txt' + split_with_space: false + bpe_path: null + non_lang_syms_path: null + is_multilingual: false + num_languages: 1 + special_tokens: + : 0 + : 1 + : 2 + : 2 + +ctc: ctc +ctc_conf: + ctc_blank_id: 0 + +cmvn: global_cmvn +cmvn_conf: + cmvn_file: 'data/train/global_cmvn' + is_json_cmvn: true + # hybrid CTC/attention +model: asr_model model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false +dataset: asr dataset_conf: filter_conf: max_length: 40960 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index a94ea7b27..44decb368 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -51,7 +51,6 @@ train_set=train # trained model, and freeze encoder module, otherwise there will be a # autograd error train_config=conf/train_conformer.yaml -cmvn=true dir=exp/conformer tensorboard_dir=tensorboard checkpoint= @@ -104,11 +103,10 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then mkdir -p $(dirname $dict) echo " 0" > ${dict} # 0 is for "blank" in CTC echo " 1" >> ${dict} # must be 1 + echo " 2" >> $dict tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \ | tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \ - awk '{print $0 " " NR+1}' >> ${dict} - num_token=$(cat $dict | wc -l) - echo " $num_token" >> $dict + awk '{print $0 " " NR+2}' >> ${dict} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then @@ -132,9 +130,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # NOTE(xcsong): deepspeed fails with gloo, see # https://github.com/microsoft/DeepSpeed/issues/2818 dist_backend="nccl" - cmvn_opts= - $cmvn && cp data/${train_set}/global_cmvn $dir - $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn" # train.py rewrite $train_config to $dir/train.yaml with model input # and output dimension, and $dir/train.yaml will be used for inference @@ -168,7 +163,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --train_engine ${train_engine} \ --config $train_config \ --data_type $data_type \ - --symbol_table ${dict} \ --train_data data/$train_set/data.list \ --cv_data data/dev/data.list \ ${checkpoint:+--checkpoint $checkpoint} \ @@ -177,7 +171,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --ddp.dist_backend $dist_backend \ --num_workers ${num_workers} \ --prefetch ${prefetch} \ - $cmvn_opts \ --pin_memory \ --deepspeed_config ${deepspeed_config} \ --deepspeed.save_states ${deepspeed_save_states} @@ -209,7 +202,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --beam_size 10 \ --batch_size 32 \ --penalty 0.0 \ - --dict $dict \ --ctc_weight $ctc_weight \ --reverse_weight $reverse_weight \ --result_dir $dir \ diff --git a/test/resources/global_cmvn b/test/resources/global_cmvn new file mode 100644 index 000000000..bf739e005 --- /dev/null +++ b/test/resources/global_cmvn @@ -0,0 +1 @@ +{"mean_stat": [533748832.0, 537379392.0, 553560256.0, 587163840.0, 631869248.0, 662597952.0, 684377536.0, 695391808.0, 692469312.0, 679435968.0, 666124288.0, 656322880.0, 665750272.0, 678694016.0, 681919872.0, 679622592.0, 669891392.0, 656596352.0, 653837696.0, 637679360.0, 628411776.0, 644834944.0, 638840384.0, 646180416.0, 639723136.0, 642757056.0, 637471488.0, 642369216.0, 643415296.0, 647383872.0, 649347712.0, 649295168.0, 650232704.0, 654485632.0, 660473088.0, 667416512.0, 673157760.0, 675672960.0, 675124608.0, 668018432.0, 670061312.0, 662625152.0, 663145152.0, 662503872.0, 666412736.0, 672262144.0, 678483264.0, 685387712.0, 692571648.0, 699067008.0, 700786048.0, 701201856.0, 702666816.0, 705442496.0, 706070272.0, 705989312.0, 702844352.0, 699318592.0, 696090432.0, 687558528.0, 675279680.0, 663676480.0, 662962688.0, 664300608.0, 666095936.0, 671681664.0, 676652800.0, 680098048.0, 683809344.0, 688702016.0, 692081536.0, 695787328.0, 701086080.0, 706389504.0, 711492544.0, 717638656.0, 719691584.0, 715811904.0, 696363712.0, 604650304.0], "var_stat": [5413303296.0, 5559859712.0, 6150998016.0, 6921247232.0, 7999772672.0, 8789866496.0, 9405788160.0, 9768055808.0, 9759781888.0, 9430648832.0, 9090540544.0, 8873153536.0, 9155926016.0, 9542536192.0, 9653545984.0, 9593427968.0, 9316633600.0, 8959290368.0, 8863544320.0, 8450601472.0, 8211594240.0, 8587095040.0, 8432593920.0, 8583958528.0, 8401722368.0, 8439351296.0, 8293776896.0, 8401498112.0, 8427502592.0, 8525175296.0, 8577086464.0, 8575106560.0, 8594980864.0, 8701697024.0, 8854953984.0, 9029480448.0, 9168768000.0, 9221465088.0, 9194510336.0, 8997099520.0, 9024593920.0, 8819391488.0, 8807892992.0, 8777240576.0, 8869681152.0, 9017400320.0, 9173416960.0, 9345583104.0, 9530641408.0, 9701219328.0, 9749004288.0, 9762750464.0, 9802030080.0, 9874440192.0, 9883308032.0, 9873497088.0, 9780661248.0, 9672590336.0, 9569466368.0, 9321841664.0, 8968135680.0, 8646355968.0, 8616979456.0, 8648616960.0, 8702088192.0, 8859219968.0, 8999392256.0, 9105938432.0, 9220412416.0, 9358609408.0, 9451415552.0, 9552719872.0, 9695440896.0, 9836685312.0, 9970956288.0, 10135898112.0, 10189419520.0, 10070478848.0, 9532954624.0, 7261243904.0], "frame_num": 54068199} \ No newline at end of file diff --git a/test/wenet/dataset/test_processor.py b/test/wenet/dataset/test_processor.py index abe6adeba..c48a17729 100644 --- a/test/wenet/dataset/test_processor.py +++ b/test/wenet/dataset/test_processor.py @@ -1,6 +1,6 @@ import pytest -import wenet.dataset.processor as processor +from wenet.dataset import processor from wenet.utils.init_tokenizer import init_tokenizer @@ -29,6 +29,11 @@ def test_tokenize(symbol_table_path): }, { "txt": "It's okay" }] + configs = {} + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['symbol_table_path'] = symbol_table_path + configs['tokenizer_conf']['non_lang_syms_path'] = None + configs['tokenizer_conf']['split_with_space'] = False if symbol_table_path == "test/resources/librispeech.words.txt": bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel" refs = [{ @@ -67,6 +72,8 @@ def test_tokenize(symbol_table_path): "tokens": ['▁IT', "'", 'S', '▁O', 'KA', 'Y'], "label": [2344, 2, 3790, 3010, 2418, 4979] }] + configs['tokenizer'] = 'bpe' + configs['tokenizer_conf']['bpe_path'] = bpe_model else: bpe_model = None refs = [{ @@ -138,9 +145,9 @@ def test_tokenize(symbol_table_path): "tokens": ['I', 't', "'", 's', '▁', 'o', 'k', 'a', 'y'], "label": [24, 46, 2, 43, 1, 35, 27, 7, 56] }] + configs['tokenizer'] = 'char' - configs = {'split_with_space': False} - tokenizer = init_tokenizer(configs, symbol_table_path, bpe_model) + tokenizer = init_tokenizer(configs) outs = processor.tokenize(txts, tokenizer) for (hyp, ref) in zip(outs, refs): assert (len(hyp["tokens"]) == len(ref["tokens"])) diff --git a/test/wenet/text/test_paraformer_tokenizer.py b/test/wenet/text/test_paraformer_tokenizer.py index 5101fe693..b581955be 100644 --- a/test/wenet/text/test_paraformer_tokenizer.py +++ b/test/wenet/text/test_paraformer_tokenizer.py @@ -28,7 +28,7 @@ def paraformer_tokenizer(request): configs['tokenizer_conf']['symbol_table_path'] = wenet_units configs['tokenizer_conf']['seg_dict_path'] = os.path.join( download_root, seg_dict) - return init_tokenizer(configs, wenet_units) + return init_tokenizer(configs) def test_tokenize(paraformer_tokenizer): diff --git a/test/wenet/utils/test_init_model.py b/test/wenet/utils/test_init_model.py index f73e181ec..e16ec7455 100644 --- a/test/wenet/utils/test_init_model.py +++ b/test/wenet/utils/test_init_model.py @@ -30,5 +30,30 @@ def test_init_model(): config['input_dim'] = input_dim # TODO(xcsong): fix vocab_size config['output_dim'] = 3000 + if config.get('cmvn', None) == "global_cmvn": + config['cmvn_conf']['cmvn_file'] = "test/resources/global_cmvn" + if 'tokenizer' in config: + if config['tokenizer'] == "char": + config['tokenizer_conf'][ + 'symbol_table_path'] = "test/resources/aishell2.words.txt" + elif config['tokenizer'] == "bpe": + config['tokenizer_conf']['bpe_path'] = \ + "test/resources/librispeech.train_960_unigram5000.bpemodel" + config['tokenizer_conf']['symbol_table_path'] = \ + "test/resources/librispeech.words.txt" + config['tokenizer_conf']['non_lang_syms_path'] = \ + "test/resources/non-linguistic-symbols.invalid" + elif config['tokenizer'] == "whisper": + config['tokenizer_conf']['is_multilingual'] = True + config['tokenizer_conf']['num_languages'] = 100 + else: + raise NotImplementedError + else: + config['tokenizer'] = "char" + config['tokenizer_conf'] = {} + config['tokenizer_conf']['symbol_table_path'] = \ + "test/resources/aishell2.words.txt" + config['tokenizer_conf']['non_lang_syms_path'] = \ + "test/resources/non-linguistic-symbols.invalid" print("checking {} {}".format(c, config)) init_model(args, config) diff --git a/test/wenet/utils/test_init_tokenizer.py b/test/wenet/utils/test_init_tokenizer.py index 1b5b4af94..feef99802 100644 --- a/test/wenet/utils/test_init_tokenizer.py +++ b/test/wenet/utils/test_init_tokenizer.py @@ -6,12 +6,13 @@ def test_init_whisper_tokenizer(): # TODO(Mddct): add configs generator configs = {} - configs['whisper'] = True - configs['whisper_conf'] = {} - configs['whisper_conf']['is_multilingual'] = False - configs['whisper_conf']['num_languages'] = 99 + configs['tokenizer'] = 'whisper' + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['symbol_table_path'] = None + configs['tokenizer_conf']['is_multilingual'] = False + configs['tokenizer_conf']['num_languages'] = 99 - tokenizer = init_tokenizer(configs, None) + tokenizer = init_tokenizer(configs) text = "whisper powered by wenet, it's great" assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) @@ -22,7 +23,11 @@ def test_init_whisper_tokenizer(): ]) def test_init_char_tokenizer(symbol_table_path): configs = {} - tokenizer = init_tokenizer(configs, symbol_table_path) + configs['tokenizer'] = 'char' + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['symbol_table_path'] = symbol_table_path + configs['tokenizer_conf']['non_lang_syms_path'] = None + tokenizer = init_tokenizer(configs) text = "大家都好帅" assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) @@ -35,7 +40,12 @@ def test_init_char_tokenizer(symbol_table_path): def test_init_bpe_tokenizer(symbol_table_path, bpe_model): configs = {} - tokenizer = init_tokenizer(configs, symbol_table_path, bpe_model) + configs['tokenizer'] = 'bpe' + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['bpe_path'] = bpe_model + configs['tokenizer_conf']['symbol_table_path'] = symbol_table_path + configs['tokenizer_conf']['non_lang_syms_path'] = None + tokenizer = init_tokenizer(configs) text = "WENET IT'S GREAT" assert text == tokenizer.tokens2text(tokenizer.text2tokens(text)) diff --git a/tools/onnx2horizonbin.py b/tools/onnx2horizonbin.py index 3f6474572..96bc4061c 100755 --- a/tools/onnx2horizonbin.py +++ b/tools/onnx2horizonbin.py @@ -49,7 +49,6 @@ from wenet.utils.common import remove_duplicates_and_blank from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import load_checkpoint -from wenet.utils.file_utils import read_symbol_table from wenet.utils.init_model import init_model from wenet.utils.init_tokenizer import init_tokenizer from wenet.bin.export_onnx_cpu import to_numpy @@ -76,10 +75,9 @@ def save_data(tensor, dirs, prefix): data.tofile(dirs + "/" + prefix + ".bin") -def make_calibration_data(enc, args, conf): +def make_calibration_data(enc, args, conf, tokenizer): conf['shuffle'] = True logger.info(conf) - tokenizer = init_tokenizer(ali_conf, args.symbol_table, args.bpe_model) dataset = Dataset("shard", args.cali_datalist, tokenizer, @@ -151,16 +149,15 @@ def make_calibration_data(enc, args, conf): prefix + "." + str(i)) -def check_wer(enc, ctc, args, conf): +def check_wer(enc, ctc, args, conf, tokenizer): conf['shuffle'] = False - tokenizer = init_tokenizer(ali_conf, args.symbol_table, args.bpe_model) dataset = Dataset("shard", args.wer_datalist, tokenizer, conf, partition=False) dataloader = DataLoader(dataset, batch_size=None, num_workers=0) - char_dict = {v: k for k, v in args.symbol_table.items()} + char_dict = {v: k for k, v in tokenizer.symbol_table.items()} eos = len(char_dict) - 1 enc_session = HB_ONNXRuntime( @@ -375,7 +372,6 @@ def get_args(): default=0.5, type=float, help='reverse_weight in attention_rescoing') - parser.add_argument('--dict', type=str, required=True, help='dict file') parser.add_argument('--max_samples', type=int, required=True, @@ -389,10 +385,6 @@ def get_args(): default=None, help='check wer') parser.add_argument('--wer_text', type=str, default=None, help='check wer') - parser.add_argument('--bpe_model', - default=None, - type=str, - help='bpe model for english part') parser.add_argument('--ln_run_on_bpu', action='store_true', help='layernorm running on bpu') @@ -418,15 +410,14 @@ def get_args(): os.environ['CUDA_VISIBLE_DEVICES'] = '-1' with open(args.config, 'r') as fin: - conf = yaml.load(fin, Loader=yaml.FullLoader) + configs = yaml.load(fin, Loader=yaml.FullLoader) - model = init_model(conf) + model = init_model(args, configs) load_checkpoint(model, args.checkpoint) + tokenizer = init_tokenizer(configs) model.eval() - symbol_table = read_symbol_table(args.dict) - args.symbol_table = symbol_table - args.feature_size = conf['input_dim'] + args.feature_size = configs['input_dim'] args.output_size = model.encoder.output_size() args.decoding_window = (args.chunk_size - 1) * \ model.encoder.embed.subsampling_rate + \ @@ -436,7 +427,7 @@ def get_args(): enc, enc_session = export_encoder(model, args) ctc, ctc_session = export_ctc(model, args) - conf = copy.deepcopy(conf['dataset_conf']) + conf = copy.deepcopy(configs['dataset_conf']) conf['filter_conf']['max_length'] = 102400 conf['filter_conf']['min_length'] = 0 conf['filter_conf']['token_max_length'] = 102400 @@ -483,7 +474,7 @@ def get_args(): generate_config(enc_session, ctc_session, args) logger.info("Stage-3: Make calibration data") - make_calibration_data(enc, args, conf) + make_calibration_data(enc, args, conf, tokenizer) output_dir = os.path.realpath(args.output_dir) logger.info("Stage-4: Make ctc.bin") @@ -502,7 +493,7 @@ def get_args(): logger.info( "Stage-6: Check wer between torch model and quantized onnx") assert args.wer_text is not None - check_wer(enc, ctc, args, conf) + check_wer(enc, ctc, args, conf, tokenizer) os.system( "python3 tools/compute-wer.py --char=1 --v=1 {} {} > {}".format( args.wer_text, args.output_dir + "/torch_text", diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py index 1d510a9d8..2bfda7bcb 100644 --- a/wenet/bin/alignment.py +++ b/wenet/bin/alignment.py @@ -201,8 +201,7 @@ def get_labformat(timestamp, subsample): ali_conf['batch_conf']['batch_type'] = "static" ali_conf['batch_conf']['batch_size'] = args.batch_size - tokenizer = init_tokenizer(ali_conf, args.dict, args.bpe_model, - args.non_lang_syms) + tokenizer = init_tokenizer(configs) ali_dataset = Dataset(args.data_type, args.input_file, tokenizer, diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index 9832519ee..9b4620802 100644 --- a/wenet/bin/export_onnx_gpu.py +++ b/wenet/bin/export_onnx_gpu.py @@ -1195,7 +1195,13 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path, with open(args.config, "r") as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) if args.cmvn_file and os.path.exists(args.cmvn_file): - configs["cmvn_file"] = args.cmvn_file + if 'cmvn' not in configs: + configs['cmvn'] = "global_cmvn" + configs['cmvn_conf'] = {} + else: + assert configs['cmvn'] == "global_cmvn" + assert configs['cmvn']['cmvn_conf'] is not None + configs['cmvn_conf']["cmvn_file"] = args.cmvn_file if (args.reverse_weight != -1.0 and "reverse_weight" in configs["model_conf"]): configs["model_conf"]["reverse_weight"] = args.reverse_weight diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 699ecd1df..7a7d800fa 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -44,10 +44,6 @@ def get_args(): default=-1, help='gpu id for this rank, -1 for cpu') parser.add_argument('--checkpoint', required=True, help='checkpoint model') - parser.add_argument('--dict', required=True, help='dict file') - parser.add_argument( - "--non_lang_syms", - help="non-linguistic symbol file. One symbol per line.") parser.add_argument('--beam_size', type=int, default=10, @@ -121,18 +117,10 @@ def get_args(): default=0.0, help='''right to left weight for attention rescoring decode mode''') - parser.add_argument('--bpe_model', - default=None, - type=str, - help='bpe model for english part') parser.add_argument('--override_config', action='append', default=[], help="override yaml config") - parser.add_argument('--connect_symbol', - default='', - type=str, - help='used to connect the output characters') parser.add_argument('--word', default='', @@ -207,8 +195,7 @@ def main(): test_conf['batch_conf']['batch_type'] = "static" test_conf['batch_conf']['batch_size'] = args.batch_size - tokenizer = init_tokenizer(configs, args.dict, args.bpe_model, - args.non_lang_syms) + tokenizer = init_tokenizer(configs) test_dataset = Dataset(args.data_type, args.test_data, tokenizer, @@ -229,7 +216,8 @@ def main(): context_graph = None if 'decoding-graph' in args.context_bias_mode: context_graph = ContextGraph(args.context_list_path, - tokenizer.symbol_table, args.bpe_model, + tokenizer.symbol_table, + configs['tokenizer_conf']['bpe_path'], args.context_graph_score) _, blank_id = get_blank_id(configs, tokenizer.symbol_table) diff --git a/wenet/bin/recognize_onnx_gpu.py b/wenet/bin/recognize_onnx_gpu.py index 5a5cea66c..3fb0d8bbb 100644 --- a/wenet/bin/recognize_onnx_gpu.py +++ b/wenet/bin/recognize_onnx_gpu.py @@ -139,7 +139,7 @@ def main(): test_conf['batch_conf']['batch_type'] = "static" test_conf['batch_conf']['batch_size'] = args.batch_size - tokenizer = init_tokenizer(test_conf, args.dict, args.bpe_model) + tokenizer = init_tokenizer(configs) test_dataset = Dataset(args.data_type, args.test_data, tokenizer, diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 2cf70a8b8..f6c3e5c4a 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -73,8 +73,7 @@ def main(): configs = override_config(configs, args.override_config) # init tokenizer - tokenizer = init_tokenizer(configs, args.symbol_table, args.bpe_model, - args.non_lang_syms) + tokenizer = init_tokenizer(configs) # Init env for ddp OR deepspeed world_size, local_rank, rank = init_distributed(args) diff --git a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py index 587a3c077..9379d4690 100644 --- a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py +++ b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py @@ -269,9 +269,10 @@ def main(): configs, seg_dict) configs['output_dim'] = vocab_size configs['model'] = 'paraformer' - configs['is_json_cmvn'] = True - configs['cmvn_file'] = json_cmvn_path - # configs['input_dim'] = 80 + configs['cmvn'] = "global_cmvn" + configs['cmvn_conf'] = {} + configs['cmvn_conf']['is_json_cmvn'] = True + configs['cmvn_conf']['cmvn_file'] = json_cmvn_path fields_to_keep = [ 'model', 'encoder_conf', 'decoder_conf', 'predictor_conf', 'input_dim', 'output_dim', 'cmvn_file', 'is_json_cmvn', 'model_conf', 'paraformer', diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 5e6680fc1..4b116bb57 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -82,8 +82,10 @@ def init_model(args, configs): - if configs.get('cmvn_file', None) is not None: - mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) + # TODO(xcsong): Forcefully read the 'cmvn' attribute. + if configs.get('cmvn', None) == 'global_cmvn': + mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'], + configs['cmvn_conf']['is_json_cmvn']) global_cmvn = GlobalCMVN( torch.from_numpy(mean).float(), torch.from_numpy(istd).float()) diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py index f16ae0493..2ce8fdd91 100644 --- a/wenet/utils/init_tokenizer.py +++ b/wenet/utils/init_tokenizer.py @@ -1,3 +1,20 @@ +# Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou) +# (authors: Xingchen Song) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + from wenet.text.base_tokenizer import BaseTokenizer from wenet.text.bpe_tokenizer import BpeTokenizer from wenet.text.char_tokenizer import CharTokenizer @@ -5,30 +22,32 @@ from wenet.text.whisper_tokenizer import WhisperTokenizer -def init_tokenizer(configs, - symbol_table, - bpe_model=None, - non_lang_syms=None) -> BaseTokenizer: - if configs.get("whisper", False): +def init_tokenizer(configs) -> BaseTokenizer: + # TODO(xcsong): Forcefully read the 'tokenizer' attribute. + tokenizer_type = configs.get("tokenizer", "char") + if tokenizer_type == "whisper": tokenizer = WhisperTokenizer( - multilingual=configs['whisper_conf']['is_multilingual'], - num_languages=configs['whisper_conf']['num_languages']) - elif configs.get("tokenizer", "char") == 'paraformer': - assert 'tokenizer' in configs - assert 'tokenizer_conf' in configs - assert symbol_table == configs['tokenizer_conf']['symbol_table_path'] - return ParaformerTokenizer( + multilingual=configs['tokenizer_conf']['is_multilingual'], + num_languages=configs['tokenizer_conf']['num_languages']) + elif tokenizer_type == "char": + tokenizer = CharTokenizer( + configs['tokenizer_conf']['symbol_table_path'], + configs['tokenizer_conf']['non_lang_syms_path'], + split_with_space=configs['tokenizer_conf'].get( + 'split_with_space', False)) + elif tokenizer_type == "bpe": + tokenizer = BpeTokenizer( + configs['tokenizer_conf']['bpe_path'], + configs['tokenizer_conf']['symbol_table_path'], + configs['tokenizer_conf']['non_lang_syms_path'], + split_with_space=configs['tokenizer_conf'].get( + 'split_with_space', False)) + elif tokenizer_type == 'paraformer': + tokenizer = ParaformerTokenizer( symbol_table=configs['tokenizer_conf']['symbol_table_path'], seg_dict=configs['tokenizer_conf']['seg_dict_path']) - elif bpe_model is None: - tokenizer = CharTokenizer(symbol_table, - non_lang_syms, - split_with_space=configs.get( - 'split_with_space', False)) else: - tokenizer = BpeTokenizer(bpe_model, - symbol_table, - split_with_space=configs.get( - 'split_with_space', False)) + raise NotImplementedError + logging.info("use {} tokenizer".format(configs["tokenizer"])) return tokenizer diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index e8ab61197..465eb9f28 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -47,17 +47,6 @@ def add_model_args(parser): parser.add_argument('--tensorboard_dir', default='tensorboard', help='tensorboard log dir') - parser.add_argument('--cmvn', default=None, help='global cmvn file') - parser.add_argument('--symbol_table', - required=True, - help='model unit symbol table for training') - parser.add_argument( - "--non_lang_syms", - help="non-linguistic symbol file. One symbol per line.") - parser.add_argument('--bpe_model', - default=None, - type=str, - help='bpe model for english part') parser.add_argument('--override_config', action='append', default=[], @@ -78,10 +67,6 @@ def add_model_args(parser): type=lambda s: [str(mod) for mod in s.split(",") if s != ""], help='free module names', ) - parser.add_argument('--lfmmi_dir', - default='', - required=False, - help='LF-MMI dir') return parser @@ -229,9 +214,6 @@ def check_modify_and_save_config(args, configs, symbol_table): configs['input_dim'] = input_dim configs['output_dim'] = configs['vocab_size'] - configs['cmvn_file'] = args.cmvn - configs['is_json_cmvn'] = True - configs['lfmmi_dir'] = args.lfmmi_dir configs['train_engine'] = args.train_engine configs['use_amp'] = args.use_amp diff --git a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py index 4a51fb9b5..932a0d63a 100644 --- a/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py +++ b/wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py @@ -114,12 +114,18 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['ctc_conf'] = {} configs['ctc_conf']['ctc_blank_id'] = tokenizer.no_speech + configs['cmvn'] = None + configs['cmvn_conf'] = {} + configs['cmvn_conf']['cmvn_file'] = None + configs['cmvn_conf']['is_json_cmvn'] = None + configs['model'] = "whisper" configs['model_conf'] = {} configs['model_conf']['ctc_weight'] = 0.3 configs['model_conf']['lsm_weight'] = 0.1 configs['model_conf']['length_normalized_loss'] = False + configs['dataset'] = "asr" configs['dataset_conf'] = {} configs['dataset_conf']['filter_conf'] = {} configs['dataset_conf']['filter_conf'][ @@ -138,7 +144,10 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['dataset_conf']['spec_aug_conf']['num_f_mask'] = 2 configs['dataset_conf']['spec_aug_conf']['max_t'] = 50 configs['dataset_conf']['spec_aug_conf']['max_f'] = 10 - configs['dataset_conf']['spec_sub'] = False + configs['dataset_conf']['spec_sub'] = True + configs['dataset_conf']['spec_sub_conf'] = {} + configs['dataset_conf']['spec_sub_conf']['num_t_sub'] = 3 + configs['dataset_conf']['spec_sub_conf']['max_t'] = 30 configs['dataset_conf']['spec_trim'] = False configs['dataset_conf']['shuffle'] = True configs['dataset_conf']['shuffle_conf'] = {} @@ -159,16 +168,16 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str): configs['dataset_conf']['batch_conf']['max_frames_in_batch'] = 12000 configs['grad_clip'] = 5 - configs['accum_grad'] = 1 + configs['accum_grad'] = 4 configs['max_epoch'] = 100 configs['log_interval'] = 100 configs['optim'] = "adam" configs['optim_conf'] = {} - configs['optim_conf']['lr'] = 0.002 + configs['optim_conf']['lr'] = 0.0005 configs['scheduler'] = "warmuplr" configs['scheduler_conf'] = {} - configs['scheduler_conf']['warmup_steps'] = 25000 + configs['scheduler_conf']['warmup_steps'] = 12000 with open(wenet_yaml_path, '+w') as f: f.write(yaml.dump(configs))