Benchmarking Efficient Transformer Variants

Shashank Gurnalkar, Data Scientist @ Infocusp · 18 Sep 2024


Transformers have transformed the landscape of natural language processing (NLP) by enabling models to capture complex relationships within text data. They are highly effective for tasks like sentiment classification and neural machine translation, where understanding the meaning of text is crucial. In this blog, we'll explore the implementation of two Transformer variants (Performer and Longformer) along with the base Transformer. Both variants were proposed after the original Transformer and are more memory efficient, particularly when dealing with longer sequences. We will compare their performance on a classification task. Specifically, we'll focus on sentiment classification using only the encoder portion of the Transformer model, since the decoder is not needed for this task.

Transformer: The Foundation

The Transformer, as presented by Vaswani et al. in the seminal "Attention is All You Need" paper, forms the backbone of modern Transformer-based models. It introduced the concept of self-attention, a mechanism that allows each token in a sequence to attend to all other tokens, capturing both local and global dependencies.

In a typical Transformer, the model consists of an encoder-decoder structure. However, for a classification task, we usually require just the encoder. For the Transformer implementation, we will refer to The Annotated Transformer tutorial, which walks through the entire process of implementing the Transformer model from scratch in PyTorch, providing extensive annotations and explanations for each part of the model. Let's briefly look at the components of the Transformer's encoder and understand how they work together to process input data for classification tasks.

1. Input Embedding

The first step in the Transformer architecture is to convert input tokens (words, subwords, etc.) into dense vectors. This process is called embedding. Each token in the input sequence is mapped to a high-dimensional vector space, where similar words have similar representations. This embedding allows the model to capture semantic information from the input text.

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
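
As a quick sanity check (assuming the usual imports, i.e. import math, import torch, and import torch.nn as nn, and a hypothetical vocabulary size), the module maps integer token ids to scaled dense vectors:

# Hypothetical sizes: a 10,000-token vocabulary embedded into 128 dimensions.
emb = Embeddings(d_model=128, vocab=10000)
tokens = torch.randint(0, 10000, (2, 50))   # (batch, seq_len) of token ids
print(emb(tokens).shape)                    # torch.Size([2, 50, 128])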

2. Positional Encoding

Unlike recurrent neural networks (RNNs), Transformers do not inherently understand the order of tokens in a sequence. To introduce this sequential information, we add positional encodings to the input embeddings. Positional encodings are vectors that encode the position of each token in the sequence, ensuring the model understands the order and relative positions of words.

class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
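
Continuing the small example above, the positional encodings are simply added on top of the embedded tokens, so the shapes are unchanged:

pos_enc = PositionalEncoding(d_model=128, dropout=0.1)
x = pos_enc(emb(tokens))                    # embeddings + positional encodings, then dropout
print(x.shape)                              # torch.Size([2, 50, 128])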

3. Multi-Head Self-Attention

The core innovation of the Transformer is the self-attention mechanism. In the multi-head self-attention layer, the model calculates attention scores for each token pair in the sequence. This means each token "attends" to every other token, allowing the model to capture dependencies and relationships regardless of their distance in the sequence.

Multi-head attention enhances this process by running multiple attention mechanisms in parallel, each focusing on different parts of the sequence. The outputs of these attention heads are then combined to form a richer representation of the input data.

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()

        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask=None):
        query, key, value = (x, x, x)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                             for lin, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in the batch.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = (x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k))

        return self.linears[-1](x)

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
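
MultiHeadedAttention relies on a small clones helper from The Annotated Transformer that is not shown in this post; a minimal version of it, together with a quick shape check on the multi-head layer (sizes are hypothetical), looks like this:

import copy

def clones(module, N):
    "Produce N identical (deep-copied) layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# 8 heads over a 128-dimensional model.
mha = MultiHeadedAttention(h=8, d_model=128)
out = mha(torch.randn(2, 50, 128))          # self-attention: query = key = value = x
print(out.shape)                            # torch.Size([2, 50, 128])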

4. Feed-Forward Networks (FFN)

After the self-attention layer, the model applies a feed-forward network (FFN) to each position in the sequence independently. This FFN is a simple fully connected neural network that further processes the information from the self-attention layer. It helps the model capture non-linear relationships in the data, refining the representations before they are passed to the next layer.

class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))

5. Residual Connections and Layer Normalization

To stabilize training and improve gradient flow, the Transformer uses residual connections and layer normalization. Residual connections allow the model to pass information directly from one layer to the next, bypassing the intermediate transformations. This helps the model learn more efficiently, especially in deep networks. Layer normalization, on the other hand, normalizes the inputs to each layer, reducing internal covariate shift and speeding up convergence.

class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))
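
The post does not show how these pieces are stitched together, so here is one plausible encoder block built from the modules above, in the spirit of The Annotated Transformer (the exact wiring in the linked repository may differ slightly):

class EncoderLayer(nn.Module):
    "Self-attention followed by a position-wise FFN, each wrapped in a residual connection + norm."

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask=None):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, mask))
        return self.sublayer[1](x, self.feed_forward)

# Hypothetical usage: one encoder block over an embedded sequence.
layer = EncoderLayer(128, MultiHeadedAttention(8, 128), PositionwiseFeedForward(128, 512), 0.1)
print(layer(torch.randn(2, 50, 128)).shape)  # torch.Size([2, 50, 128])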

6. Final Classification Step

For the sentiment classification task, we add a special `[CLS]` token at the beginning of the input sequence. After the input passes through the entire encoder, the output corresponding to this `[CLS]` token is used as the aggregate representation of the input sequence. This output is then passed through a feed-forward network, which classifies the sentiment of the input text.

class ClsDecoder(nn.Module):
    """
    Decoder for classification task.
    """

    def __init__(self, *args):

        super().__init__()
        self._decoder = nn.ModuleList()

        n = len(args)

        for i in range(n-2):
            self._decoder.append(nn.Linear(args[i], args[i+1]))
            self._decoder.append(nn.ReLU())

        self.out_layer = nn.Linear(args[n-2], args[n-1])
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        for layer in self._decoder:
            x = layer(x)
        x = self.out_layer(x)
        return self.sigmoid(x)
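
For instance, a hypothetical head ClsDecoder(128, 64, 1) maps the 128-dimensional `[CLS]` representation through a single 64-unit hidden layer to one sigmoid output:

cls_head = ClsDecoder(128, 64, 1)
encoder_output = torch.randn(2, 50, 128)    # (batch, seq_len, d_model) from the encoder
cls_repr = encoder_output[:, 0, :]          # representation of the leading [CLS] token
print(cls_head(cls_repr).shape)             # torch.Size([2, 1]), probabilities in (0, 1)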

Challenges:

Quadratic Complexity: While the Transformer architecture is powerful, it struggles with long sequences because of the quadratic complexity of its self-attention mechanism, which arises from the pairwise interaction of each token with every other token in the sequence. For a sequence of length N, self-attention must compute a dot product for every pair of tokens, i.e. N * N = N^2 operations just for the attention scores. The resulting matrix of attention scores is also N x N, giving a memory complexity of O(N^2). The softmax and subsequent operations (such as applying the attention scores to the value vectors) also scale with the size of this matrix, further contributing to the quadratic cost.
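
A back-of-the-envelope illustration of this growth (the float32 scores, batch of 16, and 8 heads used here are assumptions, not values from the post):

for n in (512, 2048, 8192):
    score_bytes = 16 * 8 * n * n * 4        # batch * heads * N * N * 4 bytes per float32
    print(f"N={n:5d}: {score_bytes / 2**30:.2f} GiB just for the attention-score matrices")
# N=  512: 0.12 GiB, N= 2048: 2.00 GiB, N= 8192: 32.00 GiB -- 4x memory for every 2x in length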

Performer: Enhancing Efficiency with FAVOR+

The Performer, introduced by Choromanski et al., addresses the quadratic complexity by introducing an efficient attention mechanism called FAVOR+ (Fast Attention Via positive Orthogonal Random features).

FAVOR+:

It approximates the traditional softmax-based self-attention with linear complexity O(N). Instead of directly computing the dot product between all pairs of tokens, FAVOR+ uses kernel-based methods to project the input into a lower-dimensional space where interactions can be computed more efficiently.

Random Feature Maps:

The Performer utilizes random feature maps to approximate the softmax function. This allows attention to be computed with a number of operations that is linear in the sequence length, significantly reducing the memory footprint and computational cost.

Despite the approximation, Performers maintain competitive accuracy on various tasks compared to traditional transformers, while being much more scalable to longer sequences.

class PerformerAttention(nn.Module):
    def __init__(self, dim, heads, device,  
                 generalized_attention = False,
                 kernel_fn = nn.ReLU(), dropout = 0.1, 
                 qkv_bias = True, attn_out_bias = True, ortho_scaling = 0, 
                 auto_check_redraw = True, feature_redraw_interval = 1000
    ):
        super(PerformerAttention, self).__init__()
        
        assert dim % heads == 0
        dim_heads = dim // heads
        inner_dim = dim_heads * heads        
        nb_features = dim_heads // 2
        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling
        self.generalized_attention = generalized_attention

        self.create_projection = partial(gaussian_orthogonal_random_matrix, 
                                         nb_rows=self.nb_features,  
                                         nb_columns=self.dim_heads,
                                         scaling=self.ortho_scaling, 
                                         device=device)  
                   
        self.register_buffer('projection_matrix', self.create_projection())
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias)
        self.dropout = nn.Dropout(dropout)
        
        self.feature_redraw_interval = feature_redraw_interval
        self.current_feature_interval = 0
        self.auto_check_redraw = auto_check_redraw

    def forward(self, x, mask):
                        
        b, n, _ = x.shape
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        
        if mask is not None:
            v.masked_fill_(mask, 0.)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
        
        device = q.device
        create_kernel = partial(exponential_kernel, projection_matrix = self.projection_matrix,device = device)
        q = create_kernel(q, is_query = True)
        k = create_kernel(k, is_query = False)
        
        out = performer_attention(q, k, v)    
        
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        
        self.current_feature_interval += 1        
        return self.dropout(out)
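
The class above depends on helpers such as gaussian_orthogonal_random_matrix, exponential_kernel, and performer_attention, which follow the Performer reference implementations and live in the linked repository. As a rough sketch of the core trick (not the exact code used here), non-causal FAVOR+ reorders the matrix products so that nothing of size N x N is ever materialized:

def performer_attention(q, k, v):
    # q, k: random-feature maps, shape (batch, heads, seq_len, nb_features)
    # v:    values,              shape (batch, heads, seq_len, head_dim)
    k_sum = k.sum(dim=-2)                                      # (b, h, m)
    d_inv = 1.0 / torch.einsum('...nm,...m->...n', q, k_sum)   # per-token normalizer, (b, h, n)
    context = torch.einsum('...nm,...nd->...md', k, v)         # (b, h, m, d): linear in seq_len
    return torch.einsum('...nm,...md,...n->...nd', q, context, d_inv)

Because context has shape (nb_features, head_dim), independent of the sequence length, both time and memory scale linearly with N.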

For our sentiment classification task, we integrate the Performer by replacing the multi-head self-attention in the Transformer with the Performer (FAVOR+) attention. The rest of the model remains the same, with the `[CLS]` token's output used for classification after passing through the encoder.

Longformer: Adapting Attention for Long Sequences

The Longformer, developed by Beltagy et al., modifies the Transformer architecture to handle long sequences effectively.

Local Attention:

Longformer introduces a sliding window attention mechanism, where each token attends to a fixed number of neighboring tokens within a sliding window, reducing the complexity from O(N^2) to O(N). Additionally, it incorporates dilated (or strided) attention, which allows tokens to attend to other tokens at regular intervals outside their immediate neighborhood, capturing broader context while still being efficient.

Global Attention:

For tasks requiring some tokens to attend globally across the entire sequence (e.g., classification tokens, question tokens in QA tasks), Longformer allows certain tokens to attend to all tokens in the sequence (global attention), while most tokens continue using the local sliding window attention. This hybrid approach balances efficiency with the ability to capture long-range dependencies when needed.
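
Before looking at the implementation, the attention pattern itself can be illustrated with a small, purely didactic boolean mask (this is not how Longformer computes attention, which uses banded matrix operations, but it shows which token pairs are allowed to interact):

def sliding_window_with_global_mask(seq_len, window, global_idx=(0,)):
    # True = attention allowed: a local band of +/- `window` positions,
    # plus full rows/columns for globally attending tokens (e.g. [CLS] at index 0).
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() <= window
    for g in global_idx:
        mask[g, :] = True   # the global token attends to every position
        mask[:, g] = True   # every position attends to the global token
    return mask

print(sliding_window_with_global_mask(8, window=2).int())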

class LongformerSelfAttention(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = int(config.hidden_size / config.num_attention_heads)
        self.embed_dim = config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.embed_dim)
        self.key = nn.Linear(config.hidden_size, self.embed_dim)
        self.value = nn.Linear(config.hidden_size, self.embed_dim)

        # separate projection layers for tokens with global attention
        self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
        self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
        self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
        self.dropout = config.attention_probs_dropout_prob
        self.layer_id = layer_id
        attention_window = config.attention_window[self.layer_id]
        self.one_sided_attn_window_size=attention_window // 2
        self.config = config 

    def forward(self, hidden_states, attention_mask=None):
        
        (hidden_states, attention_mask, padding_len) =  self._pad_to_window_size(hidden_states,attention_mask, 
                                                                                 2*self.one_sided_attn_window_size, factor=1)

        is_index_masked = attention_mask
        
        # create a global attention mask to have just CLS token in it
        is_index_global_attn = attention_mask.new_empty(attention_mask.size())
        is_index_global_attn[:, 0, :] = True
        is_index_global_attn[:, 1:, :] = False   

        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)

        batch_size, seq_len, embed_dim = hidden_states.size()
        
        # mask the key_vectors (padding) with 0
        key_vectors = key_vectors.masked_fill(attention_mask, 0)

        query_vectors = query_vectors.view(batch_size, seq_len, self.num_heads, self.head_dim)
        key_vectors = key_vectors.view(batch_size, seq_len, self.num_heads, self.head_dim)    
        
        # calculate partial local attention
        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size)
   
        # get global indices
        (max_num_global_attn_indices, is_index_global_attn_nonzero, 
         is_local_index_global_attn_nonzero,) = self._get_global_attn_indices(is_index_global_attn)
        
        # (A): calculate (partial) attention where each token (from query) attends global token (from key)
        global_key_attn_scores = self._concat_with_global_key_attn_probs(
                                query_vectors=query_vectors,
                                key_vectors=key_vectors,
                                max_num_global_attn_indices=max_num_global_attn_indices,
                                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,)
        
        # concat (A) and partial local attention            
        attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)
        
        # fill the padding with -inf to get ignored from softmax    
        attn_scores = attn_scores.masked_fill(attention_mask.unsqueeze(dim=-1), -float("inf"))
        
        # mask zero values to -inf in last valid w tokens before padding 
        attn_scores = attn_scores.masked_fill(attn_scores==0,-float("inf"))

        attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)

        # replace NaN with 0
        attn_probs = attn_probs.masked_fill(attention_mask.unsqueeze(dim=-1), 0)
        attn_probs = attn_probs.type_as(attn_scores)

        # apply dropout
        attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)

        value_vectors = value_vectors.view(batch_size, seq_len, self.num_heads, self.head_dim)
                    
        # (B) calculate complete attention where each token (from (A)) attends global token (from value)
        
        # calculate complete local attention
        # (X): Add ((B) + complete local attention)
        attn_output = self._compute_attn_output_with_global_indices(
                                value_vectors=value_vectors,
                                attn_probs=attn_probs,
                                max_num_global_attn_indices=max_num_global_attn_indices,
                                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,)
        
        attn_output = attn_output.reshape(batch_size, seq_len, embed_dim).contiguous()

        global_attn_output, _ =  self._compute_global_attn_output_from_hidden(
                                hidden_states=hidden_states,
                                max_num_global_attn_indices=max_num_global_attn_indices,
                                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                                is_index_masked=is_index_masked,)

        # replace the attention values in (X) at global indices with (Y)
        attn_output[is_index_global_attn_nonzero] = global_attn_output[is_local_index_global_attn_nonzero]        
        outputs = attn_output[:,:attn_output.shape[1]-padding_len,:]
        return outputs

In our implementation, the Longformer replaces the self-attention mechanism in the Transformer with the Longformer attention (sliding window and global attention). As with the Performer, the `[CLS]` token's output from the encoder is passed through a feed-forward network to produce the final classification output.

Dataset Information:

For this sentiment classification task, we used a review dataset consisting of fashion product reviews, which was downloaded from Amazon Reviews 2023. The dataset contains reviews with ratings ranging from 1 to 5, providing a rich source of text data for sentiment analysis.

Polarity Assignment

To prepare the dataset for binary classification, we assigned sentiment polarity based on the ratings:

  • Negative Sentiment (0): Reviews with a rating of 1, 2, or 3 were labeled as 0, indicating negative sentiment.
  • Positive Sentiment (1): Reviews with a rating of 4 or 5 were labeled as 1, indicating positive sentiment.

Sampling and Splitting

Given the large size of the dataset, we randomly sampled 120,000 reviews, stratified by polarity, to create a balanced and more manageable dataset for training, testing, and validation. The sampled dataset was then split into three subsets (a rough preprocessing sketch follows the list):

  • Training Set: 80,000 samples used to train the models.
  • Validation Set: 20,000 samples used for validation.
  • Test Set: 20,000 unseen samples used to evaluate the final performance of the models.
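
For reference, the labelling and splitting described above could be sketched roughly as follows (the file name, the rating column name, and the fixed random_state are assumptions; the actual preprocessing is in the linked repository):

import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.read_json("Amazon_Fashion.jsonl", lines=True)             # hypothetical file name
df["polarity"] = (df["rating"] >= 4).astype(int)                  # 4-5 -> positive (1), 1-3 -> negative (0)

# 120k reviews stratified by polarity, then an 80k / 20k / 20k train / val / test split
sampled, _ = train_test_split(df, train_size=120_000, stratify=df["polarity"], random_state=42)
train_df, rest = train_test_split(sampled, train_size=80_000, stratify=sampled["polarity"], random_state=42)
val_df, test_df = train_test_split(rest, train_size=20_000, stratify=rest["polarity"], random_state=42)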

This balanced dataset allowed us to rigorously train and compare the Transformer variants on a real-world sentiment classification task.

Training Information:

The table below provides a summary of the common key training parameters used to train these Transformer variants.

| Parameter | Value |
| --- | --- |
| Number of Encoder Layers | 6 |
| Number of Attention Heads | 8 |
| Embedding Dimension | 128 |
| Dropout Rate | 0.1 |
| Optimizer | Adam |
| Learning Rate | 0.0001 |
| Early Stopping | True |
| Batch Size | 16 |
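
A minimal optimizer/loss wiring consistent with this table might look like the sketch below; the binary cross-entropy choice is an assumption that pairs naturally with the sigmoid output of ClsDecoder, and the full training loop (including early stopping) is in the linked repository. The ClsDecoder here is only a stand-in for the full encoder plus classification head:

model = ClsDecoder(128, 64, 1)                        # stand-in for encoder + classification head
criterion = nn.BCELoss()                              # sigmoid outputs -> binary cross-entropy
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

features = torch.randn(16, 128)                       # one batch of 16 [CLS] representations
labels = torch.randint(0, 2, (16, 1)).float()         # binary sentiment targets
optimizer.zero_grad()
loss = criterion(model(features), labels)
loss.backward()
optimizer.step()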

Comparison

To evaluate the training performance of the three Transformer variants (Transformer, Performer, and Longformer), we compared their validation loss and validation accuracy on the validation dataset over the training epochs. Here's a pictorial summary of the findings:

Validation Loss:

[Figure: validation loss vs. epochs for the three variants]

Validation Accuracy:

[Figure: validation accuracy vs. epochs for the three variants]

The tables below show the performance (classification report) of each variant on the test dataset.

Transformer:

| | precision | recall | f1-score | support |
| --- | --- | --- | --- | --- |
| 0 | 0.865 | 0.917 | 0.889 | 10019 |
| 1 | 0.906 | 0.858 | 0.882 | 10009 |
| accuracy | | | 0.885 | 20028 |
| macro avg | 0.886 | 0.885 | 0.885 | 20028 |
| weighted avg | 0.886 | 0.885 | 0.885 | 20028 |

Performer:

| | precision | recall | f1-score | support |
| --- | --- | --- | --- | --- |
| 0 | 0.875 | 0.888 | 0.882 | 10019 |
| 1 | 0.886 | 0.873 | 0.880 | 10009 |
| accuracy | | | 0.881 | 20028 |
| macro avg | 0.881 | 0.881 | 0.881 | 20028 |
| weighted avg | 0.881 | 0.881 | 0.881 | 20028 |

Longformer:

| | precision | recall | f1-score | support |
| --- | --- | --- | --- | --- |
| 0 | 0.879 | 0.775 | 0.824 | 10019 |
| 1 | 0.798 | 0.893 | 0.843 | 10009 |
| accuracy | | | 0.834 | 20028 |
| macro avg | 0.839 | 0.834 | 0.833 | 20028 |
| weighted avg | 0.839 | 0.834 | 0.833 | 20028 |

Conclusion

In this blog, we explored the implementation of three Transformer variants for sentiment classification: the base Transformer, the Performer, and the Longformer. Each variant modifies the attention mechanism to address specific challenges, whether it's improving efficiency or adapting to very long sequences. Despite these differences, the overall architecture remains consistent: the encoder processes the input, and the `[CLS]` token's representation is used to predict sentiment through a feed-forward network.

By understanding and implementing these Transformer variants, we can tailor our models to the specific demands of a classification task, leading to more accurate and efficient models.

The complete code implementation of these variants is available on GitHub.