License: arXiv.org perpetual non-exclusive license
arXiv:2403.09195v2 [cs.CV] 18 Mar 2024

SAM-Lightening: A Lightweight Segment Anything Model with Dilated Flash Attention to Achieve 30× Acceleration

Abstract

The Segment Anything Model (SAM) has garnered significant attention in segmentation tasks due to its zero-shot generalization ability. However, broader application of SAM in real-world practice has been restricted by its low inference speed and high memory demands, which mainly stem from the attention mechanism. Existing work has concentrated on optimizing the encoder, yet has not adequately addressed the inefficiency of the attention mechanism itself, even when distilled to a smaller model, which leaves room for further improvement. In response, we introduce SAM-Lightening, a variant of SAM that features a re-engineered attention mechanism, termed Dilated Flash Attention. It not only facilitates higher parallelism, enhancing processing efficiency, but also retains compatibility with the existing FlashAttention. Correspondingly, we propose a progressive distillation to enable efficient knowledge transfer from the vanilla SAM without costly training from scratch. Experiments on COCO and LVIS reveal that SAM-Lightening significantly outperforms the state-of-the-art methods in both run-time efficiency and segmentation accuracy. Specifically, it achieves an inference speed of 7 milliseconds (ms) per image for images of size 1024×1024 pixels, which is 30.1× faster than the vanilla SAM and 2.1× faster than the state-of-the-art. Moreover, it takes only 244 MB of memory, which is 3.5% of the vanilla SAM. The code and weights are available at https://anonymous.4open.science/r/SAM-LIGHTENING-BC25/.

Index Terms—  Segment Anything Model (SAM), Knowledge Distillation, Computational Efficient Attention Mechanisms

1 Introduction

Fig. 1: The overall framework of SAM-Lightening along with the dynamic layer-wise distillation that can efficiently transfer knowledge from the vanilla SAM without training from scratch.

Image segmentation has traditionally been constrained by the need for deep learning models to be trained specifically on datasets designed for particular tasks. This reliance on hand-crafted, task-specific datasets often limits their generalization ability. Addressing this constraint, the Segment Anything Model (SAM) [1] represents a paradigm shift with its zero-shot learning abilities, which allow it to segment new and unseen images. However, SAM’s application in varied sectors such as augmented reality (AR), image editing, smartphone deployment, and medical imaging [2, 3, 4, 5, 6] is impeded by the computational burden of its image encoder, which comprises a substantial 632 million parameters. This size is roughly 20 times that of conventional segmentation networks like U-Net [7], leading to high computational demands.

In response to this challenge, various efforts have been initiated. For example, FastSAM [8] adopts a strategy of replacing SAM’s transformer encoder with a more streamlined convolutional neural network (CNN), aiming to create a lighter model. However, this often leads to diminished accuracy, especially in complex segmentation tasks. Another notable approach is MobileSAM [9], which employs distillation techniques to transfer knowledge from SAM’s encoder to a more compact ViT-tiny [10] encoder. Similarly, initiatives like EfficientSAM [11] aim to refine the training processes of MobileSAM to improve accuracy. Conversely, SAMFast [12] focuses on speed optimization of the original SAM through techniques such as quantization and pruning, but these modifications have limited impact on performance enhancement.

Our research identifies key limitations in previous works [9, 11, 12] on SAM, primarily in terms of inefficient computation and memory usage in attention mechanisms. To address these issues, we integrate FlashAttention [13] and dilated attention mechanisms into our SAM framework, providing orthogonal improvements over existing methods. These enhancements not only reduce memory consumption but also improve parallel processing, making them complementary to previous advancements. However, directly applying these mechanisms to SAM would necessitate a complete retraining of the model, incurring substantial computational costs. To circumvent this challenge, we propose dynamic layer-wise distillation (DLD). DLD implements a progressive distillation scheme for the image encoder by progressively allocating feature weights, effectively facilitating the transfer of knowledge from SAM to our lightweight model. We demonstrate that our model (SAM-Lightening) is not only expressive enough to represent the original SAM but is also computationally efficient, completing inference within 7 ms.

In brief, our main contributions are four-fold:

  • We introduce a novel SAM structure, SAM-Lightening, to significantly reduce the computational complexity.

  • We design a novel dilated flash attention mechanism to replace the vanilla self-attention to enhance the efficiency and inference speed of SAM-Lightening.

  • To efficiently transfer the knowledge from vanilla SAM to SAM-Lightening, we propose a dynamic layer-wise distillation without compromising the performance.

  • SAM-Lightening achieves state-of-the-art performance of 7 ms per image, which is 30.1× faster than vanilla SAM.

2 Related work

Segment Anything Model: SAM comprises three main parts: the image encoder, prompt encoder, and mask decoder. Notably, the image encoder is the most parameter-intensive part of SAM, accounting for a substantial 98.3% of its processing time [1], which highlights the need for optimization. FastSAM [8] employs a CNN encoder, specifically YOLOv8-seg [14], to replace the ViT encoder and enhance processing speed. However, it has been observed to compromise segmentation precision, particularly in complex scenarios and in capturing fine edge details. MobileSAM [9] distills the encoder to reduce both the model size and computational requirements. Nevertheless, the imbalance in MobileSAM’s encoder structure and parameter distribution limits its potential for practical deployment and performance optimization. SAMFast [12] represents another optimization strategy, focusing on enhancing the processing speed of SAM using methods like quantization and sparsification. While this scheme does offer some acceleration, its overall impact remains moderate. EfficientSAM [11], on the other hand, improves upon MobileSAM’s training methodology, specifically targeting the accuracy of the MobileSAM approach.

FlashAttention: The FlashAttention mechanism [13] introduces an efficient and exact approach for computing attention in neural networks. It achieves a significant reduction in high-bandwidth memory (HBM) reads and writes, primarily through strategic tiling and recomputation techniques. Building upon this, FlashAttention-2 [15] further refines the process with more efficient matrix multiplication operations. These improvements have been shown to deliver up to a twofold increase in performance in specific computational settings.
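The tiling idea can be illustrated with a small NumPy sketch. It computes exact softmax attention one key/value block at a time, keeping a running maximum and a running normaliser per query so that the full score matrix is never stored. This is only an illustration of the principle, not the fused GPU kernel of [13]: the function names, the block size, and the conventional 1/sqrt(d) scaling are assumptions made for the example, and the recomputation used in the backward pass is not shown.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference implementation that materialises the full score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=4):
    """Exact attention computed over key/value tiles with an online softmax."""
    n, d = q.shape
    out = np.zeros((n, v.shape[-1]))
    row_max = np.full((n, 1), -np.inf)    # running maximum score per query
    row_sum = np.zeros((n, 1))            # running softmax denominator per query
    for start in range(0, k.shape[0], block):
        k_blk, v_blk = k[start:start + block], v[start:start + block]
        s = q @ k_blk.T / np.sqrt(d)      # scores for this tile only
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        scale = np.exp(row_max - new_max) # rescales results from earlier tiles
        p = np.exp(s - new_max)
        row_sum = row_sum * scale + p.sum(axis=-1, keepdims=True)
        out = out * scale + p @ v_blk
        row_max = new_max
    return out / row_sum
```

Because the rescaling keeps the partial sums consistent, the tiled result agrees with the reference up to floating-point error, which is what "exact" means in this context.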

Knowledge Distillation: Knowledge distillation [16] is a technique for transferring knowledge from a complex model to a simpler one. It aims to retain the performance attributes of the larger model while significantly reducing its computational footprint and model size. MobileSAM employs a decoupled knowledge distillation by extracting outputs from the original SAM’s ViT-H image encoder and using them to distill into a pre-trained ViT-tiny encoder directly. This strategy proves particularly beneficial for smaller models that already possess pre-trained parameters.

3 Methods

3.1 Dilated Flash Attention

To address the high computational demands in the image encoder of SAM, we design a novel attention operation with FlashAttention to expedite the inference speed.

Segmentation and Sparsification: To alleviate the computational burden of processing $(Q, K, V)$ in the attention operation, we divide each input into equal-length segments of length $w$ and then apply sparsification along the sequence dimension within each segment. This sparsification selects rows at fixed intervals $r$, thereby reducing the volume of data the attention mechanism needs to process. As shown in Fig. 1, the sparsification process can be formulated as:

$$\widetilde{X}_{i} = [X_{iw}, X_{iw+r}, X_{iw+2r}, \ldots, X_{(i+1)w-1}], \tag{1}$$

Here, $\widetilde{X}_{i}$ represents the sampled sparse matrix. $X$ represents any of the variables $Q$, $K$, or $V$.
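As a minimal sketch of Eq. (1), the segmentation and sparsification step can be written with a reshape and a strided slice. The function name `sparsify_segments` and the toy sizes are hypothetical, and the sketch assumes that the sequence length is a multiple of $w$. With a stride $r > 1$ the last kept row of a segment is the largest index of the form $iw + kr$ below $(i+1)w$, which coincides with $(i+1)w-1$ only when $r$ divides $w-1$.

```python
import numpy as np

def sparsify_segments(x, w, r):
    """Split x (N x d) into N // w segments of length w and keep every r-th row."""
    n, d = x.shape
    assert n % w == 0, "this sketch assumes N is a multiple of w"
    segments = x.reshape(n // w, w, d)    # segment i holds rows iw .. (i+1)w - 1
    return segments[:, ::r, :]            # rows iw, iw + r, iw + 2r, ... per segment

x = np.arange(16).reshape(16, 1)          # toy sequence: N = 16 tokens, d = 1
x_tilde = sparsify_segments(x, w=8, r=2)
# Segment 0 keeps rows 0, 2, 4, 6 and segment 1 keeps rows 8, 10, 12, 14.
print(x_tilde[..., 0])
```

The same call would be applied to each of $Q$, $K$, and $V$, so every segment yields a small dense matrix.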

Parallel Processing With FlashAttention: Sparsified segments of each input data are dense matrices that can participate in the attention calculation independently and thus can be processed in parallel. This parallelism is vital for efficiently managing large-scale image datasets, significantly speeding up the processing time and enhancing the efficiency of our model for real-time image segmentation. Incorporating FlashAttention further increases efficiency by parallelizing dense matrix computations in the process.

Output Recomposition: In the proposed Dilated Flash Attention framework, we process sparsified segments in parallel, applying a softmax function to the product of $\widetilde{Q}_{i}$ and the transpose of $\widetilde{K}_{i}$, followed by multiplication with $\widetilde{V}_{i}$ as follows:

$$\widetilde{O}_{i} = \texttt{softmax}(\widetilde{Q}_{i} \cdot \widetilde{K}^{T}_{i}) \cdot \widetilde{V}_{i}. \tag{2}$$

The reassembly of these outputs into the cohesive final output $O$ involves the following steps:

  1. Initially, we establish a zero matrix $O_{\text{init}}$ that mirrors the dimensions of the original input for accumulating the outputs of the individual segments.

  2. For each computed segment output $\widetilde{O}_{i}$, a specific offset $\gamma_{i}$ is identified. This offset determines the precise starting position of $\widetilde{O}_{i}$ within the $O_{\text{init}}$ matrix.

  3. Each $\widetilde{O}_{i}$ is mapped to $O_{\text{init}}$ using a mapping operation based on its $\gamma_{i}$:

    $$O = \sum_{i} \texttt{MAP}(O_{\text{init}}, \widetilde{O}_{i}, \gamma_{i}) \tag{3}$$

The “MAP” operation places each $\widetilde{O}_{i}$ element into $O_{\text{init}}$ according to the position determined by $\gamma_{i}$. This guarantees the accurate alignment of each segment’s output within the final output matrix $O$, based on its original input position.
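The recomposition can be sketched in NumPy as follows. This is an illustrative reading of Eqs. (2) and (3), not the released implementation: the function name `dilated_attention` is hypothetical, the offset $\gamma_{i}$ is assumed to be the segment start $iw$, the softmax follows Eq. (2) literally without a scaling factor, and plain matrix products stand in for the FlashAttention kernel. Rows that were skipped by the sparsification stay zero in this single-pattern sketch.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def dilated_attention(q, k, v, w, r):
    """Attention within sparsified segments, scattered back into a zero matrix."""
    n, d = q.shape
    assert n % w == 0, "this sketch assumes N is a multiple of w"
    out = np.zeros_like(v, dtype=float)           # O_init: same shape as the input
    for i in range(n // w):                       # segments are independent of each other
        gamma = i * w                             # offset of segment i (assumed to be iw)
        rows = np.arange(gamma, gamma + w, r)     # original positions of the kept rows
        q_i, k_i, v_i = q[rows], k[rows], v[rows]
        o_i = softmax(q_i @ k_i.T) @ v_i          # Eq. (2), unscaled as written
        out[rows] = o_i                           # MAP: place rows at their source positions
    return out
```

The loop is written sequentially for readability. Since no segment depends on another, the iterations could equally be batched and run in parallel.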

Computation Efficiency: With the proposed Dilated Flash Attention mechanism, efficiency is quantitatively enhanced by a factor of $\frac{N}{wr^{2}}$, where $N$ represents the total size of the input, $w$ the length of each segment, and $r$ the interval of sparsification. This mathematical relationship demonstrates that Dilated Flash Attention requires substantially fewer computations for any given input size. Consequently, this boosts the model’s capability in efficiently processing large-scale image segmentation tasks, marking a notable improvement in both performance and practicality.

3.2 Dynamic Layer-Wise Distillation (DLD)

Training the SAM-Lightening from scratch is costly, while layer adaptation is challenging due to the distinctive structures between SAM with ViT-H as the feature encoder and SAM-Lightening. To enable efficient knowledge transfer from vanilla SAM to the proposed framework, we propose a novel Dynamic Layer-Wise Distillation (DLD), which dynamically modifies feature weights to enhance the layer-wise distillation between the models [17].

Dynamic Layer-Wise Weights: When preceding layers are not well-distilled, the performance of subsequent layers can suffer from the low-quality features extracted by those preceding layers. By assigning greater weight to the loss of these initial layers, dynamic weighting ensures they receive more focus during the training process. This helps in better aligning the student model with the teacher model in the initial stages. Given a deep neural network consisting of $L$ layers, each layer $i$ is associated with a temporal weight $\alpha_{i}(t)$. This mechanism adjusts the significance of each layer $i$ in the neural network across various training stages $t$. The initial layer retains maximum emphasis ($\alpha_{1}(t) = 1$) and the subsequent layers adhere to a dynamic weighting scheme, which can be mathematically represented by the piece-wise function:

$$\alpha_{i}(t) = \begin{cases} 0 & \text{for } t < T_{i} \\ \frac{t - T_{i}}{\Delta t} & \text{for } T_{i} \leq t < T_{i} + \Delta t \\ 1 & \text{for } t \geq T_{i} + \Delta t \end{cases} \tag{4}$$

where $T_{i}$ denotes the epoch at which the $i^{th}$ layer commences updating its weight, which occurs once the previous layer has reached saturation, i.e., $T_{i} = T_{i-1} + \Delta t$. The parameter $\Delta t$ captures the number of epochs over which the weight transitions from 0 to 1. For a predefined epoch increment $\Delta t$, each layer sequentially activates its learning potential after the preceding layer reaches its peak weight. This mechanism facilitates a cascading knowledge absorption from the teacher model.
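The schedule of Eq. (4) can be sketched as a small Python function. The text does not state when the second layer starts ramping up, so the sketch assumes $T_{2}$ equals a hypothetical parameter `t_start` (default 0) and derives later start epochs from $T_{i} = T_{i-1} + \Delta t$. The function name and the example values are likewise illustrative.

```python
def layer_weight(i, t, delta_t, t_start=0.0):
    """alpha_i(t) from Eq. (4) for layer i >= 1 at epoch t."""
    if i == 1:
        return 1.0                          # the first layer keeps full weight
    t_i = t_start + (i - 2) * delta_t       # T_i = T_{i-1} + delta_t, with T_2 = t_start (assumed)
    if t < t_i:
        return 0.0                          # layer not yet active
    if t < t_i + delta_t:
        return (t - t_i) / delta_t          # linear ramp from 0 to 1
    return 1.0                              # layer saturated

# Weights of layers 1..4 over the first epochs with delta_t = 2.
for epoch in range(0, 7):
    print(epoch, [layer_weight(i, epoch, delta_t=2) for i in range(1, 5)])
```

With these values, layer 2 ramps up over epochs 0 to 2 and layer 3 only begins once layer 2 has saturated, which is the cascading behaviour described above.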

Decoupled Feature Distillation: The distillation process transfers knowledge from SAM’s encoder (the teacher model) to our proposed encoder (the student model), as shown in Fig. 1. We have chosen the $N$ layers closest to the output for feature distillation. Since these deeper layers are directly related to the model’s outputs, distilling them can more effectively transfer crucial information for prediction results. These layers are designated as “Focus Layers”.

During the initial phase of training, layers closer to the input are given precedence. Here, the intent is to align the primary feature representations of the student model (SAM-Lightening), expressed as $f_{\text{SAM-L}}^{(i)}(x)$, with those of the teacher model, $f_{\text{SAM}}^{(i)}(x)$, for the $i$ layers closest to the input. As training advances, the layer-wise weighting dynamically shifts, and the loss associated with subsequent layers is incrementally amplified. In the process, the loss function evolves to assimilate representations from succeeding layers:

$$L_{\text{P}} = \sum_{i \in \text{Focus}} \alpha_{i}(t) \sum_{j=1}^{N} \left\| f_{\text{SAM}}^{(i)}(x_{j}) - f_{\text{SAM-L}}^{(i)}(x_{j}) \right\|_{2}^{2} \tag{5}$$

where $L$ is the total number of layers, and the coefficient $\alpha_{i}(t)$ is the piece-wise function of Eq. (4), determined by the training epoch $t$ and the layer $i$. The integrated distillation loss is formulated as:

$$L_{\text{integrated}} = L_{\text{P}} + \lambda L_{\text{output}} \tag{6}$$

where $L_{\text{P}}$ encapsulates the weighted sum of all selected feature layer losses, $L_{\text{output}}$ is the loss for the image encoder output layer, and $\lambda$ is a scaling factor that balances the significance of this output loss in the overall distillation process.
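A minimal NumPy sketch of Eqs. (5) and (6) for one batch is given below. Several details are assumptions made for illustration: the function name `integrated_loss` is hypothetical, teacher and student features are taken to have matching shapes (for example after a projection), and $L_{\text{output}}$ is modelled as a squared error, since its exact form is not stated here.

```python
import numpy as np

def integrated_loss(teacher_feats, student_feats, alphas,
                    teacher_out, student_out, lam=1.0):
    """L_integrated = L_P + lam * L_output for one batch (Eqs. (5) and (6))."""
    l_p = 0.0
    for f_t, f_s, alpha in zip(teacher_feats, student_feats, alphas):
        # Squared L2 distance between teacher and student features of one
        # focus layer, summed over the batch and weighted by alpha_i(t).
        l_p += alpha * np.sum((f_t - f_s) ** 2)
    l_output = np.sum((teacher_out - student_out) ** 2)   # assumed form of L_output
    return l_p + lam * l_output
```

In use, `alphas` would hold the current values of $\alpha_{i}(t)$ for the focus layers, so layers that are not yet active contribute nothing to the loss.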

Align Decoder: Additionally, the lightweight image encoder obtained through decoupled distillation has alignment issues with the frozen decoder, especially for point-based prompt segmentation tasks. Therefore, we fine-tuned the decoder by sampling point prompts and box prompts on the SA-1B dataset to align with the image encoder. The loss function is defined as follows:

$$L_{\text{fine-tune}} = 20 \times \text{IOU} + \text{Dice} + \text{Focal Loss} \tag{7}$$

Here, IOU represents the Intersection over Union loss, while Dice loss and Focal Loss are used to address class imbalance and challenging segmentation regions, respectively.
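For illustration only, Eq. (7) could be instantiated as below. The exact definitions of the three terms are not given in the text, so the soft IoU loss, the soft Dice loss, the focal loss with a hypothetical focusing parameter `gamma = 2`, and the function name `fine_tune_loss` are all assumptions rather than the authors' implementation.

```python
import numpy as np

def fine_tune_loss(prob, target, gamma=2.0, eps=1e-6):
    """20 * IoU loss + Dice loss + focal loss for one predicted mask (Eq. (7))."""
    inter = np.sum(prob * target)
    total = np.sum(prob) + np.sum(target)
    iou_loss = 1.0 - (inter + eps) / (total - inter + eps)   # soft IoU loss (assumed form)
    dice_loss = 1.0 - (2.0 * inter + eps) / (total + eps)    # soft Dice loss (assumed form)
    p_t = np.where(target == 1, prob, 1.0 - prob)            # probability of the true class
    focal = np.mean(-((1.0 - p_t) ** gamma) * np.log(np.clip(p_t, eps, 1.0)))
    return 20.0 * iou_loss + dice_loss + focal
```

Here `prob` holds per-pixel foreground probabilities and `target` the binary ground-truth mask. A perfect prediction gives a loss of zero under these definitions.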

4 Experiment

4.1 Experimental Setups

Our model utilizes 1% of the SA-1B dataset for distillation and fine-tuning. It features an encoder with an embedding dimension of 384, six attention heads, and a six-layer structure. For the FlashAttention component, we use bfloat16. Both the distillation and fine-tuning processes are conducted for 10 epochs each, with a learning rate of $10^{-3}$ and a batch size of 32. Gradient accumulation is set with a step size of 4. The model is trained on two NVIDIA RTX 4090 GPUs. To enhance training speed, the outputs of SAM’s image encoder are saved [10, 9].

4.2 Results

Run-Time And Memory Efficiency Evaluation: We compare the performance of our proposed SAM-Lightening with vanilla SAM (i.e., SAM-ViT-H) [1], FastSAM [8], MobileSAM [9], EfficientSAM [11], and SAMFast [12] in Table 1 and Table 2. Regarding the segmentation performance, the vanilla SAM is considered as the upper bound. Importantly, Table 1 shows that SAM-Lightening outperforms all its counterparts in terms of inference latency and peak memory usage, achieving 30.1× acceleration and 96.5% peak memory reduction when compared to vanilla SAM, and 2.1× acceleration when compared to the state-of-the-art. The throughput comparison in Table 2 further reinforces SAM-Lightening’s superior performance, as it achieves the highest throughput across various batch sizes. This high throughput, together with its low latency and memory usage, positions SAM-Lightening as a highly efficient model for image segmentation tasks.

Table 1: Performance comparison on an NVIDIA RTX 4090 GPU, where “Enc.” refers to the Encoder, “Dec.” to the Decoder, “Mem.” to Memory usage, “Tot.” to Total Time, and “SU” denotes the Speed-Up ratio.
Model           Enc. (ms)  Dec. (ms)  Tot. (ms)  SU      Mem.
SAM-ViT-H       216.1      3.8        219.9      1.0×    5.7 GB
SAMFast         23.2       3.8        27.0       8.5×    4.1 GB
FastSAM         20.7       3.4        24.1       9.1×    2.6 GB
EfficientSAM    22.3       3.8        26.1       8.3×    309 MB
MobileSAM       8.1        3.8        11.9       18.5×   309 MB
SAM-Lightening  3.5        3.4        6.9        30.1×   224 MB

Table 2: Parallel throughput comparison. Inference times are given in milliseconds (ms); “Size n” denotes a batch size of n, and “OOM” denotes out of memory.
Model           Size 1  Size 4  Size 8  Size 16
SAM-ViT-H       219.9   944.9   OOM     OOM
SAMFast         53.6    206.6   438.2   964.2
FastSAM         24.1    80.1    171.5   349.1
EfficientSAM    22.3    79.2    157.7   317.5
MobileSAM       8.1     34.1    72.3    156.8
SAM-Lightening  3.5     13.0    27.2    59.2
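
Since Table 2 reports times rather than rates, the following sketch converts its entries into images per second. It assumes that each entry is the time for one whole batch (the values grow with batch size, which suggests this reading) and that the "Size" columns are batch sizes. The helper name `images_per_second` is hypothetical.

```python
# Batch inference times in ms, copied from Table 2 (None marks out-of-memory runs).
batch_sizes = [1, 4, 8, 16]
times_ms = {
    "SAM-ViT-H": [219.9, 944.9, None, None],
    "SAMFast": [53.6, 206.6, 438.2, 964.2],
    "FastSAM": [24.1, 80.1, 171.5, 349.1],
    "EfficientSAM": [22.3, 79.2, 157.7, 317.5],
    "MobileSAM": [8.1, 34.1, 72.3, 156.8],
    "SAM-Lightening": [3.5, 13.0, 27.2, 59.2],
}

def images_per_second(batch_size, time_ms):
    """Throughput implied by one batch taking time_ms milliseconds."""
    return None if time_ms is None else 1000.0 * batch_size / time_ms

for model, times in times_ms.items():
    rates = [images_per_second(b, t) for b, t in zip(batch_sizes, times)]
    print(model, ["OOM" if r is None else round(r, 1) for r in rates])
```

Under this reading, a batch of 16 images in 59.2 ms corresponds to roughly 270 images per second for SAM-Lightening.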
Fig. 2: Representative image segmentation results between SAM-Lightening and the vanilla SAM in prompt mode.

Comparison In Box/Point Prompt Mode: We first evaluated the performance under bounding-box and point-based prompts. For bounding-box prompts, we followed the settings in vanilla SAM by leveraging the ground-truth annotations in COCO [18] and LVIS [19] to synthesize bounding boxes that define areas of interest in each image. For point prompts, we randomly sampled points within the ground-truth masks, challenging all the models to accurately segment the object or region associated with each point. Quantitatively, we used mean Intersection over Union (mIoU) as the metric. As shown in Table 3, both SAMFast and MobileSAM suffer from a performance decline when compared to vanilla SAM, particularly with point prompts. FastSAM, as a CNN-based model, shows an even more pronounced drop, which is especially evident on the LVIS dataset, which contains a large number of small objects. This observation reflects the limitations of CNN-based encoders in processing more complex segmentation scenarios. In contrast, SAM-Lightening matches the original SAM most closely in terms of segmentation performance. This holds even in scenarios of point-based prompts, where SAM-Lightening achieves mIoU similar to the vanilla SAM.
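The evaluation protocol described above can be sketched with three small helpers. They are illustrative assumptions rather than the evaluation code used for Table 3: the function names are hypothetical, masks are taken to be non-empty binary NumPy arrays, boxes are tight around the mask, and points are drawn uniformly without replacement.

```python
import numpy as np

def box_from_mask(mask):
    """Tight bounding box (x_min, y_min, x_max, y_max) around a binary mask."""
    ys, xs = np.nonzero(mask)
    return xs.min(), ys.min(), xs.max(), ys.max()

def sample_points(mask, k, rng):
    """Draw k random (x, y) point prompts from inside a binary mask."""
    ys, xs = np.nonzero(mask)
    idx = rng.choice(len(xs), size=k, replace=False)  # requires k <= mask area
    return np.stack([xs[idx], ys[idx]], axis=1)

def mean_iou(preds, targets):
    """Mean intersection over union across pairs of binary masks."""
    ious = []
    for p, t in zip(preds, targets):
        p, t = p.astype(bool), t.astype(bool)
        union = np.logical_or(p, t).sum()
        ious.append(np.logical_and(p, t).sum() / union if union else 1.0)
    return float(np.mean(ious))
```

A prompt produced by `box_from_mask` or `sample_points` would be passed to the model under test, and the predicted masks would then be scored against the ground truth with `mean_iou`.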

Table 3: Segmentation performance comparison in terms of mIoU on COCO and LVIS. The labels “Box”, “1P”, and “3P” correspond to the use of a bounding box, one point, and three points as prompts, respectively.
Model           COCO                 LVIS
                Box    1P     3P     Box    1P     3P
SAM-ViT-H       80.1   49.2   72.5   83.8   60.6   74.7
SAMFast         77.3   44.7   66.3   80.5   54.9   69.4
FastSAM         65.0   50.9   52.4   61.5   41.5   41.8
EfficientSAM    77.8   43.6   69.7   79.5   53.7   72.9
MobileSAM       77.9   47.9   67.4   78.5   55.4   66.8
SAM-Lightening  78.8   48.4   72.5   81.0   59.9   74.6

Comparison In Anything Mode:

Fig. 3: Representative samples under anything mode.

While the segment-anything mode is an innovative approach, it is not a commonly used segmentation method and thus does not effectively represent typical segmentation tasks. Therefore, our analysis has primarily focused on visually comparing the segmentation outcomes through point-based and box-based methods, which are more prevalent in practical applications. However, for completeness and to demonstrate the versatility of the models, we have also included the outputs of the segment-anything mode in our comparison.

From the representative samples demonstrated in Fig. 3, both SAM-Lightening and MobileSAM exhibit segmentation results that are nearly indistinguishable from those of the vanilla SAM. This similarity is notable in terms of edge clarity and detail preservation, which are hallmarks of high-quality segmentation. SAM-Lightening demonstrates its robustness and accuracy, aligning closely with the performance of the vanilla SAM.

4.3 Ablation study

Fig. 4: Impact of FlashAttention on inference time across input sizes, where we select two embedding dimensions, namely 768 and 384, for comparison.

It’s noteworthy that many previous works [4, 20, 21] use input sizes smaller than 1024 for SAM. For a fair comparison, we also conducted experiments in these scenarios and found that keeping FlashAttention for input sizes equal to or smaller than 512×512 achieves optimal performance. This indicates that the applicability of FlashAttention depends on the model’s input size and specific hardware configuration. The decision to use FlashAttention should be made based on the specific application context and performance requirements. Although FlashAttention accelerates training in model distillation, its impact on inference performance is determined by various hardware metrics. On our inference platform, especially for SAM with a 1024 input size, the multi-head attention operator is more computation-intensive. As shown in Fig. 4, this results in a slightly lower inference speed with FlashAttention compared to without it. Therefore, we opt to use FlashAttention during the distillation process to optimize performance while removing it during the evaluation phase.

5 Conclusion

We propose SAM-Lightening to address the primary limitations of high computational demand and slow inference speed in vanilla SAM to make it more suitable for deployment on resource-constrained devices. Our approach involves the redesign of the image encoder in SAM, by distilling the self-attention operators into dilated flash attentions with dynamic layer-wise distillation. These optimizations contribute to a notable reduction in computational complexity and memory usage without compromising the segmentation performance. Specifically, SAM-Lightening can complete inference within 7 milliseconds per image, achieving a 30.1× speed-up over SAM-ViT-H. Since SAM-Lightening is complementary to pruning and quantization, one future direction is to integrate it with them.

References

  • [1] Kirillov et al., “Segment anything,” arXiv preprint arXiv:2304.02643, 2023.
  • [2] Archit et al., “Segment anything for microscopy,” Aug 2023.
  • [3] Ma et al., “Segment anything in medical images,” Apr 2023.
  • [4] Cheng et al., “Sam-med2d,” arXiv preprint arXiv:2308.16184, 2023.
  • [5] Yang et al., “Track anything: Segment anything meets videos,” Apr 2023.
  • [6] Shen et al., “Anything-3d: Towards single-view anything reconstruction in the wild,” arXiv preprint arXiv:2304.10261, 2023.
  • [7] Ronneberger et al., U-Net: Convolutional Networks for Biomedical Image Segmentation, p. 234–241, Jan 2015.
  • [8] Zhao et al., “Fast segment anything,” arXiv preprint arXiv:2306.12156, 2023.
  • [9] Zhang et al., “Faster segment anything: Towards lightweight sam for mobile applications,” arXiv preprint arXiv:2306.14289, 2023.
  • [10] Wu et al., “Tinyvit: Fast pretraining distillation for small vision transformers,” Springer, Cham, 2022.
  • [11] Xiong et al., “Efficientsam: Leveraged masked image pretraining for efficient segment anything,” arXiv preprint arXiv:2312.00863, 2023.
  • [12] PyTorch Team, “Accelerating generative ai,” accelerating-generative-ai, 2023.
  • [13] Dao et al., “Flashattention: Fast and memory-efficient exact attention with io-awareness,” Advances in Neural Information Processing Systems, vol. 35, pp. 16344–16359, 2022.
  • [14] Jocher et al., “Ultralytics yolov8,” 2023.
  • [15] Dao, “Flashattention-2: Faster attention with better parallelism and work partitioning,” arXiv preprint arXiv:2307.08691, 2023.
  • [16] Hinton et al., “Distilling the knowledge in a neural network,” arXiv: Machine Learning, Mar 2015.
  • [17] Ji et al., “Show, attend and distill: Knowledge distillation via attention-based feature matching,” Proceedings of the AAAI Conference on Artificial Intelligence, p. 7945–7952, Sep 2022.
  • [18] Lin et al., “Microsoft coco: Common objects in context,” COMPUTER VISION - ECCV 2014, PT V, pp. 740–755, 2014.
  • [19] Gupta et al., “Lvis: A dataset for large vocabulary instance segmentation,” in 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Jun 2019.
  • [20] Chen et al., “MMDetection: Open mmlab detection toolbox and benchmark,” arXiv preprint arXiv:1906.07155, 2019.
  • [21] Wang et al., “Seggpt: Segmenting everything in context,” arXiv preprint arXiv:2304.03284, 2023.