MLKD-BERT: Multi-level Knowledge Distillation for Pre-trained Language Models

Ying Zhang1, Ziheng Yang1, Shufan Ji1
1Beihang University
{yingzhang1998, yzh2206140, jishufan}@buaa.edu.cn
Abstract

Knowledge distillation is an effective technique for pre-trained language model compression. Although existing knowledge distillation methods perform well for the most typical model BERT, they could be further improved in two aspects: the relation-level knowledge could be further explored to improve model performance; and the setting of student attention head number could be more flexible to decrease inference time. Therefore, we are motivated to propose a novel knowledge distillation method MLKD-BERT to distill multi-level knowledge in teacher-student framework. Extensive experiments on GLUE benchmark and extractive question answering tasks demonstrate that our method outperforms state-of-the-art knowledge distillation methods on BERT. In addition, MLKD-BERT can flexibly set student attention head number, allowing for substantial inference time decrease with little performance drop.

1 Introduction

In recent years, large-scale pre-trained language models (PLMs) have been widely applied in natural language processing, such as BERT (Devlin et al., 2019), XLNet (Yang et al., 2019), and RoBERTa (Liu et al., 2019). PLMs usually have a large number of parameters and long inference time, making them inapplicable to resource-limited devices and real-time scenarios. Therefore, it is crucial to reduce the storage and computation overhead of PLMs while retaining their performance. Knowledge distillation (Hinton et al., 2015) is an effective technique for PLM compression. In knowledge distillation, a smaller, compact student model is trained under the guidance of a larger, more complex teacher model so that it keeps similar model performance.

As for the most typical PLM BERT, there exist several knowledge distillation methods for model compression, including DistilBERT (Sanh et al., 2019), BERT-PKD (Sun et al., 2019), TinyBERT (Jiao et al., 2020), BERT-EMD (Li et al., 2020), MINILM (Wang et al., 2020a), and MINILMv2 (Wang et al., 2021). Despite the effectiveness of previous methods, there still exist two problems not addressed well:

  • Existing methods mainly distill feature-level knowledge, but seldom consider relation-level knowledge (relation among tokens and relation among samples). However, the relation-level knowledge may be valuable to improve the performance of student model.

  • Most previous works use self-attention distribution to distill teacher’s self-attention modules, thus the student model is restricted to take the same attention head number as its teacher. Such restriction prevents the reduction of attention head number in student model, resulting in increased inference time.

Therefore, a more flexible knowledge distillation method with improved performance is preferred to transfer knowledge from teacher to student model.

In this paper, we propose a novel multi-level knowledge distillation method MLKD-BERT for BERT compression. MLKD-BERT conducts a two-stage distillation for feature-level as well as relation-level knowledge, with 6 distillation loss functions designed for embedding layer, Transformer layers, and prediction layer. Compared with previous works, we have made two main contributions:

  • In addition to feature-level knowledge, our student model learns valuable relation-level knowledge (relation among tokens and relation among samples) from its teacher model, which further improves the performance.

  • Our student model learns self-attention relation instead of self-attention distribution, making it flexible in attention head number setting. As such, our student model could reduce attention head number to further decrease inference time.

Extensive experiments on the GLUE benchmark (Wang et al., 2019) and extractive question answering tasks show that our MLKD-BERT outperforms state-of-the-art BERT distillation methods on various prediction tasks. In addition, MLKD-BERT can use a smaller number of student attention heads, allowing for a substantial decrease in inference time with little performance drop. Moreover, MLKD-BERT is effective in PLM compression, in that it keeps performance competitive with its teacher (99.5% on average for GLUE tasks) with 50% compression in parameters and inference time.

The rest of the paper is organized as follows. Related works are reviewed in Section 2. We introduce our method MLKD-BERT in Section 3, and conduct extensive experiments in Section 4. Finally, in Section 5, conclusions are drawn.

2 Related Works

2.1 Pre-trained Language Models

Nowadays, large-scale pre-trained language models have significantly improved the performance of many natural language processing tasks. Pre-trained language models are usually trained on large amounts of text data and then fine-tuned for a specific task. Early research efforts mainly focused on word embedding, such as word2vec (Mikolov et al., 2013) and GloVe (Pennington et al., 2014). Subsequently, researchers shifted to contextual word embedding, including BERT (Devlin et al., 2019), GPT (Radford et al., 2018), ERNIE (Zhang et al., 2019), XLNet (Yang et al., 2019), and RoBERTa (Liu et al., 2019). However, those PLMs contain millions of parameters and take long inference time, making them inapplicable to resource-limited devices and real-time scenarios. Fortunately, there exist many compression techniques for PLMs, which reduce model size and accelerate model inference while keeping model performance.

2.2 Knowledge Distillation

Besides quantization (Shen et al., 2020) and network pruning (Wang et al., 2020b), knowledge distillation (Tang et al., 2019) has been proven to be an effective technique for PLM compression. As for the most typical PLM BERT, there exist several knowledge distillation methods for model compression. Distilled BiLSTM (Tang et al., 2019) tries to distill knowledge from BERT into a simple LSTM. DistilBERT (Sanh et al., 2019) uses soft target probabilities and embedding outputs to train student model. BERT-PKD (Sun et al., 2019) learns from multiple intermediate layers of teacher model for incremental knowledge extraction. TinyBERT (Jiao et al., 2020), MobileBERT (Sun et al., 2020), and SID (Aguilar et al., 2020) further improve BERT-PKD by distilling more internal representations, such as embedding layer outputs and self-attention distribution. BERT-EMD (Li et al., 2020) allows each intermediate student layer to learn from any intermediate teacher layer based on Earth Mover’s Distance. MINILM (Wang et al., 2020a) uses self-attention distribution and value relation to conduct deep self-attention distillation. MINILMv2 (Wang et al., 2021) generalizes deep self-attention distillation in MINILM (Wang et al., 2020a), employing self-attention relation.

Although existing knowledge distillation methods perform well in BERT compression, they are limited in two aspects: the relation-level knowledge is not well explored by the student model to enhance performance; and the attention head number of student model is restricted to the same as its teacher, increasing the inference time. Hence, we are motivated to propose a more flexible knowledge distillation method for BERT with improved performance.

3 Method

MLKD-BERT employs a two-stage distillation procedure for downstream task prediction in a teacher-student framework (illustrated in Figure 1). Stage 1 distills embedding-layer and Transformer-layers, emphasizing feature representation and transformation, while Stage 2 distills prediction-layer, emphasizing sample prediction. In knowledge distillation, each student layer is mapped to a corresponding teacher layer. As the number of Transformer-layers in the student model is smaller than that in the teacher model, we adopt the uniform mapping strategy (Jiao et al., 2020) for Transformer-layer (TL) mapping. For example, in Figure 1, when 2 student TLs are mapped to 4 teacher TLs, TL1 and TL2 in the student model are mapped to TL2 and TL4 in the teacher model, respectively.
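The uniform mapping in the example above can be sketched in a few lines of Python. This is a minimal illustration, assuming the teacher depth is an integer multiple of the student depth; the function name `uniform_layer_mapping` is hypothetical and does not come from the paper.

```python
def uniform_layer_mapping(num_student_layers: int, num_teacher_layers: int) -> dict:
    """Map each student Transformer layer (1-indexed) to a teacher layer.

    Assumption for this sketch: the teacher depth is a multiple of the
    student depth, so student layer n maps to teacher layer n * step.
    """
    if num_teacher_layers % num_student_layers != 0:
        raise ValueError("sketch assumes teacher depth is a multiple of student depth")
    step = num_teacher_layers // num_student_layers
    return {n: n * step for n in range(1, num_student_layers + 1)}


# The example from the text: 2 student TLs mapped to 4 teacher TLs.
mapping = uniform_layer_mapping(2, 4)
print(mapping)  # {1: 2, 2: 4}
```

With 2 student layers and 4 teacher layers, the sketch reproduces the TL1 to TL2 and TL2 to TL4 assignment described for Figure 1.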

Figure 1: Framework of MLKD-BERT.

In MLKD-BERT, we bring in some new relation-level knowledge to improve layer distillation. At embedding-layer, token similarity relation is employed to enhance feature representation. At Transformer-layers, besides feature-level knowledge, self-attention relation is proposed for flexible student attention head number setting. At prediction-layer, in addition to soft labels, sample similarity relation and sample contrastive relation are introduced to enhance prediction. The distillation procedures will be detailed in the following subsections.

3.1 Embedding-layer Distillation

Since embedding-layer performs feature representation of tokens for each data sample, token similarity relation could be valuable knowledge to enhance embedding-layer distillation. In embedding-layer distillation, token similarity relation is transferred from teacher to student by minimizing the KL-divergence of token embedding similarities between teacher and student, according to the embedding distillation loss function $\mathcal{L}_{\text{EMB}}$ defined in Eqn.(1):

$$\mathcal{L}_{\text{EMB}} = \frac{1}{|x|}\sum_{i=1}^{|x|} D_{\text{KL}}\left(\mathbf{R}_{i}^{T} \,\|\, \mathbf{R}_{i}^{S}\right) \tag{1}$$
$$\mathbf{R}^{T} = \text{softmax}\left(\frac{\mathbf{E}^{T}\mathbf{E}^{T\mathsf{T}}}{\sqrt{d_{h}^{T}}}\right) \tag{2}$$
$$\mathbf{R}^{S} = \text{softmax}\left(\frac{\mathbf{E}^{S}\mathbf{E}^{S\mathsf{T}}}{\sqrt{d_{h}^{S}}}\right) \tag{3}$$

where $|x|$ is the length of the input sequence; matrices $\mathbf{E}^{T} \in \mathbb{R}^{|x| \times d_{h}^{T}}$ and $\mathbf{E}^{S} \in \mathbb{R}^{|x| \times d_{h}^{S}}$ are token embeddings of teacher and student; $d_{h}^{T}$ and $d_{h}^{S}$ are the hidden dimensions of teacher and student; matrices $\mathbf{R}^{T} \in \mathbb{R}^{|x| \times |x|}$ and $\mathbf{R}^{S} \in \mathbb{R}^{|x| \times |x|}$ are token embedding similarity matrices of teacher and student, respectively.
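As a rough illustration of Eqns.(1)-(3), the NumPy sketch below computes the token similarity matrices and the row-averaged KL-divergence for one input sequence. It is not the authors' implementation: the helper names (`similarity_relation`, `mean_row_kl`), the random toy embeddings, and the small `eps` added inside the logarithms for numerical safety are all assumptions made for this example.

```python
import numpy as np


def softmax(z, axis=-1):
    # Subtract the row maximum for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def similarity_relation(E):
    """Row-wise softmax of scaled dot-product similarities, as in Eqn.(2)/(3)."""
    d_h = E.shape[-1]
    return softmax(E @ E.T / np.sqrt(d_h), axis=-1)


def mean_row_kl(R_t, R_s, eps=1e-12):
    """Average over rows i of KL(R_t[i] || R_s[i]), as in Eqn.(1)."""
    kl = np.sum(R_t * (np.log(R_t + eps) - np.log(R_s + eps)), axis=-1)
    return kl.mean()


rng = np.random.default_rng(0)
seq_len, d_teacher, d_student = 5, 16, 8      # toy sizes, chosen arbitrarily
E_t = rng.normal(size=(seq_len, d_teacher))   # stand-in for teacher token embeddings
E_s = rng.normal(size=(seq_len, d_student))   # stand-in for student token embeddings

loss_emb = mean_row_kl(similarity_relation(E_t), similarity_relation(E_s))
print(loss_emb)
```

Because both similarity matrices have shape $|x| \times |x|$, the loss is well defined even though the teacher and student hidden dimensions differ, and no projection matrix is needed at this layer.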

3.2 Transformer-layer Distillation

Standard Transformer-layer contains two main sub-layers: Multi-Head Attention (MHA) and Feed Forward Network (FFN). Transformer-layer distillation (illustrated in Figure 2) is conducted on MHA and FFN to transfer self-attention relation and feature-level knowledge, respectively.

Figure 2: Transformer-layer Distillation: MHA Distillation and FFN Distillation.

As for the MHA sub-layer, the similarities among its output vectors are defined as self-attention relation. At Transformer layer $l$, let $A_{h}$ denote the number of attention heads; then the output $\mathbf{O}_{l,a}$ ($a \in [1, A_{h}]$) of the $a$-th attention head is computed via:

$$\mathbf{O}_{l,a} = \text{softmax}\left(\frac{\mathbf{Q}_{l,a}\mathbf{K}_{l,a}^{\mathsf{T}}}{\sqrt{d_{k}}}\right)\mathbf{V}_{l,a} \tag{4}$$
$$\mathbf{Q}_{l,a} = \mathbf{H}^{l-1}\mathbf{W}_{l,a}^{Q} \tag{5}$$
$$\mathbf{K}_{l,a} = \mathbf{H}^{l-1}\mathbf{W}_{l,a}^{K} \tag{6}$$
$$\mathbf{V}_{l,a} = \mathbf{H}^{l-1}\mathbf{W}_{l,a}^{V} \tag{7}$$

where $\mathbf{H}^{l-1} \in \mathbb{R}^{|x| \times d_{h}}$ denotes the input vectors of Transformer layer $l$ (i.e., the output of layer $l-1$), with $|x|$ representing the length of the input sequence and $d_{h}$ representing the hidden dimension; $\mathbf{Q}_{l,a}$, $\mathbf{K}_{l,a}$, $\mathbf{V}_{l,a}$ are linear projections of $\mathbf{H}^{l-1}$; $\mathbf{W}_{l,a}^{Q}, \mathbf{W}_{l,a}^{K}, \mathbf{W}_{l,a}^{V} \in \mathbb{R}^{d_{h} \times d_{k}}$ are parameter matrices; and $d_{k}$ is the attention head size.
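For readers who prefer code to notation, the per-head computation in Eqns.(4)-(7) can be sketched with NumPy as follows. The function name `attention_head_output`, the toy dimensions, and the randomly initialized weights are assumptions for illustration only; a real BERT layer also applies an output projection, residual connection, and layer normalization, which are omitted here.

```python
import numpy as np


def softmax(z, axis=-1):
    # Subtract the row maximum for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_head_output(H_prev, W_q, W_k, W_v):
    """Output of one attention head, following Eqns.(4)-(7)."""
    Q = H_prev @ W_q          # Eqn.(5)
    K = H_prev @ W_k          # Eqn.(6)
    V = H_prev @ W_v          # Eqn.(7)
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k), axis=-1) @ V   # Eqn.(4)


rng = np.random.default_rng(0)
seq_len, d_h, num_heads = 5, 12, 3   # toy sizes, chosen arbitrarily
d_k = d_h // num_heads

H_prev = rng.normal(size=(seq_len, d_h))   # stand-in for H^{l-1}
head_outputs = []
for _ in range(num_heads):
    W_q, W_k, W_v = (rng.normal(size=(d_h, d_k)) for _ in range(3))
    head_outputs.append(attention_head_output(H_prev, W_q, W_k, W_v))

# Concatenating the heads gives one (|x|, A_h * d_k) matrix per layer.
O_concat = np.concatenate(head_outputs, axis=-1)
print(O_concat.shape)
```

The concatenated matrix is the quantity that can later be regrouped independently of how many heads produced it.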

In MHA distillation, MHA outputs (i.e., $\mathbf{O}_{l,a}$ defined in Eqn.(4)) are concatenated together and then split into a certain number of vector groups (named MHA-splits), in both teacher and student models. We suggest setting the number of MHA-splits as the number of student attention heads. After that, self-attention relation is transferred from teacher to student by minimizing the KL-divergence of MHA output vector similarities between teacher and student. Given that the $n$-th student layer is mapped to the $m$-th teacher layer, the loss function for MHA distillation $\mathcal{L}_{\text{MHA}}$ is defined in Eqn.(8):

$$\mathcal{L}_{\text{MHA}} = \frac{1}{A_{s}|x|}\sum_{n=1}^{N}\sum_{a=1}^{A_{s}}\sum_{i=1}^{|x|} D_{\text{KL}}\left(\mathbf{R}_{m,a,i}^{T} \,\|\, \mathbf{R}_{n,a,i}^{S}\right) \tag{8}$$
$$\mathbf{R}_{m,a}^{T} = \text{softmax}\left(\frac{\mathbf{O}_{m,a}^{T}\mathbf{O}_{m,a}^{T\mathsf{T}}}{\sqrt{d_{s}^{T}}}\right) \tag{9}$$
$$\mathbf{R}_{n,a}^{S} = \text{softmax}\left(\frac{\mathbf{O}_{n,a}^{S}\mathbf{O}_{n,a}^{S\mathsf{T}}}{\sqrt{d_{s}^{S}}}\right) \tag{10}$$

where $|x|$ is the length of the input sequence; $A_{s}$ is the number of MHA-splits; $N$ is the number of student Transformer-layers; matrices $\mathbf{O}_{m,a}^{T} \in \mathbb{R}^{|x| \times d_{s}^{T}}$ and $\mathbf{O}_{n,a}^{S} \in \mathbb{R}^{|x| \times d_{s}^{S}}$ are MHA outputs in MHA-split $a$ at teacher's Layer $m$ and student's Layer $n$; $d_{s}^{T}$ and $d_{s}^{S}$ are the split-head sizes of teacher and student MHA-splits; matrices $\mathbf{R}_{m,a}^{T} \in \mathbb{R}^{|x| \times |x|}$ and $\mathbf{R}_{n,a}^{S} \in \mathbb{R}^{|x| \times |x|}$ are MHA output vector similarities in MHA-split $a$ at teacher's Layer $m$ and student's Layer $n$, respectively.

Note that the teacher and student model should have the same number of MHA-splits. In MLKD-BERT, self-attention relation is transferred by MHA-split rather than MHA attention head. As such, student model does not have to set the same number of attention heads as its teacher. In this way, the number of attention heads in the student model can be set smaller, which could decrease inference time.
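To make the MHA-split idea concrete, the following NumPy sketch computes the relation loss for a single mapped layer pair. It is an illustration under stated assumptions rather than the authors' implementation: the function names (`mha_split_relations`, `mha_relation_loss`), the toy dimensions, and the small `eps` inside the logarithms are hypothetical, and Eqn.(8) additionally sums over all $N$ student layers.

```python
import numpy as np


def softmax(z, axis=-1):
    # Subtract the row maximum for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mha_split_relations(O_concat, num_splits):
    """Regroup concatenated MHA outputs of shape (|x|, d_h) into `num_splits`
    MHA-splits and return their similarity relations, shape (A_s, |x|, |x|)."""
    seq_len, d_h = O_concat.shape
    if d_h % num_splits != 0:
        raise ValueError("hidden size must be divisible by the number of MHA-splits")
    d_s = d_h // num_splits
    # (|x|, A_s, d_s) -> (A_s, |x|, d_s)
    splits = O_concat.reshape(seq_len, num_splits, d_s).transpose(1, 0, 2)
    scores = splits @ splits.transpose(0, 2, 1) / np.sqrt(d_s)   # Eqn.(9)/(10)
    return softmax(scores, axis=-1)


def mha_relation_loss(O_teacher, O_student, num_splits, eps=1e-12):
    """KL term of Eqn.(8) for one mapped layer pair, averaged over splits and rows."""
    R_t = mha_split_relations(O_teacher, num_splits)
    R_s = mha_split_relations(O_student, num_splits)
    kl = np.sum(R_t * (np.log(R_t + eps) - np.log(R_s + eps)), axis=-1)
    return kl.mean()   # mean over A_s splits and |x| rows


rng = np.random.default_rng(0)
seq_len = 5
O_teacher = rng.normal(size=(seq_len, 24))   # e.g. 12 teacher heads of size 2
O_student = rng.normal(size=(seq_len, 16))   # e.g. 4 student heads of size 4
num_splits = 4                               # set to the number of student heads

loss_mha = mha_relation_loss(O_teacher, O_student, num_splits)
print(loss_mha)
```

Note that the head counts never enter the loss: only the concatenated outputs and the shared number of MHA-splits do, which is why the student is free to use fewer attention heads than the teacher.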

As for the FFN sub-layer, the feature-level knowledge is distilled by minimizing the mean squared error of the output hidden states between teacher and student, according to the FFN distillation loss function $\mathcal{L}_{\text{FFN}}$ in Eqn.(11):

$$\mathcal{L}_{\text{FFN}} = \sum_{n=1}^{N}\text{MSE}\left(\mathbf{H}_{n}^{S}\mathbf{W}_{h}, \mathbf{H}_{m}^{T}\right) \tag{11}$$

Given that the $n$-th student layer is mapped to the $m$-th teacher layer, matrices $\mathbf{H}_{n}^{S} \in \mathbb{R}^{|x| \times d_{h}^{S}}$ and $\mathbf{H}_{m}^{T} \in \mathbb{R}^{|x| \times d_{h}^{T}}$ are the hidden states of student's Layer $n$ and teacher's Layer $m$; matrix $\mathbf{W}_{h} \in \mathbb{R}^{d_{h}^{S} \times d_{h}^{T}}$ is a learnable linear transformation, which transforms the student's hidden states into the same space as the teacher's hidden states.
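A minimal NumPy sketch of Eqn.(11) is given below. It assumes that MSE averages over all elements of the hidden-state matrix and uses a fixed random matrix in place of the learnable $\mathbf{W}_{h}$; the function name `ffn_distillation_loss`, the toy dimensions, and the `layer_map` dictionary (student layer to teacher layer, 1-indexed) are hypothetical choices for this example.

```python
import numpy as np


def ffn_distillation_loss(H_students, H_teachers, W_h, layer_map):
    """Sum over mapped layer pairs (n -> m) of MSE(H_n^S W_h, H_m^T), as in Eqn.(11).

    H_students / H_teachers are lists of per-layer hidden states; layer_map uses
    1-indexed layer numbers. MSE is taken as the mean over all matrix elements
    (an assumption of this sketch).
    """
    total = 0.0
    for n, m in layer_map.items():
        projected = H_students[n - 1] @ W_h          # student states in teacher space
        total += np.mean((projected - H_teachers[m - 1]) ** 2)
    return total


rng = np.random.default_rng(0)
seq_len, d_student, d_teacher = 5, 8, 16             # toy sizes, chosen arbitrarily
H_students = [rng.normal(size=(seq_len, d_student)) for _ in range(2)]
H_teachers = [rng.normal(size=(seq_len, d_teacher)) for _ in range(4)]
W_h = rng.normal(size=(d_student, d_teacher))        # stand-in for the learnable projection
layer_map = {1: 2, 2: 4}                             # uniform mapping for 2 -> 4 layers

loss_ffn = ffn_distillation_loss(H_students, H_teachers, W_h, layer_map)
print(loss_ffn)
```

In training, $\mathbf{W}_{h}$ would be optimized jointly with the student, so that the projection bridges the gap between $d_{h}^{S}$ and $d_{h}^{T}$.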

In summary, as the first stage distillation covers both embedding-layer and Transformer-layers, the above distillation loss functions are summed up to the first stage distillation loss function in Eqn.(12):

$$\mathcal{L}_{\text{Stage 1}} = \mathcal{L}_{\text{EMB}} + \mathcal{L}_{\text{MHA}} + \mathcal{L}_{\text{FFN}} \tag{12}$$

3.3 Prediction-layer Distillation

As samples with the same class label tend to be more similar than samples with different class labels, relation among samples would be valuable knowledge for prediction-layer distillation. Here, we bring in sample similarity relation and sample contrastive relation to enhance prediction-layer distillation.

Sample similarity relation is defined as the sample similarities within a data batch, without considering sample labels. The sample similarity relation is distilled from teacher to student by minimizing the KL-divergence of sample similarities between teacher and student, according to the sample-similarity distillation loss function $\mathcal{L}_{\text{SS}}$ defined in Eqn.(13):

$$\mathcal{L}_{\text{SS}} = \frac{1}{b}\sum_{i=1}^{b} D_{\text{KL}}\left(\mathbf{R}_{i}^{T} \,\|\, \mathbf{R}_{i}^{S}\right) \tag{13}$$
$$\mathbf{R}^{T} = \text{softmax}\left(\frac{\mathbf{G}^{T}\mathbf{G}^{T\mathsf{T}}}{\sqrt{d_{h}^{T}}}\right) \tag{14}$$
$$\mathbf{R}^{S} = \text{softmax}\left(\frac{\mathbf{G}^{S}\mathbf{G}^{S\mathsf{T}}}{\sqrt{d_{h}^{S}}}\right) \tag{15}$$

where $b$ is the batch size; $d_{h}^{T}$ and $d_{h}^{S}$ are the hidden dimensions of teacher and student; matrices $\mathbf{G}^{T} \in \mathbb{R}^{b \times d_{h}^{T}}$ and $\mathbf{G}^{S} \in \mathbb{R}^{b \times d_{h}^{S}}$ are sample representations in a batch, i.e., [CLS] outputs from the last Transformer-layer of teacher and student; $\mathbf{R}^{T} \in \mathbb{R}^{b \times b}$ and $\mathbf{R}^{S} \in \mathbb{R}^{b \times b}$ are sample similarity matrices of teacher and student, respectively.

Sample contrastive relation is employed to map samples with same and different class labels (according to ground truth of training samples) into close and distant representation space, respectively. Sample contrastive relation is distilled by minimizing the sample-contrastive distillation loss function $\mathcal{L}_{\text{SC}}$ in Eqn.(16) (Khosla et al., 2020):

$$\mathcal{L}_{\text{SC}} = \frac{1}{2b}\sum_{i=1}^{2b}\sum_{i \in I}\frac{1}{|P(i)|}\sum_{p \in P(i)}\mathcal{L}_{\text{InfoNCE}}(i,p) \tag{16}$$
$$\mathcal{L}_{\text{InfoNCE}}(i,p) = -\log\frac{\exp(\mathbf{h}_{i} \cdot \mathbf{h}_{p})/\rho}{\sum_{a \in A(i)}\exp(\mathbf{h}_{i} \cdot \mathbf{h}_{a})/\rho} \tag{17}$$

where $b$ is batch size; $I \equiv \{1, \ldots, 2b\}$, $A(i) \equiv I \backslash \{i\}$, $P(i) \equiv \{p \mid p \in A(i), y_p = y_i\}$; $y_i$ is class label of $i$-th sample; $\rho$ is scalar temperature parameter; $\mathbf{h}_i$ is the $i$-th row of $\mathbf{H} \in \mathbb{R}^{2b \times d_h^T}$; $\mathbf{H} = \text{Concat}(\mathbf{G}^S \mathbf{W}_g, \mathbf{G}^T) = [\mathbf{h}_1; \ldots; \mathbf{h}_{2b}]$, i.e., $\mathbf{G}^S \mathbf{W}_g = [\mathbf{h}_1; \ldots; \mathbf{h}_b]$, $\mathbf{G}^T = [\mathbf{h}_{b+1}; \ldots; \mathbf{h}_{2b}]$; $\mathbf{W}_g \in \mathbb{R}^{d_h^S \times d_h^T}$ is linear transformation matrix; $\mathbf{G}^T$ and $\mathbf{G}^S$ are defined in Eqn.(14) and Eqn.(15).
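The sample-contrastive loss can be illustrated with a small pure-Python sketch. It is not the authors' implementation: the function name `sample_contrastive_loss` and the toy inputs are hypothetical, the outer average is read as a single pass over the $2b$ rows of $\mathbf{H}$, and the temperature $\rho$ is assumed to divide the dot product inside the exponential, as in Khosla et al. (2020). Rows without any positive are skipped, which is also an assumption.

```python
import math


def sample_contrastive_loss(h, labels, rho=1.0):
    """Supervised contrastive loss over the 2b rows of H.

    h      : list of 2b feature vectors (student rows first, then teacher rows)
    labels : list of 2b class labels y_i, aligned with the rows of h
    rho    : temperature, assumed to scale the dot product inside exp
    """
    n = len(h)  # n = 2b

    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))

    total = 0.0
    for i in range(n):
        others = [a for a in range(n) if a != i]                   # A(i) = I \ {i}
        positives = [p for p in others if labels[p] == labels[i]]  # P(i)
        if not positives:
            continue  # assumption: a row without positives contributes nothing
        # Log of the InfoNCE denominator, computed with a stable log-sum-exp.
        logits = [dot(h[i], h[a]) / rho for a in others]
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
        # Average InfoNCE(i, p) over all positives of row i.
        total += sum(log_denom - dot(h[i], h[p]) / rho for p in positives) / len(positives)
    return total / n


# Toy batch with b = 2: rows 0-1 stand in for G^S W_g, rows 2-3 for G^T.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
aligned = sample_contrastive_loss(features, [0, 1, 0, 1], rho=0.1)
misaligned = sample_contrastive_loss(features, [0, 1, 1, 0], rho=0.1)
print(aligned, misaligned)
```

In the toy batch, the loss is small when rows with the same label point in the same direction and large when the labels disagree with the geometry, which is the behaviour the loss is meant to reward.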

Similar to previous distillation works, we also adopt soft label distillation by minimizing the soft-label distillation loss function $\mathcal{L}_{\text{KD}}$ in Eqn.(18):

$$\mathcal{L}_{\text{KD}} = D_{\text{KL}}\big(\text{softmax}(\mathbf{z}^T / \tau) \,\|\, \text{softmax}(\mathbf{z}^S / \tau)\big) \tag{18}$$

where $\tau$ is scalar temperature parameter; $\mathbf{z}^T$ and $\mathbf{z}^S$ are logits predicted by teacher and student, respectively.
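As a minimal sketch of Eqn.(18) for a single sample, the following pure-Python code computes the KL divergence between the temperature-scaled teacher and student distributions. The helper names `log_softmax` and `soft_label_loss` and the example logits are hypothetical, and the computation goes through log-probabilities only for numerical stability.

```python
import math


def log_softmax(z, tau):
    """Log of softmax(z / tau), computed with a stable log-sum-exp."""
    scaled = [v / tau for v in z]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(v - m) for v in scaled))
    return [v - lse for v in scaled]


def soft_label_loss(z_teacher, z_student, tau=1.0):
    """KL(softmax(z^T / tau) || softmax(z^S / tau)) for one sample."""
    log_p = log_softmax(z_teacher, tau)  # teacher distribution
    log_q = log_softmax(z_student, tau)  # student distribution
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Hypothetical logits for a 3-class task.
teacher_logits = [2.0, 0.5, -1.0]
student_logits = [1.0, 1.0, -0.5]
print(soft_label_loss(teacher_logits, student_logits, tau=1.0))
print(soft_label_loss(teacher_logits, student_logits, tau=4.0))
```

The loss is zero when the student reproduces the teacher's logits and positive otherwise; a larger $\tau$ flattens both distributions before they are compared.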

| Model | #Params | Speedup | MNLI-m/-mm (393k) | QQP (364k) | QNLI (105k) | SST-2 (67k) | CoLA (8.5k) | STS-B (5.7k) | MRPC (3.7k) | RTE (2.5k) | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| BERT-base | 109M | 1.0× | 84.2/83.6 | 71.6 | 90.8 | 94.3 | 52.6 | 83.9 | 87.3 | 67.3 | 79.5 |
| BERT$_{\text{TINY}}$ | 14.5M | 9.4× | 75.4/74.9 | 66.5 | 84.8 | 87.6 | 19.5 | 77.1 | 83.2 | 62.6 | 70.2 |
| BERT$_4$-PKD | 52.2M | 3.0× | 79.9/79.3 | 70.2 | 85.1 | 89.4 | 24.8 | 79.8 | 82.6 | 62.3 | 72.6 |
| DistilBERT$_4$ | 52.2M | 3.0× | 78.9/78.0 | 68.5 | 85.2 | 91.4 | 32.8 | 76.1 | 82.4 | 54.1 | 71.9 |
| BERT-EMD$_4$ | 14.5M | 9.4× | 82.1/80.6 | 69.3 | 87.2 | 91.0 | 25.6 | 82.3 | 87.6 | 66.2 | 74.7 |
| TinyBERT$_4$ | 14.5M | 9.4× | 81.4/80.4 | 69.9 | 85.9 | 91.9 | 35.2 | 81.5 | 85.4 | 62.1 | 74.8 |
| MLKD-BERT$_4$ | 14.5M | 9.4× | 82.0/80.7 | 70.6 | 87.5 | 91.9 | 35.5 | 81.9 | 86.3 | 63.5 | 75.6 |
| BERT$_6$-PKD | 67.0M | 2.0× | 81.5/81.0 | 70.7 | 89.0 | 92.0 | 43.5 | 81.6 | 85.0 | 65.5 | 76.6 |
| DistilBERT$_6$ | 67.0M | 2.0× | 82.6/81.3 | 70.1 | 88.9 | 92.5 | 49.0 | 81.3 | 86.9 | 58.4 | 76.8 |
| TinyBERT$_6$ | 67.0M | 2.0× | 83.9/83.4 | 72.0 | 89.9 | 93.7 | 46.7 | 83.3 | 85.7 | 66.6 | 78.4 |
| MINILMv2 | 67.0M | 2.0× | 83.8/83.3 | 70.9 | 90.2 | 92.9 | 46.6 | 84.3 | 89.1 | 69.2 | 78.9 |
| MLKD-BERT$_6$ | 67.0M | 2.0× | 84.4/83.5 | 72.2 | 90.8 | 93.3 | 48.0 | 84.3 | 87.3 | 67.8 | 79.1 |

Table 1: Comparative Studies on GLUE Benchmark. The number under each task represents the number of its training samples. Avg represents the average score over all tasks. The subscript within each model name represents the number of Transformer layers. The best result on each task is in bold.

| Model | SQuAD 1.1 | SQuAD 2.0 | Avg |
|---|---|---|---|
| BERT-base | 88.5 | 77.0 | 82.8 |
| BERT$_4$-PKD | 79.5 | 64.6 | 72.1 |
| DistilBERT$_4$ | 81.2 | 64.1 | 72.7 |
| TinyBERT$_4$ | 81.0 | 68.2 | 74.6 |
| MLKD-BERT$_4$ | 82.0 | 68.9 | 75.5 |
| BERT$_6$-PKD | 85.3 | 69.8 | 77.6 |
| DistilBERT$_6$ | 86.2 | 69.5 | 77.9 |
| TinyBERT$_6$ | 88.0 | 76.1 | 82.1 |
| MINILMv2 | - | 76.3 | - |
| MLKD-BERT$_6$ | 88.3 | 76.5 | 82.4 |

Table 2: Comparative Studies on SQuAD 1.1 and SQuAD 2.0.

| MLKD-BERT$_4$ | $B_{\text{size}}$ | MNLI-m, $A_h^S=12$ | MNLI-m, $A_h^S=6$ | MNLI-m, $A_h^S=3$ | MNLI-mm, $A_h^S=12$ | MNLI-mm, $A_h^S=6$ | MNLI-mm, $A_h^S=3$ |
|---|---|---|---|---|---|---|---|
| Inference time (ms) | 1 | 4.23 ± 0.19 | 4.12 ± 0.24 (-2.60%) | 4.07 ± 0.25 (-3.78%) | 4.24 ± 0.28 | 4.14 ± 0.40 (-2.36%) | 4.03 ± 0.27 (-4.95%) |
| Inference time (ms) | 16 | 5.46 ± 0.19 | 5.09 ± 0.22 (-6.78%) | 4.90 ± 0.32 (-10.26%) | 5.47 ± 0.17 | 5.10 ± 0.29 (-6.76%) | 4.84 ± 0.20 (-11.52%) |
| Inference time (ms) | 32 | 10.34 ± 0.18 | 9.46 ± 0.17 (-8.51%) | 9.00 ± 0.30 (-12.96%) | 10.35 ± 0.20 | 9.46 ± 0.13 (-8.60%) | 8.99 ± 0.24 (-13.14%) |
| Inference time (ms) | 64 | 19.28 ± 0.20 | 17.52 ± 0.22 (-9.13%) | 16.56 ± 0.18 (-14.11%) | 19.28 ± 0.21 | 17.49 ± 0.17 (-9.28%) | 16.59 ± 0.17 (-13.95%) |
| Performance | | 82.0 | 81.4 (-0.73%) | 80.6 (-1.70%) | 80.7 | 80.6 (-0.12%) | 79.5 (-1.49%) |

Table 3: Effect of Attention Head Number $A_h^S$ on Model Performance and Inference Time.

In summary, the sample-similarity, sample-contrastive, and soft-label distillation loss functions are summed to form the second-stage distillation loss function in Eqn.(19):

$$\mathcal{L}_{\text{Stage 2}} = \mathcal{L}_{\text{SS}} + \mathcal{L}_{\text{SC}} + \mathcal{L}_{\text{KD}} \tag{19}$$

4 Experiments

4.1 Experimental Datasets

Our experiments are conducted on the General Language Understanding Evaluation (GLUE) benchmark (Wang et al., 2019) and extractive question answering tasks. The former are sentence-level tasks, while the latter are token-level tasks. GLUE includes 8 tasks: 1) Corpus of Linguistic Acceptability (CoLA) (Warstadt et al., 2019); 2) Stanford Sentiment Treebank (SST-2) (Socher et al., 2013); 3) Microsoft Research Paraphrase Corpus (MRPC) (Dolan and Brockett, 2005); 4) Semantic Textual Similarity Benchmark (STS-B) (Cer et al., 2017); 5) Quora Question Pairs (QQP) (Chen et al., 2018); 6) Question Natural Language Inference (QNLI) (Rajpurkar et al., 2016); 7) Recognizing Textual Entailment (RTE) (Bentivogli et al., 2009); 8) Multi-Genre Natural Language Inference (MNLI) (Williams et al., 2018), which is further divided into in-domain (MNLI-m) and cross-domain (MNLI-mm) tasks. The extractive question answering tasks include SQuAD 1.1 (Rajpurkar et al., 2016) and SQuAD 2.0 (Rajpurkar et al., 2018). The GLUE tasks are evaluated on the GLUE test sets, and the extractive question answering tasks are evaluated on the dev sets.

4.2 Experimental Setup

We take a pre-trained language model BERT-base (Devlin et al., 2019) with 109M parameters as the teacher model (number of layers $M=12$, hidden size $d_h^T=768$, intermediate size $d_h^{T'}=3072$, and attention head number $A_h^T=12$), which is fine-tuned for each specific task. Two student models are instantiated for comparative studies: 1) MLKD-BERT$_4$ ($N=4$, $d_h^S=312$, $d_h^{S'}=1200$, $A_h^S=12$) with 14.5M parameters; and 2) MLKD-BERT$_6$ ($N=6$, $d_h^S=768$, $d_h^{S'}=3072$, $A_h^S=12$) with 67.0M parameters. The student models are initialized with the general distillation model delivered by TinyBERT (https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT). The detailed hyper-parameters are presented in Appendix A.

4.3 Comparative Studies

Comparative studies are conducted to evaluate the performance of MLKD-BERT against state-of-the-art BERT distillation baselines, including BERT-PKD (Sun et al., 2019), DistilBERT (Sanh et al., 2019), BERT-EMD (Li et al., 2020), TinyBERT (Jiao et al., 2020), and MINILMv2 (Wang et al., 2021). In addition, to measure the performance improvement delivered by distillation, MLKD-BERT is compared with a pre-trained BERT model of the same structure, BERT$_{\text{TINY}}$ (Turc et al., 2019), which is fine-tuned for each specific task without knowledge distillation. Note that all experiments are done without data augmentation.

The comparative results evaluated on the test sets of the official GLUE benchmark (https://gluebenchmark.com) are presented in Table 1. Different evaluation metrics are adopted depending on the task: F1 for MRPC and QQP, Spearman correlation for STS-B, Matthews correlation for CoLA, and accuracy for the other tasks.

Experimental results in Table 1 show that, among 4-layer models, our MLKD-BERT$_4$ ranks first on average performance and on 5 specific tasks, and ranks second on the remaining tasks. Among 6-layer models, our MLKD-BERT$_6$ delivers similar results. Moreover, MLKD-BERT$_4$ performs better than BERT$_4$-PKD and DistilBERT$_4$, with only $30\%$ of their parameters and inference time. Therefore, our MLKD-BERT improves average performance over state-of-the-art knowledge distillation methods on BERT.

To evaluate the effectiveness of MLKD-BERT distillation, MLKD-BERT is compared with BERT$_{\text{TINY}}$ and its teacher BERT-base. Results show that our MLKD-BERT$_4$ is consistently better than BERT$_{\text{TINY}}$ on all tasks, with an average improvement of $5.3$. Compared with the teacher model BERT-base, MLKD-BERT$_4$ is 7.5x smaller and 9.4x faster, while keeping on average $95.1\%$ of its teacher's performance; MLKD-BERT$_6$ keeps on average $99.5\%$ of its teacher's performance, with only $50\%$ of the parameters and inference time. As such, the distillation strategies designed for MLKD-BERT are effective in enhancing model performance.
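The size and retention figures quoted above follow directly from Table 1. The short sketch below reproduces the arithmetic; the dictionary layout and the helper names `compression_ratio` and `retained_performance` are illustrative choices, and the numbers are copied from the table.

```python
# Parameter counts (millions) and GLUE average scores, copied from Table 1.
teacher = {"params_m": 109.0, "avg": 79.5}    # BERT-base
student_4 = {"params_m": 14.5, "avg": 75.6}   # MLKD-BERT_4
student_6 = {"params_m": 67.0, "avg": 79.1}   # MLKD-BERT_6


def compression_ratio(t, s):
    """How many times smaller the student is than the teacher."""
    return t["params_m"] / s["params_m"]


def retained_performance(t, s):
    """Student average score as a percentage of the teacher average score."""
    return 100.0 * s["avg"] / t["avg"]


print(round(compression_ratio(teacher, student_4), 1))
print(round(retained_performance(teacher, student_4), 1))
print(round(retained_performance(teacher, student_6), 1))
```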

The comparative results evaluated on the dev sets of SQuAD 1.1 and SQuAD 2.0 are presented in Table 2, with F1 metric for evaluation. As those two tasks are token-level tasks, we remove $\mathcal{L}_{\text{SS}}$ and $\mathcal{L}_{\text{SC}}$ in Stage 2. The results in Table 2 show that our MLKD-BERT$_4$ and MLKD-BERT$_6$ outperform all other methods on SQuAD 1.1 and SQuAD 2.0, which further demonstrates the effectiveness of our method.

4.4 Inference Time vs. Performance

As MLKD-BERT can flexibly set the attention head number of the student model, this group of experiments evaluates its effect on model performance and inference time. The 4-layer student model MLKD-BERT$_4$, instantiated with varied numbers of attention heads, is evaluated on GLUE tasks. MLKD-BERT$_4$ with 12 attention heads is taken as the baseline, as it has the same number of attention heads as its teacher. The inference time is measured as the time ($\pm$ standard deviation) required for each batch of data on a single GeForce RTX 2080Ti GPU. The batch size $B_{\text{size}}$ is varied to provide more insights.

As shown in Table 3, as the attention head number decreases, the inference time decreases quickly with relatively little performance drop, and this effect grows with the batch size. For a batch size of 64 on the MNLI-m task, when the attention head number drops from 12 to 3, the inference time of MLKD-BERT$_4$ decreases by 14.11% while keeping over 98% of the prediction performance. More experimental results, presented in Appendix B, show a similar effect. Therefore, the flexible setting of student attention head number allows a substantial decrease in inference time at the expense of a small performance drop.
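The percentages in Table 3 are relative changes against the 12-head baseline. As a quick check of the trend across batch sizes, the sketch below recomputes them for MNLI-m with 3 heads; the helper name `relative_change` and the dictionary layout are illustrative, and the timings are the mean values copied from Table 3.

```python
def relative_change(baseline, value):
    """Percentage change of `value` relative to `baseline`."""
    return 100.0 * (value - baseline) / baseline


# Mean inference time in ms on MNLI-m, copied from Table 3:
# batch size -> (time with 12 heads, time with 3 heads)
mnli_m_time = {1: (4.23, 4.07), 16: (5.46, 4.90), 32: (10.34, 9.00), 64: (19.28, 16.56)}

time_change = {
    batch: round(relative_change(t12, t3), 2)
    for batch, (t12, t3) in mnli_m_time.items()
}

# MNLI-m accuracy with 12 and 3 heads, also from Table 3.
retained = 100.0 * 80.6 / 82.0

print(time_change)
print(round(retained, 1))
```

The saving grows with every increase in batch size, while the 3-head student still retains more than 98% of the 12-head accuracy.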

| Model | $A_s$ | MNLI-m/-mm | SST-2 | QQP | Avg |
|---|---|---|---|---|---|
| $A_h^S=3$ | 3 | 80.6/79.5 | 91.9 | 70.0 | 80.5 |
| $A_h^S=3$ | 6 | 80.6/79.0 | 90.4 | 69.8 | 79.9 |
| $A_h^S=3$ | 12 | 80.4/79.3 | 91.0 | 69.9 | 80.1 |
| $A_h^S=6$ | 6 | 81.4/80.6 | 92.0 | 70.4 | 81.1 |
| $A_h^S=6$ | 8 | 80.7/80.0 | 91.3 | 70.4 | 80.6 |
| $A_h^S=6$ | 12 | 81.0/80.2 | 91.9 | 70.4 | 80.9 |

Table 4: Effect of MHA-split Number on Model Performance.

| Model | Method | MNLI-m/-mm | SST-2 | QQP | Avg |
|---|---|---|---|---|---|
| $A_h^S=3$ | Concat-split | 80.6/79.5 | 91.9 | 70.0 | 80.5 |
| $A_h^S=3$ | Average | 80.3/79.2 | 90.8 | 70.0 | 80.0 |
| $A_h^S=3$ | Random | 80.4/79.2 | 90.8 | 69.9 | 80.0 |
| $A_h^S=6$ | Concat-split | 81.4/80.6 | 92.0 | 70.4 | 81.1 |
| $A_h^S=6$ | Average | 81.2/80.4 | 91.9 | 70.5 | 81.0 |
| $A_h^S=6$ | Random | 81.4/80.0 | 92.0 | 70.3 | 80.9 |

Table 5: Concat-split vs. Average and Random mapping.

4.5 MHA-split Studies

First, we study the effect of the MHA-split number $A_s$ on model performance in MHA distillation. We instantiate MLKD-BERT$_4$ with 3 attention heads ($A_h^S=3$) and 6 attention heads ($A_h^S=6$), respectively. For the student model with $A_h^S=3$, the number of MHA-splits is set to 3, 6, and 12; for the model with $A_h^S=6$, the number of MHA-splits is set to 6, 8, and 12. The training is conducted under the supervision of the teacher model BERT-base, which has 12 attention heads.

Table 4 shows that the student model performs best when the MHA-split number $A_s$ equals the number of student attention heads $A_h^S$. This might be because this setting keeps the integrity of the student model's subspaces. That explains why we suggest setting the number of MHA-splits to the number of student attention heads.

As the number of attention heads in the student model could be smaller than that in the teacher model, the teacher attention heads are divided into several MHA-splits according to the number of student attention heads, so that each student attention head is mapped to several teacher attention heads in an MHA-split. Our method (named Concat-split) concatenates all teacher attention heads in an MHA-split together. There are two other ways to map a student attention head to its teacher heads in an MHA-split: 1) Average mapping averages the teacher attention heads in an MHA-split; 2) Random mapping randomly selects one teacher attention head in an MHA-split.
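The three mapping strategies can be contrasted on toy data. The sketch below is only an illustration under stated assumptions: teacher heads are grouped contiguously, the teacher head count is divisible by the number of splits (as with 12 teacher heads and 3 or 6 student heads), and each head is represented by a short list of numbers. The function names and the toy head vectors are hypothetical and do not come from the paper.

```python
import random


def mha_splits(num_teacher_heads, num_splits):
    """Partition teacher head indices into contiguous, equally sized MHA-splits."""
    if num_teacher_heads % num_splits != 0:
        raise ValueError("this sketch only covers evenly divisible head counts")
    size = num_teacher_heads // num_splits
    return [list(range(k * size, (k + 1) * size)) for k in range(num_splits)]


def concat_split(heads, split):
    """Concat-split: join the per-head vectors of one split end to end."""
    return [x for h in split for x in heads[h]]


def average_mapping(heads, split):
    """Average mapping: element-wise mean of the per-head vectors in one split."""
    return [sum(col) / len(split) for col in zip(*(heads[h] for h in split))]


def random_mapping(heads, split, rng):
    """Random mapping: keep one randomly chosen teacher head from the split."""
    return list(heads[rng.choice(split)])


# 12 toy teacher heads, each a 2-dimensional vector filled with its head index.
teacher_heads = [[float(h)] * 2 for h in range(12)]
splits = mha_splits(12, 3)  # one split per student head when A_h^S = 3

first = splits[0]
print(concat_split(teacher_heads, first))     # keeps every value of all 4 heads
print(average_mapping(teacher_heads, first))  # collapses 4 heads into one vector
print(random_mapping(teacher_heads, first, random.Random(0)))
```

In this toy setting, concatenation preserves every teacher value in the split, whereas averaging and random selection each reduce four head vectors to one.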

The second group of experiments compares the performance of Concat-split against Average and Random mapping. MLKD-BERT$_4$ with 3 attention heads ($A_h^S=3$) and 6 attention heads ($A_h^S=6$) are instantiated, respectively. As shown in Table 5, our Concat-split outperforms the two other methods on almost all tasks. We think this is because our method incurs less information loss.

| Model | MNLI-m/-mm | CoLA | MRPC | Avg |
|---|---|---|---|---|
| MLKD-BERT$_4$ | 82.0/80.7 | 35.5 | 86.3 | 71.2 |
| $\mathcal{L}_{\text{FFN}}$ for both FFN and MHA sub-layers | 81.7/80.8 | 25.7 | 85.7 | 68.5 |
| $\mathcal{L}_{\text{MHA}}$ for both FFN and MHA sub-layers | 80.6/79.7 | 26.6 | 85.6 | 68.1 |

Table 6: FFN vs. MHA Distillation.

4.6 FFN vs. MHA Distillation

In Transformer-layer distillation, as FFN distillation and MHA distillation can each be applied to both the FFN sub-layer and the MHA sub-layer, we study whether the two distillations could replace each other. As shown in Table 6, neither MHA distillation $\mathcal{L}_{\text{MHA}}$ nor FFN distillation $\mathcal{L}_{\text{FFN}}$ performs well when applied to both FFN and MHA sub-layers. Instead, $\mathcal{L}_{\text{MHA}}$ on the MHA sub-layer combined with $\mathcal{L}_{\text{FFN}}$ on the FFN sub-layer, which is the configuration used by our method MLKD-BERT$_4$, delivers the best performance.

Since the outputs of FFN sub-layer contain feature representations of tokens while the outputs of MHA sub-layer encode the relations of different tokens, feature-level distillation (FFN distillation) fits better for FFN sub-layer, and relation-level distillation (MHA distillation) performs better for MHA. That might explain why our method performs best.

4.7 Ablation Studies

The effect of different distillation loss functions (including $\mathcal{L}_{\text{EMB}}$, $\mathcal{L}_{\text{MHA}}$, $\mathcal{L}_{\text{FFN}}$, $\mathcal{L}_{\text{SS}}$, and $\mathcal{L}_{\text{SC}}$) on model performance is evaluated by ablation studies on MNLI-m/-mm, CoLA and MRPC tasks. MLKD-BERT$_4$ keeping all distillation loss functions is taken as the baseline.

In Table 7, a greater drop in performance indicates that the corresponding distillation loss function is more important. According to average performance, the distillation loss functions in descending order of importance are $\mathcal{L}_{\text{FFN}}$, $\mathcal{L}_{\text{MHA}}$, $\mathcal{L}_{\text{SC}}$, $\mathcal{L}_{\text{SS}}$, and $\mathcal{L}_{\text{EMB}}$. As the removal of any distillation loss function leads to a drop in average performance, we may conclude that the distillation loss functions proposed by MLKD-BERT are all effective in enhancing performance. Therefore, the relation-level knowledge ($\mathcal{L}_{\text{EMB}}$, $\mathcal{L}_{\text{MHA}}$, $\mathcal{L}_{\text{SS}}$, $\mathcal{L}_{\text{SC}}$) can be complementary to feature-level knowledge ($\mathcal{L}_{\text{FFN}}$) for performance enhancement. That could explain why our MLKD-BERT outperforms state-of-the-art knowledge distillation methods on BERT.
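The importance ordering can be reproduced from the Avg column of Table 7 by ranking each loss by the drop it causes when removed. The sketch below does this with values copied from the table; the variable names are illustrative.

```python
# Average scores from Table 7.
baseline_avg = 71.2  # MLKD-BERT_4 with all distillation losses
ablated_avg = {"EMB": 70.7, "MHA": 69.9, "FFN": 69.9, "SS": 70.3, "SC": 70.1}

# Drop in average score when each loss is removed (rounded to one decimal).
drops = {name: round(baseline_avg - avg, 1) for name, avg in ablated_avg.items()}

# Rank losses from the largest drop (most important) to the smallest.
ranking = sorted(drops, key=lambda name: drops[name], reverse=True)

print(drops)
print(ranking)
```

Note that removing $\mathcal{L}_{\text{FFN}}$ and removing $\mathcal{L}_{\text{MHA}}$ give the same average score in Table 7, so their relative order in this ranking is a tie on the averaged metric.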

| Model | MNLI-m/-mm | CoLA | MRPC | Avg |
|---|---|---|---|---|
| MLKD-BERT$_4$ | 82.0/80.7 | 35.5 | 86.3 | 71.2 |
| w/o $\mathcal{L}_{\text{EMB}}$ | 81.9/80.8 | 33.3 | 86.9 | 70.7 |
| w/o $\mathcal{L}_{\text{MHA}}$ | 81.8/80.3 | 31.3 | 86.1 | 69.9 |
| w/o $\mathcal{L}_{\text{FFN}}$ | 81.7/80.1 | 32.7 | 84.9 | 69.9 |
| w/o $\mathcal{L}_{\text{SS}}$ | 81.8/80.4 | 33.7 | 85.2 | 70.3 |
| w/o $\mathcal{L}_{\text{SC}}$ | 81.6/80.4 | 32.2 | 86.2 | 70.1 |

Table 7: Effect of Distillation Loss Functions on Performance. The worst performance on each task is in bold.

| Model | MNLI-m/-mm | CoLA | MRPC | Avg |
|---|---|---|---|---|
| MLKD-BERT$_4$ (two-stage) | 82.0/80.7 | 35.5 | 86.3 | 71.2 |
| MLKD-BERT$_4$ (one-stage) | 80.7/80.2 | 28.5 | 86.2 | 68.9 |

Table 8: One-stage vs. Two-stage Distillation.

4.8 One-stage vs. Two-stage Distillation

Here we study whether the distillation procedure of MLKD-BERT should be partitioned into two stages. As MLKD-BERT has 6 distillation loss functions, the one-stage procedure is designed to minimize the sum of the 6 distillation loss functions. Experimental results with MLKD-BERT$_4$ on MNLI-m/-mm, CoLA and MRPC tasks are presented in Table 8. We find that MLKD-BERT$_4$ with two-stage distillation outperforms one-stage distillation on all tasks. That might be because the two-stage partition emphasizes different distillation objectives in Stage 1 and Stage 2: Stage 1 emphasizes distilling feature representation and transformation, while Stage 2 emphasizes distilling sample prediction.

5 Conclusion

In this paper, we propose a novel two-stage distillation method, MLKD-BERT, to distill multi-level knowledge in the teacher-student framework. MLKD-BERT enhances existing knowledge distillation methods on BERT in two ways: bringing in valuable relation-level knowledge, and allowing a flexible setting of the student attention head number. Experimental results show that MLKD-BERT outperforms state-of-the-art BERT distillation methods on the GLUE benchmark and extractive question answering tasks. We believe that our method can be easily adapted to compress other Transformer-based PLMs in the teacher-student framework.

Limitations

Our MLKD-BERT has two limitations: 1) The two-stage distillation costs relatively more training time than one-stage methods; 2) MLKD-BERT is limited to natural language understanding tasks.

References

  • Aguilar et al. (2020) Gustavo Aguilar, Yuan Ling, Yu Zhang, Benjamin Yao, Xing Fan, and Chenlei Guo. 2020. Knowledge distillation from internal representations. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pages 7350–7357. AAAI Press.
  • Bentivogli et al. (2009) Luisa Bentivogli, Bernardo Magnini, Ido Dagan, Hoa Trang Dang, and Danilo Giampiccolo. 2009. The fifth PASCAL recognizing textual entailment challenge. In Proceedings of the Second Text Analysis Conference, TAC 2009, Gaithersburg, Maryland, USA, November 16-17, 2009. NIST.
  • Cer et al. (2017) Daniel M. Cer, Mona T. Diab, Eneko Agirre, Iñigo Lopez-Gazpio, and Lucia Specia. 2017. Semeval-2017 task 1: Semantic textual similarity - multilingual and cross-lingual focused evaluation. CoRR, abs/1708.00055.
  • Chen et al. (2018) Zihan Chen, Hongbo Zhang, Xiaoji Zhang, and Leqi Zhao. 2018. Quora question pairs. URL https://www.kaggle.com/c/quora-question-pairs.
  • Devlin et al. (2019) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 1 (Long and Short Papers), pages 4171–4186. Association for Computational Linguistics.
  • Dolan and Brockett (2005) William B. Dolan and Chris Brockett. 2005. Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third International Workshop on Paraphrasing, IWP@IJCNLP 2005, Jeju Island, Korea, October 2005. Asian Federation of Natural Language Processing.
  • Hinton et al. (2015) Geoffrey E. Hinton, Oriol Vinyals, and Jeffrey Dean. 2015. Distilling the knowledge in a neural network. CoRR, abs/1503.02531.
  • Jiao et al. (2020) Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. 2020. Tinybert: Distilling BERT for natural language understanding. In Findings of the Association for Computational Linguistics: EMNLP 2020, Online Event, 16-20 November 2020, volume EMNLP 2020 of Findings of ACL, pages 4163–4174. Association for Computational Linguistics.
  • Khosla et al. (2020) Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. 2020. Supervised contrastive learning. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual.
  • Li et al. (2020) Jianquan Li, Xiaokang Liu, Honghong Zhao, Ruifeng Xu, Min Yang, and Yaohong Jin. 2020. BERT-EMD: many-to-many layer mapping for BERT compression with earth mover’s distance. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing, EMNLP 2020, Online, November 16-20, 2020, pages 3009–3018. Association for Computational Linguistics.
  • Liu et al. (2019) Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. 2019. Roberta: A robustly optimized BERT pretraining approach. CoRR, abs/1907.11692.
  • Mikolov et al. (2013) Tomás Mikolov, Ilya Sutskever, Kai Chen, Gregory S. Corrado, and Jeffrey Dean. 2013. Distributed representations of words and phrases and their compositionality. In Advances in Neural Information Processing Systems 26: 27th Annual Conference on Neural Information Processing Systems 2013. Proceedings of a meeting held December 5-8, 2013, Lake Tahoe, Nevada, United States, pages 3111–3119.
  • Pennington et al. (2014) Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. Glove: Global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing, EMNLP 2014, October 25-29, 2014, Doha, Qatar, A meeting of SIGDAT, a Special Interest Group of the ACL, pages 1532–1543. ACL.
  • Radford et al. (2018) Alec Radford, Karthik Narasimhan, Tim Salimans, Ilya Sutskever, et al. 2018. Improving language understanding by generative pre-training.
  • Rajpurkar et al. (2018) Pranav Rajpurkar, Robin Jia, and Percy Liang. 2018. Know what you don’t know: Unanswerable questions for squad. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics, ACL 2018, Melbourne, Australia, July 15-20, 2018, Volume 2: Short Papers, pages 784–789. Association for Computational Linguistics.
  • Rajpurkar et al. (2016) Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. 2016. Squad: 100,000+ questions for machine comprehension of text. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, EMNLP 2016, Austin, Texas, USA, November 1-4, 2016, pages 2383–2392. The Association for Computational Linguistics.
  • Sanh et al. (2019) Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. 2019. Distilbert, a distilled version of BERT: smaller, faster, cheaper and lighter. CoRR, abs/1910.01108.
  • Shen et al. (2020) Sheng Shen, Zhen Dong, Jiayu Ye, Linjian Ma, Zhewei Yao, Amir Gholami, Michael W. Mahoney, and Kurt Keutzer. 2020. Q-BERT: hessian based ultra low precision quantization of BERT. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pages 8815–8821. AAAI Press.
  • Socher et al. (2013) Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D. Manning, Andrew Y. Ng, and Christopher Potts. 2013. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, EMNLP 2013, 18-21 October 2013, Grand Hyatt Seattle, Seattle, Washington, USA, A meeting of SIGDAT, a Special Interest Group of the ACL, pages 1631–1642. ACL.
  • Sun et al. (2019) Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019. Patient knowledge distillation for BERT model compression. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing, EMNLP-IJCNLP 2019, Hong Kong, China, November 3-7, 2019, pages 4322–4331. Association for Computational Linguistics.
  • Sun et al. (2020) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. 2020. Mobilebert: a compact task-agnostic BERT for resource-limited devices. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, ACL 2020, Online, July 5-10, 2020, pages 2158–2170. Association for Computational Linguistics.
  • Tang et al. (2019) Raphael Tang, Yao Lu, Linqing Liu, Lili Mou, Olga Vechtomova, and Jimmy Lin. 2019. Distilling task-specific knowledge from BERT into simple neural networks. CoRR, abs/1903.12136.
  • Turc et al. (2019) Iulia Turc, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. Well-read students learn better: The impact of student initialization on knowledge distillation. CoRR, abs/1908.08962.
  • Wang et al. (2019) Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R. Bowman. 2019. GLUE: A multi-task benchmark and analysis platform for natural language understanding. In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019. OpenReview.net.
  • Wang et al. (2021) Wenhui Wang, Hangbo Bao, Shaohan Huang, Li Dong, and Furu Wei. 2021. Minilmv2: Multi-head self-attention relation distillation for compressing pretrained transformers. In Findings of the Association for Computational Linguistics: ACL/IJCNLP 2021, Online Event, August 1-6, 2021, volume ACL/IJCNLP 2021 of Findings of ACL, pages 2140–2151. Association for Computational Linguistics.
  • Wang et al. (2020a) Wenhui Wang, Furu Wei, Li Dong, Hangbo Bao, Nan Yang, and Ming Zhou. 2020a. Minilm: Deep self-attention distillation for task-agnostic compression of pre-trained transformers. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual.
  • Wang et al. (2020b) Ziheng Wang, Jeremy Wohlwend, and Tao Lei. 2020b. Structured pruning of large language models. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing, EMNLP 2020, Online, November 16-20, 2020, pages 6151–6162. Association for Computational Linguistics.
  • Warstadt et al. (2019) Alex Warstadt, Amanpreet Singh, and Samuel R. Bowman. 2019. Neural network acceptability judgments. Trans. Assoc. Comput. Linguistics, 7:625–641.
  • Williams et al. (2018) Adina Williams, Nikita Nangia, and Samuel R. Bowman. 2018. A broad-coverage challenge corpus for sentence understanding through inference. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2018, New Orleans, Louisiana, USA, June 1-6, 2018, Volume 1 (Long Papers), pages 1112–1122. Association for Computational Linguistics.
  • Yang et al. (2019) Zhilin Yang, Zihang Dai, Yiming Yang, Jaime G. Carbonell, Ruslan Salakhutdinov, and Quoc V. Le. 2019. Xlnet: Generalized autoregressive pretraining for language understanding. In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, pages 5754–5764.
  • Zhang et al. (2019) Zhengyan Zhang, Xu Han, Zhiyuan Liu, Xin Jiang, Maosong Sun, and Qun Liu. 2019. ERNIE: enhanced language representation with informative entities. In Proceedings of the 57th Conference of the Association for Computational Linguistics, ACL 2019, Florence, Italy, July 28- August 2, 2019, Volume 1: Long Papers, pages 1441–1451. Association for Computational Linguistics.

Appendix A Hyper-parameters for Two-stage Distillation

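| Task | Stage 1: Epochs | Stage 1: Batch size | Stage 1: Max seq length | Stage 1: Learning rate | Stage 2: Epochs | Stage 2: Batch size | Stage 2: Max seq length | Stage 2: Learning rate | $\rho$ | $\tau$ |
|---|---|---|---|---|---|---|---|---|---|---|
| CoLA | 50 | 32 | 64 | 1e-5 | 30 | 32 | 64 | 1e-5 | 0.07 | 1.0 |
| MNLI | 6 | 32 | 128 | 3e-5 | 6 | 32 | 128 | 3e-5 | 0.07 | 1.0 |
| MRPC | 20 | 32 | 128 | 2e-5 | 15 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| SST-2 | 15 | 32 | 64 | 2e-5 | 10 | 32 | 64 | 2e-5 | 0.07 | 1.0 |
| STS-B | 20 | 32 | 128 | 3e-5 | 15 | 32 | 128 | 3e-5 | 0.07 | 1.0 |
| QQP | 6 | 32 | 128 | 2e-5 | 6 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| QNLI | 10 | 32 | 128 | 2e-5 | 10 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| RTE | 20 | 32 | 128 | 2e-5 | 15 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| SQuAD 1.1 | 4 | 16 | 384 | 3e-5 | 3 | 16 | 384 | 3e-5 | 0.07 | 1.0 |
| SQuAD 2.0 | 4 | 16 | 384 | 3e-5 | 3 | 16 | 384 | 3e-5 | 0.07 | 1.0 |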
Table 9: Hyper-parameters for Two-stage Distillation.
| $\mbox{MLKD-BERT}_{4}$ | $B_{\text{size}}$ | QQP, $A_h^S=12$ | QQP, $A_h^S=6$ | QQP, $A_h^S=3$ | SST-2, $A_h^S=12$ | SST-2, $A_h^S=6$ | SST-2, $A_h^S=3$ |
|---|---|---|---|---|---|---|---|
| Inference time (ms) | 1 | 4.27 ± 0.33 | 4.16 ± 0.19 (-2.58%) | 4.09 ± 0.20 (-4.22%) | 4.24 ± 0.22 | 4.12 ± 0.22 (-2.83%) | 4.00 ± 0.19 (-5.66%) |
| Inference time (ms) | 16 | 5.46 ± 0.17 | 5.06 ± 0.21 (-7.33%) | 4.93 ± 0.21 (-9.71%) | 5.49 ± 0.20 | 5.07 ± 0.24 (-7.65%) | 4.91 ± 0.30 (-10.56%) |
| Inference time (ms) | 32 | 10.35 ± 0.17 | 9.48 ± 0.26 (-8.41%) | 8.99 ± 0.22 (-13.14%) | 10.39 ± 0.22 | 9.49 ± 0.18 (-8.66%) | 9.01 ± 0.24 (-13.28%) |
| Inference time (ms) | 64 | 19.28 ± 0.21 | 17.49 ± 0.18 (-9.28%) | 16.60 ± 0.22 (-13.90%) | 19.27 ± 0.20 | 17.49 ± 0.21 (-9.24%) | 16.59 ± 0.23 (-13.91%) |
| Performance | | 70.6 | 70.4 (-0.28%) | 70.0 (-0.85%) | 91.9 | 92.0 (+0.11%) | 91.9 (-0.00%) |
Table 10: Effect of Attention Head Number $A_h^S$ on Model Performance and Inference Time on more GLUE tasks.

The detailed hyper-parameters for two-stage distillation of MLKD-BERT on GLUE tasks and extractive question answering tasks are presented in Table 9.
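For readers who want to reuse these settings, the rows of Table 9 can be captured in a small configuration structure. The following is a minimal sketch, not the authors' code: the names `DistillConfig`, `TABLE_9` and `stage_settings` are hypothetical, and the fields `rho` and `tau` simply mirror the table's symbols without restating what they control. Because batch size, maximum sequence length and learning rate coincide across the two stages in every row, they are stored once per task.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DistillConfig:
    """One row of Table 9 (hypothetical container, names are assumptions)."""
    stage1_epochs: int
    stage2_epochs: int
    batch_size: int        # identical in Stage 1 and Stage 2 in Table 9
    max_seq_length: int    # identical in Stage 1 and Stage 2 in Table 9
    learning_rate: float   # identical in Stage 1 and Stage 2 in Table 9
    rho: float = 0.07      # same value for every task in Table 9
    tau: float = 1.0       # same value for every task in Table 9


# Values transcribed from Table 9.
TABLE_9 = {
    "CoLA": DistillConfig(50, 30, 32, 64, 1e-5),
    "MNLI": DistillConfig(6, 6, 32, 128, 3e-5),
    "MRPC": DistillConfig(20, 15, 32, 128, 2e-5),
    "SST-2": DistillConfig(15, 10, 32, 64, 2e-5),
    "STS-B": DistillConfig(20, 15, 32, 128, 3e-5),
    "QQP": DistillConfig(6, 6, 32, 128, 2e-5),
    "QNLI": DistillConfig(10, 10, 32, 128, 2e-5),
    "RTE": DistillConfig(20, 15, 32, 128, 2e-5),
    "SQuAD 1.1": DistillConfig(4, 3, 16, 384, 3e-5),
    "SQuAD 2.0": DistillConfig(4, 3, 16, 384, 3e-5),
}


def stage_settings(task: str, stage: int) -> dict:
    """Return the hyper-parameters of one task for Stage 1 or Stage 2."""
    if stage not in (1, 2):
        raise ValueError("stage must be 1 or 2")
    cfg = TABLE_9[task]
    epochs = cfg.stage1_epochs if stage == 1 else cfg.stage2_epochs
    return {
        "epochs": epochs,
        "batch_size": cfg.batch_size,
        "max_seq_length": cfg.max_seq_length,
        "learning_rate": cfg.learning_rate,
        "rho": cfg.rho,
        "tau": cfg.tau,
    }


if __name__ == "__main__":
    print(stage_settings("CoLA", 1))
    print(stage_settings("CoLA", 2))
```

Only the epoch count differs between the two stages, and only for some tasks (for example, CoLA uses 50 epochs in Stage 1 and 30 in Stage 2).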

Appendix B Effect of Attention Head Number $A_h^S$ on Model Performance and Inference Time on more GLUE tasks

The effect of attention head number $A_h^S$ on model performance and inference time on more GLUE tasks is summarized in Table 10. We observe trends similar to the results on the MNLI task in Table 3.
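The percentages in parentheses in Table 10 can be reproduced with a short calculation. The sketch below assumes they are relative changes with respect to the $A_h^S=12$ column of the same task and batch size; this reading is consistent with the reported numbers but is an assumption, and the function name `relative_change_percent` is hypothetical.

```python
def relative_change_percent(baseline: float, value: float) -> float:
    """Relative change of `value` with respect to `baseline`, in percent."""
    return (value - baseline) / baseline * 100.0


# Mean inference times (ms) on QQP with batch size 1, taken from Table 10,
# keyed by the student attention head number.
qqp_batch1_ms = {12: 4.27, 6: 4.16, 3: 4.09}

for heads in (6, 3):
    change = relative_change_percent(qqp_batch1_ms[12], qqp_batch1_ms[heads])
    print(f"heads={heads}: {change:+.2f}%")
# prints:
# heads=6: -2.58%
# heads=3: -4.22%
```

The same function applied to the QQP performance row gives -0.28% for 70.4 against the 70.6 baseline, matching the table.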

Appendix C GLUE Dataset

In this section, we provide a brief description of the tasks in the GLUE benchmark (Wang et al., 2019).
MNLI. Multi-Genre Natural Language Inference is a large-scale, crowd-sourced entailment classification task (Williams et al., 2018). Given a pair of $\langle premise, hypothesis \rangle$, the goal is to predict whether the $hypothesis$ is an entailment, contradiction, or neutral with respect to the $premise$.
QQP. Quora Question Pairs is a collection of question pairs from the website Quora. The task is to determine whether two questions are semantically equivalent (Chen et al., 2018).
QNLI. Question Natural Language Inference is a version of the Stanford Question Answering Dataset which has been converted to a binary sentence pair classification task by Wang et al. (2019). Given a pair of $\langle question, context \rangle$, the task is to determine whether the $context$ contains the $answer$ to the question.
SST-2. The Stanford Sentiment Treebank is a binary single-sentence classification task, where the goal is to predict the sentiment of movie reviews (Socher et al., 2013).
CoLA. The Corpus of Linguistic Acceptability is a task to predict whether an English sentence is grammatically correct (Warstadt et al., 2019).
STS-B. The Semantic Textual Similarity Benchmark is a collection of sentence pairs drawn from news headlines and many other domains (Cer et al., 2017). The task aims to evaluate how similar two pieces of text are by a score from 1 to 5.
MRPC. Microsoft Research Paraphrase Corpus is a paraphrase identification dataset where systems aim to identify if two sentences are paraphrases of each other (Dolan and Brockett, 2005).
RTE. Recognizing Textual Entailment is a binary entailment task with a small training dataset (Bentivogli et al., 2009).

Appendix D SQuAD

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.
SQuAD 1.1 (Rajpurkar et al., 2016) contains 100,000+ question-answer pairs on 500+ articles.
SQuAD 2.0 (Rajpurkar et al., 2018) combines the 100,000 questions in SQuAD 1.1 with over 50,000 unanswerable questions written adversarially by crowdworkers to look similar to answerable ones.