Multi-Window Multi-Head Attention implementation for ASR transformer by NikolaiKyhne · Pull Request #2675 · speechbrain/speechbrain

URL: http://github.com/speechbrain/speechbrain/pull/2675/files
Status: Open

Changes from all commits (27 commits)
9d765e4
Create multiwindow_attention.py
NikolaiKyhne Sep 9, 2024
4f63396
added mwmha option
NikolaiKyhne Sep 9, 2024
c7a8fd3
Update TransformerASR.py with mwmha option
NikolaiKyhne Sep 9, 2024
c24435c
added medium recipe for mwmha transformer
NikolaiKyhne Sep 9, 2024
e04f3db
added small mwmha transformer recipe
NikolaiKyhne Sep 9, 2024
7599854
added section about mwmha
NikolaiKyhne Sep 9, 2024
740552d
Merge branch 'develop' into mwmha_final
SarthakYadav Sep 10, 2024
dc472e7
pre-commit fixes
SarthakYadav Sep 10, 2024
bf6c747
Added large mwmha transformer recipe for CommomVoice english
NikolaiKyhne Sep 10, 2024
fc16b25
Update output folder name to match other recipes
NikolaiKyhne Sep 10, 2024
0d9e0df
Added about MWMWA section
NikolaiKyhne Sep 10, 2024
96d1e25
Added model links for MWMHA recipes
NikolaiKyhne Sep 10, 2024
9741a32
Added model link for MWMHA recipe
NikolaiKyhne Sep 10, 2024
1537e15
Added mwmha recipes to the list
NikolaiKyhne Sep 10, 2024
3e76bee
Added mwmha transformer recipe to the list
NikolaiKyhne Sep 10, 2024
55fa0b9
updated about sections for MW-MHA
SarthakYadav Sep 10, 2024
512ad12
updated docstrings
SarthakYadav Sep 10, 2024
588a66a
refactoring for pre-commit fixes
NikolaiKyhne Sep 10, 2024
3bf2799
refactoring for pre-commit fixes
NikolaiKyhne Sep 10, 2024
df4f255
pre-commit fixes (getting rid of unnecessary stuff)
NikolaiKyhne Sep 10, 2024
427ebed
pre-commit fixes (getting rid of unnecessary stuff)
NikolaiKyhne Sep 10, 2024
401420b
fix examples
SarthakYadav Sep 11, 2024
e03be4a
Merge branch 'develop' into mwmha_final
SarthakYadav Sep 11, 2024
eec182b
fix flake8 violations
SarthakYadav Sep 11, 2024
6f298a6
fix MWMHA example
SarthakYadav Sep 11, 2024
bf80f66
Merge branch 'speechbrain:develop' into mwmha_final
SarthakYadav Sep 11, 2024
8fe5202
update trunc_normal_ example
SarthakYadav Sep 11, 2024
21 changes: 21 additions & 0 deletions recipes/CommonVoice/ASR/transformer/README.md
@@ -38,10 +38,31 @@ For Whisper-large-v2 and medium finetuning, here is list of the different langua
## Transformer
| Language | CV version | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | Hugging Face link | Model link | GPUs |
| ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:|:-----------:| :-----------:| :-----------:|
| English | 16.1 | mwmha_transformer_large.yaml | No | 4.72 | 10.97 | 6.68 | 13.69 | - | [model](https://1drv.ms/f/c/039f8ffe91e06416/Et7KEbSlWNdJhkjLIi7_vGQBMVhGwRRBzCSljh6aA4sJSw?e=dXeuiY) | 1xL40 48GB |

Review thread on the `mwmha_transformer_large.yaml` row above:

**Collaborator:** Why is the Val WER so high? I think you swapped CER and WER, right?

**Author:** No, that's right. I just double-checked; it is the same for Conformer English on CV 16.1 :)

**Contributor:** @Adel-Moumen The Val WER for MWMHA (10.97) follows the same trend, is quite close to that of the Conformer model (10.48), and is reported correctly; CER and WER are not swapped.

| English | 16.1 | conformer_large.yaml | No | 4.48 | 10.48 | 6.42 | 13.39 | - | [model](https://www.dropbox.com/scl/fo/3w24pxln0fjyofl6xbfv1/AJJqzWfCtGFFTRLwM3DeZG8?rlkey=wpzzhizreedptts64d2m9jq4u&st=xu5g9an8&dl=0) | 4xA40 46GB |
| Italian | 14.0 | conformer_large.yaml | No | 2.91 | 9.79 | 2.68 | 9.27 | - | [model](https://www.dropbox.com/scl/fo/tf44itp8f4icf2z5qlxpm/AIOYS_CMov5ss5Q9AonFEno?rlkey=xek5ikbhqoovcao31iniqimrr&dl=0) | 2xV100 32GB |
| French | 14.0 | conformer_large.yaml | No | 2.64 | 7.62 | 3.55 | 9.48 | - | [model](https://www.dropbox.com/scl/fo/y862nl95zoe4sj3347095/ACxmT3_uw1ScLoYs0DSbGRM?rlkey=q66dk13w5nu1lkphtdinnnigm&dl=0) | 2xV100 32GB |
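
To illustrate the point raised in the review thread above (why CER values sit well below WER values for the same system), here is a small, self-contained toy example. It is not part of the recipe; the sentences and the plain Levenshtein implementation are purely illustrative.

```python
# Toy illustration: WER and CER are both edit distances normalized by the
# reference length, computed over words and characters respectively.
# A single misrecognized word counts as one word error, but usually only
# one or two character errors, so CER is typically much lower than WER.

def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences."""
    d = [[i + j if i * j == 0 else 0 for j in range(len(hyp) + 1)]
         for i in range(len(ref) + 1)]
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            d[i][j] = min(
                d[i - 1][j] + 1,                               # deletion
                d[i][j - 1] + 1,                               # insertion
                d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1]),  # substitution
            )
    return d[len(ref)][len(hyp)]

ref = "the cat sat on the mat"
hyp = "the cat sad on the mat"
wer = edit_distance(ref.split(), hyp.split()) / len(ref.split())  # 1/6, about 16.7%
cer = edit_distance(list(ref), list(hyp)) / len(ref)              # 1/22, about 4.5%
print(f"WER: {wer:.1%}  CER: {cer:.1%}")
```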

### **About MW-MHA Transformer**
Multi-Window Multi-Head Attention (MW-MHA) is a new multi-head attention module in which the individual attention heads operate on local windows of different sizes over the input sequence, capturing local-global dependencies more effectively. The method was proposed in the paper "Masked Autoencoders with Multi-Window Local-Global Attention Are Better Audio Learners" by Yadav et al. (2024) in the context of learning general-purpose audio representations.

Here, we simply replaced the standard MHA in the Transformer encoder with MW-MHA, achieving performance quite close to that of a Conformer model without adding any parameters or other modifications. You can learn more about MW-MHA through the following links:

- Paper: https://openreview.net/forum?id=Q53QLftNkA
- Code: https://github.com/SarthakYadav/mwmae-jax-official
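
To make the idea concrete, below is a minimal PyTorch sketch of multi-window attention. It is an illustration only, not the SpeechBrain implementation added in this PR: the function name, shapes, and the banded (sliding-window) masking convention are assumptions, and a window size of 0 is treated here as an unrestricted (global) head.

```python
import torch
import torch.nn.functional as F


def multi_window_attention(x, w_q, w_k, w_v, window_sizes):
    """Illustrative MW-MHA: one local window size per head (0 = global head).

    x: (batch, time, d_model); w_q/w_k/w_v: (d_model, d_model) projections.
    """
    B, T, D = x.shape
    H = len(window_sizes)
    d_head = D // H
    q = (x @ w_q).view(B, T, H, d_head).transpose(1, 2)  # (B, H, T, d_head)
    k = (x @ w_k).view(B, T, H, d_head).transpose(1, 2)
    v = (x @ w_v).view(B, T, H, d_head).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / d_head**0.5        # (B, H, T, T)

    # One additive mask per head: position i may attend to j iff |i - j| <= w // 2.
    idx = torch.arange(T)
    dist = (idx[None, :] - idx[:, None]).abs()
    masks = []
    for w in window_sizes:
        if w == 0:
            masks.append(torch.zeros(T, T))                # global head: no masking
        else:
            m = torch.full((T, T), float("-inf"))
            m[dist <= w // 2] = 0.0
            masks.append(m)
    scores = scores + torch.stack(masks)[None]             # broadcast over batch

    out = F.softmax(scores, dim=-1) @ v                    # (B, H, T, d_head)
    return out.transpose(1, 2).reshape(B, T, D)            # concatenate heads
```

Each head therefore mixes information at a different scale, from very local to effectively global, which is consistent with the plain Transformer results above being close to those of the Conformer.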

If you use MW-MHA in your work, please cite the following paper:

```bibtex
@inproceedings{
yadav2024masked,
title={Masked Autoencoders with Multi-Window Local-Global Attention Are Better Audio Learners},
author={Sarthak Yadav and Sergios Theodoridis and Lars Kai Hansen and Zheng-Hua Tan},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=Q53QLftNkA}
}
```

## Whisper Finetuning
The following table contains Whisper fine-tuning results for 1 epoch using the Whisper model, freezing the encoder and fine-tuning the decoder.

New file: mwmha_transformer_large.yaml (298 additions)
@@ -0,0 +1,298 @@
# ############################################################################
# Model: E2E ASR with Transformer
# Encoder: Transformer Encoder
# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch
# Tokens: unigram
# losses: CTC + KLdiv (Label Smoothing loss)
# Authors: Nikolai Lund Kühne, Sarthak Yadav
# ############################################################################
# Seed needs to be set at top of yaml, before objects with parameters are made

seed: 3407
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/mwmha_transformer_large/<seed>
output_wer_folder: !ref <output_folder>/
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Data files
data_folder: !PLACEHOLDER # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
accented_letters: False
language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev.csv
test_csv: !ref <output_folder>/test.csv
skip_prep: False # Skip data preparation
convert_to_wav: False # Switch this to True to convert all mp3 files to wav.

# We remove utterances longer than 10 s in the train/dev/test sets, as
# longer sentences certainly correspond to "open microphones".
avoid_if_longer_than: 10.0

# THIS IS TERRIBLE BUT WE HAVE NO CHOICE.
# Some versions of the CV dataset may contain one or two files of more than
# 2 min in the validation and/or test sets. This is a design flaw of the dataset,
# as these files contain 90% silence. We exclude them.
avoid_if_longer_than_val_test: 90.0

ckpt_interval_minutes: 15 # save checkpoint every N min

####################### Training Parameters ####################################
number_of_epochs: 78
optimizer_step_limit: 250000
batch_size: 32 # This works with a 32GB GPU! Keep (bs * nb_gpu * accum) > 128.
ctc_weight: 0.3
grad_accumulation_factor: 4
loss_reduction: 'batchmean'
sorting: random
num_workers: 4
precision: fp32 # bf16, fp16 or fp32

# stages related parameters
lr_adam: 0.0008
weight_decay: 0.01
warmup_steps: 25000
augment_warmup: 8000

# BPE parameters
token_type: unigram # ["unigram", "bpe", "char"]
character_coverage: 1.0

# Feature parameters
sample_rate: 16000
n_fft: 400
n_mels: 80

# This setup works well for an A100 80GB GPU; adapt it to your needs,
# or turn dynamic batching off (but training speed will decrease).
dynamic_batching: True
max_batch_length_train: 500
max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM)
num_bucket: 200
shuffle: True # if true re-creates batches at each epoch shuffling examples.
batch_ordering: random
max_batch_ex: 256

dynamic_batch_sampler_train:
    max_batch_length: !ref <max_batch_length_train>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

dynamic_batch_sampler_valid:
    max_batch_length: !ref <max_batch_length_val>
    num_buckets: !ref <num_bucket>
    shuffle: !ref <shuffle>
    batch_ordering: !ref <batch_ordering>
    max_batch_ex: !ref <max_batch_ex>

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

valid_dataloader_opts:
    batch_size: 1
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

test_dataloader_opts:
    batch_size: 1
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

####################### Model Parameters ###########################
# Transformer
d_model: 512
nhead: 8
num_encoder_layers: 24
num_decoder_layers: 6
d_ffn: 2048
transformer_dropout: 0.1
activation: !name:torch.nn.GELU
output_neurons: 5120

# Outputs
blank_index: 0
label_smoothing: 0.1
pad_index: 0
bos_index: 1
eos_index: 2

# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 1.0
valid_search_interval: 10
valid_beam_size: 1 # We do greedy here so it's faster to decode ...
test_beam_size: 80
ctc_weight_decode: 0.3
scorer_beam_scale: 0.3

############################## models ################################

CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
    num_blocks: 2
    num_layers_per_block: 1
    out_channels: (64, 32)
    kernel_sizes: (3, 3)
    strides: (2, 2)
    residuals: (False, False)

Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
    input_size: 640
    tgt_vocab: !ref <output_neurons>
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    d_ffn: !ref <d_ffn>
    dropout: !ref <transformer_dropout>
    conformer_activation: !ref <activation>
    activation: !ref <activation>
    encoder_module: transformer
    attention_type: MWMHA
    normalize_before: True
    causal: False
    # One local window size per attention head (nhead: 8); presumably a value
    # of 0 denotes an unrestricted (global) head. See the MW-MHA paper and code
    # linked in the README sections of this PR.
    mwmha_windows: [3, 4, 5, 8, 8, 128, 0, 0]

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>

seq_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>

modules:
    CNN: !ref <CNN>
    Transformer: !ref <Transformer>
    seq_lin: !ref <seq_lin>
    ctc_lin: !ref <ctc_lin>

model: !new:torch.nn.ModuleList
    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]

# Optimizer (AdamW with weight decay)
Adam: !name:torch.optim.AdamW
    lr: !ref <lr_adam>
    weight_decay: !ref <weight_decay>

# Scorer
ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer
    eos_index: !ref <eos_index>
    blank_index: !ref <blank_index>
    ctc_fc: !ref <ctc_lin>

scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
    full_scorers: [!ref <ctc_scorer>]
    weights:
        ctc: !ref <ctc_weight_decode>
    scorer_beam_scale: !ref <scorer_beam_scale>

valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
    modules: [!ref <Transformer>, !ref <seq_lin>]
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <valid_beam_size>
    using_eos_threshold: False
    length_normalization: True
    scorer: !ref <scorer>

test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher
    modules: [!ref <Transformer>, !ref <seq_lin>]
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    temperature: 1.15
    using_eos_threshold: True
    scorer: !ref <scorer>

log_softmax: !new:torch.nn.LogSoftmax
    dim: -1

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
    reduction: !ref <loss_reduction>

seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
    label_smoothing: !ref <label_smoothing>
    reduction: !ref <loss_reduction>

noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr_adam>
    n_warmup_steps: !ref <warmup_steps>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        noam_scheduler: !ref <noam_annealing>
        normalizer: !ref <normalize>
        counter: !ref <epoch_counter>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
    update_until_epoch: 4

############################## Augmentations ###################################

# Time Drop
time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
    drop_length_low: 15
    drop_length_high: 25
    drop_count_low: 3
    drop_count_high: 3
    replace: "zeros"
    dim: 1

# Frequency Drop
freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
    drop_length_low: 25
    drop_length_high: 35
    drop_count_low: 2
    drop_count_high: 2
    replace: "zeros"
    dim: 2

# Time warp
time_warp: !new:speechbrain.augment.freq_domain.Warping

fea_augment: !new:speechbrain.augment.augmenter.Augmenter
    min_augmentations: 3
    max_augmentations: 3
    augment_prob: 1.0
    augmentations: [
        !ref <time_drop>,
        !ref <freq_drop>,
        !ref <time_warp>]

compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
24 changes: 24 additions & 0 deletions recipes/LibriSpeech/ASR/transformer/README.md
@@ -46,6 +46,9 @@ Following table contains whisper-finetuning results for 1 epoch using Whisper mo
| 03-09-23 | hyperbranchformer_13M.yaml | NA | 2.54 | 6.58 | Not Avail. | Not Avail. | 1xP40 24GB
| 03-09-23 | hyperbranchformer_25M.yaml | NA | 2.36 | 5.89 | Not Avail. | Not Avail. | 1xP40 24GB
| 05-01-24 | bayesspeech.yaml | 4.28 | 2.84 | 6.27 | Not Avail. | [DropBox](https://www.dropbox.com/scl/fo/cdken4jqfj96ev1v84jxm/h?rlkey=25eu1ytgm5ac51zqj8p65zwxd&dl=0) | 1xV100 32GB |
| 07-06-24 | mwmha_transformer_small.yaml | 4.60 | 2.66 | 6.50 (**only 12.7M parameters**) | NA | [OneDrive](https://1drv.ms/f/c/039f8ffe91e06416/EkeVvQiD9lREpHflL8DkpuYBeMFJkTxyzCh0DOEoGMNLgw?e=3vuycE) | 1xA40 48GB |
| 07-06-24 | mwmha_transformer_medium.yaml | 3.55 | 2.26 | 5.66 (**only 39.9M parameters**) | NA | [OneDrive](https://1drv.ms/f/c/039f8ffe91e06416/EvX2HV5lfY9GvbB8iKK18-kBMMtdNVDYtCPvcMk4aB4jfw?e=2gV9x0) | 1xA40 48GB |


# **About HyperConformer**
HyperConformer is a new architecture that replaces the self-attention mechanism of Conformer with the linear-time token-mixing architecture HyperMixer.
@@ -67,6 +70,27 @@ Please cite HyperConformer if you use it for your research or business.
}
```

# **About MW-MHA Transformer**
Multi-Window Multi-Head Attention (MW-MHA) is a new multi-head attention module in which the individual attention heads operate on local windows of different sizes over the input sequence, capturing local-global dependencies more effectively. The method was proposed in the paper "Masked Autoencoders with Multi-Window Local-Global Attention Are Better Audio Learners" by Yadav et al. (2024) in the context of learning general-purpose audio representations.

Here, we simply replaced the standard MHA in the Transformer encoder with MW-MHA, achieving performance quite close to that of a Conformer model without adding any parameters or other modifications. You can learn more about MW-MHA through the following links:

- Paper: https://openreview.net/forum?id=Q53QLftNkA
- Code: https://github.com/SarthakYadav/mwmae-jax-official

If you use MW-MHA in your work, please cite the following paper:

```bibtex
@inproceedings{
yadav2024masked,
title={Masked Autoencoders with Multi-Window Local-Global Attention Are Better Audio Learners},
author={Sarthak Yadav and Sergios Theodoridis and Lars Kai Hansen and Zheng-Hua Tan},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=Q53QLftNkA}
}
```
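
For completeness, here is a hedged Python sketch of how the new encoder option would be selected programmatically. It simply mirrors the `Transformer:` block of the CommonVoice `mwmha_transformer_large.yaml` recipe shown earlier in this PR, and it assumes an installation that includes this PR, since `attention_type="MWMHA"` and `mwmha_windows` are introduced here; everything else follows the standard `TransformerASR` interface.

```python
import torch
from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR

# Values mirror the CommonVoice mwmha_transformer_large.yaml recipe above.
model = TransformerASR(
    tgt_vocab=5120,
    input_size=640,                 # CNN front-end output: 32 channels x 20 mel bins
    d_model=512,
    nhead=8,
    num_encoder_layers=24,
    num_decoder_layers=6,
    d_ffn=2048,
    dropout=0.1,
    activation=torch.nn.GELU,
    encoder_module="transformer",
    attention_type="MWMHA",         # added by this PR; replaces standard MHA in the encoder
    normalize_before=True,
    causal=False,
    mwmha_windows=[3, 4, 5, 8, 8, 128, 0, 0],  # one window size per head; added by this PR
)

src = torch.rand(2, 120, 640)           # (batch, time, features) after the CNN front-end
tgt = torch.randint(0, 5120, (2, 25))   # target token ids
enc_out, dec_out = model(src, tgt)      # same call signature as the standard TransformerASR
```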

# **About SpeechBrain**
- Website: https://speechbrain.github.io/
- Code: https://github.com/speechbrain/speechbrain/