Multi-Window Multi-Head Attention implementation for ASR transformer #2675
Open
NikolaiKyhne wants to merge 27 commits into speechbrain:develop from NikolaiKyhne:mwmha_final
Commits (27)
- 9d765e4 Create multiwindow_attention.py (NikolaiKyhne)
- 4f63396 added mwmha option (NikolaiKyhne)
- c7a8fd3 Update TransformerASR.py with mwmha option (NikolaiKyhne)
- c24435c added medium recipe for mwmha transformer (NikolaiKyhne)
- e04f3db added small mwmha transformer recipe (NikolaiKyhne)
- 7599854 added section about mwmha (NikolaiKyhne)
- 740552d Merge branch 'develop' into mwmha_final (SarthakYadav)
- dc472e7 pre-commit fixes (SarthakYadav)
- bf6c747 Added large mwmha transformer recipe for CommomVoice english (NikolaiKyhne)
- fc16b25 Update output folder name to match other recipes (NikolaiKyhne)
- 0d9e0df Added about MWMWA section (NikolaiKyhne)
- 96d1e25 Added model links for MWMHA recipes (NikolaiKyhne)
- 9741a32 Added model link for MWMHA recipe (NikolaiKyhne)
- 1537e15 Added mwmha recipes to the list (NikolaiKyhne)
- 3e76bee Added mwmha transformer recipe to the list (NikolaiKyhne)
- 55fa0b9 updated about sections for MW-MHA (SarthakYadav)
- 512ad12 updated docstrings (SarthakYadav)
- 588a66a refactoring for pre-commit fixes (NikolaiKyhne)
- 3bf2799 refactoring for pre-commit fixes (NikolaiKyhne)
- df4f255 pre-commit fixes (getting rid of unnecessary stuff) (NikolaiKyhne)
- 427ebed pre-commit fixes (getting rid of unnecessary stuff) (NikolaiKyhne)
- 401420b fix examples (SarthakYadav)
- e03be4a Merge branch 'develop' into mwmha_final (SarthakYadav)
- eec182b fix flake8 violations (SarthakYadav)
- 6f298a6 fix MWMHA example (SarthakYadav)
- bf80f66 Merge branch 'speechbrain:develop' into mwmha_final (SarthakYadav)
- 8fe5202 update trunc_normal_ example (SarthakYadav)
recipes/CommonVoice/ASR/transformer/hparams/mwmha_transformer_large.yaml (298 additions, 0 deletions)
# ############################################################################
# Model: E2E ASR with Transformer
# Encoder: Transformer Encoder
# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch
# Tokens: unigram
# losses: CTC + KLdiv (Label Smoothing loss)
# Authors: Nikolai Lund Kühne, Sarthak Yadav
# ############################################################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 3407
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/mwmha_transformer_large/<seed>
output_wer_folder: !ref <output_folder>/
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Data files
data_folder: !PLACEHOLDER  # e.g., /localscratch/cv-corpus-5.1-2020-06-22/fr
train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
accented_letters: False
language: en  # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev.csv
test_csv: !ref <output_folder>/test.csv
skip_prep: False  # Skip data preparation
convert_to_wav: False  # Switch this to True to convert all mp3 files to wav.

# We remove utterances longer than 10s in the train/dev/test sets, as
# longer sentences certainly correspond to "open microphones".
avoid_if_longer_than: 10.0

# THIS IS TERRIBLE BUT WE HAVE NO CHOICE.
# Some versions of the CV dataset may contain one or two files of more than
# 2 min in the validation and/or test sets. This is an error by design of the
# dataset, as these files contain 90% silence. We exclude them.
avoid_if_longer_than_val_test: 90.0

ckpt_interval_minutes: 15  # save checkpoint every N min

####################### Training Parameters ####################################
number_of_epochs: 78
optimizer_step_limit: 250000
batch_size: 32  # This works with a 32GB GPU! (bs * nb_gpu * accum) > 128!
ctc_weight: 0.3
grad_accumulation_factor: 4
loss_reduction: 'batchmean'
sorting: random
num_workers: 4
precision: fp32  # bf16, fp16 or fp32

# Stage-related parameters
lr_adam: 0.0008
weight_decay: 0.01
warmup_steps: 25000
augment_warmup: 8000

# BPE parameters
token_type: unigram  # ["unigram", "bpe", "char"]
character_coverage: 1.0

# Feature parameters
sample_rate: 16000
n_fft: 400
n_mels: 80

# This setup works well for an A100 80GB GPU; adapt it to your needs,
# or turn it off (but training speed will decrease).
dynamic_batching: True
max_batch_length_train: 500
max_batch_length_val: 100  # we reduce it as the beam is much wider (VRAM)
num_bucket: 200
shuffle: True  # if True, re-creates batches at each epoch, shuffling examples.
batch_ordering: random
max_batch_ex: 256

dynamic_batch_sampler_train:
    max_batch_length: !ref <max_batch_length_train>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

dynamic_batch_sampler_valid:
    max_batch_length: !ref <max_batch_length_val>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

valid_dataloader_opts:
    batch_size: 1
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

test_dataloader_opts:
    batch_size: 1
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

####################### Model Parameters ###########################
# Transformer
d_model: 512
nhead: 8
num_encoder_layers: 24
num_decoder_layers: 6
d_ffn: 2048
transformer_dropout: 0.1
activation: !name:torch.nn.GELU
output_neurons: 5120

# Outputs
blank_index: 0
label_smoothing: 0.1
pad_index: 0
bos_index: 1
eos_index: 2

# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 1.0
valid_search_interval: 10
valid_beam_size: 1  # We do greedy decoding here, so it's faster to decode ...
test_beam_size: 80
ctc_weight_decode: 0.3
scorer_beam_scale: 0.3

############################## models ################################

CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
    num_blocks: 2
    num_layers_per_block: 1
    out_channels: (64, 32)
    kernel_sizes: (3, 3)
    strides: (2, 2)
    residuals: (False, False)

Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
    input_size: 640
    tgt_vocab: !ref <output_neurons>
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    d_ffn: !ref <d_ffn>
    dropout: !ref <transformer_dropout>
    conformer_activation: !ref <activation>
    activation: !ref <activation>
    encoder_module: transformer
    attention_type: MWMHA
    normalize_before: True
    causal: False
    mwmha_windows: [3, 4, 5, 8, 8, 128, 0, 0]

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>

seq_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>

modules:
    CNN: !ref <CNN>
    Transformer: !ref <Transformer>
    seq_lin: !ref <seq_lin>
    ctc_lin: !ref <ctc_lin>

model: !new:torch.nn.ModuleList
    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]

# Optimizer (AdamW)
Adam: !name:torch.optim.AdamW
    lr: !ref <lr_adam>
    weight_decay: !ref <weight_decay>

# Scorer
ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
    eos_index: !ref <eos_index>
    blank_index: !ref <blank_index>
    ctc_fc: !ref <ctc_lin>

scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
    full_scorers: [!ref <ctc_scorer>]
    weights:
        ctc: !ref <ctc_weight_decode>
    scorer_beam_scale: !ref <scorer_beam_scale>

valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
    modules: [!ref <Transformer>, !ref <seq_lin>]
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <valid_beam_size>
    using_eos_threshold: False
    length_normalization: True
    scorer: !ref <scorer>

test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
    modules: [!ref <Transformer>, !ref <seq_lin>]
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    temperature: 1.15
    using_eos_threshold: True
    scorer: !ref <scorer>

log_softmax: !new:torch.nn.LogSoftmax
    dim: -1

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
    reduction: !ref <loss_reduction>

seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
    label_smoothing: !ref <label_smoothing>
    reduction: !ref <loss_reduction>

noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr_adam>
    n_warmup_steps: !ref <warmup_steps>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        noam_scheduler: !ref <noam_annealing>
        normalizer: !ref <normalize>
        counter: !ref <epoch_counter>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
    update_until_epoch: 4

############################## Augmentations ###################################

# Time Drop
time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
    drop_length_low: 15
    drop_length_high: 25
    drop_count_low: 3
    drop_count_high: 3
    replace: "zeros"
    dim: 1

# Frequency Drop
freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
    drop_length_low: 25
    drop_length_high: 35
    drop_count_low: 2
    drop_count_high: 2
    replace: "zeros"
    dim: 2

# Time warp
time_warp: !new:speechbrain.augment.freq_domain.Warping

fea_augment: !new:speechbrain.augment.augmenter.Augmenter
    min_augmentations: 3
    max_augmentations: 3
    augment_prob: 1.0
    augmentations: [
        !ref <time_drop>,
        !ref <freq_drop>,
        !ref <time_warp>]

compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
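The key additions in this recipe are attention_type: MWMHA and mwmha_windows: [3, 4, 5, 8, 8, 128, 0, 0], one entry per attention head (nhead: 8). The PR's multiwindow_attention.py is not included in this excerpt, so the sketch below is only one plausible reading of that list: each head restricted to a local attention window of the given size, with a window of 0 taken to mean an unrestricted (global) head. The function name, shapes, and masking scheme are illustrative assumptions, not the actual SpeechBrain implementation.

```python
# Illustrative sketch only: the PR's multiwindow_attention.py is not shown in
# this diff, so the head-wise windowing below is an assumed reading of
# mwmha_windows, not the actual SpeechBrain implementation.
import torch
import torch.nn.functional as F


def multi_window_attention(q, k, v, windows):
    """Scaled dot-product attention with one (assumed) local window per head.

    q, k, v : tensors of shape (batch, nhead, time, d_head)
    windows : one window size per head; 0 is read here as "no restriction".
    """
    batch, nhead, time, d_head = q.shape
    assert len(windows) == nhead
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5  # (batch, nhead, time, time)

    # Band mask per head: query i may attend to key j only if |i - j| <= window.
    idx = torch.arange(time)
    dist = (idx[None, :] - idx[:, None]).abs()  # (time, time)
    masks = []
    for w in windows:
        if w == 0:
            masks.append(torch.zeros(time, time, dtype=torch.bool))  # global head
        else:
            masks.append(dist > w)  # True = position masked out
    mask = torch.stack(masks)  # (nhead, time, time)

    scores = scores.masked_fill(mask.unsqueeze(0), float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ v  # (batch, nhead, time, d_head)


if __name__ == "__main__":
    # Mirrors the recipe: nhead=8, mwmha_windows=[3, 4, 5, 8, 8, 128, 0, 0].
    q = k = v = torch.randn(2, 8, 50, 64)
    out = multi_window_attention(q, k, v, [3, 4, 5, 8, 8, 128, 0, 0])
    print(out.shape)  # torch.Size([2, 8, 50, 64])
```

As with other SpeechBrain recipes, the hparams file above is resolved with HyperPyYAML; a minimal loading sketch follows (the file and data paths are placeholders, and in practice the recipe's training script handles this along with data preparation and the training loop).

```python
# Minimal sketch of how a SpeechBrain hparams file like the one above is loaded.
from hyperpyyaml import load_hyperpyyaml

with open("hparams/mwmha_transformer_large.yaml") as fin:
    # data_folder is a !PLACEHOLDER in the YAML, so it must be overridden.
    hparams = load_hyperpyyaml(fin, overrides={"data_folder": "/path/to/cv-corpus/en"})

transformer = hparams["Transformer"]  # TransformerASR built with attention_type MWMHA
```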
Comment: Why is the Val WER so high? I think you swapped CER and WER, right?
Reply: No, that's right, I just double-checked; it is the same for Conformer English on CV 16.1 :)
Reply: @Adel-Moumen The Val WER for MWMHA (10.97) follows the same trend, is quite close to that of the Conformer model (10.48), and is reported correctly; CER and WER are not swapped.