Is Flash Attention Stable?
Abstract
Training large-scale machine learning models poses distinct system challenges, given both the size and complexity of today’s workloads. Recently, many organizations training state-of-the-art Generative AI models have reported cases of instability during training, often taking the form of loss spikes. Numeric deviation has emerged as a potential cause of this training instability, although quantifying this is especially challenging given the costly nature of training runs. In this work, we develop a principled approach to understanding the effects of numeric deviation, and construct proxies to put observations into context when downstream effects are difficult to quantify. As a case study, we apply this framework to analyze the widely-adopted Flash Attention optimization. We find that Flash Attention sees roughly an order of magnitude more numeric deviation as compared to Baseline Attention at BF16 when measured during an isolated forward pass. We then use a data-driven analysis based on the Wasserstein Distance to provide upper bounds on how this numeric deviation impacts model weights during training, finding that the numerical deviation present in Flash Attention is 2-5 times less significant than low-precision training.
Index Terms:
Generative AI, Numeric Deviation, Training Instability, Attention, Transformers
I Introduction
As machine learning trends toward larger and more complex models, the model training process is becoming ever more compute- and resource-intensive. The advent of Generative AI has further pushed the boundaries in model development, with Large Language Models (LLMs) often training for months at a time, across hundreds or thousands of GPUs. Take, for example, LLaMA2’s 70B-parameter model, which required 1,720,320 GPU hours to train [10]. With such long training jobs, training instability has become increasingly problematic. As reported for models such as Google’s PaLM, training instability often manifests itself in the form of loss spikes, occurring up to 20 times throughout training [1]. These loss spikes are costly, as they often interrupt the training process, requiring training to stop and restart.
While previous work has examined possible mitigation responses to increase training stability from an algorithm perspective [6, 5, 9, 3], the fundamental cause of this instability is still unclear. The stochastic nature of model training, combined with the fact that many such effects only manifest at large scale, presents distinct challenges to fully understanding the nature of these instabilities. Furthermore, repeatedly training such large models to isolate contributing factors is often infeasible, given the cost and high demand for data center compute.
One under-explored potential cause of training instability is numeric deviation. Numeric deviation between an optimization and its corresponding baseline can lead to the gradual accumulation of errors, which over the course of training have the potential to culminate in loss spikes that require a resetting of the model state [1]. This is challenging to quantify, as training’s stochastic nature suggests some level of numeric deviation might be acceptable, yet determining the threshold for when training becomes unstable proves difficult.
In this work, we develop a principled quantitative approach to understanding numeric deviation in training optimizations. Our approach consists of two phases: (i) developing a microbenchmark to perturb numeric precision in the given optimization, and (ii) evaluating how numeric deviation translates to changes in model weights through a data-driven analysis based on Wasserstein distance. This ultimately allows us to provide an upper bound on the amount of numeric deviation for a given optimization, and helps to contextualize the improvement within known techniques. We aim to use this principled analysis to evaluate different state-of-the-art optimization techniques and identify whether they are likely to introduce unintended instabilities when used to train large models.
As a case study, we analyze the state-of-the-art optimization technique Flash Attention [2], and quantify the potential numeric deviation introduced. Flash Attention is a widely-adopted technique used to speed up the attention mechanism, often considered a system bottleneck in transformer models [11]. However, while offering increased speedup and reduced memory accesses, Flash Attention depends on algorithm optimizations that have the potential to contribute to increased numeric deviation. Specifically, we hypothesize that the addition of rescaling factors could introduce unintentional approximation that leads to numeric tradeoff, which could later impact training stability. We analyze Flash Attention in the context of multi-modal Text-to-Image workloads in order to determine the potential significance of numeric deviation between Flash Attention and its baseline.
Ultimately, we introduce a framework to quantify the numeric deviation of training optimizations and their downstream impacts. Our key contributions are as follows:
- We design a microbenchmark to isolate the impact of numerical precision on numeric deviation. Our microbenchmark serves as a technique to measure and quantify numeric deviation resulting from traditionally black-box optimizations such as Flash Attention. By perturbing aspects not typically available through the provided kernel, we initially find Flash Attention sees roughly an order of magnitude more numeric deviation as compared to Baseline Attention at low numerical precision (BF16).
- We perform a data-driven analysis based on the Wasserstein Distance metric to contextualize this observed numeric deviation and form an upper bound for the impact on downstream model properties. In our case study, we are able to bound the impact of this observed numerical deviation, and find that Flash Attention introduces roughly 2-5 times less model weight deviation as compared to low-precision training.
Our investigations underscore the importance of developing a principled approach to not only quantify, but contextualize, the impact of training optimizations on numeric deviation. By constructing proxies to put this numeric deviation in context, we aim to reason about the likelihood of downstream model effects (i.e., training instability) that are traditionally difficult to measure.
II Background
As a case study for this work, we analyze the state-of-the-art optimization Flash Attention and its potential numeric deviation from Baseline Attention. The Attention operation has been the focus of myriad optimizations as of late. Attention is currently the main system-performance bottleneck of the Transformer architecture, which is a widely adopted technique in modern machine learning algorithms known for its effectiveness in modeling sequence-to-sequence tasks [11].
II-A Attention as a System-Performance Bottleneck
The Attention operation essentially derives a weighted sum that represents how much emphasis a model should place on all previous words when generating a new token [11]. The key computation in Attention involves matrix multiplications and softmax. Three representations (Query, Key, Value) are first derived from the input vector, each having dimension $N \times d$, where $N$ is the sequence length and $d$ is the model dimension. The dot product of the Query and Key matrices then forms an $N \times N$ matrix, which scales quadratically in size with sequence length. This so-called similarity matrix subsequently undergoes a softmax operation, before being multiplied with the Value matrix to complete the computation. The full operation is described below:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d}}\right)V$$
Given this quadratic scaling with sequence length, myriad techniques have been proposed to accelerate the Attention mechanism, including system-aware methods such as Flash Attention [2].
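As a concrete reference point, Baseline Attention can be sketched in a few lines of plain Python. This is a minimal illustration rather than the implementation used in this work: the function names are hypothetical, the $1/\sqrt{d}$ scaling follows [11], and nested lists stand in for the tensors that a real model would use.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def baseline_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d)) V with the full similarity matrix."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # N x N similarity matrix: this is the part that grows quadratically with N.
    S = [[scale * sum(q * k for q, k in zip(q_row, k_row)) for k_row in K]
         for q_row in Q]
    # Row-wise softmax turns scores into weights that sum to one.
    P = [softmax(row) for row in S]
    # Each output row is a weighted sum of the Value rows.
    return [[sum(p * v_row[j] for p, v_row in zip(p_row, V))
             for j in range(len(V[0]))]
            for p_row in P]
```

Note that `S` is materialized in full here, so its memory footprint grows with the square of the sequence length; this is the cost that system-aware methods aim to avoid.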
II-B Understanding Flash Attention
Flash Attention is a recently proposed technique that is designed to accelerate the Attention bottleneck characteristic of Transformers [2]. As an IO-aware technique, it aims to minimize the memory overhead of the large similarity matrix typically used in the attention mechanism. It essentially uses the traditional techniques of tiling and recomputation, along with an online softmax trick, in order to calculate the matrix only one tile at a time [2]. As shown in Figure 1, the introduction of tiling eliminates the need for the large similarity matrix to be materialized. The block/tile size is defined by $B_c = \lceil \frac{M}{4d} \rceil$ and $B_r = \min\left(\lceil \frac{M}{4d} \rceil, d\right)$, where $M$ is the size of the SRAM and $d$ is the model dimension. This ensures the block fits inside SRAM, and thus eliminates the need for the large matrix to be loaded/stored from HBM.
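To make the tiling arithmetic concrete, the sketch below computes block sizes and tile counts, assuming the block-size rule from the Flash Attention paper [2], $B_c = \lceil M / (4d) \rceil$ and $B_r = \min(B_c, d)$. The function names, the choice to measure $M$ in elements, and the example values are assumptions made purely for illustration.

```python
import math

def flash_block_sizes(sram_elems, d):
    """Return (B_r, B_c) for an SRAM capacity of sram_elems elements (M)
    and model dimension d, following the rule stated in [2]."""
    b_c = math.ceil(sram_elems / (4 * d))
    b_r = min(b_c, d)
    return b_r, b_c

def tile_counts(seq_len, b_r, b_c):
    """Number of row tiles and column tiles needed to cover an N x N matrix."""
    return math.ceil(seq_len / b_r), math.ceil(seq_len / b_c)

# Hypothetical example values: M = 48 * 1024 elements, d = 64, N = 4096.
b_r, b_c = flash_block_sizes(sram_elems=48 * 1024, d=64)
t_r, t_c = tile_counts(seq_len=4096, b_r=b_r, b_c=b_c)
```

With these example values, `(b_r, b_c)` is `(64, 192)` and `(t_r, t_c)` is `(64, 22)`. For a fixed tile size, the number of column tiles grows linearly with the sequence length.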
However, softmax requires global information about each row of the matrix (its maximum and its normalizing sum), and this information is no longer available when computing on a single block. Flash Attention must therefore incorporate re-scaling factors in order to keep calculations consistent across tiles. While introducing relatively minimal overhead, these additional re-scaling factors do introduce extra computation that is calculated per tile.
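The role of the re-scaling factors can be illustrated with a simplified, plain-Python sketch of tiled attention with an online softmax. It is not the Flash Attention kernel: it tiles only along the key/value dimension, handles one query row at a time, and normalizes once at the end instead of after every tile (mathematically equivalent in exact arithmetic). The names `tiled_attention` and `block_size` are hypothetical.

```python
import math

def tiled_attention(Q, K, V, block_size):
    """Attention computed one key/value tile at a time with an online softmax."""
    d = len(Q[0])
    dv = len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        m = -math.inf        # running maximum of the scores seen so far
        l = 0.0              # running softmax denominator
        acc = [0.0] * dv     # running unnormalized output row
        for start in range(0, len(K), block_size):
            k_tile = K[start:start + block_size]
            v_tile = V[start:start + block_size]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(s))
            # Re-scaling factor: corrects everything accumulated under the
            # old maximum so that it is consistent with the new maximum.
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(pi * v[j] for pi, v in zip(p, v_tile))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out
```

Every tile after the first multiplies the accumulated state by `alpha`, so the number of extra multiplications grows with the number of tiles. In exact arithmetic the result does not depend on `block_size`; in finite precision each re-scaling step is an additional rounding opportunity.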
Flash Attention offers both a speedup and more efficient resource utilization; we find that it yields a 14% speedup of forward + backward pass for our example text-to-image model. However, it has also been hypothesized that the additional computation introduced by the rescaling factors of Flash Attention could introduce numeric deviation when used in text-to-image model training, and we aim to investigate this below.
![Refer to caption](extracted/2405.02803v1/figures/flashblock.png)
III Experimental Methodology
![Refer to caption](extracted/2405.02803v1/figures/Figure1-2.png)
We first develop a microbenchmark to isolate and study the numeric deviation caused by Flash Attention. A summary of our microbenchmark design can be found in Figure 2. As shown, we numerically re-implement Flash Attention in order to analyze different numerical precisions and apply potential optimizations at each step of the algorithm, which is not easily done with the original CUDA implementation. This is necessary, as the Flash Attention kernel currently only supports FP16 and BF16 number formats. The Flash Attention kernel is also a wrapped API call of CUDA code, making it challenging to perturb the algorithm to examine the impact on numeric deviation. In contrast, our microbenchmark design allows for varying precision inputs and modifications inside the algorithm. We validate our microbenchmark against the original Flash Attention kernel.
We further devise a technique to compare the output of the Attention matrix at each step during model execution. We modify the model code to compute both Baseline Attention and Flash Attention each time Attention is called, which allows for identical input matrices and an exact output matrix comparison. To contextualize this, we additionally quantify the difference in model weights throughout training via identical and independent training runs, using the Max Difference and Wasserstein Distance metrics, which are further detailed in Section V.
For training experiments we use a type of Generative AI workload that converts text inputs to images (i.e., Text-to-Image model). We retrain the model using the Shutterstock dataset and run our experiments across a cluster of NVIDIA 80GB A100 GPUs.
IV Quantifying Numeric Deviation Through Microbenchmark
We first analyze the impact of Flash Attention during the forward pass. We utilize our microbenchmark to examine how different numerical precisions impact the output matrix of the Attention calculation, given the same randomly initialized query, key, value vectors.
IV-A Sweep Numerical Precision
We find that numerical precision has an impact on the output of Flash Attention, causing it to deviate from the output of Baseline Attention. We quantify this through a measure of the maximum difference between attention output matrices taken elementwise, which serves as an upper bound on the possible deviation. As Figure 3 shows, when using different number formats varying from BF16 to FP64, the numeric deviation between Flash Attention and Baseline Attention decreases with increasing number of mantissa bits. This suggests the numeric difference is a result of approximation inherent with fewer mantissa bits.
We then compare this to the behavior of Baseline Attention. For a common standard of comparison, we set a “golden value” of Attention to be Baseline Attention at FP64. We then compare the maximum difference of the Attention output at different number formats to this golden value, as shown in Figure 4. Note that we plot the maximum difference between Flash Attention outputs and this golden value (blue bars), alongside the maximum difference between Baseline Attention outputs and the same golden value (red bars). We find that Flash Attention sees roughly 10× more numeric deviation as compared to Baseline Attention at BF16. A detailed discussion on whether this level of deviation is significant can be found in Section V.
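The golden-value comparison can be sketched as follows. This is an illustrative stand-in for the microbenchmark, not a reproduction of it: number formats are emulated in plain Python by rounding after every scalar operation (stricter than typical GPU kernels, which often accumulate in FP32), BF16 is emulated through the FP32 encoding, and all names, sizes, and the random seed are assumptions. Only one row of Baseline Attention is evaluated.

```python
import math
import random
import struct

def round_fp32(x):
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def round_fp16(x):
    """Round to the nearest IEEE half-precision (FP16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def round_bf16(x):
    """Round to bfloat16 via the FP32 encoding (keep the top 16 bits,
    round-to-nearest-even). NaN handling is omitted in this sketch."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def attention_row(q, K, V, rnd):
    """One row of Baseline Attention with every intermediate rounded by rnd."""
    scale = rnd(1.0 / math.sqrt(len(q)))
    scores = []
    for k in K:
        dot = 0.0
        for a, b in zip(q, k):
            dot = rnd(dot + rnd(rnd(a) * rnd(b)))
        scores.append(rnd(dot * scale))
    m = max(scores)
    exps = [rnd(math.exp(rnd(s - m))) for s in scores]
    total = 0.0
    for e in exps:
        total = rnd(total + e)
    probs = [rnd(e / total) for e in exps]
    out = []
    for j in range(len(V[0])):
        acc = 0.0
        for p, v in zip(probs, V):
            acc = rnd(acc + rnd(p * rnd(v[j])))
        out.append(acc)
    return out

def max_abs_diff(a, b):
    """Maximum elementwise absolute difference between two vectors."""
    return max(abs(x - y) for x, y in zip(a, b))

random.seed(0)
N, d = 32, 8
q = [random.gauss(0.0, 1.0) for _ in range(d)]
K = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(N)]

golden = attention_row(q, K, V, lambda x: x)  # FP64 "golden value"
deviation = {
    name: max_abs_diff(attention_row(q, K, V, rnd), golden)
    for name, rnd in [("BF16", round_bf16), ("FP16", round_fp16), ("FP32", round_fp32)]
}
```

The `deviation` dictionary then holds one maximum-difference value per emulated format, analogous to one group of bars in the comparison described above. Passing a tiled implementation in place of `attention_row` would give the corresponding Flash Attention measurement.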
![Refer to caption](extracted/2405.02803v1/figures/number_formats.jpg)
![Refer to caption](extracted/2405.02803v1/figures/number_formats_alt.jpg)
To analyze this observed numeric deviation further, we sweep the sequence length of our matrices while keeping the tile size and SRAM size the same. As shown in Figure 5, as sequence length increases, the numerical deviation between Flash Attention and Baseline Attention increases, when measured by both (a) the maximum difference upper bound, and (b) the mean and standard deviation of that difference. Since a larger sequence length implies a larger intermediate matrix that must be tiled while the tile size stays the same, more rescaling is needed. This presents more opportunities for precision errors to accumulate, and thus more deviation.
IV-B Sweep Algorithm Changes
![Refer to caption](extracted/2405.02803v1/figures/sequence_length_1.jpg)
![Refer to caption](extracted/2405.02803v1/figures/sequence_length_2.jpg)
![Refer to caption](extracted/2405.02803v1/figures/algorithm_changes_cropped_1.jpg)
![Refer to caption](extracted/2405.02803v1/figures/algorithm_changes_cropped_2.jpg)
![Refer to caption](extracted/2405.02803v1/figures/algorithm_changes_cropped_3.jpg)
We further leverage the microbenchmark design to experiment with different optimizations so we can better understand the effects of this numeric deviation. Figure 6 shows several algorithm changes and their corresponding impact on observed numeric deviation. For each experiment, we sweep block area (defined according to Bc and Br dimensions introduced in Section II), and plot the corresponding maximum difference between Attention matrix outputs. We highlight how this changes at the four precisions analyzed in Section IV-A. Figure 6a shows how swapping the order of the block dimensions leads to larger numerical difference between Flash Attention and Baseline Attention.
Notably, we additionally find that larger block/tile sizes lead to smaller numeric deviation (Figure 6c). This is because fewer re-scaling calculations are needed with larger tile sizes, since fewer tiles are needed to cover the original matrix. Note that other perturbations, such as constraining tile sizes to be square, do not have an impact on numeric deviation, since they do not drastically change the number of re-scaling calculations that need to be performed (Figure 6b).
V Contextualizing Numeric Deviation Via Weight Differences
While Flash Attention may cause numeric deviation of the Attention output during a forward pass, our ultimate goal is to determine whether this has any impact during model training, in order to investigate whether it contributes to training instability. We thus aim to quantify whether Flash Attention changes the model during training — i.e., if the observed Attention output difference from Section IV is reflected in model weights that are updated during training.
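We utilize two metrics to measure the model weight difference between a model trained with Baseline Attention and one trained with Flash Attention. We first calculate Max Difference, by finding the absolute value of the difference between weight matrices and taking the maximum to give an upper bound on the deviation, as shown below:
$$\text{Max Difference} = \max\left(\left| W_{\text{Flash}} - W_{\text{Baseline}} \right|\right)$$
where $W_{\text{Flash}}$ and $W_{\text{Baseline}}$ denote the weight matrices of the models trained with Flash Attention and Baseline Attention, respectively.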
While Max Difference provides an upper bound on numeric deviation, it fails to take into account the distribution of each matrix. We therefore additionally quantify weight differences through Wasserstein Distance, which is a commonly used metric to measure the similarity between tensors [7, 4]. While slightly more computationally complex, the Wasserstein metric incorporates information about the shape of the tensor distributions in order to measure similarity. The formulation for Wasserstein Distance is outlined below:
$$W(P, Q) = \inf_{\gamma \in \Pi(P, Q)} \mathbb{E}_{(x, y) \sim \gamma}\left[\lVert x - y \rVert\right]$$
where $\Pi(P, Q)$ denotes the set of joint distributions whose marginals are $P$ and $Q$.
Note that lower values indicate higher similarity between matrices.
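Both metrics can be sketched in plain Python under one explicit assumption that the text does not confirm: each weight tensor is flattened and treated as a set of samples from a one-dimensional distribution. For two equally sized one-dimensional samples, the 1-Wasserstein distance reduces to the mean absolute difference between the sorted values. The function names are hypothetical.

```python
def max_difference(w_a, w_b):
    """Largest elementwise absolute difference between two flattened tensors."""
    return max(abs(a - b) for a, b in zip(w_a, w_b))

def wasserstein_1d(w_a, w_b):
    """1-Wasserstein distance between the empirical distributions of two
    equal-length flattened tensors (mean absolute difference of sorted values)."""
    if len(w_a) != len(w_b):
        raise ValueError("this shortcut assumes equally many samples")
    return sum(abs(a - b) for a, b in zip(sorted(w_a), sorted(w_b))) / len(w_a)

# Toy example: w_perm holds the same values as w_ref in a different order,
# while w_shift moves every value by the same offset.
w_ref = [0.1, -0.2, 0.3]
w_perm = [0.3, 0.1, -0.2]
w_shift = [x + 0.5 for x in w_ref]
```

In the toy example, `max_difference(w_ref, w_perm)` is about 0.5 while `wasserstein_1d(w_ref, w_perm)` is 0, because the two tensors share the same distribution of values. For `w_shift` both metrics are about 0.5. This shows how the two metrics capture different aspects of a weight difference.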
Using these two metrics, we subsequently quantify how model weights change throughout the course of training when Flash Attention is implemented as compared to Baseline Attention. As shown in Figure 7, the incorporation of Flash Attention does in fact change the model weights throughout training, as measured by both Wasserstein Distance and Max Difference, and as training continues, this difference only increases. This suggests that models trained with Flash Attention converge to a different model than an identical one trained with Baseline Attention.
![Refer to caption](extracted/2405.02803v1/figures/weights_cropped_2.png)
However, training is a stochastic process and some model architecture changes can yield comparable results in terms of downstream effects and accuracy. Thus, even if the model weights differ between a model trained with Flash Attention and one trained with Baseline Attention, is this significant? Training models to completion and evaluating accuracy is a costly and resource-intensive task, especially for large models where training takes on the order of months. We therefore formulate a proxy to understand (a) how significant these weight changes are, and (b) whether we can put them into context with standard weight changes in other widely-adopted training optimizations.
To achieve this, we devise a series of experiments to compare how weight differences change over the course of training under different scenarios. In addition to comparing training runs with Flash and Baseline Attention, we quantify the weight difference of identical training runs where weights are initialized to different random values at the beginning of training. This provides a bound, since random weight initialization is a commonly-used technique [8], and typically produces equivalent results. Furthermore, we also measure the change in model weights of models trained with different precisions. Numeric precision (i.e., FP16 vs FP32) has potential to cause changes downstream, and this serves as an upper bound for determining the significance of Flash Attention weights.
Figure 8 shows the relative weight differences over the course of training as measured using Wasserstein Distance. We find that the rate of change of weight deviation for a model using Flash Attention is comparable to or less than the deviation from a different model initialization (note the slope of red and blue curves). Furthermore, we see that the rate of change of weights when using FP16 vs FP32 is higher and more variable than the rate of change for different model initializations. These results provide a proxy to suggest that although numeric deviation occurs with Flash Attention, it is bounded by random model initialization and low-precision training, and introduces roughly 2-5 times less model weight deviation as compared to low-precision training.
![Refer to caption](extracted/2405.02803v1/figures/relative_weights.jpg)
VI Discussion and Future Work
Our work takes the first step towards addressing the question “Is Flash Attention Stable?”, yet there is still more work that needs to be done before drawing a conclusive answer. Since training instability is challenging and costly to isolate, we explore numeric deviation, which is hypothesized to be a potential cause. Through our principled numeric deviation analysis, we make progress towards this goal by developing a framework to quantify numeric deviation and by developing proxies to bound the impact of this deviation in terms of model weights. However, ultimately linking this numeric deviation back to training instability requires further investigation.
Our numeric quantification methodology opens up a broader set of inquiries, including understanding how various other optimizations impact numeric deviation. Although this analysis is focused on Flash Attention, future work should aim to widen the scope, and investigate additional training optimizations and their corresponding numeric deviation from the appropriate baseline. For example, future work could investigate the numeric deviation caused by the Winograd Algorithm compared to traditional convolution, or by various other Attention optimizations and kernel fusion techniques.
More broadly, this work raises a larger set of research questions that deserve attention, specifically regarding training instability. It would be interesting to investigate, for example, the exact point in training where loss spikes occur, and then relate this to the measured weight deviation at this same point. In general, further developing proxies to understand training instability is critical, given the costly nature of these experiments. The following aspects would be interesting to explore as well:
Training Instability and Reliability. Training instability leads to interruptions in training, often in the form of loss spikes. However, this is not the only cause of interruptions. Hardware failures, for example, also contribute to the starting/stopping of model training across datacenters. Future work should investigate the relationship between this hardware reliability, checkpointing, and instability.
Training Instability and System Overhead. We are additionally interested in quantifying the system overhead that comes with queuing/requeuing training jobs after a loss spike has occurred. The additional overheads that come with these instabilities only magnify when training at-scale, and thus are important to quantify.
Training Instability and Sustainability. Frequent interruptions to the training procedure cause major power swings across the datacenter as well. For example, interruptions can cause significant surges in power consumption when training resumes. This has sustainability implications for designing low-carbon, power-efficient infrastructure for datacenters.
VII Conclusion
In this work, we develop a principled approach to understanding the effects of numeric deviation, and develop proxies to put observations into context when downstream effects are otherwise challenging to measure. We examine Flash Attention as a case study, and make progress towards quantifying numeric deviation. We hope that sharing our methodology will allow others to investigate future research questions in a similar manner, and encourage others to investigate this challenging problem of training instability.
References
- [1] A. Chowdhery, S. Narang, J. Devlin, M. Bosma, G. Mishra, A. Roberts, P. Barham, H. W. Chung, C. Sutton, S. Gehrmann, P. Schuh, K. Shi, S. Tsvyashchenko, J. Maynez, A. Rao, P. Barnes, Y. Tay, N. Shazeer, V. Prabhakaran, E. Reif, N. Du, B. Hutchinson, R. Pope, J. Bradbury, J. Austin, M. Isard, G. Gur-Ari, P. Yin, T. Duke, A. Levskaya, S. Ghemawat, S. Dev, H. Michalewski, X. Garcia, V. Misra, K. Robinson, L. Fedus, D. Zhou, D. Ippolito, D. Luan, H. Lim, B. Zoph, A. Spiridonov, R. Sepassi, D. Dohan, S. Agrawal, M. Omernick, A. M. Dai, T. S. Pillai, M. Pellat, A. Lewkowycz, E. Moreira, R. Child, O. Polozov, K. Lee, Z. Zhou, X. Wang, B. Saeta, M. Diaz, O. Firat, M. Catasta, J. Wei, K. Meier-Hellstern, D. Eck, J. Dean, S. Petrov, and N. Fiedel, “Palm: Scaling language modeling with pathways,” 2022.
- [2] T. Dao, D. Y. Fu, S. Ermon, A. Rudra, and C. Ré, “FlashAttention: Fast and memory-efficient exact attention with IO-awareness,” 2022.
- [3] J. Gilmer, B. Ghorbani, A. Garg, S. Kudugunta, B. Neyshabur, D. Cardoze, G. Dahl, Z. Nado, and O. Firat, “A loss curvature perspective on training instability in deep learning,” 2021.
- [4] S. B. Harma, A. Chakraborty, B. Falsafi, M. Jaggi, and Y. Oh, “Accuracy boosters: Epoch-driven mixed-mantissa block floating-point for DNN training,” 2023.
- [5] L. Liu, X. Liu, J. Gao, W. Chen, and J. Han, “Understanding the difficulty of training transformers,” 2023.
- [6] I. Molybog, P. Albert, M. Chen, Z. DeVito, D. Esiobu, N. Goyal, P. S. Koura, S. Narang, A. Poulton, R. Silva, B. Tang, D. Liskovich, P. Xu, Y. Zhang, M. Kambadur, S. Roller, and S. Zhang, “A theory on Adam instability in large-scale machine learning,” 2023.
- [7] V. M. Panaretos and Y. Zemel, “Statistical aspects of Wasserstein distances,” Annual Review of Statistics and Its Application, vol. 6, no. 1, pp. 405–431, Mar. 2019. [Online]. Available: http://dx.doi.org/10.1146/annurev-statistics-030718-104938
- [8] L. Scabini, B. D. Baets, and O. M. Bruno, “Improving deep neural network random initialization through neuronal rewiring,” 2022.
- [9] Y. Sun, D. Lao, G. Sundaramoorthi, and A. Yezzi, “Surprising instabilities in training deep networks and a theoretical analysis,” in Advances in Neural Information Processing Systems, S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, Eds., vol. 35. Curran Associates, Inc., 2022, pp. 19567–19578.
- [10] H. Touvron, L. Martin, K. Stone, P. Albert, A. Almahairi, Y. Babaei, N. Bashlykov, S. Batra, P. Bhargava, S. Bhosale, D. Bikel, L. Blecher, C. C. Ferrer, M. Chen, G. Cucurull, D. Esiobu, J. Fernandes, J. Fu, W. Fu, B. Fuller, C. Gao, V. Goswami, N. Goyal, A. Hartshorn, S. Hosseini, R. Hou, H. Inan, M. Kardas, V. Kerkez, M. Khabsa, I. Kloumann, A. Korenev, P. S. Koura, M.-A. Lachaux, T. Lavril, J. Lee, D. Liskovich, Y. Lu, Y. Mao, X. Martinet, T. Mihaylov, P. Mishra, I. Molybog, Y. Nie, A. Poulton, J. Reizenstein, R. Rungta, K. Saladi, A. Schelten, R. Silva, E. M. Smith, R. Subramanian, X. E. Tan, B. Tang, R. Taylor, A. Williams, J. X. Kuan, P. Xu, Z. Yan, I. Zarov, Y. Zhang, A. Fan, M. Kambadur, S. Narang, A. Rodriguez, R. Stojnic, S. Edunov, and T. Scialom, “Llama 2: Open foundation and fine-tuned chat models,” 2023.
- [11] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin, “Attention is all you need,” 2023.