| Layer | Channel x Input Dimension | Kernel/Filter Width | Stride | Channel x Output Dimension |
|---|---|---|---|---|
| 1 | 1 x 16000 | 10 | 5 | 512 x 3199 |
| 2 | 512 x 3199 | 3 | 2 | 512 x 1599 |
| 3 | 512 x 1599 | 3 | 2 | 512 x 799 |
| 4 | 512 x 799 | 3 | 2 | 512 x 399 |
| 5 | 512 x 399 | 3 | 2 | 512 x 199 |
| 6 | 512 x 199 | 2 | 2 | 512 x 99 |
| 7 | 512 x 99 | 2 | 2 | 512 x 49 |

Total Stride = 5 x 2 x 2 x 2 x 2 x 2 x 2 = 320
Time per Sample = [Tex]\frac{1}{16000}\ \text{s} = 0.0625\ \text{ms}[/Tex] per sample
Duration per output frame = 320 x 0.0625 ms = 20 ms

Feature Encoder of Wav2Vec2
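As a quick sanity check of the arithmetic above, the minimal sketch below applies the standard no-padding convolution output-length formula to the kernel widths and strides from the table (plain Python, no model code assumed).
Python
# Minimal sketch: verify the feature-encoder output lengths from the table above
kernels = [10, 3, 3, 3, 3, 2, 2]
strides = [5, 2, 2, 2, 2, 2, 2]

length = 16000  # 1 second of audio sampled at 16 kHz
for k, s in zip(kernels, strides):
    length = (length - k) // s + 1  # conv output length with no padding
print(length)  # 49 frames, i.e. one 512-dim vector every 20 ms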
Contextualized representations with Transformers
The core of wav2vec 2.0 is its Transformer encoder, which takes the latent feature vectors produced by the feature encoder and processes them through a stack of transformer blocks. The input sequence first passes through a feature projection layer, which increases the dimension from 512 (the feature encoder output) to 768 for the BASE variant or 1,024 for the LARGE variant, matching the model dimension of the Transformer encoder.
BASE contains 12 transformer blocks with a model dimension of 768, an inner (FFN) dimension of 3,072 and 8 attention heads. The LARGE model is made up of 24 transformer blocks with a model dimension of 1,024, an inner dimension of 4,096 and 16 attention heads.
One difference with respect to the BERT architecture is how positional information is incorporated. Instead of fixed positional embeddings that encode absolute positions, wav2vec 2.0 uses a grouped convolutional layer that learns relative positional embeddings on its own.
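For intuition, here is a rough structural sketch of the BASE context network described above, built with PyTorch's generic transformer encoder rather than the actual Wav2Vec2 modules; the kernel size of 128 and 16 groups for the positional convolution follow the paper's description, and the remaining details are simplified.
Python
import torch
import torch.nn as nn

# Rough structural sketch of the BASE context network (not the real Wav2Vec2 modules)
feature_projection = nn.Linear(512, 768)            # feature encoder output -> model dimension
pos_conv = nn.Conv1d(768, 768, kernel_size=128, padding=64, groups=16)  # grouped positional conv
encoder_layer = nn.TransformerEncoderLayer(
    d_model=768, nhead=8, dim_feedforward=3072, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)

z = torch.randn(1, 49, 512)                         # 49 latent vectors for 1 s of audio
x = feature_projection(z)                           # (1, 49, 768)
pos = pos_conv(x.transpose(1, 2))[..., :x.shape[1]].transpose(1, 2)  # relative positions
context = encoder(x + pos)                          # contextualized representations, (1, 49, 768)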
The output of the transformer is a context vector. The transformer builds context representations over the continuous speech representations, and these are compared against the output of the quantization module. The quantized vectors produced by the quantization module serve as the discrete targets for the transformer encoder. Both the quantized vectors and the context vectors are learnt jointly using a contrastive loss; more details follow in the training section.
Quantization module
The quantization module of Wav2Vec2 is adopted from the vq-wav2vec architecture. The diagram below shows the overall quantization process.

Wav2Vec2 Quantization Process
The output of the feature encoder (rather than the context transformer) is discretized in parallel by a quantization module based on product quantization. Quantization is the process of mapping continuous values to a discrete set. A codebook in product quantization is a set of representative points that help us discretize; these representative vectors can be thought of as speech units. The quantization steps are as follows (a small code sketch follows the list):
- For a 1-second input we get a 512 x 49 output from the feature encoder, i.e. 49 latent feature vectors, each of size 512.
- A linear layer projects each feature vector from 512 dimensions to 640 logits. These 640 logits are divided into two groups (G = 2) of V = 320 logits each, and each group of 320 logits indexes a codebook of 320 discrete vectors. The codebooks are randomly initialized and their entries are learnt during training using the contrastive loss. Since each feature vector is mapped to one entry from each of the two groups, the total number of possible combinations is 320 x 320 = 102,400 speech units.
- Using Gumbel-Softmax, a one-hot vector is produced for each group G. We therefore get two one-hot vectors, each of which selects one of the 320 discrete vectors in its codebook.
- Gumbel-Softmax is a popular technique for sampling from a discrete space. It introduces stochasticity (via the Gumbel distribution) into the discrete decision and replaces the argmax operation with a differentiable approximation (softmax), which makes it possible to backpropagate through random samples of discrete variables. The Gumbel-Max trick is very similar to the reparameterization trick: we combine a deterministic part (the model logits) with a stochastic part (Gumbel noise). During the forward pass or inference the largest index is picked and the corresponding codebook vector is used; during the backward pass the soft logits are used for backpropagation.
- Each vector in a codebook is of size d/2. We obtain two codebook vectors (e1 and e2) for each latent feature vector z; e1 and e2 are concatenated into a d-dimensional vector, which is then passed through a linear transformation [Tex]\mathbb{R}^d \rightarrow \mathbb{R}^f[/Tex] to obtain the quantized vector q ∈ [Tex]\mathbb{R}^f[/Tex]. This transformation matches the dimension of the transformer output.
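The following is a simplified, self-contained sketch of this product quantization with Gumbel-Softmax, not the library implementation; G = 2 groups and 320 entries per group come from the text, while the per-group codeword size (128) and the output size f = 256 are assumptions made only for illustration.
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Simplified product quantization with Gumbel-Softmax (illustrative only)
G, V, d_feat, d_code, d_out = 2, 320, 512, 128, 256   # d_code and d_out are assumptions

to_logits = nn.Linear(d_feat, G * V)                  # project 512 features to 640 logits
codebooks = nn.Parameter(torch.randn(G, V, d_code))   # randomly initialised codebooks
to_q = nn.Linear(G * d_code, d_out)                   # final R^d -> R^f projection

z = torch.randn(49, d_feat)                           # 49 latent vectors for 1 s of audio
logits = to_logits(z).view(49, G, V)                  # one group of 320 logits per codebook
one_hot = F.gumbel_softmax(logits, tau=2.0, hard=True)  # differentiable one-hot per group
e = torch.einsum('tgv,gvc->tgc', one_hot, codebooks)  # pick one codeword per group
q = to_q(e.reshape(49, G * d_code))                   # concatenate e1, e2 and project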
Training Process
First, let us understand what a contrastive score and a contrastive loss are, since they are central to the training procedure of the wav2vec model.
Contrastive Score typically involves computing a similarity metric between pairs of samples. Commonly used similarity metrics include cosine similarity or dot product. The idea is to compare the representations of two instances in the embedding space. For positive pairs (examples that should be similar), the contrastive score should be high, indicating high similarity. For negative pairs (examples that should be dissimilar), the contrastive score should be low, indicating low similarity.
Contrastive Loss is used as part of the training objective. It encourages the model to bring positive pairs closer together in the embedding space while pushing negative pairs apart.
[Tex]L(i,j) = -\log \frac{e^{\,sim(z_i, z_j)/\tau}}{\sum_{k} e^{\,sim(z_i, z_k)/\tau}}[/Tex]
where
- [Tex]L(i,j)[/Tex] is the contrastive loss for samples i and j
- [Tex]sim(z_i, z_j)[/Tex] is the similarity score between samples i and j
- [Tex]\tau[/Tex] is a temperature parameter
- the sum is over all samples in the batch
An important point to note is that positive pairs are pulled together to make them more similar, while negative pairs are pushed apart to make them more dissimilar. A small sketch of this loss follows.
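For intuition, here is a minimal sketch of this loss for a single context vector, its positive quantized target and 100 sampled distractors; the use of cosine similarity and the temperature value are illustrative assumptions.
Python
import torch
import torch.nn.functional as F

# Minimal contrastive loss for one position: positive target at index 0, distractors after it
def contrastive_loss(c, q_pos, q_neg, tau=0.1):
    candidates = torch.cat([q_pos.unsqueeze(0), q_neg], dim=0)        # (1 + K, dim)
    sims = F.cosine_similarity(c.unsqueeze(0), candidates, dim=-1)    # similarity scores
    return -F.log_softmax(sims / tau, dim=0)[0]                       # -log p(positive)

c = torch.randn(256)             # context vector from the transformer
q_pos = torch.randn(256)         # quantized positive target
q_neg = torch.randn(100, 256)    # 100 sampled distractors
loss = contrastive_loss(c, q_pos, q_neg)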
In the context of Wav2Vec2:
- The output from the feature encoder is passed through the quantization module to get a quantized representation from the codebook. This is the positive sample.
- The same output from the feature encoder is also passed through the transformer encoder, with roughly 50% of the feature time steps masked beforehand. The objective is to learn the representation of the speech at each masked position by comparing it with the true quantized latent speech representation. For each masked position, 100 negative distractors (negative samples) are uniformly sampled from other positions in the same utterance; these 100 distractors come from the codebook of 320 representations, excluding the positive vector.
- The model compares the similarities using the contrastive loss equation shown above.
- The loss is then backpropagated through the transformer as well as the quantization module, pushing the transformer encoder output towards the positive codebook sample and away from the negative codebook samples.
Diversity Loss is added during training to encourage equal use of all the entries in the codebooks when representing both positive and negative samples. It works by maximizing the entropy of the averaged softmax distribution over codebook entries, preventing the model from always choosing from a small subset of the available entries. A small sketch follows.
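As a rough sketch of that idea (one possible formulation, not necessarily the exact one used in practice), the function below penalizes low entropy of the batch-averaged softmax over codebook entries, so minimizing it maximizes the entropy.
Python
import torch
import torch.nn.functional as F

# Sketch of a diversity loss: penalise low entropy of the averaged softmax per codebook group
def diversity_loss(logits):                                    # logits: (batch*time, G, V)
    probs = F.softmax(logits, dim=-1).mean(dim=0)              # averaged distribution, (G, V)
    entropy = -(probs * torch.log(probs + 1e-7)).sum(dim=-1)   # entropy per group
    return -entropy.mean()                                     # minimising this maximises entropy

loss_d = diversity_loss(torch.randn(49, 2, 320))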
Wav2Vec2 Model Implementation
Install Libraries
Install the libraries below if they are not already available in your environment. They are required to run the subsequent code.
!pip install datasets
!pip install transformers
!pip install torch
!pip install evaluate
!pip install transformers[torch]
Import Libraries
Then import the libraries into your notebook. The required libraries include NumPy, transformers and PyTorch.
Python
# Imports required
import numpy as np
from datasets import load_dataset, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import evaluate
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from transformers import TrainingArguments, Trainer
Loading Dataset and Preprocessing
Load the MInDS-14 dataset and split it in an 80:20 ratio.
Python
# Load the first 80 examples of the en-US subset of the PolyAI MInDS-14 dataset
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train[:80]")
# Remove unnecessary columns
dataset = dataset.remove_columns(['path','english_transcription','intent_class'])
# Split the dataset into train and test
dataset = dataset.train_test_split(test_size = 0.2, shuffle=False)
Resampling data
We need to resample the data to 16 kHz because the Wav2Vec2 model was trained on 16 kHz audio, while the dataset is sampled at 8 kHz. For this we will use the Audio feature of the datasets library.
Python
# Declare device variable
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Resample the dataset to 16 kHz, since the Wav2Vec2 model is trained on 16 kHz audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(device)
Drawing Inferences
We format an input and use the base model to infer its transcription. The model produces output in logits, and we decode it by selecting the maximum value among the logits. The use of ‘torch.no_grad()’ ensures that these operations do not contribute to gradient computation, which is particularly helpful when there’s no need to update the model weights.
Python
# Let's process an example from the train dataset
inputs = processor(dataset['train'][3]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
# getting the predictions
with torch.no_grad():
    logits = model(**inputs.to(device)).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
transcription
Output:
['HOW DO I FURN A JOINA COUT']
The actual text of the audio is ‘how do I start a joint account’.
Fine Tuning the model
We want to prepare our data to match the expected format for the Wav2Vec2 model using the Dataset map function. For this, we’re creating two columns named ‘input_values,’ where the raw input sound wave array needs to be resampled to 16kHz, and ‘labels,’ which will hold the transcription in the format expected by the tokenizer. To achieve this, we’re passing each piece of data through a processor defined below.
Python
# Prepare a function to process the entire dataset.
# We need to create two columns: 'input_values'
# (the raw sound wave array) and 'labels' (the tokenized transcription)
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    with processor.as_target_processor():
        batch["labels"] = processor(batch["transcription"].upper()).input_ids
    return batch

encoded_dataset = dataset.map(prepare_dataset, num_proc=1)
Creating a specialized class for Data
We’re crafting a DataCollator class specifically designed for fine-tuning Wav2Vec2. Unlike most Transformers tasks, ASR does not come with a ready-made data collator, so we adapt the DataCollatorWithPadding pattern to create batches of examples that match the elements found in the training or evaluation datasets.
It’s worth highlighting that ‘input_values’ and ‘labels’ need different padding strategies since they can have varying lengths. In ASR tasks with potentially large input sizes, it’s more efficient to dynamically pad training batches. This means each training sample only gets padded to match the length of the longest sample within its batch, rather than padding to the overall longest sample.
So, in essence, for fine-tuning Wav2Vec2, we’re crafting a specialized padding data collator, and we’ll define it below:
Python
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_values = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(input_values, padding=self.padding, return_tensors="pt")
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(label_features, padding=self.padding, return_tensors="pt")

        # replace padding with -100 to ignore it correctly in the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")
Evaluation Metric
For our task, we’ll be using the word error rate metric. To measure this, we need to define a ‘compute_metrics’ function. Each logit vector has a length equal to the configured vocabulary size, which is noted as ‘config.vocab_size.’ Our main focus is on figuring out the model’s prediction, and we do this by calculating the argmax(…) of the logits.
To make sense of the predictions, we convert the encoded labels back into their original string form. This involves a couple of steps. First, we replace instances of -100 with the ‘pad_token_id.’ Then, we decode the IDs while making sure that consecutive tokens are not incorrectly grouped together. This decoding process aligns with the CTC (Connectionist Temporal Classification) style, ensuring accuracy in the representation of the original string.
Python
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Replace -100 with the pad token id before decoding the labels
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # group_tokens=False avoids CTC-style merging of repeated label tokens
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
Model Training
Wav2Vec2 is a sizable model that demands a significant amount of memory, making GPU training a necessity. If your system lacks sufficient memory, there’s a risk of encountering out-of-memory issues. The learning rate has been fine-tuned through heuristic methods to ensure stable fine-tuning. It’s crucial to note that these parameters are highly dependent on the dataset, so experimenting with various values is essential.
To initiate the training process, pass these training arguments, along with the dataset, model, tokenizer, and data collator, to the Trainer. Once set up, call the ‘.train()’ method to kickstart the training.
Python
del model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id)
model.to(device)

# defining training arguments and trainer
training_args = TrainingArguments(
    output_dir="wav2vec2_finetuned",
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    learning_rate=1e-5,
    warmup_steps=2,
    max_steps=2000,
    fp16=True,
    optim='adafactor',
    group_by_length=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=1,
    eval_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
Output:
Step Training Loss Validation Loss Wer
100 No log 1.422148 0.354839
200 No log 1.584326 0.379032
300 No log 1.595137 0.346774
400 No log 1.534755 0.314516
500 1.022900 1.548012 0.322581
600 1.022900 1.525821 0.322581
Getting Prediction from the Fine-tuned model
Python
## getting test data
i2 = processor(dataset['test'][6]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
print(f"The input test audio is: {dataset['test'][6]['transcription']}")

# prediction for test data
with torch.no_grad():
    logits = model(**i2.to(device)).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
print(f'The output prediction is : {transcription[0]}')
Output :
The input test audio is: so you spent the money I'd like to see my new account balance
The output prediction is : SO JUS SPEND SOME MONEY I'D LIKE TO SEE MY NEW ACCOUNT BALANCE
The output is better this time.
Conclusion
Self-supervised learning, exemplified by models like Wav2Vec2, offers a robust approach for representation learning in domains with limited labeled data. Fine-tuning on specific tasks further refines the model’s performance, showcasing the adaptability and effectiveness of this training methodology.