Understanding CTC loss for speech recognition
What exactly is automatic speech recognition (ASR) trying to do, and how is the loss of an ASR model computed? This post tries to explain, as simply as possible, how CTC loss works for ASR.
In transformers==4.2.0, a new model called Wav2Vec2ForCTC was added, which supports speech recognition in a few lines:
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
import soundfile as sf

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_array(batch):
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch

ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.map(map_to_array)

input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])
Result: 'A MAN SAID TO THE UNIVERSE SIR I EXIST'
How does that work?
What does speech recognition do?
The input to speech recognition is a recording, which is processed into a spectrogram (a spectrum over time).
A spectrogram shows the energy in different frequency bands at every moment.
The model uses the spectrogram to make a prediction at every time step (every 100 ms, for example), and then derives the corresponding text from the predicted results.
We only have the recording and its corresponding text; we don’t have the alignment between each character and the part of the audio that produced it.
Calculating a loss requires the input and the target to correspond, so we cannot train the model directly in this case.
To solve the alignment problem, we have Connectionist Temporal Classification loss (CTC loss).
CTC Loss
First, we need a mechanism that allows a long sequence to be mapped to a shorter one. In our case, these are the recording and its text.
- Allow repeated output. When the model is not sure at which moment it should output g, it is allowed to predict the same token multiple times.
- Merge output. The next step is to merge these repeated outputs.
- In order to distinguish a token that genuinely occurs twice in a row from duplicated outputs of a single token, a separate token _ is also introduced.
The goal of training is to make the model's predictions follow these rules.
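The three rules above can be sketched as a small collapse function. This is a minimal illustration, not code from any library: the name ctc_collapse and the example paths are made up here, and _ is assumed to be the blank token as in the text.

```python
BLANK = "_"  # assumed blank symbol, matching the _ token used in the text


def ctc_collapse(path, blank=BLANK):
    """Merge consecutive duplicate tokens, then drop the blank token."""
    merged = []
    previous = None
    for token in path:
        if token != previous:  # keep only the first token of each run
            merged.append(token)
        previous = token
    return "".join(token for token in merged if token != blank)


# Hypothetical frame-level outputs and what they collapse to.
for path in ["ggg", "_g_", "g_g", "gg_gg", "___"]:
    print(path, "->", repr(ctc_collapse(path)))
```

Note that g_g collapses to gg while ggg collapses to g: the blank is what keeps two real occurrences of the same token apart.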
Let's use an example to show how this works: we input a piece of audio and want to predict the word g. Assume the model decodes three states. Each state gives us the probability of every token, and we select the token with the highest probability.
After obtaining the output tokens, we merge them according to the rules above.
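The decoding step just described (pick the most likely token per state, then merge) can be sketched as follows. The probabilities below are made-up numbers for three states over the two tokens _ and g, and greedy_ctc_decode is a hypothetical helper name, so treat this as an illustration only.

```python
def greedy_ctc_decode(step_probs, vocab, blank="_"):
    """Return (raw path, merged text) for a list of per-step probabilities."""
    # Step 1: at every state, select the token with the highest probability.
    path = []
    for probs in step_probs:
        best_index = max(range(len(probs)), key=lambda i: probs[i])
        path.append(vocab[best_index])

    # Step 2: merge repeated tokens and drop blanks.
    decoded = []
    previous = None
    for token in path:
        if token != previous and token != blank:
            decoded.append(token)
        previous = token
    return "".join(path), "".join(decoded)


vocab = ["_", "g"]           # index 0 is the blank token
step_probs = [               # hypothetical model output, one row per state
    [0.7, 0.3],
    [0.2, 0.8],
    [0.1, 0.9],
]
raw_path, text = greedy_ctc_decode(step_probs, vocab)
print(raw_path, "->", text)
```

This is roughly what the argmax followed by processor.decode does in the Wav2Vec2 snippet at the top of the post, just on a toy vocabulary.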
Many combinations can generate the same result. The training goal is to guide the model to produce any one of them, so that we can decode the corresponding text. An intuitive approach is to enumerate all the combinations and calculate the loss over these candidates.
Following the same example, the model outputs three states and we want to decode g. With the two tokens g and _, the combinations that merge to g are ggg, gg_, g__, _gg, _g_ and __g, and the loss is the negative log of their total probability.
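The exhaustive calculation can be sketched in plain Python. The probabilities are the ones used in the PyTorch check in this post (index 0 is _, index 1 is g); the helper names are made up for this sketch.

```python
import itertools
import math

VOCAB = ["_", "g"]   # index 0 is the blank token
PROBS = [            # per-state probabilities, same numbers as the PyTorch check
    [0.4, 0.6],
    [0.3, 0.7],
    [0.2, 0.8],
]


def collapse(path, blank="_"):
    """Merge consecutive duplicates, then remove blanks."""
    out = []
    previous = None
    for token in path:
        if token != previous and token != blank:
            out.append(token)
        previous = token
    return "".join(out)


def brute_force_ctc(probs, vocab, target):
    """Enumerate every path and sum the probability of those that collapse to target."""
    matching = {}
    for indices in itertools.product(range(len(vocab)), repeat=len(probs)):
        path = "".join(vocab[i] for i in indices)
        if collapse(path) == target:
            matching[path] = math.prod(probs[t][i] for t, i in enumerate(indices))
    total = sum(matching.values())
    return matching, -math.log(total)


paths, loss = brute_force_ctc(PROBS, VOCAB, "g")
for path, probability in paths.items():
    print(f"{path}: {probability:.3f}")
print(f"loss = {loss:.4f}")
```

The six path probabilities add up to 0.832, and the negative log of that sum is about 0.1839.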
We can check the result with PyTorch's CTC loss:
import torch

# Log-probabilities with shape (T=3 time steps, N=1 batch, C=2 classes);
# class 0 is the blank token _ and class 1 is g.
log_probs = torch.log(torch.tensor([
    [[0.4, 0.6]],
    [[0.3, 0.7]],
    [[0.2, 0.8]],
], dtype=torch.float, requires_grad=True))
target = torch.tensor([[1]])  # the single token g
input_lengths = torch.tensor([3])
target_lengths = torch.tensor([1])

ctc_loss = torch.nn.CTCLoss()  # the blank index defaults to 0
loss = ctc_loss(log_probs, target, input_lengths, target_lengths)
print(loss)
# tensor(0.1839, grad_fn=<MeanBackward0>)
That is the main idea of CTC loss, but there is an obvious flaw: the number of combinations grows exponentially with the length of the input, which is too inefficient for training on a large amount of data.
To increase efficiency, CTC loss uses dynamic programming during the loss calculation.
First, we list the output probabilities at every time step, which we get right after the model’s softmax. Then we create another table to simulate the dynamic programming calculation.
At the first time step T1, we fill in each element with its corresponding probability from the first table. Note that there are two _ entries here: the upper one means _ appears before g, and the other means it appears after g. We use a circle and a triangle to distinguish these two _, as shown in the table.
Then, calculate T2 based on the result of T1. According to the previous exhaustive results:
- (T2, _ circle) can only come from (T1, _ circle).
- (T2, g) may come from (T1, _ circle) and (T1, g).
- (T2, _ triangle) is the result after (T1, g).
Calculate T3 with the same rules.
Summing all the paths that end at T3 and taking the negative log gives the overall loss. The result is consistent with the previous one.
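The table calculation can be sketched as a small forward pass in plain Python. This is an illustrative sketch, not PyTorch's implementation: the function name and table layout are assumptions, and the numbers are the same three-step example (index 0 is _, index 1 is g).

```python
import math


def ctc_forward(probs, target_ids, blank=0):
    """Forward table alpha[t][s] over the blank-extended target."""
    # Extended target: a blank before, between and after the target tokens.
    # For the target g this gives [_ (circle), g, _ (triangle)].
    extended = [blank]
    for token in target_ids:
        extended += [token, blank]

    num_steps, num_states = len(probs), len(extended)
    alpha = [[0.0] * num_states for _ in range(num_steps)]

    # T1: a path may start only in the first two cells.
    alpha[0][0] = probs[0][extended[0]]
    if num_states > 1:
        alpha[0][1] = probs[0][extended[1]]

    for t in range(1, num_steps):
        for s in range(num_states):
            total = alpha[t - 1][s]              # stay on the same symbol
            if s >= 1:
                total += alpha[t - 1][s - 1]     # advance by one
            # Advance by two: only onto a non-blank token that differs
            # from the token two positions back.
            if s >= 2 and extended[s] != blank and extended[s] != extended[s - 2]:
                total += alpha[t - 1][s - 2]
            alpha[t][s] = total * probs[t][extended[s]]

    # A path may end only in the last two cells.
    likelihood = alpha[-1][-1] + (alpha[-1][-2] if num_states > 1 else 0.0)
    return alpha, -math.log(likelihood)


probs = [[0.4, 0.6], [0.3, 0.7], [0.2, 0.8]]
alpha, loss = ctc_forward(probs, target_ids=[1])
for t, row in enumerate(alpha, start=1):
    print(f"T{t}:", [round(value, 3) for value in row])
print(f"loss = {loss:.4f}")
```

Each row is filled from the previous one only, so the cost grows with the number of time steps times the length of the extended target rather than with the number of paths.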
In the above process, you may notice that:
- The start is always one of the first two cells in the upper left corner, and the end is always one of the last two cells in the lower right corner.
- When T advances to T+1, a path only moves forward through the symbols, by at most one or two steps, which means it moves either to the next token or to the blank symbol.
We can also calculate the backward result in the same way. The only difference is that it runs from the last time step back to the first.
Combining the forward and backward results tells us how much each token at each time step T contributes to the loss, which is what backpropagation needs.
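As a rough sketch of how the two passes combine, the code below builds both tables and, for every cell, computes the share of the total path probability that passes through it. It follows the standard forward-backward formulation of CTC; the post does not spell out the backward table, so the function names and the exact definition of beta (which includes the emission at its own time step) are assumptions of this sketch.

```python
def extend_with_blanks(target_ids, blank=0):
    extended = [blank]
    for token in target_ids:
        extended += [token, blank]
    return extended


def forward_backward(probs, target_ids, blank=0):
    ext = extend_with_blanks(target_ids, blank)
    T, S = len(probs), len(ext)

    def can_skip(s_from, s_to):
        # A two-position jump must land on a non-blank token that
        # differs from the token it jumps from.
        return ext[s_to] != blank and ext[s_to] != ext[s_from]

    # Forward table: probability of all valid prefixes ending in (t, s).
    alpha = [[0.0] * S for _ in range(T)]
    alpha[0][0] = probs[0][ext[0]]
    if S > 1:
        alpha[0][1] = probs[0][ext[1]]
    for t in range(1, T):
        for s in range(S):
            total = alpha[t - 1][s]
            if s >= 1:
                total += alpha[t - 1][s - 1]
            if s >= 2 and can_skip(s - 2, s):
                total += alpha[t - 1][s - 2]
            alpha[t][s] = total * probs[t][ext[s]]

    # Backward table: same idea, filled from the last time step to the first.
    beta = [[0.0] * S for _ in range(T)]
    beta[T - 1][S - 1] = probs[T - 1][ext[S - 1]]
    if S > 1:
        beta[T - 1][S - 2] = probs[T - 1][ext[S - 2]]
    for t in range(T - 2, -1, -1):
        for s in range(S):
            total = beta[t + 1][s]
            if s + 1 < S:
                total += beta[t + 1][s + 1]
            if s + 2 < S and can_skip(s, s + 2):
                total += beta[t + 1][s + 2]
            beta[t][s] = total * probs[t][ext[s]]

    likelihood = alpha[T - 1][S - 1] + (alpha[T - 1][S - 2] if S > 1 else 0.0)

    # alpha and beta both include the emission at (t, s), so divide it out once.
    # Assumes strictly positive probabilities, as produced by a softmax.
    gamma = [
        [alpha[t][s] * beta[t][s] / (probs[t][ext[s]] * likelihood) for s in range(S)]
        for t in range(T)
    ]
    return alpha, beta, gamma, likelihood


probs = [[0.4, 0.6], [0.3, 0.7], [0.2, 0.8]]
alpha, beta, gamma, likelihood = forward_backward(probs, target_ids=[1])
for t, row in enumerate(gamma, start=1):
    print(f"T{t}:", [round(value, 3) for value in row])
```

Each row of gamma sums to 1, because every valid path passes through exactly one cell at each time step.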
A limitation of CTC loss is that the input sequence must be at least as long as the output sequence, and the longer the input sequence, the harder the model is to train.
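The length constraint can be made concrete with a small check. Because a blank is needed between two identical neighbouring tokens, each such repeat costs one extra input step. The helper names below are made up for this sketch.

```python
def min_ctc_input_length(target):
    """Smallest number of input steps that can still produce target."""
    # Each pair of identical neighbouring tokens needs a blank in between.
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats


def is_feasible(num_steps, target):
    """True if at least one valid alignment of target exists in num_steps steps."""
    return num_steps >= min_ctc_input_length(target)


for target in ["g", "gg", "hello"]:
    print(target, "needs at least", min_ctc_input_length(target), "steps")
print(is_feasible(3, "g"), is_feasible(2, "gg"))
```

When no valid alignment exists, the total path probability is zero, so the loss is infinite for that sample.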
That’s all for CTC loss! It solves the alignment problem, which makes it possible to calculate a loss between a long input sequence and a shorter target sequence. As a result, speech recognition models can be trained on larger amounts of data.