SAM-Lightening: A Lightweight Segment Anything Model with Dilated Flash Attention to Achieve 30× Acceleration
Abstract
Segment Anything Model (SAM) has garnered significant attention in segmentation tasks due to its zero-shot generalization ability. However, broader application of SAM in real-world practice has been restricted by its low inference speed and high computational memory demands, which mainly stem from the attention mechanism. Existing work has concentrated on optimizing the encoder but has not adequately addressed the inefficiency of the attention mechanism itself, even when distilled to a smaller model, which leaves room for further improvement. In response, we introduce SAM-Lightening, a variant of SAM that features a re-engineered attention mechanism, termed Dilated Flash Attention. It not only facilitates higher parallelism, enhancing processing efficiency, but also retains compatibility with the existing FlashAttention. Correspondingly, we propose a progressive distillation that enables efficient knowledge transfer from the vanilla SAM without costly training from scratch. Experiments on COCO and LVIS reveal that SAM-Lightening significantly outperforms the state-of-the-art methods in both run-time efficiency and segmentation accuracy. Specifically, it achieves an inference speed of 7 milliseconds (ms) per image for images of size 1024×1024 pixels, which is 30.1× faster than the vanilla SAM and also faster than the state-of-the-art. Moreover, it requires only 224MB of memory, compared with 5.7GB for the vanilla SAM. The code and weights are available at https://anonymous.4open.science/r/SAM-LIGHTENING-BC25/.
Index Terms: Segment Anything Model (SAM), Knowledge Distillation, Computationally Efficient Attention Mechanisms
1 Introduction
Image segmentation has traditionally been constrained by the need to train deep learning models on datasets designed for particular tasks. This reliance on hand-crafted, task-specific datasets often limits their generalization ability. Addressing this constraint, the Segment Anything Model (SAM) [1] represents a paradigm shift, with zero-shot abilities that allow it to segment new and unseen images. However, SAM's application in varied sectors such as augmented reality (AR), image editing, deployment on smartphones, and medical imaging [2, 3, 4, 5, 6] is impeded by the computational burden of its image encoder, which comprises 632 million parameters. This is roughly 20 times the size of conventional segmentation networks such as U-Net [7], leading to high computational demands.
In response to this challenge, various efforts have been initiated. For example, FastSAM [8] replaces SAM's transformer encoder with a more streamlined convolutional neural network (CNN) to create a lighter model. However, this often diminishes accuracy, especially in complex segmentation tasks. Another notable approach is MobileSAM [9], which employs distillation to transfer knowledge from SAM's encoder to a more compact ViT-tiny [10] encoder. Building on this, EfficientSAM [11] refines the training process of MobileSAM to improve accuracy. Conversely, SAMFast [12] focuses on speeding up the original SAM through techniques such as quantization and pruning, but these modifications yield only moderate performance gains.
Our research identifies key limitations in previous work on SAM [9, 11, 12], primarily the inefficient computation and memory usage of attention mechanisms. To address these issues, we integrate FlashAttention [13] and dilated attention into our SAM framework. These improvements are orthogonal to existing methods: they not only reduce memory consumption but also improve parallel processing, making them complementary to previous advancements. However, directly applying these mechanisms to SAM would necessitate complete retraining of the model, incurring substantial computational costs. To circumvent this challenge, we propose a dynamic layer-wise distillation (DLD). DLD implements a progressive distillation scheme for the image encoder by progressively allocating feature weights, effectively facilitating the transfer of knowledge from SAM to our lightweight model. We demonstrate that our model (SAM-Lightening) is not only expressive enough to represent the original SAM but is also computationally efficient, completing inference within 7 ms per image.
In brief, our main contributions are four-fold:
- We introduce a novel SAM structure, SAM-Lightening, which significantly reduces computational complexity.
- We design a novel dilated flash attention mechanism that replaces the vanilla self-attention to enhance the efficiency and inference speed of SAM-Lightening.
- To efficiently transfer knowledge from vanilla SAM to SAM-Lightening, we propose a dynamic layer-wise distillation that does not compromise performance.
- SAM-Lightening achieves a state-of-the-art inference latency of 7 ms per image, which is 30.1× faster than vanilla SAM.
2 Related work
Segment Anything Model: SAM comprises three main parts: the image encoder, the prompt encoder, and the mask decoder. Notably, the image encoder is the most parameter-intensive component of SAM, accounting for 98.3% of its processing time [1], which highlights the need for optimization. FastSAM [8] employs a CNN encoder, specifically YOLOv8-seg [14], in place of the ViT encoder to enhance processing speed. However, it has been observed to compromise segmentation precision, particularly in complex scenarios and in capturing fine edge details. MobileSAM [9] distills the encoder to reduce both model size and computational requirements. Nevertheless, the imbalance in MobileSAM's encoder structure and parameter distribution limits its potential for practical deployment and performance optimization. SAMFast [12] represents another optimization strategy, enhancing the processing speed of SAM with methods such as quantization and sparsification. While this scheme does offer some acceleration, its overall impact remains moderate. EfficientSAM [11], on the other hand, improves upon MobileSAM's training methodology, specifically targeting its accuracy.
FlashAttention: The FlashAttention mechanism [13] is an efficient and exact approach for computing attention in neural networks. It significantly reduces reads and writes to GPU high-bandwidth memory (HBM), primarily through tiling and recomputation. Building upon this, FlashAttention-2 [15] further refines the computation through better parallelism and work partitioning. These improvements have been shown to deliver up to a twofold speedup in specific settings.
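FlashAttention's tiling works because softmax can be computed incrementally: keys and values are streamed in blocks while a running maximum, a running normalizer, and a rescaled output accumulator are maintained, so the full row of attention scores is never materialized at once. The pure-Python sketch below illustrates only this online-softmax idea for a single query; it is not the FlashAttention kernel (there is no GPU memory hierarchy and no backward-pass recomputation), the $1/\sqrt{d}$ scaling is omitted for brevity, and the function names and `block_size` parameter are hypothetical.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax over q.k_j, then a weighted sum of the v_j."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / z for i in range(dim)]


def tiled_attention(q, keys, values, block_size=2):
    """Same result, but keys/values are consumed block by block (online softmax)."""
    dim = len(values[0])
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax normalizer relative to running_max
    acc = [0.0] * dim         # unnormalized output relative to running_max
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale previous partial results to the new maximum (exp(-inf) == 0.0).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vi for a, vi in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Any `block_size` should give the same output as `naive_attention` up to floating-point rounding; the block size only controls how many scores are held at a time, which is the property a tiled kernel exploits to stay within fast on-chip memory.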
Knowledge Distillation: Knowledge distillation [16] is a technique for transferring knowledge from a complex model to a simpler one. It aims to retain the performance of the larger model while significantly reducing computational footprint and model size. MobileSAM employs decoupled knowledge distillation: it extracts outputs from the original SAM's ViT-H image encoder and uses them to directly distill a pre-trained ViT-tiny encoder. This strategy is particularly beneficial for smaller models that already possess pre-trained parameters.
3 Methods
3.1 Dilated Flash Attention
To address the high computational demands of SAM's image encoder, we design a novel attention operation built on FlashAttention to accelerate inference.
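Segmentation and Sparsification: To alleviate the computational burden of processing the queries, keys, and values ($Q$, $K$, $V$) in the attention operation, we divide each input of length $N$ into equal-length segments of length $w$ and then apply sparsification along the sequence dimension within each segment. This sparsification selects rows at a fixed interval $r$, thereby reducing the volume of data the attention mechanism needs to process. As shown in Fig. 1, the sparsification process can be formulated as:

$$\tilde{X}_i = \left[X_{iw},\ X_{iw+r},\ X_{iw+2r},\ \dots,\ X_{(i+1)w-r}\right], \qquad i = 0, 1, \dots, \frac{N}{w} - 1. \tag{1}$$

Here, $\tilde{X}_i$ represents the sampled sparse matrix of the $i$-th segment, $X_j$ denotes the $j$-th row of $X$, and $X$ represents any of the variables $Q$, $K$, or $V$.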
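A minimal sketch of the row selection in Eq. (1), assuming that $N$ is divisible by $w$ and $w$ by $r$ (the text does not state how other lengths are handled). The helper names are hypothetical; `dilated_indices` returns, for each segment $i$, the row indices $iw, iw+r, \dots, (i+1)w-r$, and `sparsify` applies them to any row-indexable sequence.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def dilated_indices(n_tokens: int, w: int, r: int) -> List[List[int]]:
    """Row indices kept by Eq. (1): segment i keeps iw, iw + r, ..., (i + 1)w - r."""
    if n_tokens % w or w % r:
        raise ValueError("assumes N divisible by w and w divisible by r")
    return [list(range(i * w, (i + 1) * w, r)) for i in range(n_tokens // w)]


def sparsify(rows: Sequence[T], w: int, r: int) -> List[List[T]]:
    """Apply Eq. (1) to the rows of Q, K or V (any indexable sequence of rows)."""
    return [[rows[j] for j in seg] for seg in dilated_indices(len(rows), w, r)]


print(dilated_indices(16, 8, 2))   # [[0, 2, 4, 6], [8, 10, 12, 14]]
print(sparsify("abcdefgh", 4, 2))  # [['a', 'c'], ['e', 'g']]
```

Each segment keeps $w/r$ of its $w$ rows, so the attention inside a segment operates on a $(w/r) \times (w/r)$ score matrix instead of $w \times w$.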
Parallel Processing With FlashAttention: The sparsified segments of each input are dense matrices that participate in the attention calculation independently and can thus be processed in parallel. This parallelism significantly shortens processing time and makes our model better suited for real-time image segmentation. Incorporating FlashAttention further increases efficiency by parallelizing the dense matrix computations within each segment.
Output Recomposition: In the proposed Dilated Flash Attention framework, we process the sparsified segments in parallel, applying a softmax function to the product of $\tilde{Q}_i$ and the transpose of $\tilde{K}_i$, followed by multiplication with $\tilde{V}_i$:

$$\tilde{O}_i = \mathrm{softmax}\!\left(\tilde{Q}_i \tilde{K}_i^{\top}\right)\tilde{V}_i. \tag{2}$$
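Because every sparsified segment is a small dense matrix of shape $(w/r) \times d$, the segments can be stacked along a batch axis and processed by one batched attention call, which is what makes handing them to a batched kernel such as FlashAttention natural. The NumPy sketch below assumes a single attention head, omits the usual $1/\sqrt{d}$ scaling to mirror Eq. (2), and uses hypothetical function names; it illustrates the data layout, not the authors' kernel.

```python
import numpy as np


def segment_batch(x: np.ndarray, w: int, r: int) -> np.ndarray:
    """Split an (N, d) array into N // w segments and keep every r-th row (Eq. (1))."""
    n, d = x.shape
    if n % w or w % r:
        raise ValueError("assumes N divisible by w and w divisible by r")
    return x.reshape(n // w, w, d)[:, ::r, :]  # (N // w, w // r, d)


def batched_segment_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, w: int, r: int) -> np.ndarray:
    """Eq. (2) for all sparsified segments at once; segments form the batch axis."""
    qs, ks, vs = (segment_batch(t, w, r) for t in (q, k, v))
    scores = qs @ ks.transpose(0, 2, 1)                    # (N // w, w // r, w // r)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)             # row-wise softmax
    return probs @ vs                                      # (N // w, w // r, d)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 4)) for _ in range(3))
out = batched_segment_attention(q, k, v, w=8, r=2)
print(out.shape)  # (2, 4, 4)
```

No segment reads another segment's rows, so the batch entries are fully independent; this independence is the parallelism described above.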
The segment outputs $\tilde{O}_i$ are reassembled into the final output as follows:

(1) We first initialize a zero matrix $O$ with the same dimensions as the original input to accumulate the outputs of the individual segments.

(2) For each computed segment output $\tilde{O}_i$, we identify an offset $\mathrm{offset}_i$, which determines the starting position of $\tilde{O}_i$ within $O$.

(3) Each $\tilde{O}_i$ is mapped into $O$ by a mapping operation based on its offset:

$$O = \mathrm{MAP}\!\left(\tilde{O}_i,\ \mathrm{offset}_i\right). \tag{3}$$

The MAP operation places each element of $\tilde{O}_i$ into $O$ at the position determined by $\mathrm{offset}_i$. This guarantees that each segment's output is aligned within the final output matrix $O$ according to its original input position.
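The three recomposition steps can be sketched end to end as follows, under the same assumptions as Eq. (1) (a single head, $N$ divisible by $w$, $w$ divisible by $r$) and with $\mathrm{offset}_i = iw$ taken as the start of segment $i$. The function name is hypothetical. With a single sparsification pattern, the rows skipped in Eq. (1) keep the zero initialization of $O$; how the full model fills them (for example, by varying the pattern across heads) is not specified above, so this sketch simply leaves them at zero.

```python
import numpy as np


def dilated_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, w: int, r: int) -> np.ndarray:
    """Single-head dilated attention with output recomposition (sketch)."""
    n = q.shape[0]
    if n % w or w % r:
        raise ValueError("assumes N divisible by w and w divisible by r")
    out = np.zeros((n, v.shape[1]))              # step (1): zero matrix O
    # Iterations are independent, so this loop could run in parallel.
    for i in range(n // w):
        offset = i * w                           # step (2): start of segment i in O
        idx = np.arange(offset, offset + w, r)   # rows kept by Eq. (1)
        scores = q[idx] @ k[idx].T               # Eq. (2): Q_i K_i^T
        probs = np.exp(scores - scores.max(axis=1, keepdims=True))
        probs /= probs.sum(axis=1, keepdims=True)
        out[idx] = probs @ v[idx]                # step (3): MAP rows back into O
    return out
```

Writing the segment result through the same index array used to gather it is what keeps every output row aligned with its original input position.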
Computation Efficiency: With the proposed Dilated Flash Attention mechanism, the cost of the attention products is reduced by a factor of $N r^2 / w$, where $N$ represents the total size of the input, $w$ the length of each segment, and $r$ the interval of sparsification. Dilated Flash Attention therefore requires substantially fewer computations for a given input size, which improves the model's ability to process large images efficiently.
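The reduction factor can be checked by counting multiply-accumulate operations in the two matrix products $QK^\top$ and $(\cdot)V$ for a single head of dimension $d$, assuming one sparsification pattern as in Eq. (1) and ignoring the softmax itself. The values in the last line ($N = 4096$, $w = 1024$, $r = 2$) are hypothetical and chosen only for illustration.

$$
\begin{aligned}
C_{\text{full}} &= \underbrace{N^2 d}_{QK^\top} + \underbrace{N^2 d}_{(\cdot)V} = 2N^2 d,\\
C_{\text{dilated}} &= \frac{N}{w}\cdot 2\left(\frac{w}{r}\right)^2 d = \frac{2Nwd}{r^2},\\
\frac{C_{\text{full}}}{C_{\text{dilated}}} &= \frac{2N^2 d}{2Nwd/r^2} = \frac{N r^2}{w},\\
N = 4096,\ w = 1024,\ r = 2:&\quad \frac{4096 \cdot 4}{1024} = 16.
\end{aligned}
$$

Shrinking $w$ or increasing $r$ both raise the factor, at the cost of each token attending to fewer others.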
3.2 Dynamic Layer-Wise Distillation (DLD)
Training SAM-Lightening from scratch is costly, while layer adaptation is challenging due to the structural differences between SAM, which uses ViT-H as its image encoder, and SAM-Lightening. To enable efficient knowledge transfer from vanilla SAM to the proposed framework, we propose a novel Dynamic Layer-Wise Distillation (DLD), which dynamically modifies feature weights to enhance layer-wise distillation between the models [17].
Dynamic Layer-Wise Weights: When preceding layers are not well distilled, the performance of subsequent layers can suffer from the low-quality features they receive. By assigning greater weight to the loss of these initial layers, dynamic weighting ensures that they receive more focus during training, which helps align the student model with the teacher model in the initial stages. Given a deep neural network consisting of $L$ layers, each layer $l$ is associated with a temporal weight $\alpha_l(t)$, which adjusts the significance of that layer across training stages (epochs) $t$. The initial layer retains maximum emphasis ($\alpha_1(t) = 1$), and the subsequent layers follow a dynamic weighting scheme, represented by the piece-wise function:

$$\alpha_l(t) = \begin{cases} 0, & t < t_l, \\ \dfrac{t - t_l}{\Delta}, & t_l \le t < t_l + \Delta, \\ 1, & t \ge t_l + \Delta, \end{cases} \qquad l = 2, \dots, L, \tag{4}$$

where $t_l$ denotes the epoch at which layer $l$ begins updating its weight, once the previous layer has reached saturation, i.e., $\alpha_{l-1}(t_l) = 1$. The parameter $\Delta$ captures the number of epochs over which the weight transitions from 0 to 1. With this predefined epoch increment $\Delta$, each layer sequentially activates its learning after the preceding layer reaches its peak weight. This mechanism facilitates a cascading absorption of knowledge from the teacher model.
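A minimal sketch of the schedule in Eq. (4), assuming a linear ramp over $\Delta$ epochs (the text only states that the weight moves from 0 to 1 over $\Delta$ epochs) and assuming each layer starts exactly when its predecessor saturates. Layers are 0-indexed here, so `layer == 0` is the first layer with $\alpha_1(t) = 1$; the function name and the `ramp` parameter (standing for $\Delta$) are hypothetical.

```python
def layer_weight(layer: int, epoch: float, ramp: float) -> float:
    """Dynamic layer-wise weight alpha_l(t) for a 0-indexed layer (sketch)."""
    if ramp <= 0:
        raise ValueError("ramp (Delta) must be positive")
    if layer == 0:
        return 1.0  # the first layer keeps maximum emphasis throughout
    # Layer 1 may start at epoch 0 because layer 0 is already saturated; every
    # later layer starts when its predecessor reaches weight 1, ramp epochs later.
    start = (layer - 1) * ramp
    # Linear ramp (t - t_l) / Delta, clamped to [0, 1].
    return min(max((epoch - start) / ramp, 0.0), 1.0)


if __name__ == "__main__":
    # Weights of four layers over the first epochs with Delta = 2.
    for epoch in range(7):
        print(epoch, [layer_weight(l, epoch, ramp=2) for l in range(4)])
```

With $\Delta = 2$, layer 1 ramps up during epochs 0 to 2, layer 2 during epochs 2 to 4, and layer 3 during epochs 4 to 6, which is the cascading activation described above.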
Decoupled Feature Distillation: The distillation process transfers knowledge from SAM's encoder (the teacher model) to our proposed encoder (the student model), as shown in Fig. 1. We choose the layers closest to the output for feature distillation. Since these deeper layers are directly related to the model's outputs, distilling them more effectively transfers the information crucial for prediction. These layers are designated as “Focus Layers”.
During the initial phase of training, layers closer to the input are given precedence. Here, the intent is to align the primary feature representations of the student model (SAM-Lightening), denoted $F^{S}_{l}$, with those of the teacher model, $F^{T}_{l}$, for the layers closest to the input. As training advances, the layer-wise weighting shifts dynamically: the loss associated with subsequent layers is incrementally amplified, and the loss function evolves to assimilate representations from succeeding layers:

$$\mathcal{L}_{\text{feat}}(t) = \sum_{l=1}^{L} \alpha_l(t)\, \mathcal{D}\!\left(F^{S}_{l}, F^{T}_{l}\right), \tag{5}$$

where $L$ is the total number of layers, $\mathcal{D}$ is the per-layer feature distillation loss, and the coefficient $\alpha_l(t)$ is the piece-wise function of Eq. (4), determined by the training epoch $t$ and the layer $l$. The integrated distillation loss is formulated as:

$$\mathcal{L}_{\text{distill}} = \mathcal{L}_{\text{feat}} + \lambda\, \mathcal{L}_{\text{out}}, \tag{6}$$

where $\mathcal{L}_{\text{feat}}$ encapsulates the weighted sum of all selected feature-layer losses, $\mathcal{L}_{\text{out}}$ is the loss for the image encoder output layer, and $\lambda$ is a scaling factor that balances the significance of the encoder output in the overall distillation process.
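Equations (5) and (6) can be sketched as below. The per-layer distance $\mathcal{D}$ is not specified in the text, so mean squared error is used here purely as an assumption (also for $\mathcal{L}_{\text{out}}$); features are flattened into plain lists, and the weights `alphas` would come from a schedule such as Eq. (4). All names are hypothetical.

```python
from typing import Sequence


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two flattened feature vectors (assumed D)."""
    if len(a) != len(b) or not a:
        raise ValueError("features must be non-empty and of equal length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(student_feats, teacher_feats, alphas, student_out, teacher_out, lam):
    """L_distill = sum_l alpha_l * D(F_l^S, F_l^T) + lam * D(out^S, out^T)."""
    feat_loss = sum(a * mse(s, t) for a, s, t in zip(alphas, student_feats, teacher_feats))
    out_loss = mse(student_out, teacher_out)
    return feat_loss + lam * out_loss


loss = distillation_loss(
    student_feats=[[1.0, 2.0], [0.0, 0.0]],
    teacher_feats=[[1.0, 2.0], [2.0, 2.0]],
    alphas=[1.0, 0.5],               # second layer only half-way through its ramp
    student_out=[1.0, 1.0],
    teacher_out=[0.0, 1.0],
    lam=2.0,
)
print(loss)  # 3.0
```

Here the first layer matches exactly, the second contributes $0.5 \times 4 = 2$, and the output term contributes $2 \times 0.5 = 1$.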
Align Decoder: The lightweight image encoder obtained through decoupled distillation is not fully aligned with the frozen decoder, especially for point-based prompt segmentation tasks. Therefore, we fine-tune the decoder with point and box prompts sampled on the SA-1B dataset to align it with the image encoder. The loss function is defined as follows:

$$\mathcal{L}_{\text{dec}} = \mathcal{L}_{\text{IoU}} + \mathcal{L}_{\text{Dice}} + \mathcal{L}_{\text{Focal}}. \tag{7}$$

Here, $\mathcal{L}_{\text{IoU}}$ is the Intersection over Union loss, while the Dice loss $\mathcal{L}_{\text{Dice}}$ and the Focal loss $\mathcal{L}_{\text{Focal}}$ address class imbalance and challenging segmentation regions, respectively.
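A sketch of the decoder loss on a flattened predicted probability mask and a binary ground-truth mask. The exact forms are assumptions: a soft IoU loss $1 - \mathrm{IoU}$, the standard soft Dice loss, and a binary focal loss with $\gamma = 2$, summed with unit weights. The function names and the `eps` smoothing constant are hypothetical.

```python
import math
from typing import Sequence


def iou_loss(pred: Sequence[float], target: Sequence[int], eps: float = 1e-6) -> float:
    """Soft IoU loss: 1 - intersection / union, computed on probabilities."""
    inter = sum(p * t for p, t in zip(pred, target))
    union = sum(pred) + sum(target) - inter
    return 1.0 - (inter + eps) / (union + eps)


def dice_loss(pred: Sequence[float], target: Sequence[int], eps: float = 1e-6) -> float:
    """Soft Dice loss: 1 - 2|P.G| / (|P| + |G|)."""
    inter = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * inter + eps) / (sum(pred) + sum(target) + eps)


def focal_loss(pred: Sequence[float], target: Sequence[int], gamma: float = 2.0, eps: float = 1e-6) -> float:
    """Binary focal loss: down-weights pixels that are already classified well."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)       # avoid log(0)
        p_t = p if t == 1 else 1.0 - p        # probability of the true class
        total += -((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(pred)


def decoder_loss(pred: Sequence[float], target: Sequence[int]) -> float:
    """Unit-weight sum of the three terms."""
    return iou_loss(pred, target) + dice_loss(pred, target) + focal_loss(pred, target)
```

For a perfect prediction all three terms are close to zero; for a fully inverted mask the focal term dominates, because $-\log$ of a clipped near-zero probability is large.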
4 Experiment
4.1 Experimental Setups
Our model uses a subset of the SA-1B dataset for distillation and fine-tuning. Its encoder has an embedding dimension of 384, six attention heads, and six layers. For the FlashAttention component, we use bfloat16. Both the distillation and fine-tuning processes are conducted for 10 epochs each with a batch size of 32, and gradient accumulation is set with a step size of 4. The model is trained on two NVIDIA RTX 4090 GPUs. To speed up training, the outputs of SAM's image encoder are precomputed and saved [10, 9].
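The reported settings can be collected into a small configuration object. This is a hypothetical sketch: the class name is invented, the learning rate and the SA-1B fraction are omitted because their values are not given above, and `samples_per_update` assumes that the batch size of 32 counts every sample in one forward/backward pass before gradient accumulation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LighteningTrainConfig:
    """Hypothetical container for the reported training settings."""
    embed_dim: int = 384
    num_heads: int = 6
    depth: int = 6
    attention_dtype: str = "bfloat16"  # precision of the FlashAttention component
    distill_epochs: int = 10
    finetune_epochs: int = 10
    batch_size: int = 32
    grad_accum_steps: int = 4
    num_gpus: int = 2                  # two NVIDIA RTX 4090 GPUs

    @property
    def head_dim(self) -> int:
        if self.embed_dim % self.num_heads:
            raise ValueError("embed_dim must be divisible by num_heads")
        return self.embed_dim // self.num_heads

    @property
    def samples_per_update(self) -> int:
        # Assumption: batch_size counts every sample in one forward/backward pass.
        return self.batch_size * self.grad_accum_steps


cfg = LighteningTrainConfig()
print(cfg.head_dim, cfg.samples_per_update)  # 64 128
```

The derived head dimension of 64 follows directly from the reported embedding dimension and head count.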
4.2 Results
Run-Time And Memory Efficiency Evaluation: We compare the proposed SAM-Lightening with vanilla SAM (i.e., SAM-ViT-H) [1], FastSAM [8], MobileSAM [9], EfficientSAM [11], and SAMFast [12] in Table 1 and Table 2. For segmentation performance, the vanilla SAM is considered the upper bound. Importantly, Table 1 shows that SAM-Lightening outperforms all its counterparts in inference latency and peak memory usage: compared with vanilla SAM, it achieves a 30.1× acceleration and reduces peak memory from 5.7GB to 224MB, and it is also faster than the fastest baseline, MobileSAM (6.9 ms vs. 11.9 ms in total). The throughput comparison in Table 2 further reinforces SAM-Lightening's superior performance, as it requires the least time per batch at every batch size. Together, its high throughput, low latency, and low memory usage position SAM-Lightening as a highly efficient model for image segmentation tasks.
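Table 1: Inference latency of the image encoder (Enc.), mask decoder (Dec.), and in total (Tot.), speedup (S.U.) relative to SAM-ViT-H, and peak memory (Mem.).

| Model | Enc. (ms) | Dec. (ms) | Tot. (ms) | S.U. (×) | Mem. |
|---|---|---|---|---|---|
| SAM-ViT-H | 216.1 | 3.8 | 219.9 | 1.0 | 5.7GB |
| SAMFast | 23.2 | 3.8 | 27.0 | 8.5 | 4.1GB |
| FastSAM | 20.7 | 3.4 | 24.1 | 9.1 | 2.6GB |
| EfficientSAM | 22.3 | 3.8 | 26.1 | 8.3 | 309MB |
| MobileSAM | 8.1 | 3.8 | 11.9 | 18.5 | 309MB |
| SAM-Lightening | 3.5 | 3.4 | 6.9 | 30.1 | 224MB |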
Table 2: Time (ms) to process a batch at batch sizes 1, 4, 8, and 16 (OOM: out of memory).

| Model | Batch 1 | Batch 4 | Batch 8 | Batch 16 |
|---|---|---|---|---|
| SAM-ViT-H | 219.9 | 944.9 | OOM | OOM |
| SAMFast | 53.6 | 206.6 | 438.2 | 964.2 |
| FastSAM | 24.1 | 80.1 | 171.5 | 349.1 |
| EfficientSAM | 22.3 | 79.2 | 157.7 | 317.5 |
| MobileSAM | 8.1 | 34.1 | 72.3 | 156.8 |
| SAM-Lightening | 3.5 | 13.0 | 27.2 | 59.2 |
Comparison In Box/Point Prompt Mode: We first evaluated performance under bounding-box and point prompts. For bounding-box prompts, we followed the settings of vanilla SAM, using the ground-truth annotations in COCO [18] and LVIS [19] to synthesize bounding boxes that define the areas of interest in each image. For point prompts, we randomly sampled points within the ground-truth masks, challenging all models to accurately segment the object or region associated with each point. We used mean Intersection over Union (mIoU) as the quantitative metric. As shown in Table 3, both SAMFast and MobileSAM suffer a performance decline compared with vanilla SAM, particularly with point prompts. FastSAM, a CNN-based model, shows an even more pronounced drop, especially on the LVIS dataset, which contains a large number of small objects. This observation reflects the limitations of CNN-based encoders in more complex segmentation scenarios. In contrast, SAM-Lightening matches the segmentation performance of the original SAM most closely. This holds even for point prompts, where SAM-Lightening achieves mIoU similar to vanilla SAM (e.g., 72.5 for both on COCO with three points).
Table 3: mIoU on COCO and LVIS with box, single-point (1P), and three-point (3P) prompts.

| Model | COCO Box | COCO 1P | COCO 3P | LVIS Box | LVIS 1P | LVIS 3P |
|---|---|---|---|---|---|---|
| SAM-ViT-H | 80.1 | 49.2 | 72.5 | 83.8 | 60.6 | 74.7 |
| SAMFast | 77.3 | 44.7 | 66.3 | 80.5 | 54.9 | 69.4 |
| FastSAM | 65.0 | 50.9 | 52.4 | 61.5 | 41.5 | 41.8 |
| EfficientSAM | 77.8 | 43.6 | 69.7 | 79.5 | 53.7 | 72.9 |
| MobileSAM | 77.9 | 47.9 | 67.4 | 78.5 | 55.4 | 66.8 |
| SAM-Lightening | 78.8 | 48.4 | 72.5 | 81.0 | 59.9 | 74.6 |
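The evaluation protocol for box and point prompts (points sampled inside ground-truth masks, scored by mIoU) can be sketched on flattened binary masks as follows. Averaging IoU over objects, treating an empty union as IoU 1, and sampling points uniformly without replacement are assumptions; the function names are hypothetical.

```python
import random
from typing import List, Sequence


def mask_iou(pred: Sequence[int], gt: Sequence[int]) -> float:
    """IoU between two flattened binary masks (empty union counted as 1.0)."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return inter / union if union else 1.0


def mean_iou(preds, gts) -> float:
    """Average per-object IoU over paired predicted and ground-truth masks."""
    return sum(mask_iou(p, g) for p, g in zip(preds, gts)) / len(gts)


def sample_point_prompts(gt: Sequence[int], k: int, seed: int = 0) -> List[int]:
    """Pick k distinct foreground pixel indices uniformly at random."""
    foreground = [i for i, g in enumerate(gt) if g]
    return random.Random(seed).sample(foreground, k)


print(mean_iou([[1, 1, 0, 0], [1, 0, 1, 0]], [[1, 0, 0, 0], [1, 0, 1, 0]]))  # 0.75
```

A fixed seed keeps the sampled prompts identical across models, so differences in mIoU reflect the models rather than the prompts.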
Comparison In Anything Mode: Although the segment-anything mode is an innovative approach, it is not a commonly used segmentation method and thus does not effectively represent typical segmentation tasks. Therefore, our analysis has primarily focused on point-based and box-based prompts, which are more prevalent in practical applications. For completeness, and to demonstrate the versatility of the models, we also include outputs of the segment-anything mode in our comparison.
In the representative samples shown in Fig. 3, both SAM-Lightening and MobileSAM produce segmentation results that are nearly indistinguishable from those of the vanilla SAM, notably in edge clarity and detail preservation, which are hallmarks of high-quality segmentation. This indicates that SAM-Lightening remains robust and accurate, aligning closely with the vanilla SAM.
4.3 Ablation study
Notably, many previous works [4, 20, 21] use input sizes smaller than 1024 for SAM. For a fair comparison, we also conducted experiments in these settings and found that keeping FlashAttention for smaller input sizes achieves optimal performance. This indicates that the benefit of FlashAttention depends on the model's input size and the specific hardware configuration, so the decision to use it should be based on the application context and performance requirements. Although FlashAttention accelerates training during model distillation, its impact on inference depends on hardware characteristics. On our inference platform, particularly for SAM with a 1024 input size, the multi-head attention operator is more computation-intensive. As shown in Fig. 4, this results in slightly lower inference speed with FlashAttention than without it. Therefore, we use FlashAttention during distillation and remove it during evaluation.
5 Conclusion
We propose SAM-Lightening to address the primary limitations of vanilla SAM, namely its high computational demand and slow inference speed, making it more suitable for deployment on resource-constrained devices. Our approach redesigns SAM's image encoder by replacing its self-attention operators with dilated flash attention and transferring knowledge from the original encoder through dynamic layer-wise distillation. These optimizations notably reduce computational complexity and memory usage without compromising segmentation performance. Specifically, SAM-Lightening completes inference within 7 milliseconds per image, a 30.1× speedup over SAM-ViT-H. Since SAM-Lightening is complementary to pruning and quantization, a future direction is to integrate it with them.
References
- [1] Kirillov et al., “Segment anything,” arXiv preprint arXiv:2304.02643, 2023.
- [2] Archit et al., “Segment anything for microscopy,” Aug 2023.
- [3] Ma et al., “Segment anything in medical images,” Apr 2023.
- [4] Cheng et al., “SAM-Med2D,” arXiv preprint arXiv:2308.16184, 2023.
- [5] Yang et al., “Track anything: Segment anything meets videos,” Apr 2023.
- [6] Shen et al., “Anything-3D: Towards single-view anything reconstruction in the wild,” arXiv preprint arXiv:2304.10261, 2023.
- [7] Ronneberger et al., “U-Net: Convolutional networks for biomedical image segmentation,” pp. 234–241, Jan 2015.
- [8] Zhao et al., “Fast segment anything,” arXiv preprint arXiv:2306.12156, 2023.
- [9] Zhang et al., “Faster segment anything: Towards lightweight SAM for mobile applications,” arXiv preprint arXiv:2306.14289, 2023.
- [10] Wu et al., “TinyViT: Fast pretraining distillation for small vision transformers,” Springer, Cham, 2022.
- [11] Xiong et al., “EfficientSAM: Leveraged masked image pretraining for efficient segment anything,” arXiv preprint arXiv:2312.00863, 2023.
- [12] PyTorch Team, “Accelerating generative AI,” 2023.
- [13] Dao et al., “FlashAttention: Fast and memory-efficient exact attention with IO-awareness,” Advances in Neural Information Processing Systems, vol. 35, pp. 16344–16359, 2022.
- [14] Jocher et al., “Ultralytics YOLOv8,” 2023.
- [15] Dao, “FlashAttention-2: Faster attention with better parallelism and work partitioning,” arXiv preprint arXiv:2307.08691, 2023.
- [16] Hinton et al., “Distilling the knowledge in a neural network,” arXiv: Machine Learning, Mar 2015.
- [17] Ji et al., “Show, attend and distill: Knowledge distillation via attention-based feature matching,” Proceedings of the AAAI Conference on Artificial Intelligence, pp. 7945–7952, Sep 2022.
- [18] Lin et al., “Microsoft COCO: Common objects in context,” Computer Vision - ECCV 2014, Part V, pp. 740–755, 2014.
- [19] Gupta et al., “LVIS: A dataset for large vocabulary instance segmentation,” in 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Jun 2019.
- [20] Chen et al., “MMDetection: Open MMLab detection toolbox and benchmark,” arXiv preprint arXiv:1906.07155, 2019.
- [21] Wang et al., “SegGPT: Segmenting everything in context,” arXiv preprint arXiv:2304.03284, 2023.