What makes the Transformer faster than the LSTM at text generation

Eric Lam
4 min read · Jul 22, 2021


This article explains, step by step, how the Transformer generates text and the main ideas behind its design.

Text generation process

Text generation means feeding in the previous tokens and outputting the next one, repeating until the end of the sequence. Suppose the input is:

how is the weather tomorrow?

And the model should reply:

sunny day

It can be represented as:

P(sunny day | how is the weather tomorrow?)

We can further factor it into:

P(sunny | how is the weather tomorrow?) × P(day | how is the weather tomorrow? sunny)
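To make this concrete, here is a minimal sketch of the generation loop in Python. The model used below is a toy stand-in that always continues with the canned reply, just to show the shape of the loop; a real model would return a learned probability distribution over the next token (the end token <E> is explained a bit further down).

def generate(model, prompt_tokens, end_token="<E>", max_len=20):
    tokens = list(prompt_tokens)
    output = []
    for _ in range(max_len):
        probs = model(tokens)                    # P(next token | tokens so far)
        next_token = max(probs, key=probs.get)   # greedy: take the most likely token
        if next_token == end_token:
            break
        output.append(next_token)
        tokens.append(next_token)
    return output

# Toy stand-in for a trained model: it simply walks through a canned reply.
prompt = "how is the weather tomorrow?".split()
reply = ["sunny", "day", "<E>"]

def toy_model(tokens):
    return {reply[len(tokens) - len(prompt)]: 1.0}

print(generate(toy_model, prompt))   # ['sunny', 'day']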

The objective of our model is to maximize the above probability. During training, we already know what the output should be, so there is no need to wait for the previous prediction before predicting the next token. In other words, we can compute the loss for every position in parallel. This training method is called teacher forcing.

Moreover, a start token <S> and an end token <E> are added: the start token triggers the first prediction, and the end token signals when to stop.

Input:   how is the weather tomorrow?   <S>     sunny   day
Target:                                 sunny   day     <E>
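Here is a minimal sketch of how teacher forcing sets up the loss in PyTorch. The toy_decoder below is an untrained stand-in for a real Transformer decoder; the point is only that every position's prediction is scored in a single parallel cross-entropy call.

import torch
import torch.nn.functional as F
from torch import nn

# Decoder input "<S> sunny day" and the left-shifted target "sunny day <E>".
vocab = {"<S>": 0, "<E>": 1, "sunny": 2, "day": 3}
decoder_input = torch.tensor([[vocab["<S>"], vocab["sunny"], vocab["day"]]])
target        = torch.tensor([[vocab["sunny"], vocab["day"], vocab["<E>"]]])

# Untrained toy stand-in for a decoder: embed, then project to vocabulary logits.
toy_decoder = nn.Sequential(nn.Embedding(len(vocab), 32), nn.Linear(32, len(vocab)))

logits = toy_decoder(decoder_input)                       # (1, 3, vocab_size)
loss = F.cross_entropy(logits.view(-1, len(vocab)), target.view(-1))
# All three next-token predictions are scored at once: no waiting for previous outputs.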

Transformer parallelism

Although teacher forcing lets us compute the loss for the whole sentence in parallel, RNN architectures such as the LSTM still need the output of the previous time step to compute the next state, so training cannot be parallelized along the sequence.

Image by dvgodoy / CC BY
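As a rough illustration, an RNN forward pass has to be an explicit loop over time, because each hidden state needs the previous one (a sketch using PyTorch's LSTMCell, with illustrative sizes):

import torch
from torch import nn

lstm_cell = nn.LSTMCell(input_size=16, hidden_size=32)
x = torch.randn(10, 1, 16)            # 10 time steps, batch size 1, input dim 16

h = torch.zeros(1, 32)
c = torch.zeros(1, 32)
outputs = []
for t in range(x.size(0)):            # must run step by step:
    h, c = lstm_cell(x[t], (h, c))    # h_t depends on h_{t-1}
    outputs.append(h)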

One of the key ideas of the Transformer architecture is to solve this problem: unlike an RNN, its computation can be parallelized across the whole sequence.

Image by dvgodoy / CC BY

The Transformer basically treats the sequence as a fully connected network of tokens, with a weight on every edge. Each edge weight measures how important one element is to another and is computed with an attention mechanism.

In detail, the Transformer turns the input text into embedding vectors and then projects each embedding into three vector spaces of the same size: Q (query), K (key) and V (value).

The QKV design lets the inputs of the Transformer “see” each other and decide which parts are important to each token.
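A minimal sketch of the Q/K/V projections and the attention weights they produce (a single head in PyTorch, with illustrative sizes):

import math
import torch
from torch import nn

d_model = 64
x = torch.randn(1, 3, d_model)                  # embeddings for three tokens: (batch, seq_len, dim)

# Project the same embeddings into three spaces of the same size.
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)
q, k, v = w_q(x), w_k(x), w_v(x)

# Q*K: how important each key token is to each query token.
scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)    # (1, 3, 3)
weights = scores.softmax(dim=-1)

# Weighted sum of V: every token's output mixes information from all tokens.
out = weights @ v                                         # (1, 3, d_model)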

The problem is that if a token can “see” every token in the sentence, including the ones that come after it, predicting the next token takes no effort at all and training collapses.

Visualization of QKV process

Let’s visualize the QKV calculation process to better understand the problem.

Input:   x x x x x x x x x x x x x x   <S>     sunny   day
Target:                                 sunny   day     <E>

In this example, the decoder input is <S> sunny day, and the Q*K matrix is:

note that all weights are set to 1 for simplicity

The Q*K result represents how important each K is to each Q token. We then apply this importance matrix to V to get a weighted combination of the corresponding token vectors. The dimension of V can be very large, so we only look at one of its dimensions, x:

The value of sunny in dimension x depends on the Q*K result. The trouble is that the row for sunny in Q*K also attends to day and <E>, so when Q*K is multiplied with V its output already contains information about the tokens it is supposed to predict.

To prevent earlier tokens from seeing later ones, we set the Q*K weights of all later tokens to zero (in practice, by setting their scores to minus infinity before the softmax). In matrix form, this is called a causal mask.
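Continuing the single-head sketch from above, a causal mask can be applied by pushing the scores for later positions to minus infinity before the softmax, so their attention weights become zero:

import torch

seq_len = 3                                    # "<S> sunny day"
scores = torch.randn(1, seq_len, seq_len)      # stand-in for the Q*K scores

# Causal mask: position i may only attend to positions <= i.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~mask, float("-inf"))
weights = masked_scores.softmax(dim=-1)
# The row for "sunny" now puts zero weight on "day", so it can no longer
# see the token it is supposed to predict.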

This is the design of the Transformer decoder. Compared with the encoder, it has an extra causal mask, so earlier tokens cannot see later ones and training for text generation can still be parallelized.

The difference between the encoder and the decoder is whether there is a causal mask. Because of the causal mask, each decoder token can only contain information from the preceding tokens, so its representation may be less complete than the encoder's.

Therefore, to take advantage of both, we can connect the encoder and the decoder together. In the Transformer, the encoder's K and V are passed to the decoder.

Image by dvgodoy / CC BY

The idea here is that for each decoder token, we compute the importance of every encoder token (decoder Q against encoder K) and then take a weighted sum over the encoder's V.
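A sketch of this cross-attention step: the queries come from the decoder, while the keys and values come from the encoder (single head, illustrative shapes):

import math
import torch
from torch import nn

d_model = 64
enc_out = torch.randn(1, 6, d_model)    # encoder output for the input sentence (6 tokens here)
dec_x   = torch.randn(1, 3, d_model)    # decoder states for "<S> sunny day"

w_q = nn.Linear(d_model, d_model)       # applied to the decoder
w_k = nn.Linear(d_model, d_model)       # applied to the encoder
w_v = nn.Linear(d_model, d_model)       # applied to the encoder

q = w_q(dec_x)
k, v = w_k(enc_out), w_v(enc_out)

# Decoder Q against encoder K: importance of each input token for each output token.
weights = (q @ k.transpose(-2, -1) / math.sqrt(d_model)).softmax(dim=-1)   # (1, 3, 6)
out = weights @ v                                                          # (1, 3, d_model)
# No causal mask here: the decoder may look at the whole input sentence.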

That’s all for the Transformer model. Because its design allows training to be parallelized across the whole sequence, the Transformer is much more efficient than recurrent neural networks such as the LSTM.

At the same time, the encoder-decoder architecture balances quality and efficiency. The main difference between the two halves is whether Q*K uses a causal mask, and they are combined by passing the encoder's K and V to the decoder, which lets the decoder compute the importance of each input token.
