Distributed Speculative Inference
of Large Language Models

Nadav Timor (corresponding author), Jonathan Mamou, Daniel Korat, Moshe Berchansky, Oren Pereg, Moshe Wasserblat, Tomer Galanti, Michal Gordon, David Harel
Abstract

Accelerating the inference of large language models (LLMs) is an important challenge in artificial intelligence. This paper introduces distributed speculative inference (DSI), a novel distributed inference algorithm that is provably faster than speculative inference (SI) (Leviathan et al., 2023; Chen et al., 2023; Miao et al., 2023) and traditional autoregressive inference (non-SI). Like other SI algorithms, DSI works on frozen LLMs, requiring no training or architectural modifications, and it preserves the target distribution.

Prior studies on SI have demonstrated empirical speedups (compared to non-SI) but require a fast and accurate drafter LLM. In practice, off-the-shelf LLMs often do not have matching drafters that are sufficiently fast and accurate. We identify a gap: SI becomes slower than non-SI when its drafters are slow or inaccurate. We close this gap by proving that DSI is faster than both SI and non-SI, given any drafters. By orchestrating multiple instances of the target and drafters, DSI is not only faster than SI but also supports LLMs that cannot be accelerated with SI.

Our simulations show speedups of off-the-shelf LLMs in realistic settings: DSI is 1.29-1.92x faster than SI.¹

¹ Code will be available upon publication.

1 Introduction

Generative LLMs, such as GPT-4 (OpenAI, 2023), have demonstrated unprecedented results in various applications (Andreas, 2022; Li et al., 2023; Bubeck et al., 2023; Wei et al., 2022). Despite their potential, the inference latency of LLMs presents a significant challenge and a bottleneck for adoption in real-time applications. For example, in algorithmic trading, the model needs to make rapid predictions to execute trades in milliseconds, and in autonomous driving, the model must act quickly to ensure the vehicle’s reliability. This challenge is compounded by existing inference algorithms that do not fully utilize the computational resources that modern hardware offers.

Given their usefulness, speeding up the inference of LLMs is an important area of research. Existing efforts to reduce inference latency fall into two main categories: algorithmic innovations and system optimizations. Algorithmic innovations include compressing LLMs through pruning (e.g., Frantar and Alistarh, 2023; Sun et al., 2024a; Ma et al., 2023), knowledge distillation (e.g., Hinton et al., 2015; Gu et al., 2024), quantization (e.g., Hubara et al., 2018; Frantar et al., 2022; Lin et al., 2024), and low-rank factorization (e.g., Hsu et al., 2022; Xu et al., 2023). On the system side, enhancements such as kernel optimizations (Dao et al., 2022), tensor parallelism (Shoeybi et al., 2019), and low-bit quantization (Yao et al., 2024; Dettmers et al., 2022) increase computation speed and reduce operational overhead, directly lowering latency.

Despite reducing inference time, these methods have a significant drawback: they typically degrade output quality. Other approaches instead recognize that some inputs require a very large model, while others can be handled effectively by more efficient models. The goal of these adaptive methods (e.g., Elbayad et al., 2020; Bapna et al., 2020; Han et al., 2022; Schuster et al., 2021) is to allocate fewer computational resources to easier inference steps. While many of these solutions are useful in practice, they often require modifications to the model architecture, changes to the training procedure, and re-training of the models, and they do not guarantee identical outputs.

A recent line of work (Stern et al., 2018) accelerates the inference of LLMs through speculative inference. The idea is to use speculative execution (Burton, 1985; Hennessy and Patterson, 2012): faster drafter LLMs that approximate the target LLM predict possible continuations of the input prompt, and the target then verifies these continuations simultaneously by exploiting the concurrency of CUDA-based processors (i.e., batching). Stern et al. (2018) provided empirical evidence that this draft-then-verify approach speeds up inference. Since then, several papers (Leviathan et al., 2023; Chen et al., 2023) have improved the method by introducing novel lossless procedures for verifying token sequences generated by the drafter LLMs. Empirically, these approaches speed up decoding in practical use cases, for example by 2-3x for LLMs with 11B and 70B parameters in some settings. Building on this line of work, Miao et al. (2023) extended the verification algorithm of Leviathan et al. (2023) and Chen et al. (2023), showed that their method increases the probability of accepting draft tokens, and proved that it is lossless. Following the success of this approach, research in this area has expanded in various directions (Mamou et al., 2024; Li et al., 2024; Cai et al., 2024; Sun et al., 2024b; Zhou et al., 2024; Liu et al., 2023; Joao Gante, 2023).

While traditional SI methods show how to accelerate the inference of LMs, they do not exploit multiple processing units (e.g., GPUs). In addition, empirical evidence indicates that acceleration occurs only when the drafter is very accurate and significantly faster than the target model. Two key questions therefore arise: (i) Can we reduce the inference time of LLMs by using multiple processors simultaneously? (ii) Can we accelerate inference using drafters that are not necessarily very fast or accurate?

Contributions. In this paper, we make the following contributions:
1. We design the first distributed algorithm (across multiple GPUs) for speculative inference of large language models. This algorithm is provably faster than both non-SI and SI methods.
2. We empirically validate, across a wide range of experiments, that our method speeds up inference compared to SI, even when the number of processors is fixed.
3. We demonstrate that SI requires a drafter model that is both sufficiently fast relative to the target model and sufficiently accurate in approximating it. Conversely, our method accelerates inference even with drafter models that are slower and less accurate (with latency greater than 10% of the target model's latency).

2 Preliminaries

We begin by describing autoregressive language models, next-token prediction, speculative inference and how to measure latency.

Autoregressive language models (LMs) are deterministic, real-valued multivariate functions. An input to an LM is a sequence of vectors. We call these vectors tokens, and the sequence a prompt. Tokens have a pre-defined dimension, denoted by $n_{\text{vocab}}$. LMs output a vector of $n_{\text{vocab}}$ real numbers, known as the logits. Since prompts to the same LM may vary in length, we simplify the notation of the forward pass as follows: $f:\mathbb{R}^{*\times n_{\text{vocab}}}\to\mathbb{R}^{n_{\text{vocab}}}$.
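To make the functional view concrete, here is a minimal Python sketch of a hypothetical toy LM. The names `toy_lm`, `one_hot`, and `N_VOCAB = 4` are illustrative assumptions, not a real model. It accepts a prompt of any length, given as a list of one-hot token vectors, and deterministically returns $n_{\text{vocab}}$ logits.

```python
from typing import List, Sequence

Vector = List[float]

N_VOCAB = 4  # hypothetical toy vocabulary size (n_vocab)


def one_hot(token_id: int, n_vocab: int = N_VOCAB) -> Vector:
    """Encode a token id as a one-hot vector of dimension n_vocab."""
    return [1.0 if k == token_id else 0.0 for k in range(n_vocab)]


def toy_lm(prompt: Sequence[Vector]) -> Vector:
    """Hypothetical deterministic LM f: R^{* x n_vocab} -> R^{n_vocab}.

    The prompt is an l x n_vocab matrix (a list of l token vectors), for any l >= 1.
    The logits favour the token after the last prompt token (mod n_vocab) and add a
    small bonus proportional to how often each token occurs in the prompt.
    """
    if not prompt:
        raise ValueError("prompt must contain at least one token")
    n_vocab = len(prompt[0])
    counts = [sum(row[k] for row in prompt) for k in range(n_vocab)]
    last = max(range(n_vocab), key=lambda k: prompt[-1][k])
    logits = [0.1 * c for c in counts]
    logits[(last + 1) % n_vocab] += 2.0
    return logits


print(toy_lm([one_hot(0), one_hot(1)]))  # [0.1, 0.1, 2.0, 0.0]
print(toy_lm([one_hot(3)]))              # [2.0, 0.0, 0.0, 0.1]
```

Prompts of different lengths map to outputs of the same dimension, which is what the notation $\mathbb{R}^{*\times n_{\text{vocab}}}$ expresses.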

Self-Attention LMs are LMs with a pre-defined context length $n_{\text{ctx}}$ (Vaswani et al., 2017). Hence, we represent the forward pass of such LMs as $f:\mathbb{R}^{n_{\text{ctx}}\times n_{\text{vocab}}}\to\mathbb{R}^{n_{\text{vocab}}}$. For example, GPT-2 and GPT-3 are Transformers with $n_{\text{vocab}}=50257$ and context lengths $n_{\text{ctx}}=1024$ and $n_{\text{ctx}}=2048$, respectively (Radford et al., 2019; Brown et al., 2020). In this paper, all LMs are Self-Attention LMs with pre-set (frozen) parameters.

We extend the prompt notation so that prompts can have length $l\leq n_{\text{ctx}}$. Self-Attention LMs handle prompts of length $l<n_{\text{ctx}}$ by starting the input sequence with a prefix of $n_{\text{ctx}}-l$ tokens, followed by the $l$ given tokens. LMs ignore the prefix, either by zeroing (masking) the Attention entries corresponding to the prefix or by left-padding with dedicated tokens. In this paper, a prompt of length $l<n_{\text{ctx}}$ is the non-masked, non-padded suffix of the input sequence of length $n_{\text{ctx}}$.
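The left-padding convention can be sketched as follows, assuming integer token ids and a hypothetical padding id `PAD_ID = 0`. The returned mask marks which of the $n_{\text{ctx}}$ positions belong to the prompt, so the prompt is recovered as the non-masked suffix even if a prompt token happens to equal the padding id.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical id of a dedicated padding token


def left_pad(prompt_ids: List[int], n_ctx: int, pad_id: int = PAD_ID) -> Tuple[List[int], List[int]]:
    """Left-pad a prompt of length l <= n_ctx to exactly n_ctx positions.

    Returns the padded id sequence and an attention mask in which 0 marks the
    (n_ctx - l)-token prefix the LM must ignore and 1 marks the l prompt tokens.
    """
    l = len(prompt_ids)
    if l > n_ctx:
        raise ValueError(f"prompt length {l} exceeds context length {n_ctx}")
    prefix_len = n_ctx - l
    padded = [pad_id] * prefix_len + list(prompt_ids)
    mask = [0] * prefix_len + [1] * l
    return padded, mask


def prompt_of(padded: List[int], mask: List[int]) -> List[int]:
    """Recover the prompt as the non-masked suffix of the n_ctx-long input."""
    return [tok for tok, keep in zip(padded, mask) if keep == 1]


padded, mask = left_pad([5, 7, 9], n_ctx=6)
print(padded, mask)             # [0, 0, 0, 5, 7, 9] [0, 0, 0, 1, 1, 1]
print(prompt_of(padded, mask))  # [5, 7, 9]
```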

Generating the next token is the primary application of autoregressive LMs. This process consists of two steps: computing the forward pass of the LM and then selecting the next token based on the output. The selection can be deterministic or non-deterministic.

Non-deterministic selection procedures apply the softmax function after the forward pass of LMs and sample from the resulting probability vector:

$$\text{softmax}:\mathbb{R}^{n_{\text{vocab}}}\to[0,1]^{n_{\text{vocab}}}\text{ such that the entries sum to 1.}\tag{1}$$

For convenience, we denote the output probability vector by $f(x_{\leq i})$:

$$x_{i+1}\sim f(x_{\leq i}):=\text{softmax}\big(f(x_{\leq i})\big):=\text{softmax}\big(f(x_{\leq 0}\oplus x_1\oplus\dots\oplus x_i)\big),\tag{2}$$

where $a\oplus b=(a,b)$ is the concatenation of the vectors $a$ and $b$, and $x_{\leq i}:=x_{\leq 0}\oplus x_1\oplus\dots\oplus x_i$.
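A minimal sketch of the sampling step in (2), repeated autoregressively. It assumes the LM is given as a Python function from a list of token ids to logits; `toy_logits` is a hypothetical stand-in, and list concatenation plays the role of $\oplus$.

```python
import math
import random
from typing import Callable, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Map n_vocab logits to a probability vector whose entries sum to 1, as in (1)."""
    top = max(logits)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def generate(f: Callable[[List[int]], List[float]], prompt: List[int],
             n_new: int, rng: random.Random) -> List[int]:
    """Sample n_new tokens: x_{i+1} ~ softmax(f(x_{<=i})), then x_{<=i+1} = x_{<=i} ⊕ x_{i+1}."""
    x = list(prompt)  # x_{<=0}
    for _ in range(n_new):
        probs = softmax(f(x))
        next_token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        x = x + [next_token]  # concatenation plays the role of ⊕
    return x


def toy_logits(ids: List[int]) -> List[float]:
    """Hypothetical 4-token LM that puts all probability on (last token + 1) mod 4."""
    logits = [float("-inf")] * 4
    logits[(ids[-1] + 1) % 4] = 0.0
    return logits


print(generate(toy_logits, [0], 3, random.Random(0)))  # [0, 1, 2, 3]
```

Because `toy_logits` assigns $-\infty$ to all but one token, the sampled continuation is forced; with finite logits, the same loop draws genuinely random tokens.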

For deterministic selection procedures, composing a monotonic function such as softmax is usually unnecessary. For example, the most likely next token is the $\arg\max$ of both the logits and the output of the softmax. Still, for convenience, we assume that LMs always output probability vectors. The sampling in (2) is either deterministic (i.e., $x_{i+1}$ is the token with maximal probability) or random (i.e., $x_{i+1}$ is drawn at random from the distribution $\text{softmax}(f(x_{\leq i}))$).
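The argmax claim can be checked directly. The sketch below, with helper names of our choosing, selects the greedy token from the logits without applying softmax and confirms that it matches the argmax of the probability vector (up to floating-point rounding when two logits are nearly equal).

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: Sequence[float]) -> int:
    """Index of the first maximal entry (ties go to the smallest index)."""
    return max(range(len(values)), key=lambda k: values[k])


def greedy_next_token(logits: Sequence[float]) -> int:
    """Deterministic selection: softmax preserves the order of entries, so it can be skipped."""
    return argmax(logits)


for logits in ([0.3, -1.2, 2.5, 2.4], [-5.0, -4.0, -6.0]):
    assert greedy_next_token(logits) == argmax(softmax(logits))
```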

Speculative Inference (SI) is an approach for accelerating the inference of a target LM $f_m$ (e.g., a member of the GPT series). Such methods use faster LMs $f_1,\dots,f_{m-1}$ that approximate the target model $f_m$ to reduce the total inference time. For example, Leviathan et al. (2023) reduce the time to infer a target model $f_2$ on a given prompt $x_{\leq 0}$ by using batching. SI algorithms start by drafting $k$ tokens $x'_i:=f_1(x'_{\leq i-1}):=f_1(x_{\leq 0}\oplus x'_1\oplus\dots\oplus x'_{i-1})$ ($i\in[k]$) using a faster drafter model $f_1$, and then feed the prompts $\{x'_{\leq i}\}_{i=0}^{k}$ (where $x'_{\leq 0}:=x_{\leq 0}$) together as one input batch to the target model $f_2$. The idea is to take advantage of the fact that modern GPUs can process the batch $\{x'_{\leq i}\}_{i=0}^{k}$ faster than processing its $k+1$ sequences one after another.
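The drafting and batching steps can be sketched as below, followed by a naive exact-match verification (accept the longest prefix of drafts the target agrees with, then append the target's own token). This assumes greedy models represented as hypothetical functions from token ids to the next token id; the list comprehension over the batch stands in for what a GPU would compute as a single batched forward pass of the target.

```python
from typing import Callable, List

Model = Callable[[List[int]], int]  # hypothetical: maps a prompt (token ids) to the next token id


def draft(f1: Model, prompt: List[int], k: int) -> List[int]:
    """Draft k tokens autoregressively with the drafter: x'_i = f1(x'_{<=i-1})."""
    seq = list(prompt)
    drafts = []
    for _ in range(k):
        tok = f1(seq)
        drafts.append(tok)
        seq.append(tok)
    return drafts


def verification_batch(prompt: List[int], drafts: List[int]) -> List[List[int]]:
    """Build the k+1 prompts {x'_{<=i}}_{i=0..k}; x'_{<=0} is the original prompt."""
    return [list(prompt) + drafts[:i] for i in range(len(drafts) + 1)]


def greedy_verify(f2: Model, prompt: List[int], drafts: List[int]) -> List[int]:
    """Naive exact-match verification: the output equals what f2 alone would produce."""
    # On a GPU, these k+1 target calls would be a single batched forward pass.
    targets = [f2(p) for p in verification_batch(prompt, drafts)]
    accepted: List[int] = []
    for i, d in enumerate(drafts):
        if d != targets[i]:
            break
        accepted.append(d)
    # The target's token at the first mismatch (or after all k drafts) comes for free.
    accepted.append(targets[len(accepted)])
    return accepted


def target(seq: List[int]) -> int:  # hypothetical greedy target f_2
    return (seq[-1] + 1) % 10


def drafter(seq: List[int]) -> int:  # hypothetical drafter f_1, wrong after its first token
    return (seq[-1] + 1) % 10 if len(seq) < 2 else 0


drafts = draft(drafter, [3], k=3)
print(drafts)                              # [4, 0, 0]
print(verification_batch([3], drafts))     # [[3], [3, 4], [3, 4, 0], [3, 4, 0, 0]]
print(greedy_verify(target, [3], drafts))  # [4, 5]
```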

Straightforward SI algorithms are typically lossless in expectation, i.e., they generate tokens from the same distributions as the target would without speculation. Naive speculation algorithms guarantee returning the same tokens as the target (Joao Gante, 2023; Spector and Re, 2023). More sophisticated speculation algorithms might generate different tokens, but their generated tokens follow the target's distribution (Leviathan et al., 2023; Chen et al., 2023).
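For the distribution-preserving case, the following is a single-token sketch of the rejection rule used by Leviathan et al. (2023) and Chen et al. (2023). A drafted token $x'\sim p$ is accepted with probability $\min(1, q(x')/p(x'))$, where $q$ is the target's probability vector; otherwise a token is resampled from the normalized residual $\max(0, q-p)$. The helper `output_distribution` computes the exact law of the returned token, which equals $q$. All function names are ours.

```python
import random
from typing import List, Sequence


def residual(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Normalized max(0, q - p): the distribution to resample from after a rejection."""
    r = [max(0.0, qx - px) for px, qx in zip(p, q)]
    s = sum(r)
    return [x / s for x in r] if s > 0 else list(q)


def verify_token(draft: int, p: Sequence[float], q: Sequence[float], rng: random.Random) -> int:
    """Accept a drafter token x' ~ p with probability min(1, q(x')/p(x')); else resample."""
    if rng.random() < min(1.0, q[draft] / p[draft]):
        return draft
    r = residual(p, q)
    return rng.choices(range(len(r)), weights=r, k=1)[0]


def output_distribution(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Exact law of verify_token when the draft is sampled from p; it equals q."""
    accept = [min(px, qx) for px, qx in zip(p, q)]  # p(x) * min(1, q(x)/p(x))
    reject_mass = 1.0 - sum(accept)                 # equals sum of max(0, q - p)
    r = residual(p, q)
    return [a + reject_mass * rx for a, rx in zip(accept, r)]


p, q = [0.5, 0.3, 0.2], [0.2, 0.5, 0.3]  # made-up drafter and target distributions
print(output_distribution(p, q))  # equal to q up to floating-point rounding
```

The identity behind the check is $\min(p,q)+\max(0,q-p)=q$ entrywise, which is why the verified token is distributed exactly as if the target had sampled it.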

To implement distributed algorithms for speculative inference, we use multiple processors, which are hardware components capable of executing threads. Processors can compute forward passes and sample tokens from the output probability vectors, and we assume that threads can run in parallel. When using DSI, we run sequences of models $f_{j_1},f_{j_2},\dots,f_{j_k}$, where the first model takes $x_{\leq 0}$ and returns some token $x^{j_1}_1$, the second takes $x_{\leq 0}\oplus x^{j_1}_1$ as a prompt and returns $x^{j_1,j_2}_2$, and so on.
Accordingly, to denote that a given thread is computing the output of $f_{j_k}$ on a sequence $x^{j_1,\dots,j_{k-1}}_{\leq k-1}:=x_{\leq 0}\oplus x^{j_1}_1\oplus\dots\oplus x^{j_1,\dots,j_{k-1}}_{k-1}$, we write $C_J$, where $J=(j_1,\dots,j_k)$. When a thread $C_J$ computes an LM, we denote the output probability vector by $C_J[\text{prob}]$.
If $C_J$ samples a new token from $C_J[\text{prob}]$, we denote this token by $C_J[\text{new}]$. For example, a thread $C_J$ implementing (2) above has

$$C_J[\text{prompt}]:=x_{\leq i},\quad C_J[\text{prob}]:=f\big(C_J[\text{prompt}]\big),\quad\text{and}\quad C_J[\text{new}]\sim C_J[\text{prob}].$$

Once a thread $C_J$ finishes sampling a new token, it outputs the concatenation of $C_J[\text{prompt}]$ and $C_J[\text{new}]$. Following the example in (2), we have

$$C_J[\text{return}]:=C_J[\text{prompt}]\oplus\big(C_J[\text{new}]\big):=(x_{\leq 0},x_1,\ldots,x_{i+1}).$$
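The thread fields can be mirrored by a small record type. This is a hypothetical sketch (the `Thread` class and its methods are ours, not the paper's), assuming token ids and an LM `f` that returns a probability vector, as the text assumes.

```python
import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class Thread:
    """Hypothetical record for a thread C_J; J = (j_1, ..., j_k) names the models on its path."""
    J: Tuple[int, ...]
    prompt: List[int]                   # C_J[prompt]
    prob: Optional[List[float]] = None  # C_J[prob], set once the forward pass is done
    new: Optional[int] = None           # C_J[new], set once a token is sampled

    def run(self, f: Callable[[List[int]], List[float]], rng: random.Random) -> List[int]:
        self.prob = f(self.prompt)  # C_J[prob] := f(C_J[prompt])
        self.new = rng.choices(range(len(self.prob)), weights=self.prob, k=1)[0]  # C_J[new] ~ C_J[prob]
        return self.ret()

    def ret(self) -> List[int]:
        """C_J[return] := C_J[prompt] ⊕ (C_J[new])."""
        if self.new is None:
            raise RuntimeError("thread has not sampled a token yet")
        return self.prompt + [self.new]


def f(prompt: List[int]) -> List[float]:  # hypothetical LM that always predicts token 2
    return [0.0, 0.0, 1.0]


t = Thread(J=(1, 3), prompt=[4, 1])
print(t.run(f, random.Random(0)))  # [4, 1, 2]
```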

A new thread initiated by $C_J$ is denoted by $C_{J\oplus(j)}$, where $J\oplus(j)$ is the concatenation of $J$ and $(j)$. The set of all threads that originate from $C_J$ is $\{C_{J\oplus J'}:J'\text{ is a nonempty tuple}\}$. We assume that terminating a concurrent thread terminates all the threads that originate from it.
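The index arithmetic for descendants and the cascading termination can be sketched with tuples as thread indices. The `ThreadTree` class below is a hypothetical bookkeeping structure, not part of the paper; it only tracks which threads are live.

```python
from typing import Set, Tuple

Index = Tuple[int, ...]  # a thread index J = (j_1, ..., j_k)


class ThreadTree:
    """Hypothetical bookkeeping of live threads C_J, keyed by their index tuple J."""

    def __init__(self) -> None:
        self.live: Set[Index] = set()

    def initiate(self, J: Index) -> None:
        self.live.add(J)

    def spawn(self, J: Index, j: int) -> Index:
        """C_J initiates C_{J ⊕ (j)}."""
        child = J + (j,)
        self.live.add(child)
        return child

    @staticmethod
    def originates_from(K: Index, J: Index) -> bool:
        """True iff K = J ⊕ J' for some nonempty tuple J'."""
        return len(K) > len(J) and K[: len(J)] == J

    def terminate(self, J: Index) -> Set[Index]:
        """Terminate C_J together with every thread that originates from it."""
        killed = {K for K in self.live if K == J or self.originates_from(K, J)}
        self.live -= killed
        return killed


tree = ThreadTree()
tree.initiate((1,))
tree.initiate((2,))
child = tree.spawn((1,), 1)
tree.spawn(child, 2)
tree.spawn((2,), 1)
print(sorted(tree.terminate((1,))))  # [(1,), (1, 1), (1, 1, 2)]
print(sorted(tree.live))             # [(2,), (2, 1)]
```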

Time in this paper is wall time: we measure the time that passes from the initiation of a task until its termination. A task is a nonempty set of threads, denoted by $\{C_J:J\in\mathfrak{J}\}$. Its time is

$$T_{\text{wall}}\left[\{C_J\}_{J\in\mathfrak{J}}\right]:=\max_{J\in\mathfrak{J}}\,(\text{timepoint }C_J\text{ finishes})-\min_{J\in\mathfrak{J}}\,(\text{timepoint }C_J\text{ starts}).$$

When a task consists of a single thread, we omit the curly brackets, namely,

$$T_{\text{wall}}[C_J]:=T_{\text{wall}}\left[\{C_J\}\right]\text{ where }|\{C_J\}|=1.$$

Note that two threads, $C_J$ and $C_{J'}$, may run concurrently and overlap in time. Hence, it is possible that $\max\{T_{\text{wall}}[C_J],T_{\text{wall}}[C_{J'}]\}\leq T_{\text{wall}}[\{C_J,C_{J'}\}]<T_{\text{wall}}[C_J]+T_{\text{wall}}[C_{J'}]$.
However, if $C_J$ and $C_{J'}$ do not overlap in time, then $T_{\text{wall}}[\{C_J,C_{J'}\}]\geq T_{\text{wall}}[C_J]+T_{\text{wall}}[C_{J'}]$.
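The definition of $T_{\text{wall}}$ reduces to interval arithmetic. In the sketch below, each thread is represented by a hypothetical `(start, finish)` pair, and the usage lines illustrate both inequalities with made-up timepoints.

```python
from typing import Iterable, Tuple

Interval = Tuple[float, float]  # (timepoint a thread starts, timepoint it finishes)


def t_wall(task: Iterable[Interval]) -> float:
    """T_wall of a nonempty task: latest finish minus earliest start."""
    intervals = list(task)
    if not intervals:
        raise ValueError("a task is a nonempty set of threads")
    return max(end for _, end in intervals) - min(start for start, _ in intervals)


a, b, c = (0.0, 4.0), (1.0, 5.0), (6.0, 9.0)
print(t_wall([a]), t_wall([b]), t_wall([a, b]))   # 4.0 4.0 5.0  (overlap: 4 <= 5 < 8)
print(t_wall([a, c]), t_wall([a]) + t_wall([c]))  # 9.0 7.0  (no overlap: 9 >= 7)
```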

3 Distributed Speculative Inference

While previous SI methods (Leviathan et al., 2023; Chen et al., 2023; Miao et al., 2023) are useful for speeding up inference, they do not use multiple processing units to compute LM outputs in parallel. In this section, we outline a theoretically sound approach to inferring LMs using a sufficiently large number of processors. The naive version of our method assumes access to enough processors that threads never have to wait. Later, we discuss how our method can be implemented in practice with a fixed number of processors.

Algorithm 1 Distributed Speculative Inference (DSI) of $N$ tokens
0:  A prompt $x_{\leq 0}$ and $m$ autoregressive models $f_1,f_2,\dots,f_m$.
1:  $v=1$.
2:  initiate $m$ threads $C_{(1)},\dots,C_{(m)}$ such that $C_{(j_1)}$ generates the token $x^{j_1}_1\sim f_{j_1}(x_{\leq 0})$, for all $j_1\in[m]$ concurrently.
3:  label thread $C_{(m)}$ as the current verifier.
4:  ONCE any thread $C_{J\oplus(j)}$ finishes generating a token, namely, has sampled $C_{J\oplus(j)}[\text{new}]\sim f_j\big(C_{J\oplus(j)}[\text{prompt}]\big)$:
5:  if $|J|+1<N$ then
6:     initiate $m$ threads $C_{J\oplus(j,1)},C_{J\oplus(j,2)},\dots,C_{J\oplus(j,m)}$ to generate a token concurrently from $f_1,f_2,\dots,f_m$, respectively.
7:     if $C_{J\oplus(j)}$ is the current verifier thread then
8:        terminate all threads $C_{J\oplus(j')}$ (and their descendant threads) that sampled a different token than $C_{J\oplus(j)}$.
9:        let $j^*=\arg\min_{j'\in[m]}\{j'\mid C_{J\oplus(j')}[\text{new}]=C_{J\oplus(j)}[\text{new}]\}$.
10:        terminate all threads $C_{J\oplus(j')}$ (and their descendant threads) where $j'>j^*$.
11:        label $C_{J\oplus(j^*,m)}$ as the current verifier.
12:        update $v=v+1$.
13:        if $C_{J\oplus(j^*,m)}$ has already finished then
14:           go back to step 7 with J=J(j,m)𝐽direct-sum𝐽superscript𝑗𝑚J=J\oplus(j^{*},m)italic_J = italic_J ⊕ ( italic_j start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT , italic_m ).
15:        end if
16:     end if
17:  else if the last entry of J(j)direct-sum𝐽𝑗J\oplus(j)italic_J ⊕ ( italic_j ) equals m𝑚mitalic_m (i.e., j=m𝑗𝑚j=mitalic_j = italic_mthen
18:     return  CJ(j)[return]subscript𝐶direct-sum𝐽𝑗[return]C_{J\oplus(j)}\textnormal{[return]}italic_C start_POSTSUBSCRIPT italic_J ⊕ ( italic_j ) end_POSTSUBSCRIPT [return].
19:  end if
20:  end ONCE

3.1 Method Overview

Consider the task of computing $N$ output tokens autoregressively from a target model $f_m$ given a prompt $x_{\leq 0}$. We have a set of drafter models, $f_1,\dots,f_{m-1}$, that are all faster than $f_m$ (as defined in Assumption 2). Our goal is to compute $x_i = f_m(x_{\leq i-1})$ for all $i\in[N]$. To achieve this, we initiate $m$ threads, $C_{(1)},\dots,C_{(m)}$ (line 2). Each thread $C_{(j_1)}$ is responsible for computing $x^{j_1}_1 = f_{j_1}(x_{\leq 0})$.
Once a thread $C_{(j_1)}$ finishes its computation, we instantiate $m$ new threads, $C_{(j_1,j_2)}$, to compute $x^{j_1,j_2}_2 = f_{j_2}(x_{\leq 0}\oplus x^{j_1}_1)$ for all $j_2\in[m]$.
In general, once we compute $x^{j_1,\dots,j_{r-1}}_{r-1}$, we initiate $m$ new threads, $C_{(j_1,\dots,j_{r-1},1)},\dots,C_{(j_1,\dots,j_{r-1},m)}$, to compute $x^{j_1,\dots,j_r}_r = f_{j_r}(x_{\leq 0}\oplus x^{j_1}_1\oplus\dots\oplus x^{j_1,\dots,j_{r-1}}_{r-1})$ for all $j_r\in[m]$. This is captured in lines 4 and 6.
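To make the indexing concrete, the sketch below builds the full thread tree for a toy setting: each thread index is a tuple $(j_1,\dots,j_r)$, and thread $C_{(j_1,\dots,j_r)}$ applies $f_{j_r}$ to the prompt extended with the tokens produced by its ancestors. The models `target`, `drafter_a`, and `drafter_b` are hypothetical deterministic stand-ins for $f_3, f_1, f_2$ (so $m=3$, target last), and the threads are enumerated sequentially rather than run concurrently.

```python
from itertools import product
from typing import Callable, Dict, List, Tuple

Token = int
Model = Callable[[List[Token]], Token]


# Hypothetical deterministic toy models; the target f_m is the last entry of MODELS.
def target(ctx: List[Token]) -> Token:
    return (7 * sum(ctx) + len(ctx)) % 5


def drafter_a(ctx: List[Token]) -> Token:
    return target(ctx) if len(ctx) % 2 == 0 else 0  # right only on even-length contexts


def drafter_b(ctx: List[Token]) -> Token:
    return (sum(ctx) + 1) % 5  # a mostly inaccurate drafter


MODELS: List[Model] = [drafter_a, drafter_b, target]


def thread_outputs(models: List[Model], prompt: List[Token], depth: int) -> Dict[Tuple[int, ...], Token]:
    """Return x_r^{j_1..j_r} for every thread index (j_1, ..., j_r) with r <= depth.

    Indices are 1-based, as in the text, so j = m denotes the target.
    """
    outputs: Dict[Tuple[int, ...], Token] = {}
    for r in range(1, depth + 1):
        for idx in product(range(1, len(models) + 1), repeat=r):
            # Context = prompt followed by the tokens of this thread's ancestors.
            context = prompt + [outputs[idx[:k]] for k in range(1, r)]
            outputs[idx] = models[idx[-1] - 1](context)
    return outputs


def target_autoregressive(f_m: Model, prompt: List[Token], n_tokens: int) -> List[Token]:
    """Plain (non-speculative) decoding with the target, for comparison."""
    context = list(prompt)
    for _ in range(n_tokens):
        context.append(f_m(context))
    return context[len(prompt):]
```

The tree has $m^r$ threads at depth $r$, and the all-target path $(m,\dots,m)$ reproduces plain target decoding, which is the path the verifier threads follow.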

Once $C_{(m)}$ completes its computation and provides the correct value of the first output token, $x^m_1 = x_1$, we can verify which other threads $C_{(j_1)}$ have accurately computed $x_1$. Any thread $C_{(j_1)}$ with $x^{j_1}_1 \neq x_1$ is immediately terminated along with its descendant threads.
For each $j_1\in[m]$ that correctly computed $x^{j_1}_1 = x_1$, we continue computing $x^{j_1,j_2}_2 = f_{j_2}(x_{\leq 0}\oplus x^{j_1}_1)$ for all $j_2\in[m]$. However, since these surviving threads share the same (correct) prefix, their descendants would repeat the same computations, so we terminate all but the one with the smallest $j_1$ that satisfies $x^{j_1}_1 = x_1$. In essence, $C_{(m)}$ serves as a verifier, identifying drafters that miscalculated the initial part of the autoregressive computation.
Once we retain one valid $j_1$, we relabel $C_{(j_1,m)}$ as the new verifier thread. Since $C_{(j_1)}$ returned the correct token $x^{j_1}_1 = x_1$ and $x_2 = f_m(x_{\leq 1})$, the output of $C_{(j_1,m)}$ must be correct.
When that thread finishes, we terminate the remaining threads $C_{(j_1,j_2)}$ that miscalculated $x_2 = x^{j_1,m}_2$ and, among those with $x^{j_1,j_2}_2 = x^{j_1,m}_2 = x_2$, keep only the one with the minimal index $j_2$. We continue this process until the output $x^{j_1,\dots,j_{N-1},m}_N$ is obtained from the last verifier thread $C_{(j_1,\dots,j_{N-1},m)}$. The process of relabeling verifier threads and terminating irrelevant threads is outlined in lines 8, 10, and 11.
Line 13 handles the case in which the newly labeled verifier thread has already finished; if so, line 14 returns to line 7 with the new verifier thread.
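The verifier hand-off described above can be emulated sequentially. The sketch below walks the chain of verifier threads: it starts from $C_{(m)}$, keeps the smallest $j$ whose thread produced the verifier's token (lines 8 to 10), and relabels $C_{J\oplus(j^*,m)}$ as the next verifier (line 11). It assumes exact-match acceptance and hypothetical deterministic toy models, and it ignores timing and concurrency entirely; its only purpose is to show why the surviving path reproduces the target's output.

```python
from functools import lru_cache
from typing import Callable, List, Tuple

Context = Tuple[int, ...]


# Hypothetical deterministic toy models; the target f_m is the last entry of MODELS.
def target(ctx: Context) -> int:
    return (7 * sum(ctx) + len(ctx)) % 5


def drafter_a(ctx: Context) -> int:
    return target(ctx) if len(ctx) % 2 == 0 else 0


def drafter_b(ctx: Context) -> int:
    return (sum(ctx) + 1) % 5


MODELS: List[Callable[[Context], int]] = [drafter_a, drafter_b, target]
M = len(MODELS)
PROMPT: Context = (1, 2)


@lru_cache(maxsize=None)
def thread_token(idx: Tuple[int, ...]) -> int:
    """Token produced by thread C_idx (1-based model indices, last index = model used)."""
    context = PROMPT + tuple(thread_token(idx[:k]) for k in range(1, len(idx)))
    return MODELS[idx[-1] - 1](context)


def walk_verifiers(n_tokens: int) -> Tuple[List[int], Tuple[int, ...]]:
    """Return the verified tokens and the surviving thread path J."""
    J: Tuple[int, ...] = ()  # indices of the surviving threads so far
    verified: List[int] = []
    for _ in range(n_tokens):
        verifier_token = thread_token(J + (M,))  # current verifier C_{J+(m)}
        verified.append(verifier_token)
        # Smallest j whose thread agrees with the verifier; j = m always qualifies.
        j_star = min(j for j in range(1, M + 1) if thread_token(J + (j,)) == verifier_token)
        J = J + (j_star,)  # C_{J+(j*, m)} becomes the next verifier
    return verified, J
```

For the toy models above, `walk_verifiers(4)` returns the target's own tokens `[3, 0, 1, 4]` along the path `(1, 1, 1, 3)`: the first drafter survives three times, and the target itself takes over where both drafters miss.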

Preemption in DSI.  A crucial enhancement in DSI is preemption. Algorithm 1 invokes a new thread for each token. With preemption, DSI waits until a lookahead number of tokens have been drafted and then processes them as a batch. This reduces the number of invocations and allows DSI to run on a fixed number of processing units; tuning the lookahead parameter lets DSI use at most any given number of available processing units. The maximum number of required processing units lies in the range $\left[2, \max\left\{2, \left\lceil \frac{\texttt{target\_latency}}{\texttt{lookahead}\cdot\texttt{drafter\_latency}} \right\rceil\right\}\right]$. For example, for a drafter whose latency is 5% of the target's and $\texttt{lookahead}=5$, 4 processing units suffice, since $\lceil 1/(5\cdot 0.05)\rceil = 4$. Preemption applies only to threads that are not currently labeled as the verifier, so it preserves the integrity of the speculative inference algorithm: the theoretical guarantees still hold while performance improves.
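The upper end of this range is simple to compute. The helper below is a hypothetical illustration (its name and signature are not from the paper); the two latencies may be in any common unit, such as milliseconds per forward pass.

```python
import math


def max_processing_units(target_latency: float, drafter_latency: float, lookahead: int) -> int:
    """Upper end of [2, max{2, ceil(target / (lookahead * drafter))}] from the text."""
    if target_latency <= 0 or drafter_latency <= 0 or lookahead < 1:
        raise ValueError("latencies must be positive and lookahead must be >= 1")
    return max(2, math.ceil(target_latency / (lookahead * drafter_latency)))
```

With a 20 ms target and a 1 ms drafter (5% latency), `lookahead=5` gives the 4 processing units of the example above, whereas `lookahead=1` would call for up to 20; a slow drafter never needs fewer than 2.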

Rejection sampling algorithm.  As lines 8 and 10 show, Algorithm 1 terminates threads (and their descendants) based on exact agreement with the token returned by the current verifier. This criterion is strict and leads to many rejections in practice: even if the drafter is another instance of the target, the two may disagree because sampling is random. To increase the number of acceptances while preserving the output distribution of the target model, Leviathan et al. (2023) and Miao et al. (2023) proposed relaxed rules for rejecting draft outputs. To incorporate these rejection sampling methods, lines 8-9 can be replaced with an application of their rejection sampling procedures.
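For reference, the relaxed rule of Leviathan et al. (2023) accepts a drafted token $x$ with probability $\min(1, p(x)/q(x))$, where $q$ and $p$ are the drafter's and the target's next-token distributions, and otherwise resamples from the normalized residual $\max(0, p - q)$. The minimal sketch below implements that rule over explicit probability vectors; the function names are hypothetical, and how the rule would be wired into lines 8-9 of Algorithm 1 is not specified here.

```python
import random
from typing import List, Sequence


def residual_distribution(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Normalized max(0, p - q), the distribution sampled after a rejection."""
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total == 0.0:
        # Only possible when p == q, in which case a rejection never happens; fall back to p.
        return list(p)
    return [r / total for r in residual]


def accept_or_resample(draft_token: int, p: Sequence[float], q: Sequence[float],
                       rng: random.Random) -> int:
    """Return the token kept at this position: the draft itself or a resampled token."""
    if q[draft_token] <= 0.0:
        raise ValueError("the drafter must assign positive probability to its own draft")
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token  # accepted
    weights = residual_distribution(p, q)
    return rng.choices(range(len(p)), weights=weights, k=1)[0]
```

If the draft is sampled from $q$ and then passed through `accept_or_resample`, the kept token is distributed according to $p$, which is the property that keeps this kind of speculative sampling lossless.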

3.2 Analysis

Next, we prove that DSI (Algorithm 1) always returns the correct sequence of tokens $x_1,\dots,x_N$ (i.e., it is lossless), that it runs at least as fast as non-SI, and that it runs at least as fast as SI in expectation. Before stating our main theoretical results, we list the assumptions used in the analysis. The proofs are provided in Appendix A.

Assumption 1.

We assume the existence of a (universal) constant $c>0$ such that, for any input prompt $x_{\leq 0}$ and model index $j\in[m]$, we have:

$$T_{\text{wall}}\left[\text{computing } f_j\left(x_{\leq 0}\right)\right] \in (0, c) \quad\text{and}\quad T_{\text{wall}}\left[\text{sampling } x \sim f_j\left(x_{\leq 0}\right)\right] = 0.$$
Assumption 2.

We assume that for all $j\in[m-1]$, $f_j$ is faster than $f_m$ (denoted $f_j \preceq f_m$) in the following sense: $\max_{x_{\leq 0}} T_{\text{wall}}\left[\text{computing } f_j\left(x_{\leq 0}\right)\right] \leq \min_{x_{\leq 0}} T_{\text{wall}}\left[\text{computing } f_m\left(x_{\leq 0}\right)\right]$.

Assumption 3.

We assume that $T_{\text{wall}}\left[\{C_{(j_1,\dots,j_i)}\}_{i=1}^{k}\right] = \sum_{i=1}^{k} T_{\text{wall}}\left[C_{(j_1,\dots,j_i)}\right]$.

The first assumption states that computing the output of any model takes a non-zero, bounded amount of time, while sampling a token from the output probabilities takes negligible time. The second states that each drafter runs faster than the target for any input prompt: even the drafter's slowest forward pass is no slower than the target's fastest one. The third states that computing $x^{j_1,\dots,j_k}_k$ takes the time of first computing $x^{j_1}_1$, then $x^{j_1,j_2}_2$, and so forth up to $x^{j_1,\dots,j_k}_k$, with no delays in between.
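Assumption 2 compares the drafter's worst-case latency with the target's best-case latency, which is stronger than comparing averages. On measured latency samples it can be checked as below; the helper is hypothetical, and a finite sample can only refute the assumption, never prove it.

```python
from typing import Iterable


def is_faster(drafter_latencies: Iterable[float], target_latencies: Iterable[float]) -> bool:
    """Sample-based check of Assumption 2 (f_j faster than f_m): the slowest observed
    drafter forward pass must be no slower than the fastest observed target pass."""
    return max(drafter_latencies) <= min(target_latencies)
```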

The following theorem states that our method returns tokens from the same distributions as those the target would generate without speculation, and that it is at least as fast as iteratively applying the target model itself.

Theorem 1.

Under Assumptions 1, 2, and 3, Algorithm 1 returns the same output as the target model and runs at least as fast as running the target model itself without speculative inference.

Theorem 2.

Under Assumptions 1, 2, and 3, Algorithm 1 runs at least as fast as SI in expectation.

The advantage of Algorithm 1 lies in its concurrency. The following proposition shows how DSI can accelerate the inference of a given target model using a drafter model that is faster than the target model and returns the correct output with high probability.

Proposition 1.

Suppose we have a drafter model $f_1$, a target model $f_2$, and a prompt $x_{\leq 0}$. Assume that $f_1$ requires $t_1$ time units to compute each of its outputs, and $f_2$ requires $t_2$ time units, where $t_2 > t_1$. Assume that, given the prompt $x_{\leq i} = x_{\leq 0}\oplus x_1\oplus\dots\oplus x_i$, the probability that $f_1$ returns the (correct) token $x_{i+1}$ is $p$. Then the expected time it takes Algorithm 1 to compute the correct output is at most $t_1 p(N-1) + t_2\left((1-p)(N-1)+1\right)$ time units, compared with the $t_2 N$ time units required to compute $f_2$ without speculative inference.
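Algebraically, the bound equals $t_2 + (N-1)\left(p\,t_1 + (1-p)\,t_2\right)$, so it can be read as: the first token costs one target pass, and each of the remaining $N-1$ tokens costs $t_1$ when the drafter is right and $t_2$ otherwise. The sketch below, with hypothetical function names, evaluates the bound, compares it with $t_2 N$, and checks that reading by Monte Carlo.

```python
import random


def prop1_bound(t1: float, t2: float, p: float, n_tokens: int) -> float:
    """Upper bound on DSI's expected time from Proposition 1."""
    return t1 * p * (n_tokens - 1) + t2 * ((1 - p) * (n_tokens - 1) + 1)


def non_si_time(t2: float, n_tokens: int) -> float:
    """Time of plain autoregressive decoding with the target."""
    return t2 * n_tokens


def monte_carlo_bound(t1: float, t2: float, p: float, n_tokens: int,
                      trials: int = 20_000, seed: int = 0) -> float:
    """Average of t2 + sum over the remaining tokens of (t1 if drafter right else t2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        total += t2 + sum(t1 if rng.random() < p else t2 for _ in range(n_tokens - 1))
    return total / trials
```

For example, with $t_1=1$, $t_2=10$, $p=0.5$, and $N=11$, the bound is 65 time units versus 110 for non-SI; with $p=0$ it degrades gracefully to exactly $t_2 N$.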

Proposition 2.

Suppose we have a drafter model $f_1$, a target model $f_2$, and a prompt $x_{\leq 0}$. Assume that $f_1$ requires $t_1$ time units to compute each of its outputs, and $f_2$ requires $t_2$ time units, where $t_1 < t_2$. Assume that, given the prompt $x_{\leq i} = x_{\leq 0}\oplus x_1\oplus\dots\oplus x_i$, the probability that $f_1$ returns the (correct) token $x_{i+1}$ is $p$. Then Algorithm 1 is expected to complete before SI (Leviathan et al., 2023; Chen et al., 2023) when both algorithms use the same lookahead.

4 Experiments and Results

| Target | Drafter | Dataset | Target Latency (ms) | Drafter Latency (ms) | Drafter Latency (%) | Acceptance Rate (%) | Speedup DSI vs. SI |
|---|---|---|---|---|---|---|---|
| Vicuna-13B | Vicuna-68M | CNN-DM | 37.7 | 2.5 | 6.5 | 63 | 1.47x |
| Vicuna-13B | Vicuna-68M | Alpaca | 33.3 | 2.5 | 7.4 | 58 | 1.41x |
| Vicuna-7B | Vicuna-68M | CNN-DM | 29.4 | 2.5 | 8.4 | 67 | 1.29x |
| Vicuna-7B | Vicuna-68M | Alpaca | 26.0 | 2.5 | 9.5 | 59 | 1.70x |
| Starcoder-15B | Starcoder-168M | HumanEval | 20.6 | 6.8 | 32.3 | 93 | 1.92x |
| Starcoder-15B | Starcoder-168M | MBPP | 21.0 | 6.8 | 32.9 | 90 | 1.66x |
| Phi3-14B | Phi3-4B | HumanEval | 52.1 | 34.0 | 65.3 | 95 | 1.41x |
| Phi3-14B | Phi3-4B | MBPP | 52.2 | 34.3 | 65.8 | 94 | 1.37x |
| Phi3-14B | Phi3-4B | CNN-DM | 52.4 | 34.6 | 66.0 | 93 | 1.39x |
| Phi3-14B | Phi3-4B | Alpaca | 49.6 | 33.4 | 67.4 | 87 | 1.60x |
Table 1: DSI Speedups over SI for various target/drafter pairs. We observe that DSI outperforms the SI implementation consistently across all models and tasks. These results are based on simulations with thread pools.

We conducted two sets of experiments to validate our theoretical results—that DSI outperforms SI (Leviathan et al., 2023) and non-SI (Theorems 1 and 2)—in practical settings.

Experiments with LLMs.  The first experiment (see Table 1) studies the latency values and acceptance rates of pairs of off-the-shelf target and drafter LLMs on various well-established datasets. All the LLMs were downloaded from the Hugging Face Hub and used as-is. We evaluate our method on four datasets spanning three tasks: text summarization using CNN Daily Mail (Hermann et al., 2015); instruction following using Alpaca (Taori et al., 2023a); and code generation using MBPP (Austin et al., 2021) and HumanEval (Chen et al., 2021). For a complete description of the models and datasets, and for examples of relevant prompts, please refer to Appendix C and Appendix B.

For each combination of dataset $d$ and corresponding target/drafter model $f$, we estimate the average latency of $f$ as follows. First, we select 50 prompts from $d$ uniformly at random and, for each prompt, generate 20 tokens using $f$, measuring the latency of each token in milliseconds. Following prior work, we distinguish between Time to First Token (TTFT) and Time Per Output Token (TPOT), the latter covering the 19 subsequent tokens. Although TTFT is usually significantly longer than TPOT, TPOT dominates the overall sequence generation time, so for brevity all latency figures in Table 1 refer to TPOT. Finally, we average the TTFTs and TPOTs over all prompts per model/dataset pair to estimate the expected latency of a single forward pass. The TPOT latencies of the target LLM and the drafter LLM are shown in “Target Latency (ms)” and “Drafter Latency (ms)”, respectively. We also report the ratio of the drafter latency to the target latency as a percentage (“Drafter Latency (%)”).
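The bookkeeping above can be sketched as follows. The helpers are hypothetical and the numbers in the usage note are illustrative, not measurements from Table 1. The sketch pools all non-first tokens, which matches averaging per-prompt TPOTs when every prompt has the same number of generated tokens (20 in the setup above).

```python
from statistics import mean
from typing import List, Tuple


def ttft_tpot(latencies_ms: List[List[float]]) -> Tuple[float, float]:
    """latencies_ms[i][k] is the latency of the k-th generated token for prompt i.

    Returns (mean TTFT, mean TPOT): the first token of each prompt counts toward
    TTFT and every later token counts toward TPOT.
    """
    ttft = mean(per_prompt[0] for per_prompt in latencies_ms)
    tpot = mean(x for per_prompt in latencies_ms for x in per_prompt[1:])
    return ttft, tpot


def drafter_latency_percent(drafter_tpot_ms: float, target_tpot_ms: float) -> float:
    """The "Drafter Latency (%)" column: drafter TPOT as a percentage of target TPOT."""
    return 100.0 * drafter_tpot_ms / target_tpot_ms
```

For two prompts with per-token latencies `[10, 2, 4]` and `[20, 6, 8]` ms, this gives a TTFT of 15 ms and a TPOT of 5 ms.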

To estimate the alignment between each target/drafter pair, we use the “Acceptance Rate” (AR). To calculate the AR, we generate 256 tokens using the drafter given the same prompt used by the target. For each prompt, we consider the length of the longest sequence of exact token matches between the target and the drafter, counted from the first generated token. Below is a simplified example where tokens are counted as English words. If the target generates “We can only see a short distance ahead, but we can see plenty there that needs to be done. […]” and the drafter generates “We can only see a short distance ahead, we done. […]”, then the longest sequence of exact matches is 8 tokens long. The expected number of accepted drafts is estimated as $\bar{n} := \frac{1}{N}\sum_{i=1}^{N} n_i$, where $n_i$ is the number of accepted draft tokens for the $i$-th prompt and $N$ here denotes the number of prompts. The AR is then calculated as $\texttt{acceptance\_rate} := 1 - \frac{1}{1+\bar{n}}$. We estimate latency and AR on a single A100 80GB GPU.
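One way to read the AR formula: if each draft token were accepted independently with probability $\alpha$ (an assumption borrowed from the analysis of Leviathan et al., 2023, ignoring the 256-token cap), the expected number of leading matches would be $\alpha/(1-\alpha)$; solving $\bar{n} = \alpha/(1-\alpha)$ gives $\alpha = 1 - 1/(1+\bar{n})$. The hypothetical helpers below compute the matching length and the AR, using word-level tokens as in the example above.

```python
from typing import List, Sequence


def matching_prefix_length(target_tokens: Sequence[str], drafter_tokens: Sequence[str]) -> int:
    """Number of leading positions where target and drafter tokens agree exactly."""
    n = 0
    for t, d in zip(target_tokens, drafter_tokens):
        if t != d:
            break
        n += 1
    return n


def acceptance_rate(match_lengths: List[int]) -> float:
    """AR = 1 - 1 / (1 + n_bar), where n_bar is the mean number of accepted drafts."""
    n_bar = sum(match_lengths) / len(match_lengths)
    return 1.0 - 1.0 / (1.0 + n_bar)
```

Splitting the two example sentences on whitespace yields a matching length of 8, and a single prompt with $\bar{n}=8$ gives an AR of $8/9 \approx 0.89$.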

This experiment suggests that off-the-shelf LLM “families” such as StarCoder (Li et al., 2023) or Vicuna (Zheng et al., 2024) can form good target/drafter pairs. Such families consist of LLMs of different sizes that were trained in similar ways and on similar datasets. We observe that even relatively small drafters show good alignment with larger LLMs of the same family. For example, Starcoder-168M (drafter) and Starcoder-15B (target) yield an AR of 93%.

Experiments with thread pools.  To measure the speedup of DSI relative to SI (“Speedup DSI vs. SI”), we simulate the generation of 50 tokens for each target/drafter pair on each dataset, using the latency and acceptance rate values computed above. Each combination is simulated with multiple lookahead values (namely 1, 5, and 10). For DSI, we grid-search over the lookahead and also over the number of parallel target instances (namely 1 and 7). The DSI simulation runs each target in a separate thread, using Python's thread-pool implementation. Overall, DSI outperforms SI consistently across all models and tasks.
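A thread pool suits this kind of simulation because a simulated forward pass only waits rather than computes, and Python's `time.sleep` releases the global interpreter lock, so waiting threads overlap. The sketch below uses the standard-library `concurrent.futures.ThreadPoolExecutor` (the paper does not say which thread pool it uses) and hypothetical helpers to show that effect; the actual DSI simulation also schedules drafts and verifications, which this sketch omits.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_forward(latency_s: float) -> float:
    """Stand-in for one target forward pass: sleeps for its latency."""
    time.sleep(latency_s)
    return latency_s


def wall_time(latency_s: float, n_passes: int, n_workers: int) -> float:
    """Wall-clock seconds to run n_passes simulated forward passes on n_workers threads."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        list(pool.map(fake_forward, [latency_s] * n_passes))
    return time.perf_counter() - start
```

With `n_workers=1`, four 50 ms passes take roughly 0.2 s back to back; with four workers they overlap and finish in roughly 0.05 s.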

[Figure 1 panels: (a) non-SI/SI; (b) SI/DSI; (c) non-SI/DSI; (d) $\min(\text{SI},\text{non-SI})/\text{DSI}$]
Figure 1: Each heatmap (labeled “X/Y”) plots the ratio between the run time of algorithm X and the run time of algorithm Y. See Appendix D for the detailed results. (a): SI is slower than non-speculative inference (non-SI) when the drafter is too slow or too inaccurate (pink marks slowdowns). (b, c, d): DSI is always faster than SI and non-SI for various drafters, and it is never slower than either. (d): DSI is up to 1.6x faster than the baseline, defined as the faster of SI and non-SI. These results are based on simulations without thread pools.

Simulated SI without thread pools.  Figure 1 presents the results of an experiment that simulates SI across a wide range of configurations. The aim of this experiment is to estimate pairwise speedups: DSI compared to SI, DSI compared to non-SI, and SI compared to non-SI. Since SI is slower than non-SI in some configurations, we have included an additional comparison that shows DSI speedups relative to the faster of the two algorithms—SI or non-SI—for any given configuration. This experiment helps identify configurations where DSI achieves the highest speedup. It demonstrates that, unlike SI, our method introduces no slowdown compared to non-SI and consistently accelerates inference, provided there are enough processing units. Furthermore, we demonstrate that our method remains useful in practical settings with a relatively small number of processing units.

In the experiment, the latency of DSI is computed using the formula from Lemma 1. This means that the lookahead parameter is set to one and that we need at most $\left\lceil \frac{\texttt{target\_latency}}{\texttt{lookahead}\cdot\texttt{drafter\_latency}} \right\rceil = \left\lceil \frac{\texttt{target\_latency}}{\texttt{drafter\_latency}} \right\rceil$ target servers (i.e., processors capable of computing the target). Having two target servers, for example, means that we may start computing the target on an input batch even if the other instance of the target is not yet available. In practice, if the target LLM fits on a single GPU, then two target servers are two GPUs; if serving the target LLM takes two GPUs, then two target servers require four GPUs. For every configuration, we consider the latency of SI with the lookahead $\in\{1,2,\ldots,200\}$ that minimizes its latency for that configuration. The configurations considered are the Cartesian product of drafter latency ($0.01, 0.02, \ldots, 1$), acceptance rate ($0, 0.01, 0.02, \ldots, 1$), and lookahead ($1, 2, \ldots, 200$). For each combination of (drafter latency, acceptance rate, lookahead), we run 5 repeats of SI and average the results to estimate the expected latency of SI (the implementation of SI is described in Appendix D). Then, for each combination of (drafter latency, acceptance rate), we take the minimal latency over lookaheads (letting SI select its optimal lookahead).
Since the lookahead is a tunable parameter, our experiment assumes that the user tunes it to minimize the latency of SI. It is well known that SI is highly sensitive to the choice of the lookahead. To calculate the speedup of algorithm A over algorithm B for each (drafter_latency, acceptance_rate) pair, we divide the latency of B by the latency of A. The speedups are not smooth for drafter latencies below 20% due to the discretization of the lookahead parameter. If we fix the lookahead and decrease the drafter latency to 0, the number of servers required by DSI grows to infinity. However, it is possible to tune the lookahead hyperparameter to an arbitrary number (as we did). For example, for $\texttt{lookahead}=5$, the speedups are smooth for both algorithms (Figure 6).
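The SI side of this experiment can be sketched as a small Monte Carlo simulation. This sketch is not the implementation described in Appendix D; it assumes that (i) the target latency is normalized to 1 and the drafter latency is a fraction of it, (ii) each drafted token is accepted independently with probability `acceptance_rate`, and (iii) each SI iteration drafts `lookahead` tokens sequentially and then runs one target forward pass, which always contributes one extra token. The function names (`simulate_si_latency`, `best_si_latency`, `speedup`) are hypothetical.

```python
import random
import statistics
from typing import Iterable, Optional


def simulate_si_latency(drafter_latency: float, acceptance_rate: float, lookahead: int,
                        n_tokens: int = 100, rng: Optional[random.Random] = None) -> float:
    """One simulated SI run; latencies are in units of one target forward pass."""
    if rng is None:
        rng = random.Random(0)
    generated, latency = 0, 0.0
    while generated < n_tokens:
        # Draft `lookahead` tokens sequentially, then verify them with one target pass.
        latency += lookahead * drafter_latency + 1.0
        accepted = 0
        # Drafts are accepted left to right until the first rejection.
        while accepted < lookahead and rng.random() < acceptance_rate:
            accepted += 1
        # The target pass always contributes one token (a correction or a bonus token).
        generated += accepted + 1
    return latency


def best_si_latency(drafter_latency: float, acceptance_rate: float,
                    lookaheads: Iterable[int] = range(1, 201), repeats: int = 5,
                    n_tokens: int = 100, seed: int = 0) -> float:
    """Average `repeats` runs per lookahead and keep the lookahead with minimal latency."""
    rng = random.Random(seed)
    return min(
        statistics.mean(
            simulate_si_latency(drafter_latency, acceptance_rate, k, n_tokens, rng)
            for _ in range(repeats)
        )
        for k in lookaheads
    )


def speedup(latency_b: float, latency_a: float) -> float:
    """Speedup of algorithm A over algorithm B: latency of B divided by latency of A."""
    return latency_b / latency_a


if __name__ == "__main__":
    n = 100
    si = best_si_latency(drafter_latency=0.1, acceptance_rate=0.8, n_tokens=n)
    print(f"SI speedup over non-SI: {speedup(float(n), si):.2f}")
```

Non-SI needs one target forward pass per token, so under this normalization its latency is simply `n_tokens`, and the SI speedup over non-SI for one configuration is `speedup(n_tokens, best_si_latency(...))`. The last iteration may overshoot `n_tokens`, which slightly overestimates SI latency; this is a simplification of the sketch.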

As shown in Figure 1(a), for SI to achieve a speedup over non-SI, the acceptance rate of the drafter must exceed the latency of the drafter model, measured as a fraction of the target latency (the region above the pink region in the figure). In other words, SI cannot speed up inference if the acceptance rate of the drafter is not sufficiently high for its latency. In contrast, Figure 1(b) shows that DSI consistently speeds up inference, regardless of the latency and acceptance rate of the drafter. This gives our method much greater flexibility and robustness. In Figure 1(c), we observe that DSI is faster than non-SI in all the configurations for which non-SI is faster than SI. Finally, to obtain a comprehensive view of the speedup achieved by DSI, Figure 1(d) compares the performance of DSI with the minimal runtime of SI and non-SI.
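The boundary in Figure 1(a) can be reasoned about in closed form under a simplifying model. Assuming each drafted token is accepted independently with probability $p$ and the drafter latency $c$ is expressed as a fraction of the target latency, the expected speedup of SI over non-SI with lookahead $k$ is $(1-p^{k+1})/((1-p)(kc+1))$ (Leviathan et al., 2023). For $k=1$ this exceeds 1 exactly when $p>c$, and when $p\le c$ no larger lookahead helps either (since $p+p^2+\dots+p^k\le kp\le kc$), which is consistent with the observation above. The sketch below combines this expression with the two-model expected latency $t_1p(N-1)+t_2((1-p)(N-1)+1)$ derived at the end of Appendix A, taking $t_1=c$ and $t_2=1$, to compare the algorithms for a single configuration. It is an analytic approximation for long generations, not the simulation behind the figure, and the function names are hypothetical.

```python
from typing import Dict, Iterable


def si_expected_speedup(c: float, p: float, k: int) -> float:
    """Expected SI speedup over non-SI (Leviathan et al., 2023) with i.i.d. acceptance p."""
    # Expected tokens per iteration: 1 + p + ... + p^k.
    tokens_per_iteration = k + 1 if p == 1 else (1 - p ** (k + 1)) / (1 - p)
    # Cost per iteration: k sequential drafter calls plus one target call.
    return tokens_per_iteration / (k * c + 1)


def best_si_speedup(c: float, p: float, lookaheads: Iterable[int] = range(1, 201)) -> float:
    """SI with its best lookahead, mirroring the per-configuration tuning in the text."""
    return max(si_expected_speedup(c, p, k) for k in lookaheads)


def dsi_expected_latency(n: int, c: float, p: float) -> float:
    """Two-model expected latency from Appendix A with t1 = c (drafter), t2 = 1 (target)."""
    return c * p * (n - 1) + ((1 - p) * (n - 1) + 1)


def figure1_speedups(c: float, p: float, n: int = 100) -> Dict[str, float]:
    """Pairwise speedups for one (drafter latency, acceptance rate) configuration."""
    non_si = float(n)  # one target pass per token
    si = n / best_si_speedup(c, p)
    dsi = dsi_expected_latency(n, c, p)
    return {
        "si_over_non_si": non_si / si,  # compare with panel (a)
        "dsi_over_non_si": non_si / dsi,
        "dsi_over_si": si / dsi,
        "dsi_over_best_baseline": min(si, non_si) / dsi,  # compare with panel (d)
    }


if __name__ == "__main__":
    for c, p in [(0.5, 0.0), (0.5, 0.5), (0.1, 0.8)]:
        print(c, p, {name: round(v, 3) for name, v in figure1_speedups(c, p).items()})
```

With a useless drafter ($p=0$), the sketch gives an SI speedup below 1 while the DSI latency equals the non-SI latency, matching the claim that DSI introduces no slowdown.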

5 Discussion

In this work, we studied how to reduce the runtime of speculative inference algorithms by taking advantage of multiple processing units (e.g., GPUs). We have shown that, despite their empirical success, traditional SI algorithms can end up slowing down the inference of LMs in various practical settings, for instance, when the drafter is insufficiently accurate or too slow. We showed that by taking advantage of multiple GPUs, we can design a speculative inference algorithm that provably reduces the inference time compared to both non-SI and SI algorithms. Our simulations validate our theory, indicating speedups for all possible configurations, where in each configuration DSI is compared with the faster alternative algorithm (SI or non-SI). In essence, this work paves the way for additional SI algorithms that orchestrate multiple processing units at the same time.

Limitations.  We introduce DSI and show, through theoretical analysis and experiments, that it is faster than SI and non-SI for all possible configurations. Our first experiment measures the time it takes to compute LLMs. The second experiment then simulates DSI and SI. In the simulations, we replace the calls to LLMs with a wait, whose duration is the expected wait time estimated in the first experiment. However, the simulations ignore latencies that exist in practice, such as the communication between processors (CPU and GPUs). Hence, the key limitation is that the algorithm has not yet been implemented and tested on a physical computing node. Another limitation of DSI is the maximal number of servers it requires. For example, if the target LLM fits on a single GPU and the drafter latency is 14.29% of the target latency, then DSI orchestrates a total of eight GPUs (seven instances of the target and one for the drafter). Faster drafters require additional target servers. The exact number of servers is discussed in Section 4. In our simulations with off-the-shelf LLMs (Table 1), we only consider configurations that require at most seven target servers.
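The server count in this example follows the bound used in the simulations, $\texttt{ceil}(\texttt{target\_latency}/(\texttt{lookahead}\cdot\texttt{drafter\_latency}))$ target servers. The sketch below turns that bound into a GPU count; the helper names and the assumption that the drafter occupies a fixed number of GPUs (one by default) are ours, for illustration only.

```python
import math


def dsi_target_servers(target_latency: float, drafter_latency: float, lookahead: int = 1) -> int:
    """Maximal number of target servers DSI needs: ceil(target / (lookahead * drafter))."""
    if target_latency <= 0 or drafter_latency <= 0 or lookahead < 1:
        raise ValueError("latencies must be positive and lookahead must be >= 1")
    return math.ceil(target_latency / (lookahead * drafter_latency))


def dsi_total_gpus(target_latency: float, drafter_latency: float, gpus_per_target: int = 1,
                   gpus_per_drafter: int = 1, lookahead: int = 1) -> int:
    """GPUs for all target servers plus the drafter (assumed to use gpus_per_drafter GPUs)."""
    servers = dsi_target_servers(target_latency, drafter_latency, lookahead)
    return servers * gpus_per_target + gpus_per_drafter


if __name__ == "__main__":
    # Drafter latency = 14.29% of the target latency; the target fits on one GPU.
    print(dsi_target_servers(1.0, 0.1429), dsi_total_gpus(1.0, 0.1429))
```

With `drafter_latency=0.1429` relative to a target latency of 1, this gives 7 target servers and 8 GPUs in total, matching the example above. Halving the drafter latency roughly doubles the number of target servers, a target that needs two GPUs doubles the target's share of the GPU count, and a larger `lookahead` divides the requirement accordingly.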

Broader Impacts.  DSI introduces a new tradeoff: reducing inference latency by utilizing more computing resources. Hence, adopting DSI as an alternative to SI or non-SI algorithms will increase the overall consumption of computing resources by LLM-based applications.

Acknowledgements

We thank Intel Labs for funding this research.

References

  • Andreas [2022] Jacob Andreas. Language models as agent models. arXiv preprint arXiv:2212.01681, 2022.
  • Austin et al. [2021] Jacob Austin, Augustus Odena, Maxwell Nye, Maarten Bosma, Henryk Michalewski, David Dohan, Ellen Jiang, Carrie Cai, Michael Terry, Quoc Le, et al. Program synthesis with large language models. arXiv preprint arXiv:2108.07732, 2021.
  • Bapna et al. [2020] Ankur Bapna, Naveen Arivazhagan, and Orhan Firat. Controlling computation versus quality for neural sequence models, 2020.
  • Ben Allal et al. [2022] Loubna Ben Allal, Niklas Muennighoff, Logesh Kumar Umapathi, Ben Lipkin, and Leandro von Werra. A framework for the evaluation of code generation models. https://github.com/bigcode-project/bigcode-evaluation-harness, 2022.
  • Brown et al. [2020] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Bubeck et al. [2023] Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, et al. Sparks of artificial general intelligence: Early experiments with GPT-4. arXiv preprint arXiv:2303.12712, 2023.
  • Burton [1985] F. Warren Burton. Speculative computation, parallelism, and functional programming. IEEE Transactions on Computers, C-34(12):1190–1193, 1985. doi: 10.1109/TC.1985.6312218.
  • Cai et al. [2024] Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D Lee, Deming Chen, and Tri Dao. Medusa: Simple LLM inference acceleration framework with multiple decoding heads. arXiv preprint arXiv:2401.10774, 2024.
  • Chen et al. [2023] Charlie Chen, Sebastian Borgeaud, Geoffrey Irving, Jean-Baptiste Lespiau, Laurent Sifre, and John Jumper. Accelerating large language model decoding with speculative sampling. arXiv preprint arXiv:2302.01318, 2023.
  • Chen et al. [2021] Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde de Oliveira Pinto, Jared Kaplan, Harri Edwards, Yuri Burda, Nicholas Joseph, Greg Brockman, et al. Evaluating large language models trained on code. arXiv preprint arXiv:2107.03374, 2021.
  • Dao et al. [2022] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
  • Dettmers et al. [2022] Tim Dettmers, Mike Lewis, Younes Belkada, and Luke Zettlemoyer. GPT3.int8(): 8-bit matrix multiplication for transformers at scale. Advances in Neural Information Processing Systems, 35:30318–30332, 2022.
  • Elbayad et al. [2020] Maha Elbayad, Jiatao Gu, Edouard Grave, and Michael Auli. Depth-adaptive transformer. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=SJg7KhVKPH.
  • Frantar and Alistarh [2023] Elias Frantar and Dan Alistarh. SparseGPT: Massive language models can be accurately pruned in one-shot. In International Conference on Machine Learning, pages 10323–10337. PMLR, 2023.
  • Frantar et al. [2022] Elias Frantar, Saleh Ashkboos, Torsten Hoefler, and Dan Alistarh. GPTQ: Accurate post-training quantization for generative pre-trained transformers. arXiv preprint arXiv:2210.17323, 2022.
  • Fried et al. [2023] Daniel Fried, Armen Aghajanyan, Jessy Lin, Sida Wang, Eric Wallace, Freda Shi, Ruiqi Zhong, Wen-tau Yih, Luke Zettlemoyer, and Mike Lewis. InCoder: A generative model for code infilling and synthesis. In Proc. of ICLR, 2023. URL https://arxiv.org/abs/2204.05999.
  • Gu et al. [2024] Yuxian Gu, Li Dong, Furu Wei, and Minlie Huang. MiniLLM: Knowledge distillation of large language models. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=5h0qf7IBZZ.
  • Han et al. [2022] Y. Han, G. Huang, S. Song, L. Yang, H. Wang, and Y. Wang. Dynamic neural networks: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence, 44(11):7436–7456, November 2022. ISSN 1939-3539. doi: 10.1109/TPAMI.2021.3117837.
  • Hennessy and Patterson [2012] John L. Hennessy and David A. Patterson. Computer Architecture: A Quantitative Approach. Morgan Kaufmann, Amsterdam, 5 edition, 2012. ISBN 978-0-12-383872-8.
  • Hermann et al. [2015] Karl Moritz Hermann, Tomas Kocisky, Edward Grefenstette, Lasse Espeholt, Will Kay, Mustafa Suleyman, and Phil Blunsom. Teaching machines to read and comprehend. Advances in neural information processing systems, 28, 2015.
  • Hinton et al. [2015] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network, 2015.
  • Hsu et al. [2022] Yen-Chang Hsu, Ting Hua, Sungen Chang, Qian Lou, Yilin Shen, and Hongxia Jin. Language model compression with weighted low-rank factorization. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=uPv9Y3gmAI5.
  • Hubara et al. [2018] Itay Hubara, Matthieu Courbariaux, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio. Quantized neural networks: Training neural networks with low precision weights and activations. Journal of Machine Learning Research, 18(187):1–30, 2018.
  • Joao Gante [2023] Joao Gante. Assisted generation: a new direction toward low-latency text generation, 2023. URL https://huggingface.co/blog/assisted-generation.
  • Leviathan et al. [2023] Yaniv Leviathan, Matan Kalman, and Yossi Matias. Fast inference from transformers via speculative decoding. In International Conference on Machine Learning, pages 19274–19286. PMLR, 2023.
  • Li et al. [2023] Raymond Li, Loubna Ben Allal, Yangtian Zi, Niklas Muennighoff, Denis Kocetkov, Chenghao Mou, Marc Marone, Christopher Akiki, Jia Li, Jenny Chim, et al. StarCoder: may the source be with you! arXiv preprint arXiv:2305.06161, 2023.
  • Li et al. [2024] Yuhui Li, Fangyun Wei, Chao Zhang, and Hongyang Zhang. EAGLE: Speculative sampling requires rethinking feature uncertainty. arXiv preprint arXiv:2401.15077, 2024.
  • Lin et al. [2024] Ji Lin, Jiaming Tang, Haotian Tang, Shang Yang, Wei-Ming Chen, Wei-Chen Wang, Guangxuan Xiao, Xingyu Dang, Chuang Gan, and Song Han. AWQ: Activation-aware weight quantization for LLM compression and acceleration. In MLSys, 2024.
  • Liu et al. [2023] Xiaoxuan Liu, Lanxiang Hu, Peter Bailis, Ion Stoica, Zhijie Deng, Alvin Cheung, and Hao Zhang. Online speculative decoding. arXiv preprint arXiv:2310.07177, 2023.
  • Ma et al. [2023] Xinyin Ma, Gongfan Fang, and Xinchao Wang. LLM-Pruner: On the structural pruning of large language models. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=J8Ajf9WfXP.
  • Mamou et al. [2024] Jonathan Mamou, Oren Pereg, Daniel Korat, Moshe Berchansky, Nadav Timor, Moshe Wasserblat, and Roy Schwartz. Accelerating speculative decoding using dynamic speculation length, 2024.
  • Miao et al. [2023] Xupeng Miao, Gabriele Oliaro, Zhihao Zhang, Xinhao Cheng, Zeyu Wang, Rae Ying Yee Wong, Zhuoming Chen, Daiyaan Arfeen, Reyna Abhyankar, and Zhihao Jia. SpecInfer: Accelerating generative LLM serving with speculative inference and token tree verification. arXiv preprint arXiv:2305.09781v2, 2023.
  • OpenAI [2023] OpenAI. GPT-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
  • Radford et al. [2019] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
  • Schuster et al. [2021] Tal Schuster, Adam Fisch, Tommi Jaakkola, and Regina Barzilay. Consistent accelerated inference via confident adaptive transformers. In Marie-Francine Moens, Xuanjing Huang, Lucia Specia, and Scott Wen-tau Yih, editors, Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 4962–4979, Online and Punta Cana, Dominican Republic, November 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.emnlp-main.406. URL https://aclanthology.org/2021.emnlp-main.406.
  • Shoeybi et al. [2019] Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, and Bryan Catanzaro. Megatron-LM: Training multi-billion parameter language models using model parallelism. arXiv preprint arXiv:1909.08053, 2019.
  • Spector and Re [2023] Benjamin Spector and Chris Re. Accelerating LLM inference with staged speculative decoding. arXiv preprint arXiv:2308.04623, 2023.
  • Stern et al. [2018] Mitchell Stern, Noam Shazeer, and Jakob Uszkoreit. Blockwise parallel decoding for deep autoregressive models. Advances in Neural Information Processing Systems, 31, 2018.
  • Sun et al. [2024a] Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. In The Twelfth International Conference on Learning Representations, 2024a. URL https://openreview.net/forum?id=PxoFut3dWW.
  • Sun et al. [2024b] Ziteng Sun, Ananda Theertha Suresh, Jae Hun Ro, Ahmad Beirami, Himanshu Jain, and Felix Yu. SpecTr: Fast speculative decoding via optimal transport. Advances in Neural Information Processing Systems, 36, 2024b.
  • Taori et al. [2023a] Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li, Carlos Guestrin, Percy Liang, and Tatsunori B. Hashimoto. Stanford alpaca: An instruction-following llama model. https://github.com/tatsu-lab/stanford_alpaca, 2023a.
  • Taori et al. [2023b] Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li, Carlos Guestrin, Percy Liang, and Tatsunori B Hashimoto. Alpaca: A strong, replicable instruction-following model. Stanford Center for Research on Foundation Models, 2023b. https://crfm.stanford.edu/2023/03/13/alpaca.html.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wei et al. [2022] Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Fei Xia, Ed Chi, Quoc V Le, Denny Zhou, et al. Chain-of-thought prompting elicits reasoning in large language models. Advances in neural information processing systems, 35:24824–24837, 2022.
  • Xu et al. [2023] Mingxue Xu, Yao Lei Xu, and Danilo P. Mandic. TensorGPT: Efficient compression of the embedding layer in LLMs based on the tensor-train decomposition, 2023.
  • Yao et al. [2024] Zhewei Yao, Xiaoxia Wu, Cheng Li, Stephen Youn, and Yuxiong He. Exploring post-training quantization in LLMs from comprehensive study to low rank compensation. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 38, pages 19377–19385, 2024.
  • Zheng et al. [2024] Lianmin Zheng, Wei-Lin Chiang, Ying Sheng, Siyuan Zhuang, Zhanghao Wu, Yonghao Zhuang, Zi Lin, Zhuohan Li, Dacheng Li, Eric Xing, et al. Judging LLM-as-a-judge with MT-Bench and Chatbot Arena. Advances in Neural Information Processing Systems, 36, 2024.
  • Zhou et al. [2024] Yongchao Zhou, Kaifeng Lyu, Ankit Singh Rawat, Aditya Krishna Menon, Afshin Rostamizadeh, Sanjiv Kumar, Jean-François Kagy, and Rishabh Agarwal. DistillSpec: Improving speculative decoding via knowledge distillation. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=rsY6J3ZaTF.

Appendix A Proofs

See 1

Proof.

We begin by demonstrating the losslessness of the algorithm. We would like to prove that when $v=k$, there is a thread $C_{J_k}$ that is the only thread labeled as a verifier, that it correctly computes the next token, and that $J_k=J'\oplus(m)$ for some sequence $J'=(j_1,\dots,j_{k-1})$ of length $k-1$, where $x^{j_1,\dots,j_i}_i=x_i$ for all $i\in[k-1]$. We prove this by induction on the value of $v$. We note that if the algorithm maintains this invariant, then it is clearly lossless.

Base case ($v=1$): Initially, when $v=1$, there is only one verifier, $C_{(m)}$, which runs the target model $f_m$. Thus, when it finishes, it returns the correct token $x_1$. Since the verifier is relabeled only when the value of $v$ changes (see lines 11-12), as long as $v=1$, the only thread labeled as a verifier is $C_{(m)}$.

Induction hypothesis: Assume that as long as $v=k$, there is only one thread $C_{J_k}$ labeled as a verifier, which returns the correct token $x_k$, and that $J_k=J'\oplus(m)$ for some $J'=(j_1,\dots,j_{k-1})$ of length $k-1$, where $x^{j_1,\dots,j_i}_i=x_i$ for all $i\in[k-1]$.

Induction step: The value of $v$ is updated from $k$ to $k+1$ only when the condition in line 7 is met. This condition indicates that the single verifier thread $C_{J_k}$, with $|J_k|=k$, has finished computing its output token. By the induction hypothesis, this thread returns $x_k$ as its output. Since $f_m$ is slower than all drafter models $f_1,\dots,f_{m-1}$, all threads $C_{J'\oplus(i)}$ with $i\in[m-1]$ have already finished computing their outputs. Thus, when executing lines 8, 10, and 11, the only threads that remain active are the descendants of $C_{J'\oplus(j^*)}$, and the only thread serving as a verifier is $C_{J'\oplus(j^*,m)}$.
Since $x^{j_1,\dots,j_i}_i=x_i$ for all $i\le k-1$ and $x^{j_1,\dots,j_{k-1},j^*}_k=x_k$, the thread $C_{J'\oplus(j^*,m)}$ simply computes the output of the target model $f_m$ on the correct sequence $x_{\le 0}\oplus x_1\oplus\dots\oplus x_k$. Hence, it correctly returns the $(k+1)$th token $x_{k+1}$, as desired.

Time: We notice that the algorithm terminates once it has computed the output of $C_{J_N}$. By Assumption 3, we have $T_{\text{wall}}[\text{Algorithm 1}]=\sum_{i=1}^{N}T_{\text{wall}}[\text{computing } f_{j_i}(x_{\le i})]$, and by Assumption 2, we have $T_{\text{wall}}[\text{computing } f_{j_i}(x_{\le i})]\le T_{\text{wall}}[\text{computing } f_m(x_{\le i})]$.
Together, we obtain $T_{\text{wall}}[\text{Algorithm 1}]\le\sum_{i=1}^{N}T_{\text{wall}}[\text{computing } f_m(x_{\le i})]$, which is the amount of time it takes to compute the output tokens without speculative inference. ∎

See 1

Proof.

To understand how it works, let $j_1\in\{1,2\}$ be the smallest index such that $x^{j_1}_1=x_1$, and for all $i\in[N-1]$, recursively define $j_i\in\{1,2\}$ to be the smallest index such that $x^{j_1,\dots,j_i}_i=x_i$. We also fix $j_N=2$. In addition, let $i_0=0$ and let $i_r$ be the $r$th index in $[N]$ such that $j_{i_r}=2$.
We notice that it takes $t_1(i_1-1)+t_2$ time units to compute the value of $x^{j_1,\ldots,j_{i_1}}_{i_1}$. This is because we first compute $x^{1}_1$, then $x^{1,1}_2$, continuing up to $x^{1,\ldots,1}_{i_1-1}$, and finally $x^{1,\ldots,1,2}_{i_1}$. Each of the first $i_1-1$ tokens takes $t_1$ time units, while the final token takes $t_2$ time units.
After $t_1(i_1-1)+t_2$ time units, we will have computed $x^{2}_1$, $x^{1,2}_2$, $x^{1,1,2}_3$, and so on, up to $x^{1,\ldots,1,2}_{i_1}$. Since $f_1$ consistently generates accurate tokens up to index $i_1-1$, once we observe that $x^{2}_1$ matches $x^{1}_1$, we know that $x^{1,2}_2=x_2$ and can then verify that $x^{1,1}_2=x_2$ is also correct.
Once we verify that $x^{1,1}_2=x_2$, we can verify $x^{1,1,2}_3$ and continue this pattern to verify $x^{1,1,1}_3$, and so forth. We note that calculating all of these tokens up to the calculation of $x^{1,\dots,1,2}_{i_1}$ takes at most $t_1(i_1-1)+t_2$ time units. Thus, we can verify that $x^{1,\ldots,1,2}_{i_1}=x_{i_1}$ within at most $t_1(i_1-1)+t_2$ time units.
By the same argument as above, it takes r(t1((irir1)1)+t2)subscript𝑟subscript𝑡1subscript𝑖𝑟subscript𝑖𝑟11subscript𝑡2\sum_{r}(t_{1}((i_{r}-i_{r-1})-1)+t_{2})∑ start_POSTSUBSCRIPT italic_r end_POSTSUBSCRIPT ( italic_t start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( ( italic_i start_POSTSUBSCRIPT italic_r end_POSTSUBSCRIPT - italic_i start_POSTSUBSCRIPT italic_r - 1 end_POSTSUBSCRIPT ) - 1 ) + italic_t start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) time units to compute the value of xNj1,,jNsubscriptsuperscript𝑥subscript𝑗1subscript𝑗𝑁𝑁x^{j_{1},\ldots,j_{N}}_{N}italic_x start_POSTSUPERSCRIPT italic_j start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , italic_j start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT (and to verify its correctness). We notice that Q=r(irir11)𝑄subscript𝑟subscript𝑖𝑟subscript𝑖𝑟11Q=\sum_{r}(i_{r}-i_{r-1}-1)italic_Q = ∑ start_POSTSUBSCRIPT italic_r end_POSTSUBSCRIPT ( italic_i start_POSTSUBSCRIPT italic_r end_POSTSUBSCRIPT - italic_i start_POSTSUBSCRIPT italic_r - 1 end_POSTSUBSCRIPT - 1 ) is the number of indices i[N1]𝑖delimited-[]𝑁1i\in[N-1]italic_i ∈ [ italic_N - 1 ] such that ji=1subscript𝑗𝑖1j_{i}=1italic_j start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = 1. 
Since 𝔼[Q]=p(N1)𝔼delimited-[]𝑄𝑝𝑁1\mathbb{E}[Q]=p(N-1)blackboard_E [ italic_Q ] = italic_p ( italic_N - 1 ), we have 𝔼[r(t1((irir1)1)+t2)]=t1p(N1)+t2((1p)(N1)+1)𝔼delimited-[]subscript𝑟subscript𝑡1subscript𝑖𝑟subscript𝑖𝑟11subscript𝑡2subscript𝑡1𝑝𝑁1subscript𝑡21𝑝𝑁11\mathbb{E}\left[\sum_{r}(t_{1}((i_{r}-i_{r-1})-1)+t_{2})\right]=t_{1}p(N-1)+t_% {2}((1-p)(N-1)+1)blackboard_E [ ∑ start_POSTSUBSCRIPT italic_r end_POSTSUBSCRIPT ( italic_t start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( ( italic_i start_POSTSUBSCRIPT italic_r end_POSTSUBSCRIPT - italic_i start_POSTSUBSCRIPT italic_r - 1 end_POSTSUBSCRIPT ) - 1 ) + italic_t start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) ] = italic_t start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_p ( italic_N - 1 ) + italic_t start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ( ( 1 - italic_p ) ( italic_N - 1 ) + 1 ). ∎
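The closed-form expectation above can be sanity-checked numerically. The sketch below assumes, as in the proof, that each index $i\in[N-1]$ independently has $j_i=1$ with probability $p$ and that $j_N=2$; it only evaluates the cost expression $\sum_r (t_1((i_r-i_{r-1})-1)+t_2)$, not an actual DSI run, and the function names are hypothetical.

```python
import random


def closed_form(t1: float, t2: float, p: float, N: int) -> float:
    """Expectation stated in the proof: t1*p*(N-1) + t2*((1-p)*(N-1) + 1)."""
    return t1 * p * (N - 1) + t2 * ((1 - p) * (N - 1) + 1)


def sampled_cost(t1: float, t2: float, p: float, N: int, rng: random.Random) -> float:
    """One draw of sum_r (t1*((i_r - i_{r-1}) - 1) + t2)."""
    # j_i = 1 (drafter token correct) w.p. p for i in [N-1]; j_N = 2 by assumption.
    j = [1 if rng.random() < p else 2 for _ in range(N - 1)] + [2]
    # i_0 = 0, and i_1 < ... < i_R are the 1-based indices where j_i = 2.
    idx = [0] + [i + 1 for i, ji in enumerate(j) if ji == 2]
    return sum(t1 * ((idx[r] - idx[r - 1]) - 1) + t2 for r in range(1, len(idx)))


def monte_carlo(t1: float, t2: float, p: float, N: int,
                trials: int = 20_000, seed: int = 0) -> float:
    """Average of sampled_cost over independent trials."""
    rng = random.Random(seed)
    return sum(sampled_cost(t1, t2, p, N, rng) for _ in range(trials)) / trials
```

For example, `monte_carlo(1.0, 5.0, 0.7, 50)` should land close to `closed_form(1.0, 5.0, 0.7, 50)`; at $p=0$ and $p=1$ the sampled cost is deterministic and matches the formula exactly.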

See 2

Proof.

Suppose we have a drafter model $f_1$, a target model $f_2$, and a prompt $x_{\leq 0}$. Assume that $f_1$ requires $t_1$ time units to compute each of its outputs and $f_2$ requires $t_2$ time units, where $t_1<t_2$. Assume that, given the prompt $x_{\leq i}=x_{\leq 0}\oplus x_1\oplus\dots\oplus x_i$, the probability that $f_1$ returns the (correct) token $x_{i+1}$ is $p$. Consider generating $N>k+1$ tokens from $f_2$ using the SI (or DSI) algorithm with $\texttt{lookahead}=k$. At time $0$, SI starts generating draft tokens, by the definition of SI.
At time $kt_1$, SI completes generating the first $k$ draft tokens $x_1^{1},x_2^{1,1},\ldots,x_k^{1,\ldots,1}$. At time $kt_1+t_2$, SI completes verifying these $k$ tokens. Let $A_1,A_2,\ldots,A_{k+1}$ be indicator variables sampled as follows: $A_i=1$ with probability $p$ and $A_i=0$ otherwise, for all $i\in[k+1]$.
Let $n:=\min\{i\mid A_i=0\}-1$. Note that $n$ is distributed as the number of accepted drafts among the first $k$ drafts of SI (or DSI). By the definition of SI, SI completes generating the first $n+1$ tokens at time $kt_1+t_2$ for any $n\in\{0,1,\ldots,k\}$, and its first iteration cannot output tokens at positions $>k+1$. Since each SI iteration takes $kt_1+t_2$ time units, SI cannot complete generating $x_{k+2}$ before the end of its second iteration, that is, SI completes generating $x_{k+2}$ at time $\geq 2(kt_1+t_2)$. Consider DSI with the same $f_1$, $f_2$, and lookahead, running over at least $\left\lceil\frac{t_2}{kt_1}\right\rceil$ servers.
We show that DSI completes generating $x_{k+2}$ at time $\leq 2(kt_1+t_2)$ and, in expectation, at time $<2(kt_1+t_2)$. By the definition of DSI, DSI never preempts the current verifier. At time $kt_1$, by the definition of DSI, DSI concurrently invokes (i) the verification of the batch containing the first $k$ tokens $x_1^{1},x_2^{1,1},\ldots,x_k^{1,\ldots,1}$, which are not yet verified, and (ii) the drafting of $x_{k+1}^{1,\ldots,1},x_{k+2}^{1,\ldots,1},\ldots,x_{2k}^{1,\ldots,1}$.
If $n=k$, then both SI and DSI complete generating the first $k+1$ tokens, through $x_{k+1}$, at time $kt_1+t_2$. At that time, DSI either invokes a new current verifier thread or labels an existing thread as the current verifier (depending on $\frac{t_2}{t_1}$ and the lookahead). Hence, DSI completes generating $x_{k+2}$ at time $\leq kt_1+2t_2$, exactly when the current verifier thread completes its verification. In this case, DSI is faster than SI for all $t_1,t_2,k$, since $2(kt_1+t_2)-(kt_1+2t_2)=kt_1>0$. Otherwise ($n<k$), both algorithms accept the first $n+1$ tokens at time $\leq kt_1+t_2$, and the argument repeats for the remaining $N-(n+1)$ tokens. ∎
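To make the case analysis concrete, the following sketch evaluates the two timing bounds for $x_{k+2}$, the server count $\lceil t_2/(kt_1)\rceil$, and a Monte Carlo estimate of how often the strict case $n=k$ occurs. It assumes i.i.d. acceptances with probability $p$ and constant latencies, and it takes $n$ as the number of leading accepted drafts among the $k$ drafts (capped at $k$); all function names are hypothetical.

```python
import math
import random


def sample_n(p: float, k: int, rng: random.Random) -> int:
    """Leading accepted drafts among k drafts, each accepted w.p. p (capped at k)."""
    n = 0
    while n < k and rng.random() < p:
        n += 1
    return n


def x_k_plus_2_bounds(t1: float, t2: float, k: int) -> tuple[float, float]:
    """(SI lower bound, DSI upper bound in the case n = k) for generating x_{k+2}."""
    si_lower = 2 * (k * t1 + t2)  # SI needs at least two full iterations
    dsi_upper = k * t1 + 2 * t2   # verifier finishes at most t2 after x_{k+1} (time k*t1 + t2)
    return si_lower, dsi_upper


def min_servers(t1: float, t2: float, k: int) -> int:
    """Server count used in the proof: ceil(t2 / (k * t1))."""
    return math.ceil(t2 / (k * t1))


def prob_full_acceptance(p: float, k: int, trials: int = 50_000, seed: int = 0) -> float:
    """Monte Carlo estimate of P(n = k); analytically this is p**k."""
    rng = random.Random(seed)
    return sum(sample_n(p, k, rng) == k for _ in range(trials)) / trials
```

With $t_1=1$, $t_2=10$, $k=5$, the bounds are $30$ and $25$ time units (a gap of $kt_1=5$) and two servers suffice; the strict gain occurs with probability $p^k$, which is why DSI is faster in expectation whenever $p>0$.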

See 2

Proof.

Consider SI and DSI with the same lookahead $k$. We use a coupling argument that aligns the $i$th iteration of SI with the corresponding generations of DSI, as follows. Let $A_1,A_2,\ldots,A_{k+1}$ be indicator variables such that $A_i=1$ with probability $p$ and $A_i=0$ otherwise, for all $i\in[k+1]$. Let $n:=\min\{i\mid A_i=0\}-1$. Note that $n$ is distributed as the number of accepted drafts in any batch of both SI and DSI. Both algorithms generate $n+1$ new tokens at every iteration. Here, we assume that SI and DSI generate $n+1$ tokens at their $i$th iteration and may differ only in the time it takes. Denote the number of iterations of the algorithms by $r$. Since each of the $r$ iterations of SI takes the same time, $kt_1+t_2$, SI completes at time $r(kt_1+t_2)$.
For every iteration $i\in[r]$, we sample $n_i$ from fresh indicators $A_1,A_2,\ldots,A_{k+1}$ as defined above. At time $0$, DSI invokes a new current verifier that completes generating $x_1,x_2,x_3,\ldots,x_{n_1+1}$ at times $t_2,2t_2,3t_2,\ldots,(n_1+1)t_2$, respectively, unless a faster thread is labeled as the current verifier in line 11. Similarly, at the beginning of the $i$th iteration, DSI invokes a new current verifier that completes generating the next $n_i+1$ tokens within $(n_i+1)t_2$ time units.
Hence, since $(n_i+1)t_2<kt_1+t_2$ if and only if $n_it_2<kt_1$, DSI completes generating the $n_i+1$ new tokens of the $i$th iteration before SI does if $n_it_2<kt_1$, and at the same time as SI otherwise. Assume that $n_i=k$ for all $i\in[r]$, namely, that all of the drafts are accepted. Then DSI completes at time $rkt_1+t_2$, that is, $(r-1)t_2$ time units before SI does. Let $B_i$ be an indicator such that $B_i=1$ if $n_i=k$ and $B_i=0$ otherwise, for all $i\in[r]$. Note that $B_i=1$ with probability $p^k$. Define $Y_1,Y_2,\ldots,Y_s$ as follows.
$Y_1$ denotes the length of the first run of consecutive successes: $Y_1=y_1'-y_1$, where $y_1=\min_{i\in[r]}\{i\mid B_i=1\}$. If $\{i\mid B_i=1\}\neq\emptyset$, then we define $y_1':=\min\left\{r+1,\ \min_{i\in[r+1]}\{i\mid B_{i-1}=1\text{ and }B_i=0\}\right\}$. $Y_2,\ldots,Y_s$ are defined analogously for the subsequent runs.
For example, if $r=14$ and $B_1=0$, $B_2=0$, $B_3=1$, $B_4=1$, $B_5=1$, $B_6=0$, $B_7=0$, $B_8=0$, $B_9=1$, $B_{10}=0$, $B_{11}=1$, $B_{12}=1$, $B_{13}=0$, $B_{14}=0$, then $Y_1=3$, $Y_2=1$, $Y_3=2$, and DSI completes at least $6t_2$ time units before SI does. ∎
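The run-length bookkeeping in this proof can be checked mechanically. The sketch below computes $Y_1,Y_2,\ldots$ as the lengths of maximal runs of consecutive $B_i=1$, reproduces the $r=14$ example, evaluates the all-accepted case ($rkt_1+t_2$ for DSI versus $r(kt_1+t_2)$ for SI), and encodes the per-iteration condition $n_it_2<kt_1$. The function names are hypothetical.

```python
def run_lengths(B: list[int]) -> list[int]:
    """Lengths Y_1, Y_2, ... of maximal runs of consecutive B_i = 1 (B holds 0/1 values)."""
    runs: list[int] = []
    current = 0
    for b in B:
        if b == 1:
            current += 1
        elif current > 0:
            runs.append(current)  # the run ends at the first B_i = 0 after it (y_j')
            current = 0
    if current > 0:
        runs.append(current)  # a run reaching i = r ends at y_j' = r + 1
    return runs


def completion_times_all_accepted(r: int, k: int, t1: float, t2: float) -> tuple[float, float]:
    """(SI, DSI) completion times when every iteration accepts all k drafts (n_i = k)."""
    si = r * (k * t1 + t2)
    dsi = r * k * t1 + t2
    return si, dsi


def dsi_faster_in_iteration(n_i: int, k: int, t1: float, t2: float) -> bool:
    """DSI's verifier time (n_i + 1) * t2 beats SI's iteration time k * t1 + t2."""
    return n_i * t2 < k * t1
```

For the example sequence, `run_lengths` returns `[3, 1, 2]`, whose sum is the multiplier $6$ in the stated saving of $6t_2$.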

Appendix B Datasets and Prompts Details

We use standard datasets from Hugging Face and standard prompts from prior work.

B.1 MBPP

The MBPP dataset consists of crowd-sourced Python programming problems and is distributed under the CC-BY-4.0 license.

For the prompt, we follow [Ben Allal et al., 2022, Fried et al., 2023] and include the description of the programming task and a single test to verify the solution, which helps the model infer the signature of the function (see Figure 2).

"""{text}
{test_list[0]}
"""
Figure 2: MBPP Prompt
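As an illustration of how the template in Figure 2 could be filled, the sketch below uses Python's `str.format`, which supports index lookups such as `{test_list[0]}` in the template. It assumes a record with MBPP's `text` and `test_list` fields; the sample record is made up for illustration and is not taken from the dataset.

```python
MBPP_TEMPLATE = """{text}
{test_list[0]}
"""


def build_mbpp_prompt(record: dict) -> str:
    """Fill the Figure 2 template: task description followed by its first test."""
    return MBPP_TEMPLATE.format(text=record["text"], test_list=record["test_list"])


# Hypothetical record, for illustration only.
example = {
    "text": "Write a function to add two integers.",
    "test_list": ["assert add(1, 2) == 3", "assert add(0, 0) == 0"],
}
print(build_mbpp_prompt(example))
```

Only the first test is included, so the remaining tests stay available for checking the generated solution.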

B.2 HumanEval

The HumanEval dataset includes programming problems and is distributed under the MIT License.

The prompt contains only the prompt field of the dataset.

B.3 CNN-DM

CNN-DM contains news articles and is distributed under the Apache License 2.0.

We include the article field in the prompt, as shown in Figure 3.

"""Summarize:
{article}
Summary:
"""
Figure 3: CNN-DM Prompt

B.4 Alpaca

The Alpaca dataset contains instructions and demonstrations and is distributed under the CC-BY-NC-4.0 license.

We follow Taori et al. [2023b] to define the prompts. For samples with a non-empty input field, we use the prompt in Figure 4; for samples with an empty input field, we use the prompt in Figure 5.

"""Below is an instruction that describes a
task, paired with an input that provides
further context. Write a response that
appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
"""

Figure 4: Alpaca prompt for samples with a non-empty input field.
"""Below is an instruction that describes a
task. Write a response that appropriately
completes the request.

### Instruction:
{instruction}

### Response:
"""
Figure 5: Alpaca prompt for samples with an empty input field.
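The choice between Figures 4 and 5 can be sketched as a small helper that checks whether the `input` field is non-empty. The preamble is written on a single line, on the assumption that the line breaks inside it in Figures 4 and 5 are typesetting wraps; the section headers and the trailing newline after `### Response:` follow the figures. The constant and function names are hypothetical.

```python
PREAMBLE_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes the request."
)
PREAMBLE_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)


def build_alpaca_prompt(record: dict) -> str:
    """Figure 4 layout if `input` is non-empty, Figure 5 layout otherwise."""
    if record.get("input"):
        return (f"{PREAMBLE_WITH_INPUT}\n\n### Instruction:\n{record['instruction']}\n\n"
                f"### Input:\n{record['input']}\n\n### Response:\n")
    return (f"{PREAMBLE_NO_INPUT}\n\n### Instruction:\n{record['instruction']}\n\n"
            "### Response:\n")
```

An empty string counts as an empty input field, so such samples fall through to the Figure 5 layout.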

Appendix C Models

For all models, we retrieve model weights from Hugging Face. For clarity and reproducibility, we provide the URLs for each model used:

Appendix D Experimental Results

D.1 SI Implementation

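import random

def get_num_accepted(acceptance_rate: float, lookahead: int) -> int:
    n = 0  # hypothetical sampler: i.i.d. acceptances until the first rejection
    while n < lookahead and random.random() < acceptance_rate:
        n += 1
    return n

def si(target_latency: float, drafter_latency: float, lookahead: int, N: int, acceptance_rate: float) -> float:
    total_cost: float = 0
    total_toks: int = 0
    while total_toks < N:
        num_accepted: int = get_num_accepted(acceptance_rate, lookahead)
        total_toks += num_accepted + 1  # accepted drafts plus one token from the target
        total_cost += lookahead * drafter_latency + target_latency  # k drafts, then one verification
    return total_cost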
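The simulated SI cost above can be cross-checked against a closed-form approximation. Assuming i.i.d. acceptances with probability $p$ (an assumption about the sampler) and that non-SI spends one target forward pass ($t_2$) per token, the expected number of tokens per SI iteration is $\sum_{m=0}^{k}p^m$, so the SI/non-SI run-time ratio of Figure 6(a) is roughly $(kt_1+t_2)/(t_2\sum_{m=0}^{k}p^m)$. This sketch ignores rounding at the end of generation; the function name is hypothetical.

```python
def expected_si_over_non_si(t1: float, t2: float, p: float, k: int) -> float:
    """Approximate ratio E[SI time] / non-SI time for a long generation.

    t1: drafter latency per token, t2: target latency per forward pass,
    p: i.i.d. acceptance probability, k: lookahead.
    """
    tokens_per_iteration = sum(p ** m for m in range(k + 1))  # E[n + 1]
    si_time_per_token = (k * t1 + t2) / tokens_per_iteration
    non_si_time_per_token = t2  # assumption: one target forward pass per token
    return si_time_per_token / non_si_time_per_token
```

A ratio above 1 means SI is slower than non-SI: for example, a drafter at half the target's latency with $p=0.5$ gives a ratio of about 1.78 at $k=5$, consistent with the slowdown region described for slow or inaccurate drafters.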

D.2 Speedups for lookahead = 5

(a) SI/non-SI (b) SI/DSI (c) non-SI/DSI
Figure 6: Each heatmap (labeled “X/Y”) plots the ratio between the run time of algorithm X and the run time of algorithm Y. SI is run with $\texttt{lookahead}=5$. (a): SI is slower than non-speculative inference (non-SI) when the drafter is either slow or inaccurate enough (pink marks slowdowns). DSI is never slower than either SI or non-SI. (b, c): DSI is always faster than both speculative inference (SI) and non-SI for various drafters.