MLKD-BERT: Multi-level Knowledge Distillation for Pre-trained Language Models
Abstract
Knowledge distillation is an effective technique for compressing pre-trained language models. Although existing knowledge distillation methods perform well for BERT, the most typical such model, they could be further improved in two aspects: relation-level knowledge could be further explored to improve model performance, and the number of student attention heads could be set more flexibly to decrease inference time. Therefore, we propose MLKD-BERT, a novel knowledge distillation method that distills multi-level knowledge in a teacher-student framework. Extensive experiments on the GLUE benchmark and extractive question answering tasks demonstrate that our method outperforms state-of-the-art knowledge distillation methods on BERT. In addition, MLKD-BERT can flexibly set the number of student attention heads, allowing a substantial decrease in inference time with little performance drop.
1 Introduction
In recent years, large-scale pre-trained language models (PLMs) such as BERT (Devlin et al., 2019), XLNet (Yang et al., 2019), and RoBERTa (Liu et al., 2019) have been widely applied in natural language processing. PLMs usually have a large number of parameters and long inference times, making them unsuitable for resource-limited devices and real-time scenarios. Therefore, it is crucial to reduce the storage and computation overhead of PLMs while retaining their performance. Knowledge distillation (Hinton et al., 2015) is an effective technique for PLM compression: a smaller, compact student model is trained under the guidance of a larger, more complex teacher model so that it achieves similar performance.
For BERT, the most typical PLM, several knowledge distillation methods exist for model compression, including DistilBERT (Sanh et al., 2019), BERT-PKD (Sun et al., 2019), TinyBERT (Jiao et al., 2020), BERT-EMD (Li et al., 2020), MINILM (Wang et al., 2020a), and MINILMv2 (Wang et al., 2021). Despite their effectiveness, two problems remain insufficiently addressed:
- Existing methods mainly distill feature-level knowledge but seldom consider relation-level knowledge (relations among tokens and among samples), which may be valuable for improving the performance of the student model.
- Most previous works distill the teacher's self-attention modules via self-attention distributions, which forces the student model to use the same number of attention heads as its teacher. This restriction prevents reducing the number of student attention heads, so inference time cannot be cut in this way.
Therefore, a more flexible knowledge distillation method with improved performance is desirable for transferring knowledge from teacher to student.
In this paper, we propose MLKD-BERT, a novel multi-level knowledge distillation method for BERT compression. MLKD-BERT conducts a two-stage distillation of feature-level as well as relation-level knowledge, with six distillation loss functions designed for the embedding layer, the Transformer layers, and the prediction layer. Compared with previous works, we make two main contributions:
- In addition to feature-level knowledge, our student model learns valuable relation-level knowledge (relations among tokens and among samples) from its teacher, which further improves performance.
- Our student model learns self-attention relations instead of self-attention distributions, which makes the number of attention heads flexible. The student can therefore use fewer attention heads to further decrease inference time.
Extensive experiments on the GLUE (Wang et al., 2019) benchmark and extractive question answering tasks show that MLKD-BERT outperforms state-of-the-art BERT distillation methods on various prediction tasks. In addition, MLKD-BERT can use fewer student attention heads, allowing a substantial decrease in inference time with little performance drop. Moreover, MLKD-BERT is effective for PLM compression: it keeps performance competitive with its teacher ( on average for GLUE tasks) with compression in parameters and inference time.
2 Related Works
2.1 Pre-trained Language Models
Nowadays, large-scale pre-trained language models have significantly improved the performance of many natural language processing tasks. They are usually pre-trained on large amounts of text data and then fine-tuned for specific tasks. Early research efforts mainly focused on word embeddings, such as word2vec (Mikolov et al., 2013) and GloVe (Pennington et al., 2014). Researchers subsequently shifted to contextual word embeddings, including BERT (Devlin et al., 2019), GPT (Radford et al., 2018), ERNIE (Zhang et al., 2019), XLNet (Yang et al., 2019), and RoBERTa (Liu et al., 2019). However, these PLMs contain millions of parameters and take long inference times, making them unsuitable for resource-limited devices and real-time scenarios. Fortunately, many compression techniques exist for PLMs that reduce model size and accelerate inference while keeping model performance.
2.2 Knowledge Distillation
Besides quantization (Shen et al., 2020) and network pruning (Wang et al., 2020b), knowledge distillation (Tang et al., 2019) has proven to be an effective technique for PLM compression. For BERT, the most typical PLM, several knowledge distillation methods exist. Distilled BiLSTM (Tang et al., 2019) distills knowledge from BERT into a simple BiLSTM. DistilBERT (Sanh et al., 2019) uses soft target probabilities and embedding outputs to train the student model. BERT-PKD (Sun et al., 2019) learns from multiple intermediate layers of the teacher model for incremental knowledge extraction. TinyBERT (Jiao et al., 2020), MobileBERT (Sun et al., 2020), and SID (Aguilar et al., 2020) further improve on BERT-PKD by distilling more internal representations, such as embedding-layer outputs and self-attention distributions. BERT-EMD (Li et al., 2020) allows each intermediate student layer to learn from any intermediate teacher layer based on the Earth Mover's Distance. MINILM (Wang et al., 2020a) uses self-attention distributions and value relations to conduct deep self-attention distillation. MINILMv2 (Wang et al., 2021) generalizes the deep self-attention distillation of MINILM (Wang et al., 2020a) by employing self-attention relations.
Although existing knowledge distillation methods perform well in BERT compression, they are limited in two aspects: relation-level knowledge is not well exploited to enhance student performance, and the student must use the same number of attention heads as its teacher, which prevents reducing inference time. Hence, we propose a more flexible knowledge distillation method for BERT with improved performance.
3 Method
MLKD-BERT employs a two-stage distillation procedure for downstream task prediction in a teacher-student framework (illustrated in Figure 1). Stage 1 distills the embedding layer and the Transformer layers, emphasizing feature representation and transformation, while Stage 2 distills the prediction layer, emphasizing sample prediction. In knowledge distillation, each student layer is mapped to a corresponding teacher layer. As the student has fewer Transformer layers than the teacher, we adopt the uniform mapping strategy (Jiao et al., 2020) for Transformer-layer (TL) mapping. For example, in Figure 1, when 2 student TLs are mapped to 4 teacher TLs, TL1 and TL2 in the student model are mapped to TL2 and TL4 in the teacher model, respectively.
In MLKD-BERT, we introduce new relation-level knowledge to improve layer distillation. At the embedding layer, token similarity relation is employed to enhance feature representation. At the Transformer layers, besides feature-level knowledge, self-attention relation is proposed to allow a flexible number of student attention heads. At the prediction layer, in addition to soft labels, sample similarity relation and sample contrastive relation are introduced to enhance prediction. The distillation procedures are detailed in the following subsections.
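The uniform mapping rule can be sketched as follows. This is a minimal illustration assuming TinyBERT's rule $l' = l \cdot N/M$ for $M$ student and $N$ teacher Transformer layers; `uniform_layer_map` is a hypothetical helper, not code from the MLKD-BERT release.

```python
def uniform_layer_map(num_student_layers, num_teacher_layers):
    """Return {student layer l: teacher layer l'} with l' = l * N / M (1-based).

    Hypothetical helper. Assumes the teacher depth is an integer multiple of
    the student depth, as in the 4-of-12 and 6-of-12 settings of this paper.
    """
    if num_teacher_layers % num_student_layers != 0:
        raise ValueError("teacher depth must be a multiple of student depth")
    step = num_teacher_layers // num_student_layers
    return {l: l * step for l in range(1, num_student_layers + 1)}


print(uniform_layer_map(2, 4))   # {1: 2, 2: 4}, the Figure 1 example
print(uniform_layer_map(4, 12))  # {1: 3, 2: 6, 3: 9, 4: 12}
```

Mapping each student layer to the last teacher layer of its block means the top student layer always learns from the top teacher layer.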
3.1 Embedding-layer Distillation
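Since the embedding layer produces feature representations of the tokens in each data sample, token similarity relation can be valuable knowledge for embedding-layer distillation. Token similarity relation is transferred from teacher to student by minimizing the KL-divergence of token embedding similarities between teacher and student, according to the embedding distillation loss function defined in Eqn.(1):
$$\mathcal{L}_{emb}=\frac{1}{|x|}\sum_{i=1}^{|x|}\mathrm{KL}\left(\mathbf{R}^{T}_{i}\,\big\|\,\mathbf{R}^{S}_{i}\right)\tag{1}$$
$$\mathbf{R}^{T}=\mathrm{softmax}\left(\frac{\mathbf{E}^{T}{\mathbf{E}^{T}}^{\top}}{\sqrt{d_h^{T}}}\right)\tag{2}$$
$$\mathbf{R}^{S}=\mathrm{softmax}\left(\frac{\mathbf{E}^{S}{\mathbf{E}^{S}}^{\top}}{\sqrt{d_h^{S}}}\right)\tag{3}$$
where $|x|$ is the length of the input sequence; matrices $\mathbf{E}^{T}\in\mathbb{R}^{|x|\times d_h^{T}}$ and $\mathbf{E}^{S}\in\mathbb{R}^{|x|\times d_h^{S}}$ are the token embeddings of teacher and student; $d_h^{T}$ and $d_h^{S}$ are the hidden dimensions of teacher and student; matrices $\mathbf{R}^{T},\mathbf{R}^{S}\in\mathbb{R}^{|x|\times|x|}$ are the token embedding similarity matrices of teacher and student, respectively, with the softmax applied row-wise and $\mathbf{R}_{i}$ denoting the $i$-th row.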
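A minimal NumPy sketch of the embedding-layer relation loss, assuming a row-wise softmax over scaled dot products and toy random embeddings. The helper names and sizes are hypothetical. Because both relation matrices are $|x|\times|x|$, teacher and student widths may differ without any projection.

```python
import numpy as np

# Sketch only: helper names and toy sizes are illustrative assumptions.


def log_softmax(x):
    """Numerically stable row-wise log-softmax."""
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def token_similarity_logits(E):
    """E @ E^T / sqrt(d_h): the pre-softmax token relation of one model."""
    return E @ E.T / np.sqrt(E.shape[-1])


def embedding_distill_loss(E_t, E_s):
    """Mean over tokens i of KL(R^T_i || R^S_i)."""
    log_p = log_softmax(token_similarity_logits(E_t))  # teacher relation rows
    log_q = log_softmax(token_similarity_logits(E_s))  # student relation rows
    kl_per_row = np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)
    return float(kl_per_row.mean())


rng = np.random.default_rng(0)
E_t = rng.normal(size=(6, 16))  # |x| = 6 tokens, toy teacher width 16
E_s = rng.normal(size=(6, 8))   # toy student width 8, no projection needed
print(embedding_distill_loss(E_t, E_s))  # some positive value
print(embedding_distill_loss(E_t, E_t))  # 0.0 when the relations match
```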
3.2 Transformer-layer Distillation
A standard Transformer layer contains two main sub-layers: Multi-Head Attention (MHA) and a Feed-Forward Network (FFN). Transformer-layer distillation (illustrated in Figure 2) is conducted on MHA and FFN to transfer self-attention relation and feature-level knowledge, respectively.
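As for the MHA sub-layer, the similarities among its output vectors are defined as the self-attention relation. At Transformer layer $l$, let $A_h$ denote the number of attention heads; the output $\mathbf{O}_{l,a}\in\mathbb{R}^{|x|\times d_k}$ of the $a$-th attention head ($a\in[1,A_h]$) is computed via:
$$\mathbf{O}_{l,a}=\mathrm{softmax}\left(\frac{\mathbf{Q}_{l,a}\mathbf{K}_{l,a}^{\top}}{\sqrt{d_k}}\right)\mathbf{V}_{l,a}\tag{4}$$
$$\mathbf{Q}_{l,a}=\mathbf{H}_{l-1}\mathbf{W}^{Q}_{l,a}\tag{5}$$
$$\mathbf{K}_{l,a}=\mathbf{H}_{l-1}\mathbf{W}^{K}_{l,a}\tag{6}$$
$$\mathbf{V}_{l,a}=\mathbf{H}_{l-1}\mathbf{W}^{V}_{l,a}\tag{7}$$
where $\mathbf{H}_{l-1}\in\mathbb{R}^{|x|\times d_h}$ is the input of Transformer layer $l$, with $|x|$ the length of the input sequence and $d_h$ the hidden dimension; $\mathbf{Q}_{l,a}$, $\mathbf{K}_{l,a}$, and $\mathbf{V}_{l,a}$ are linear projections of $\mathbf{H}_{l-1}$; $\mathbf{W}^{Q}_{l,a},\mathbf{W}^{K}_{l,a},\mathbf{W}^{V}_{l,a}\in\mathbb{R}^{d_h\times d_k}$ are parameter matrices; and $d_k$ is the attention head size.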
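For reference, the per-head computation above is standard scaled dot-product attention. The sketch below (NumPy, toy sizes, hypothetical function name) stores the per-head projection weights as stacked arrays of shape $(A_h, d_h, d_k)$, a layout chosen only for illustration.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attention_heads(H, W_q, W_k, W_v):
    """Per-head outputs O_a for a = 1..A_h, shape (A_h, |x|, d_k).

    H: (|x|, d_h) layer input; W_q, W_k, W_v: (A_h, d_h, d_k).
    Hypothetical helper with an assumed stacked-weight layout.
    """
    d_k = W_q.shape[-1]
    Q = np.einsum("td,adk->atk", H, W_q)  # (A_h, |x|, d_k)
    K = np.einsum("td,adk->atk", H, W_k)
    V = np.einsum("td,adk->atk", H, W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (A_h, |x|, |x|)
    return softmax(scores) @ V


rng = np.random.default_rng(0)
seq_len, d_h, n_heads = 5, 24, 3
d_k = d_h // n_heads
H = rng.normal(size=(seq_len, d_h))
W_q, W_k, W_v = (rng.normal(size=(n_heads, d_h, d_k)) for _ in range(3))
O = attention_heads(H, W_q, W_k, W_v)
print(O.shape)  # (3, 5, 8)
```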
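In MHA distillation, the MHA outputs (i.e., $\mathbf{O}_{l,a}$ defined in Eqn.(4)) are concatenated and then split into a certain number of vector groups (named MHA-splits), in both the teacher and student models. We suggest setting the number of MHA-splits to the number of student attention heads. Self-attention relation is then transferred from teacher to student by minimizing the KL-divergence of MHA output vector similarities between teacher and student. Given that the $l$-th student layer is mapped to the $l'$-th teacher layer, the loss function for MHA distillation is defined in Eqn.(8):
$$\mathcal{L}_{mha}=\frac{1}{L\,M\,|x|}\sum_{l=1}^{L}\sum_{m=1}^{M}\sum_{i=1}^{|x|}\mathrm{KL}\left(\mathbf{R}^{T}_{l',m,i}\,\big\|\,\mathbf{R}^{S}_{l,m,i}\right)\tag{8}$$
$$\mathbf{R}^{T}_{l',m}=\mathrm{softmax}\left(\frac{\mathbf{O}^{T}_{l',m}{\mathbf{O}^{T}_{l',m}}^{\top}}{\sqrt{d_s^{T}}}\right)\tag{9}$$
$$\mathbf{R}^{S}_{l,m}=\mathrm{softmax}\left(\frac{\mathbf{O}^{S}_{l,m}{\mathbf{O}^{S}_{l,m}}^{\top}}{\sqrt{d_s^{S}}}\right)\tag{10}$$
where $|x|$ is the length of the input sequence; $M$ is the number of MHA-splits; $L$ is the number of student Transformer layers; matrices $\mathbf{O}^{T}_{l',m}\in\mathbb{R}^{|x|\times d_s^{T}}$ and $\mathbf{O}^{S}_{l,m}\in\mathbb{R}^{|x|\times d_s^{S}}$ are the MHA outputs in MHA-split $m$ at teacher layer $l'$ and student layer $l$; $d_s^{T}$ and $d_s^{S}$ are the split-head sizes of the teacher and student MHA-splits; matrices $\mathbf{R}^{T}_{l',m}$ and $\mathbf{R}^{S}_{l,m}$ are the MHA output vector similarities in MHA-split $m$ at teacher layer $l'$ and student layer $l$, respectively, and $\mathbf{R}_{\cdot,m,i}$ denotes the $i$-th row.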
Note that the teacher and student models should have the same number of MHA-splits. In MLKD-BERT, self-attention relation is transferred per MHA-split rather than per attention head. As such, the student model does not need the same number of attention heads as its teacher, and its number of attention heads can be set smaller to decrease inference time.
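The MHA-split mechanism can be illustrated with a short NumPy sketch for a single mapped layer pair (the averaging over layers is omitted). It assumes per-head outputs are given as arrays of shape $(A_h, |x|, d_k)$ and that heads are concatenated in index order before splitting; function names and sizes are hypothetical. With 12 teacher heads and 3 student heads, each teacher split holds 4 consecutive heads while each student split is exactly one head.

```python
import numpy as np

# Sketch only: names, shapes and the concatenation order are assumptions.


def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def mha_split_relations(O, num_splits):
    """Concatenate head outputs (A_h, |x|, d_k), split the features into
    num_splits groups, and return relation logits (num_splits, |x|, |x|)."""
    n_heads, seq_len, d_k = O.shape
    concat = O.transpose(1, 0, 2).reshape(seq_len, n_heads * d_k)
    if concat.shape[1] % num_splits:
        raise ValueError("hidden size must be divisible by the number of MHA-splits")
    splits = concat.reshape(seq_len, num_splits, -1).transpose(1, 0, 2)
    d_s = splits.shape[-1]  # split-head size
    return splits @ splits.transpose(0, 2, 1) / np.sqrt(d_s)


def mha_split_loss(O_teacher, O_student, num_splits):
    """Mean over splits and rows of KL(teacher relation || student relation)."""
    log_p = log_softmax(mha_split_relations(O_teacher, num_splits))
    log_q = log_softmax(mha_split_relations(O_student, num_splits))
    return float(np.sum(np.exp(log_p) * (log_p - log_q), axis=-1).mean())


rng = np.random.default_rng(0)
O_t = rng.normal(size=(12, 6, 4))  # teacher: 12 heads, toy head size 4
O_s = rng.normal(size=(3, 6, 8))   # student: 3 heads, toy head size 8
print(mha_split_loss(O_t, O_s, num_splits=3))
```

The head counts never need to match; only the number of splits does, which is exactly the flexibility described above.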
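As for the FFN sub-layer, feature-level knowledge is distilled by minimizing the mean squared error of the output hidden states between teacher and student, according to the FFN distillation loss function in Eqn.(11):
$$\mathcal{L}_{ffn}=\sum_{l=1}^{L}\mathrm{MSE}\left(\mathbf{H}^{S}_{l}\mathbf{W}_{h},\,\mathbf{H}^{T}_{l'}\right)\tag{11}$$
Given that the $l$-th student layer is mapped to the $l'$-th teacher layer, matrices $\mathbf{H}^{S}_{l}\in\mathbb{R}^{|x|\times d_h^{S}}$ and $\mathbf{H}^{T}_{l'}\in\mathbb{R}^{|x|\times d_h^{T}}$ are the hidden states of student layer $l$ and teacher layer $l'$; matrix $\mathbf{W}_{h}\in\mathbb{R}^{d_h^{S}\times d_h^{T}}$ is a learnable linear transformation that maps the student's hidden states into the same space as the teacher's hidden states.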
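A minimal sketch of the FFN distillation loss for one mapped layer pair, assuming NumPy and toy sizes. In training, $\mathbf{W}_h$ is learned together with the student; here a closed-form least-squares fit merely stands in for that optimization, to show that fitting the projection lowers the loss.

```python
import numpy as np


def ffn_distill_loss(H_s, H_t, W_h):
    """MSE(H_s W_h, H_t) for one student/teacher layer pair (sketch)."""
    return float(np.mean((H_s @ W_h - H_t) ** 2))


rng = np.random.default_rng(0)
H_s = rng.normal(size=(20, 8))   # student hidden states, toy width 8
H_t = rng.normal(size=(20, 16))  # teacher hidden states, toy width 16
W_init = rng.normal(size=(8, 16)) * 0.1

# Illustration only: least squares replaces gradient training of W_h.
W_fit = np.linalg.lstsq(H_s, H_t, rcond=None)[0]
print(ffn_distill_loss(H_s, H_t, W_init), ffn_distill_loss(H_s, H_t, W_fit))
```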
In summary, since the first stage covers both the embedding layer and the Transformer layers, the above distillation loss functions are summed into the first-stage distillation loss function in Eqn.(12):
$$\mathcal{L}_{stage1}=\mathcal{L}_{emb}+\mathcal{L}_{mha}+\mathcal{L}_{ffn}\tag{12}$$
3.3 Prediction-layer Distillation
As samples with the same class label tend to be more similar than samples with different class labels, relations among samples are valuable knowledge for prediction-layer distillation. Here, we introduce sample similarity relation and sample contrastive relation to enhance prediction-layer distillation.
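Sample similarity relation is defined as the similarities among samples within a data batch, without considering sample labels. It is distilled from teacher to student by minimizing the KL-divergence of sample similarities between teacher and student, according to the sample-similarity distillation loss function defined in Eqn.(13):
$$\mathcal{L}_{ss}=\frac{1}{b}\sum_{i=1}^{b}\mathrm{KL}\left(\mathbf{G}^{T}_{i}\,\big\|\,\mathbf{G}^{S}_{i}\right)\tag{13}$$
$$\mathbf{G}^{T}=\mathrm{softmax}\left(\frac{\mathbf{C}^{T}{\mathbf{C}^{T}}^{\top}}{\sqrt{d_h^{T}}}\right)\tag{14}$$
$$\mathbf{G}^{S}=\mathrm{softmax}\left(\frac{\mathbf{C}^{S}{\mathbf{C}^{S}}^{\top}}{\sqrt{d_h^{S}}}\right)\tag{15}$$
where $b$ is the batch size; $d_h^{T}$ and $d_h^{S}$ are the hidden dimensions of teacher and student; matrices $\mathbf{C}^{T}\in\mathbb{R}^{b\times d_h^{T}}$ and $\mathbf{C}^{S}\in\mathbb{R}^{b\times d_h^{S}}$ are the sample representations in a batch, i.e., the [CLS] outputs from the last Transformer layer of teacher and student; $\mathbf{G}^{T},\mathbf{G}^{S}\in\mathbb{R}^{b\times b}$ are the sample similarity matrices of teacher and student, respectively, and $\mathbf{G}_{i}$ denotes the $i$-th row.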
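The sample-similarity loss reuses the relation computation of the embedding layer, but over the [CLS] vectors of a batch rather than over tokens. A hedged NumPy sketch with toy inputs and hypothetical names:

```python
import numpy as np


def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def sample_similarity_loss(C_t, C_s):
    """Mean over the batch of KL(G^T_i || G^S_i) (sketch).

    C_t: (b, d_t) teacher [CLS] vectors; C_s: (b, d_s) student [CLS] vectors.
    Class labels are deliberately not used.
    """
    log_p = log_softmax(C_t @ C_t.T / np.sqrt(C_t.shape[1]))
    log_q = log_softmax(C_s @ C_s.T / np.sqrt(C_s.shape[1]))
    return float(np.mean(np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)))


rng = np.random.default_rng(1)
C_t = rng.normal(size=(4, 16))  # batch of b = 4 samples, toy teacher width
C_s = rng.normal(size=(4, 8))   # toy student width
print(sample_similarity_loss(C_t, C_s))
```

Since labels are not used, the loss depends only on how samples relate to each other, so reordering the batch consistently for teacher and student leaves it unchanged.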
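Sample contrastive relation is employed to pull samples with the same class label (according to the ground truth of the training samples) close together in the representation space and to push samples with different class labels apart. It is distilled by minimizing the sample-contrastive distillation loss function in Eqn.(16) (Khosla et al., 2020):
$$\mathcal{L}_{sc}=\sum_{i\in I}\mathcal{L}_{i}\tag{16}$$
$$\mathcal{L}_{i}=\frac{-1}{|P(i)|}\sum_{p\in P(i)}\log\frac{\exp\left(\mathbf{z}_{i}\cdot\mathbf{z}_{p}/\tau\right)}{\sum_{a\in A(i)}\exp\left(\mathbf{z}_{i}\cdot\mathbf{z}_{a}/\tau\right)}\tag{17}$$
where $b$ is the batch size; $I\equiv\{1,\dots,2b\}$, $A(i)\equiv I\setminus\{i\}$, $P(i)\equiv\{p\in A(i):y_{p}=y_{i}\}$; $y_{i}$ is the class label of the $i$-th sample; $\tau$ is a scalar temperature parameter; $\mathbf{z}_{i}$ is the $i$-th row of $\mathbf{Z}$; $\mathbf{Z}=[\mathbf{C}^{T};\mathbf{C}^{S}\mathbf{W}_{c}]$; $\mathbf{W}_{c}\in\mathbb{R}^{d_h^{S}\times d_h^{T}}$ is a linear transformation matrix; and $\mathbf{C}^{T}$ and $\mathbf{C}^{S}$ are defined in Eqn.(14) and Eqn.(15).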
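A NumPy sketch of a supervised contrastive loss of this kind, assuming that $\mathbf{Z}$ stacks the teacher [CLS] rows on top of the linearly projected student rows, so every sample appears twice with its own label. That stacking, the use of unnormalized dot products, and the helper name are assumptions for illustration (Khosla et al. additionally L2-normalize the embeddings).

```python
import numpy as np


def sample_contrastive_loss(C_t, C_s, W_c, labels, tau=0.1):
    """Supervised contrastive loss over Z = [C_t; C_s @ W_c] (sketch).

    Assumption: teacher rows and projected student rows are stacked, so each
    row keeps its sample's label and every anchor has at least one positive.
    """
    Z = np.concatenate([C_t, C_s @ W_c], axis=0)         # (2b, d_t)
    y = np.concatenate([labels, labels])                 # (2b,)
    n = Z.shape[0]
    not_self = ~np.eye(n, dtype=bool)                    # A(i) = I \ {i}
    logits = np.where(not_self, Z @ Z.T / tau, -np.inf)  # exclude a = i
    row_max = logits.max(axis=1, keepdims=True)
    log_prob = logits - row_max - np.log(
        np.exp(logits - row_max).sum(axis=1, keepdims=True))
    positives = (y[:, None] == y[None, :]) & not_self    # P(i)
    per_anchor = [-log_prob[i, positives[i]].mean() for i in range(n)]
    return float(np.sum(per_anchor))


C = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
W_c = np.eye(2)  # toy projection; student and teacher widths are equal here
aligned = sample_contrastive_loss(C, C, W_c, np.array([0, 0, 1, 1]))
mixed = sample_contrastive_loss(C, C, W_c, np.array([0, 1, 0, 1]))
print(aligned < mixed)  # True: same-label samples already close -> lower loss
```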
Similar to previous distillation works, we also adopt soft label distillation by minimizing the soft-label distillation loss function in Eqn.(18):
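$$\mathcal{L}_{sl}=-\,\mathrm{softmax}\left(\mathbf{u}^{T}/t\right)\cdot\log\mathrm{softmax}\left(\mathbf{u}^{S}/t\right)\tag{18}$$
where $t$ is a scalar temperature parameter; $\mathbf{u}^{T}$ and $\mathbf{u}^{S}$ are the logits predicted by teacher and student, respectively.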
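A hedged sketch of soft-label distillation as a temperature-scaled soft cross-entropy, the form used by TinyBERT; treating it as the exact form here, and averaging over the batch, are assumptions.

```python
import numpy as np


def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def soft_label_loss(u_t, u_s, t=1.0):
    """Soft cross-entropy between temperature-scaled teacher and student
    logits, averaged over the batch (batch averaging is an assumption)."""
    p_t = np.exp(log_softmax(u_t / t))
    return float(np.mean(-np.sum(p_t * log_softmax(u_s / t), axis=-1)))


u_t = np.array([[2.0, 0.5, -1.0]])  # teacher logits for one sample
u_s = np.array([[1.0, 1.0, 0.0]])   # student logits
print(soft_label_loss(u_t, u_s, t=2.0))
print(soft_label_loss(u_t, u_t, t=2.0))  # minimum: the teacher's entropy
```

The loss is smallest when the student reproduces the teacher's softened distribution, which is what the temperature $t$ exposes.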
| Model | #Params | Speedup | MNLI-m/-mm (393k) | QQP (364k) | QNLI (105k) | SST-2 (67k) | CoLA (8.5k) | STS-B (5.7k) | MRPC (3.7k) | RTE (2.5k) | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| BERT-base | 109M |  | 84.2/83.6 | 71.6 | 90.8 | 94.3 | 52.6 | 83.9 | 87.3 | 67.3 | 79.5 |
|  | 14.5M |  | 75.4/74.9 | 66.5 | 84.8 | 87.6 | 19.5 | 77.1 | 83.2 | 62.6 | 70.2 |
|  | 52.2M |  | 79.9/79.3 | 70.2 | 85.1 | 89.4 | 24.8 | 79.8 | 82.6 | 62.3 | 72.6 |
|  | 52.2M |  | 78.9/78.0 | 68.5 | 85.2 | 91.4 | 32.8 | 76.1 | 82.4 | 54.1 | 71.9 |
|  | 14.5M |  | 82.1/80.6 | 69.3 | 87.2 | 91.0 | 25.6 | 82.3 | 87.6 | 66.2 | 74.7 |
|  | 14.5M |  | 81.4/80.4 | 69.9 | 85.9 | 91.9 | 35.2 | 81.5 | 85.4 | 62.1 | 74.8 |
|  | 14.5M |  | 82.0/80.7 | 70.6 | 87.5 | 91.9 | 35.5 | 81.9 | 86.3 | 63.5 | 75.6 |
|  | 67.0M |  | 81.5/81.0 | 70.7 | 89.0 | 92.0 | 43.5 | 81.6 | 85.0 | 65.5 | 76.6 |
|  | 67.0M |  | 82.6/81.3 | 70.1 | 88.9 | 92.5 | 49.0 | 81.3 | 86.9 | 58.4 | 76.8 |
|  | 67.0M |  | 83.9/83.4 | 72.0 | 89.9 | 93.7 | 46.7 | 83.3 | 85.7 | 66.6 | 78.4 |
| MINILMv2 | 67.0M |  | 83.8/83.3 | 70.9 | 90.2 | 92.9 | 46.6 | 84.3 | 89.1 | 69.2 | 78.9 |
|  | 67.0M |  | 84.4/83.5 | 72.2 | 90.8 | 93.3 | 48.0 | 84.3 | 87.3 | 67.8 | 79.1 |
| Model | SQuAD 1.1 | SQuAD 2.0 | Avg |
|---|---|---|---|
| BERT-base | 88.5 | 77.0 | 82.8 |
|  | 79.5 | 64.6 | 72.1 |
|  | 81.2 | 64.1 | 72.7 |
|  | 81.0 | 68.2 | 74.6 |
|  | 82.0 | 68.9 | 75.5 |
|  | 85.3 | 69.8 | 77.6 |
|  | 86.2 | 69.5 | 77.9 |
|  | 88.0 | 76.1 | 82.1 |
| MINILMv2 | - | 76.3 | - |
|  | 88.3 | 76.5 | 82.4 |
| Metric | Batch size | MNLI-m, 12 heads | MNLI-m, 6 heads | MNLI-m, 3 heads | MNLI-mm, 12 heads | MNLI-mm, 6 heads | MNLI-mm, 3 heads |
|---|---|---|---|---|---|---|---|
| Inference time (ms) | 1 | 4.23 ± 0.19 | 4.12 ± 0.24 (-2.60%) | 4.07 ± 0.25 (-3.78%) | 4.24 ± 0.28 | 4.14 ± 0.40 (-2.36%) | 4.03 ± 0.27 (-4.95%) |
| Inference time (ms) | 16 | 5.46 ± 0.19 | 5.09 ± 0.22 (-6.78%) | 4.90 ± 0.32 (-10.26%) | 5.47 ± 0.17 | 5.10 ± 0.29 (-6.76%) | 4.84 ± 0.20 (-11.52%) |
| Inference time (ms) | 32 | 10.34 ± 0.18 | 9.46 ± 0.17 (-8.51%) | 9.00 ± 0.30 (-12.96%) | 10.35 ± 0.20 | 9.46 ± 0.13 (-8.60%) | 8.99 ± 0.24 (-13.14%) |
| Inference time (ms) | 64 | 19.28 ± 0.20 | 17.52 ± 0.22 (-9.13%) | 16.56 ± 0.18 (-14.11%) | 19.28 ± 0.21 | 17.49 ± 0.17 (-9.28%) | 16.59 ± 0.17 (-13.95%) |
| Performance | - | 82.0 | 81.4 (-0.73%) | 80.6 (-1.70%) | 80.7 | 80.6 (-0.12%) | 79.5 (-1.49%) |
In summary, the sample-similarity, sample-contrastive, and soft-label distillation loss functions are summed into the second-stage distillation loss function in Eqn.(19):
$$\mathcal{L}_{stage2}=\mathcal{L}_{ss}+\mathcal{L}_{sc}+\mathcal{L}_{sl}\tag{19}$$
4 Experiments
4.1 Experimental Datasets
Our experiments are conducted on the General Language Understanding Evaluation (GLUE) (Wang et al., 2019) benchmark and on extractive question answering tasks. The former are sentence-level tasks, while the latter are token-level tasks. GLUE includes 8 tasks: 1) Corpus of Linguistic Acceptability (CoLA) (Warstadt et al., 2019); 2) Stanford Sentiment Treebank (SST-2) (Socher et al., 2013); 3) Microsoft Research Paraphrase Corpus (MRPC) (Dolan and Brockett, 2005); 4) Semantic Textual Similarity Benchmark (STS-B) (Cer et al., 2017); 5) Quora Question Pairs (QQP) (Chen et al., 2018); 6) Question Natural Language Inference (QNLI) (Rajpurkar et al., 2016); 7) Recognizing Textual Entailment (RTE) (Bentivogli et al., 2009); 8) Multi-Genre Natural Language Inference (MNLI) (Williams et al., 2018), which is further divided into in-domain (MNLI-m) and cross-domain (MNLI-mm) tasks. The extractive question answering tasks include SQuAD 1.1 (Rajpurkar et al., 2016) and SQuAD 2.0 (Rajpurkar et al., 2018). The GLUE tasks are evaluated on the GLUE test sets, and the extractive question answering tasks on the dev sets.
4.2 Experimental Setup
We take the pre-trained language model BERT-base (Devlin et al., 2019) with 109M parameters as the teacher model (12 layers, hidden size 768, intermediate size 3072, and 12 attention heads), fine-tuned for each specific task. Two student models are instantiated for comparative studies: 1) a 4-layer student with 14.5M parameters; and 2) a 6-layer student with 67.0M parameters. The student models are initialized with the general distillation model released by TinyBERT (https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT). The detailed hyper-parameters are presented in Appendix A.
4.3 Comparative Studies
Comparative studies are conducted to evaluate the performance of MLKD-BERT against state-of-the-art BERT distillation baselines, including BERT-PKD (Sun et al., 2019), DistilBERT (Sanh et al., 2019), BERT-EMD (Li et al., 2020), TinyBERT (Jiao et al., 2020), and MINILMv2 (Wang et al., 2021). In addition, to measure the performance improvement delivered by distillation, MLKD-BERT is compared with a pre-trained BERT model of the same structure (Turc et al., 2019), fine-tuned for each specific task without knowledge distillation. Note that all experiments are done without data augmentation.
The comparative results on the test sets of the official GLUE benchmark (https://gluebenchmark.com) are presented in Table 1. Different evaluation metrics are used for different tasks: F1 for MRPC and QQP, Spearman correlation for STS-B, Matthews correlation for CoLA, and accuracy for the other tasks.
Experimental results in Table 1 show that, among 4-layer models, our 4-layer MLKD-BERT ranks first in average performance and on 5 specific tasks, and second on the remaining tasks. Among 6-layer models, our 6-layer MLKD-BERT delivers similar results. Moreover, performs better than and , with only parameters and inference time. Therefore, MLKD-BERT achieves better average performance than state-of-the-art knowledge distillation methods on BERT.
To evaluate the effectiveness of distillation, MLKD-BERT is compared with the same-structured pre-trained BERT fine-tuned without distillation (Turc et al., 2019) and with its teacher BERT-base. Results show that our is consistently better than on all tasks, with improvement on average. Compared with the teacher BERT-base, the 4-layer MLKD-BERT is 7.5x smaller and 9.4x faster while keeping average performance of its teacher; the 6-layer MLKD-BERT keeps average performance of its teacher with only parameters and inference time. As such, the distillation strategies designed for MLKD-BERT are effective in enhancing model performance.
The comparative results evaluated on the dev sets of SQuAD 1.1 and SQuAD 2.0 are presented in Table 2, with F1 metric for evaluation. As those two tasks are token-level tasks, we remove and in Stage 2. The results in Table 2 show that our and outperform all other methods on SQuAD 1.1 and SQuAD 2.0, which further demonstrates the effectiveness of our method.
4.4 Inference Time vs. Performance
As MLKD-BERT can flexibly set the number of student attention heads, this group of experiments evaluates its effect on model performance and inference time. The 4-layer student model, instantiated with varied numbers of attention heads, is evaluated on GLUE tasks. The 4-layer student with 12 attention heads is taken as the baseline, as it has the same number of attention heads as its teacher. Inference time is measured as the time (± standard deviation) required for each data batch on a single GeForce RTX 2080 Ti GPU. The batch size is varied to provide more insight.
As shown in Table 3, as the number of attention heads decreases, inference time drops noticeably while performance drops only slightly, and this effect becomes more pronounced as the batch size increases. For batch size 64 on the MNLI-m task, when the number of attention heads drops from 12 to 3, the inference time of the 4-layer student decreases by 14.11% while it keeps over 98% of its prediction performance. Further experimental results in Appendix B show a similar effect. Therefore, the flexible setting of the number of student attention heads allows a substantial decrease in inference time at the expense of a small performance drop.
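The percentages quoted for batch size 64 on MNLI-m can be reproduced from the Table 3 entries; this small check simply recomputes the relative change against the 12-head baseline and the retained share of performance.

```python
# Batch size 64, MNLI-m; values copied from Table 3.
time_ms = {12: 19.28, 6: 17.52, 3: 16.56}  # attention heads -> ms per batch
score = {12: 82.0, 6: 81.4, 3: 80.6}       # MNLI-m accuracy

relative_change = {
    h: round((time_ms[h] - time_ms[12]) / time_ms[12] * 100, 2) for h in (6, 3)
}
retained = round(score[3] / score[12] * 100, 2)
print(relative_change)  # {6: -9.13, 3: -14.11}
print(retained)         # 98.29
```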
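| Model | #MHA-splits | MNLI-m/-mm | SST-2 | QQP | Avg |
|---|---|---|---|---|---|
| MLKD-BERT (4-layer, 3 heads) | 3 | 80.6/79.5 | 91.9 | 70.0 | 80.5 |
| MLKD-BERT (4-layer, 3 heads) | 6 | 80.6/79.0 | 90.4 | 69.8 | 79.9 |
| MLKD-BERT (4-layer, 3 heads) | 12 | 80.4/79.3 | 91.0 | 69.9 | 80.1 |
| MLKD-BERT (4-layer, 6 heads) | 6 | 81.4/80.6 | 92.0 | 70.4 | 81.1 |
| MLKD-BERT (4-layer, 6 heads) | 8 | 80.7/80.0 | 91.3 | 70.4 | 80.6 |
| MLKD-BERT (4-layer, 6 heads) | 12 | 81.0/80.2 | 91.9 | 70.4 | 80.9 |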
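| Model | Method | MNLI-m/-mm | SST-2 | QQP | Avg |
|---|---|---|---|---|---|
| MLKD-BERT (4-layer, 3 heads) | Concat-split | 80.6/79.5 | 91.9 | 70.0 | 80.5 |
| MLKD-BERT (4-layer, 3 heads) | Average | 80.3/79.2 | 90.8 | 70.0 | 80.0 |
| MLKD-BERT (4-layer, 3 heads) | Random | 80.4/79.2 | 90.8 | 69.9 | 80.0 |
| MLKD-BERT (4-layer, 6 heads) | Concat-split | 81.4/80.6 | 92.0 | 70.4 | 81.1 |
| MLKD-BERT (4-layer, 6 heads) | Average | 81.2/80.4 | 91.9 | 70.5 | 81.0 |
| MLKD-BERT (4-layer, 6 heads) | Random | 81.4/80.0 | 92.0 | 70.3 | 80.9 |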
4.5 MHA-split Studies
We first study the effect of the number of MHA-splits on model performance in MHA distillation. We instantiate the 4-layer student with 3 attention heads and with 6 attention heads, respectively. For the student with 3 heads, the number of MHA-splits is set to 3, 6, and 12; for the student with 6 heads, it is set to 6, 8, and 12. Training is conducted under the supervision of the teacher model BERT-base, which has 12 attention heads.
Table 4 shows that the student model performs best when the number of MHA-splits equals the number of student attention heads. This might be because such a setting keeps the integrity of the student model's attention subspaces, which is why we suggest setting the number of MHA-splits to the number of student attention heads.
As the student model may have fewer attention heads than the teacher, the teacher attention heads are divided into MHA-splits according to the number of student attention heads, so that each student attention head is mapped to several teacher attention heads within one MHA-split. Our method (named Concat-split) concatenates all teacher attention heads in an MHA-split. Two other methods could map a student attention head to its teacher heads in an MHA-split: 1) Average mapping averages the teacher attention heads in the MHA-split; 2) Random mapping randomly selects one teacher attention head in the MHA-split.
The second group of experiments compares Concat-split against Average and Random mapping, using the 4-layer student with 3 attention heads and with 6 attention heads, respectively. As shown in Table 5, Concat-split outperforms the two other methods on almost all tasks. We attribute this to its smaller information loss: Concat-split keeps every teacher head in the split, whereas averaging blends them and random selection discards all but one.
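The three mapping options can be sketched in NumPy as follows, assuming teacher heads are grouped into MHA-splits in consecutive index order and that the teacher head count divides evenly by the student head count. `map_teacher_heads` is a hypothetical helper, not the authors' code.

```python
import numpy as np


def map_teacher_heads(O_t, num_student_heads, method="concat", rng=None):
    """Build one teacher target per student head (hypothetical helper).

    O_t: teacher per-head outputs, shape (A_t, |x|, d_k).
    Returns (num_student_heads, |x|, d) with d = g * d_k for "concat"
    and d = d_k for "average" / "random", where g = A_t / num_student_heads.
    """
    a_t, seq_len, d_k = O_t.shape
    if a_t % num_student_heads:
        raise ValueError("teacher heads must divide evenly into MHA-splits")
    g = a_t // num_student_heads
    groups = O_t.reshape(num_student_heads, g, seq_len, d_k)  # consecutive heads
    if method == "concat":   # Concat-split: keep all heads side by side
        return groups.transpose(0, 2, 1, 3).reshape(num_student_heads, seq_len, g * d_k)
    if method == "average":  # Average mapping: blend the heads
        return groups.mean(axis=1)
    if method == "random":   # Random mapping: keep one head per split
        if rng is None:
            rng = np.random.default_rng()
        idx = rng.integers(0, g, size=num_student_heads)
        return groups[np.arange(num_student_heads), idx]
    raise ValueError(f"unknown method: {method}")


O_t = np.random.default_rng(0).normal(size=(12, 5, 4))  # 12 teacher heads
for m in ("concat", "average", "random"):
    print(m, map_teacher_heads(O_t, 3, m, rng=np.random.default_rng(0)).shape)
```

Only Concat-split keeps the full $g \cdot d_k$ features of each split, which matches the information-loss explanation given above.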
| Model | MNLI-m/-mm | CoLA | MRPC | Avg |
|---|---|---|---|---|
|  | 82.0/80.7 | 35.5 | 86.3 | 71.2 |
|  for both FFN and MHA sub-layers | 81.7/80.8 | 25.7 | 85.7 | 68.5 |
|  for both FFN and MHA sub-layers | 80.6/79.7 | 26.6 | 85.6 | 68.1 |
4.6 FFN vs. MHA Distillation
In Transformer-layer distillation, since both FFN distillation and MHA distillation could be applied to both the FFN and the MHA sub-layers, we study whether the two distillations could replace each other. As shown in Table 6, neither MHA distillation nor FFN distillation performs well when applied to both sub-layers. Instead, applying MHA distillation on the MHA sub-layer jointly with FFN distillation on the FFN sub-layer, as done in our method, delivers the best performance.
Since the outputs of the FFN sub-layer contain feature representations of tokens while the outputs of the MHA sub-layer encode relations among tokens, feature-level distillation (FFN distillation) fits the FFN sub-layer better, and relation-level distillation (MHA distillation) fits the MHA sub-layer better. That might explain why our method performs best.
4.7 Ablation Studies
The effect of the different distillation loss functions (including , , , , and ) on model performance is evaluated by ablation studies on the MNLI-m/-mm, CoLA and MRPC tasks. The 4-layer MLKD-BERT keeping all distillation loss functions is taken as the baseline.
As shown in Table 7, a greater performance drop indicates a more important distillation loss function. As such, according to average performance, the distillation loss functions , , , , and are listed in descending order of importance. As removing any distillation loss function leads to a drop in average performance, we may conclude that all distillation loss functions proposed by MLKD-BERT are effective in enhancing performance. Therefore, the relation-level knowledge (, , , ) can complement the feature-level knowledge () for performance enhancement. That could explain why MLKD-BERT outperforms state-of-the-art knowledge distillation methods on BERT.
| Model | MNLI-m/-mm | CoLA | MRPC | Avg |
|---|---|---|---|---|
|  | 82.0/80.7 | 35.5 | 86.3 | 71.2 |
| w/o | 81.9/80.8 | 33.3 | 86.9 | 70.7 |
| w/o | 81.8/80.3 | 31.3 | 86.1 | 69.9 |
| w/o | 81.7/80.1 | 32.7 | 84.9 | 69.9 |
| w/o | 81.8/80.4 | 33.7 | 85.2 | 70.3 |
| w/o | 81.6/80.4 | 32.2 | 86.2 | 70.1 |
| Model | MNLI-m/-mm | CoLA | MRPC | Avg |
|---|---|---|---|---|
| (two-stage) | 82.0/80.7 | 35.5 | 86.3 | 71.2 |
| (one-stage) | 80.7/80.2 | 28.5 | 86.2 | 68.9 |
4.8 One-stage vs. Two-stage Distillation
Here we study whether the distillation procedure of MLKD-BERT should be partitioned into two stages. As MLKD-BERT has 6 distillation loss functions, the one-stage procedure is designed to minimize the sum of all 6. Experimental results with the 4-layer student on the MNLI-m/-mm, CoLA and MRPC tasks are presented in Table 8. We find that two-stage distillation outperforms one-stage distillation on all tasks. This might be because partitioning the procedure lets each stage emphasize a different distillation objective: Stage 1 emphasizes distilling feature representation and transformation, while Stage 2 emphasizes distilling sample prediction.
5 Conclusion
In this paper, we propose MLKD-BERT, a novel two-stage distillation method that distills multi-level knowledge in a teacher-student framework. MLKD-BERT enhances existing knowledge distillation methods on BERT in two ways: it introduces valuable relation-level knowledge, and it allows a flexible number of student attention heads. Experimental results show that MLKD-BERT outperforms state-of-the-art BERT distillation methods on the GLUE benchmark and extractive question answering tasks. We believe that our method can be easily adapted to compress other Transformer-based PLMs in a teacher-student framework.
Limitations
Our MLKD-BERT has two limitations: 1) the two-stage distillation requires more training time than one-stage methods; 2) MLKD-BERT is limited to natural language understanding tasks.
References
- Aguilar et al. (2020) Gustavo Aguilar, Yuan Ling, Yu Zhang, Benjamin Yao, Xing Fan, and Chenlei Guo. 2020. Knowledge distillation from internal representations. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pages 7350–7357. AAAI Press.
- Bentivogli et al. (2009) Luisa Bentivogli, Bernardo Magnini, Ido Dagan, Hoa Trang Dang, and Danilo Giampiccolo. 2009. The fifth PASCAL recognizing textual entailment challenge. In Proceedings of the Second Text Analysis Conference, TAC 2009, Gaithersburg, Maryland, USA, November 16-17, 2009. NIST.
- Cer et al. (2017) Daniel M. Cer, Mona T. Diab, Eneko Agirre, Iñigo Lopez-Gazpio, and Lucia Specia. 2017. Semeval-2017 task 1: Semantic textual similarity - multilingual and cross-lingual focused evaluation. CoRR, abs/1708.00055.
- Chen et al. (2018) Zihan Chen, Hongbo Zhang, Xiaoji Zhang, and Leqi Zhao. 2018. Quora question pairs. URL https://www.kaggle.com/c/quora-question-pairs.
- Devlin et al. (2019) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2-7, 2019, Volume 1 (Long and Short Papers), pages 4171–4186. Association for Computational Linguistics.
- Dolan and Brockett (2005) William B. Dolan and Chris Brockett. 2005. Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third International Workshop on Paraphrasing, IWP@IJCNLP 2005, Jeju Island, Korea, October 2005, 2005. Asian Federation of Natural Language Processing.
- Hinton et al. (2015) Geoffrey E. Hinton, Oriol Vinyals, and Jeffrey Dean. 2015. Distilling the knowledge in a neural network. CoRR, abs/1503.02531.
- Jiao et al. (2020) Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. 2020. Tinybert: Distilling BERT for natural language understanding. In Findings of the Association for Computational Linguistics: EMNLP 2020, Online Event, 16-20 November 2020, volume EMNLP 2020 of Findings of ACL, pages 4163–4174. Association for Computational Linguistics.
- Khosla et al. (2020) Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. 2020. Supervised contrastive learning. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual.
- Li et al. (2020) Jianquan Li, Xiaokang Liu, Honghong Zhao, Ruifeng Xu, Min Yang, and Yaohong Jin. 2020. BERT-EMD: many-to-many layer mapping for BERT compression with earth mover's distance. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing, EMNLP 2020, Online, November 16-20, 2020, pages 3009–3018. Association for Computational Linguistics.
- Liu et al. (2019) Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. 2019. Roberta: A robustly optimized BERT pretraining approach. CoRR, abs/1907.11692.
- Mikolov et al. (2013) Tomás Mikolov, Ilya Sutskever, Kai Chen, Gregory S. Corrado, and Jeffrey Dean. 2013. Distributed representations of words and phrases and their compositionality. In Advances in Neural Information Processing Systems 26: 27th Annual Conference on Neural Information Processing Systems 2013. Proceedings of a meeting held December 5-8, 2013, Lake Tahoe, Nevada, United States, pages 3111–3119.
- Pennington et al. (2014) Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. Glove: Global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing, EMNLP 2014, October 25-29, 2014, Doha, Qatar, A meeting of SIGDAT, a Special Interest Group of the ACL, pages 1532–1543. ACL.
- Radford et al. (2018) Alec Radford, Karthik Narasimhan, Tim Salimans, Ilya Sutskever, et al. 2018. Improving language understanding by generative pre-training.
- Rajpurkar et al. (2018) Pranav Rajpurkar, Robin Jia, and Percy Liang. 2018. Know what you don’t know: Unanswerable questions for SQuAD. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics, ACL 2018, Melbourne, Australia, July 15-20, 2018, Volume 2: Short Papers, pages 784–789. Association for Computational Linguistics.
- Rajpurkar et al. (2016) Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. 2016. SQuAD: 100,000+ questions for machine comprehension of text. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, EMNLP 2016, Austin, Texas, USA, November 1-4, 2016, pages 2383–2392. The Association for Computational Linguistics.
- Sanh et al. (2019) Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. 2019. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. CoRR, abs/1910.01108.
- Shen et al. (2020) Sheng Shen, Zhen Dong, Jiayu Ye, Linjian Ma, Zhewei Yao, Amir Gholami, Michael W. Mahoney, and Kurt Keutzer. 2020. Q-BERT: Hessian based ultra low precision quantization of BERT. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pages 8815–8821. AAAI Press.
- Socher et al. (2013) Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D. Manning, Andrew Y. Ng, and Christopher Potts. 2013. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, EMNLP 2013, 18-21 October 2013, Grand Hyatt Seattle, Seattle, Washington, USA, A meeting of SIGDAT, a Special Interest Group of the ACL, pages 1631–1642. ACL.
- Sun et al. (2019) Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019. Patient knowledge distillation for BERT model compression. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing, EMNLP-IJCNLP 2019, Hong Kong, China, November 3-7, 2019, pages 4322–4331. Association for Computational Linguistics.
- Sun et al. (2020) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. 2020. MobileBERT: a compact task-agnostic BERT for resource-limited devices. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, ACL 2020, Online, July 5-10, 2020, pages 2158–2170. Association for Computational Linguistics.
- Tang et al. (2019) Raphael Tang, Yao Lu, Linqing Liu, Lili Mou, Olga Vechtomova, and Jimmy Lin. 2019. Distilling task-specific knowledge from BERT into simple neural networks. CoRR, abs/1903.12136.
- Turc et al. (2019) Iulia Turc, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. Well-read students learn better: The impact of student initialization on knowledge distillation. CoRR, abs/1908.08962.
- Wang et al. (2019) Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R. Bowman. 2019. GLUE: A multi-task benchmark and analysis platform for natural language understanding. In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6-9, 2019. OpenReview.net.
- Wang et al. (2021) Wenhui Wang, Hangbo Bao, Shaohan Huang, Li Dong, and Furu Wei. 2021. MiniLMv2: Multi-head self-attention relation distillation for compressing pretrained transformers. In Findings of the Association for Computational Linguistics: ACL/IJCNLP 2021, Online Event, August 1-6, 2021, volume ACL/IJCNLP 2021 of Findings of ACL, pages 2140–2151. Association for Computational Linguistics.
- Wang et al. (2020a) Wenhui Wang, Furu Wei, Li Dong, Hangbo Bao, Nan Yang, and Ming Zhou. 2020a. MiniLM: Deep self-attention distillation for task-agnostic compression of pre-trained transformers. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual.
- Wang et al. (2020b) Ziheng Wang, Jeremy Wohlwend, and Tao Lei. 2020b. Structured pruning of large language models. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing, EMNLP 2020, Online, November 16-20, 2020, pages 6151–6162. Association for Computational Linguistics.
- Warstadt et al. (2019) Alex Warstadt, Amanpreet Singh, and Samuel R. Bowman. 2019. Neural network acceptability judgments. Trans. Assoc. Comput. Linguistics, 7:625–641.
- Williams et al. (2018) Adina Williams, Nikita Nangia, and Samuel R. Bowman. 2018. A broad-coverage challenge corpus for sentence understanding through inference. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2018, New Orleans, Louisiana, USA, June 1-6, 2018, Volume 1 (Long Papers), pages 1112–1122. Association for Computational Linguistics.
- Yang et al. (2019) Zhilin Yang, Zihang Dai, Yiming Yang, Jaime G. Carbonell, Ruslan Salakhutdinov, and Quoc V. Le. 2019. XLNet: Generalized autoregressive pretraining for language understanding. In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, pages 5754–5764.
- Zhang et al. (2019) Zhengyan Zhang, Xu Han, Zhiyuan Liu, Xin Jiang, Maosong Sun, and Qun Liu. 2019. ERNIE: enhanced language representation with informative entities. In Proceedings of the 57th Conference of the Association for Computational Linguistics, ACL 2019, Florence, Italy, July 28- August 2, 2019, Volume 1: Long Papers, pages 1441–1451. Association for Computational Linguistics.
Appendix A Hyper-parameters for Two-stage Distillation
The detailed hyper-parameters for the two-stage distillation of MLKD-BERT on the GLUE tasks and the extractive question answering tasks are presented in Table 9.

| Task | Stage 1 Epochs | Stage 1 Batch size | Stage 1 Max seq length | Stage 1 Learning rate | Stage 2 Epochs | Stage 2 Batch size | Stage 2 Max seq length | Stage 2 Learning rate | | |
|---|---|---|---|---|---|---|---|---|---|---|
| CoLA | 50 | 32 | 64 | 1e-5 | 30 | 32 | 64 | 1e-5 | 0.07 | 1.0 |
| MNLI | 6 | 32 | 128 | 3e-5 | 6 | 32 | 128 | 3e-5 | 0.07 | 1.0 |
| MRPC | 20 | 32 | 128 | 2e-5 | 15 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| SST-2 | 15 | 32 | 64 | 2e-5 | 10 | 32 | 64 | 2e-5 | 0.07 | 1.0 |
| STS-B | 20 | 32 | 128 | 3e-5 | 15 | 32 | 128 | 3e-5 | 0.07 | 1.0 |
| QQP | 6 | 32 | 128 | 2e-5 | 6 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| QNLI | 10 | 32 | 128 | 2e-5 | 10 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| RTE | 20 | 32 | 128 | 2e-5 | 15 | 32 | 128 | 2e-5 | 0.07 | 1.0 |
| SQuAD 1.1 | 4 | 16 | 384 | 3e-5 | 3 | 16 | 384 | 3e-5 | 0.07 | 1.0 |
| SQuAD 2.0 | 4 | 16 | 384 | 3e-5 | 3 | 16 | 384 | 3e-5 | 0.07 | 1.0 |

Appendix B Effect of Attention Head Number on Model Performance and Inference Time on More GLUE Tasks

| | | QQP | | | SST-2 | | |
|---|---|---|---|---|---|---|---|
| Inference time (ms) | 1 | 4.27 ± 0.33 | 4.16 ± 0.19 (-2.58%) | 4.09 ± 0.20 (-4.22%) | 4.24 ± 0.22 | 4.12 ± 0.22 (-2.83%) | 4.00 ± 0.19 (-5.66%) |
| Inference time (ms) | 16 | 5.46 ± 0.17 | 5.06 ± 0.21 (-7.33%) | 4.93 ± 0.21 (-9.71%) | 5.49 ± 0.20 | 5.07 ± 0.24 (-7.65%) | 4.91 ± 0.30 (-10.56%) |
| Inference time (ms) | 32 | 10.35 ± 0.17 | 9.48 ± 0.26 (-8.41%) | 8.99 ± 0.22 (-13.14%) | 10.39 ± 0.22 | 9.49 ± 0.18 (-8.66%) | 9.01 ± 0.24 (-13.28%) |
| Inference time (ms) | 64 | 19.28 ± 0.21 | 17.49 ± 0.18 (-9.28%) | 16.60 ± 0.22 (-13.90%) | 19.27 ± 0.20 | 17.49 ± 0.21 (-9.24%) | 16.59 ± 0.23 (-13.91%) |
| Performance | | 70.6 | 70.4 (-0.28%) | 70.0 (-0.85%) | 91.9 | 92.0 (+0.11%) | 91.9 (0.00%) |

Percentages in parentheses are relative changes with respect to the first column of each task.
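The parenthetical percentages in the QQP/SST-2 inference-time table are relative changes against the first column of each task. The short sketch below recomputes a few of them from the mean values as an arithmetic check; it is illustrative only, not part of MLKD-BERT, and the helper names `relative_change` and `format_row` and the row keys are hypothetical.

```python
# Illustrative check (hypothetical helpers, not from the paper): recompute the
# parenthetical percentages of the QQP/SST-2 table as relative changes
# with respect to the first column of each task.


def relative_change(baseline: float, value: float) -> float:
    """Percentage change of `value` relative to `baseline`."""
    return (value - baseline) / baseline * 100.0


def format_row(values: tuple) -> list:
    """Format the change of each later column against the first one, e.g. '-2.58%'."""
    baseline, *others = values
    return [f"{relative_change(baseline, v):+.2f}%" for v in others]


# Mean values copied from the table (the ± deviations are omitted).
rows = {
    "QQP, 1": (4.27, 4.16, 4.09),
    "QQP, 64": (19.28, 17.49, 16.60),
    "SST-2, 64": (19.27, 17.49, 16.59),
    "QQP performance": (70.6, 70.4, 70.0),
    "SST-2 performance": (91.9, 92.0, 91.9),
}

for name, values in rows.items():
    print(name, format_row(values))
```

For the QQP row labelled 1 this prints `['-2.58%', '-4.22%']`, matching the table, and the unchanged SST-2 performance entry formats as `+0.00%`, i.e. no change.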
Appendix C GLUE Dataset
In this section, we provide a brief description of the tasks in the GLUE benchmark (Wang et al., 2019).
MNLI. Multi-Genre Natural Language Inference is a large-scale, crowd-sourced entailment classification task (Williams et al., 2018). Given a (premise, hypothesis) pair, the goal is to predict whether the hypothesis is an entailment, contradiction, or neutral with respect to the premise.
QQP. Quora Question Pairs is a collection of question pairs from the website Quora. The task is to determine whether two questions are semantically equivalent (Chen et al., 2018).
QNLI. Question Natural Language Inference is a version of the Stanford Question Answering Dataset that has been converted to a binary sentence-pair classification task by Wang et al. (2019). Given a (question, context) pair, the task is to determine whether the context contains the answer to the question.
SST-2. The Stanford Sentiment Treebank is a binary single-sentence classification task, where the goal is to predict the sentiment of movie reviews (Socher et al., 2013).
CoLA. The Corpus of Linguistic Acceptability is a task to predict whether an English sentence is grammatically acceptable (Warstadt et al., 2019).
STS-B. The Semantic Textual Similarity Benchmark is a collection of sentence pairs drawn from news headlines and many other domains (Cer et al., 2017). The task is to predict how similar two pieces of text are, on a score from 0 to 5.
MRPC. Microsoft Research Paraphrase Corpus is a paraphrase identification dataset where systems aim to identify if two sentences are paraphrases of each other (Dolan and Brockett, 2005).
RTE. Recognizing Textual Entailment is a binary entailment task with a small training dataset (Bentivogli et al., 2009).
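Read together, the task descriptions above fix each task's input format (single sentence or sentence pair) and output space, which in turn fixes the size of the prediction layer that a fine-tuned teacher or a distilled student needs. The mapping below is a minimal sketch of that summary, assuming a standard classification head with one output per class and a single regression output for STS-B; the `GlueTask` type, its field names, and `is_regression` are hypothetical and not taken from the MLKD-BERT code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GlueTask:
    """Hypothetical summary of one GLUE task's input and output format."""
    name: str
    sentence_pair: bool  # True if the input is a pair of texts
    num_labels: int      # 1 means a single regression output


# Derived from the task descriptions above; a sketch, not an official config.
GLUE_TASKS = {
    "MNLI": GlueTask("MNLI", True, 3),     # entailment / contradiction / neutral
    "QQP": GlueTask("QQP", True, 2),       # semantically equivalent or not
    "QNLI": GlueTask("QNLI", True, 2),     # context contains the answer or not
    "SST-2": GlueTask("SST-2", False, 2),  # binary sentiment of a review
    "CoLA": GlueTask("CoLA", False, 2),    # acceptable or not
    "STS-B": GlueTask("STS-B", True, 1),   # similarity score (regression)
    "MRPC": GlueTask("MRPC", True, 2),     # paraphrase or not
    "RTE": GlueTask("RTE", True, 2),       # entailment or not
}


def is_regression(task: str) -> bool:
    """A single output unit signals a regression head (here only STS-B)."""
    return GLUE_TASKS[task].num_labels == 1


for task in GLUE_TASKS.values():
    kind = "regression" if is_regression(task.name) else f"{task.num_labels}-way"
    print(task.name, "pair" if task.sentence_pair else "single", kind)
```

Six of the eight tasks take sentence pairs; only SST-2 and CoLA are single-sentence tasks.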
Appendix D SQuAD
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.
SQuAD 1.1 (Rajpurkar et al., 2016) contains 100,000+ question-answer pairs on 500+ articles.
SQuAD 2.0 (Rajpurkar et al., 2018) combines the 100,000 questions in SQuAD 1.1 with over 50,000 unanswerable questions written adversarially by crowdworkers to look similar to answerable ones.
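Because every SQuAD answer is a span of the passage (or, in SQuAD 2.0, possibly no answer at all), an extractive QA model typically produces a start score and an end score per token and selects the best-scoring valid span. The sketch below follows the decoding described for BERT by Devlin et al. (2019), where the null answer is scored at the [CLS] position and a span is predicted only if it beats that score by a threshold `tau`. It is an assumption that MLKD-BERT is evaluated with this exact decoding; `max_answer_len`, `tau`, and the helper names are illustrative, and for simplicity every token after position 0 is a candidate, whereas real pipelines restrict spans to passage tokens.

```python
from typing import Optional, Sequence, Tuple


def best_span(start_logits: Sequence[float], end_logits: Sequence[float],
              max_answer_len: int = 30) -> Tuple[int, int, float]:
    """Highest-scoring (start, end, score) with start <= end < start + max_answer_len.

    Position 0 is assumed to hold [CLS], so it is excluded from non-null spans.
    """
    best = (1, 1, float("-inf"))
    n = len(start_logits)
    for i in range(1, n):
        for j in range(i, min(n, i + max_answer_len)):
            score = start_logits[i] + end_logits[j]
            if score > best[2]:
                best = (i, j, score)
    return best


def predict(start_logits: Sequence[float], end_logits: Sequence[float],
            tau: float = 0.0, allow_null: bool = True) -> Optional[Tuple[int, int]]:
    """Return (start, end) token indices, or None for 'no answer' (SQuAD 2.0 style)."""
    start, end, span_score = best_span(start_logits, end_logits)
    if allow_null:
        # Null answer: the span that starts and ends at [CLS] (position 0).
        null_score = start_logits[0] + end_logits[0]
        if span_score <= null_score + tau:
            return None
    return start, end


# SQuAD 1.1 style (always answer) vs. SQuAD 2.0 style (may abstain).
start = [5.0, 0.1, 0.2]
end = [5.0, 0.3, 0.1]
print(predict(start, end, allow_null=False))  # (1, 1)
print(predict(start, end))                    # None
```

In the example the [CLS] scores dominate, so the SQuAD 2.0 style decoder abstains while the SQuAD 1.1 style decoder still returns the best non-null span; raising `tau` makes the model answer less often.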