mistralrs_core/pipeline/speculative.rs

use std::{
    any::Any,
    sync::{Arc, Mutex},
    time::{Duration, Instant},
};

use anyhow::Result as anyhowResult;
use candle_core::{Device, IndexOp, Result, Tensor};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    get_mut_arcmutex,
    pipeline::sampling::{
        finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
    },
    prefix_cacher::PrefixCacheManagerV2,
    sequence::Sequence,
    DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
};

use super::{
    cache_manager::NormalCacheManager, chat_template::ChatTemplate, sampling::SpeculativeSample,
    AnyMoePipelineMixin, CacheBackendMetadata, CacheInstruction, CacheManager, CacheManagerMixin,
    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, MetadataMixin,
    ModelCategory, ModelPaths, PreProcessingMixin,
};

/// A loader for a speculative pipeline using 2 [`Loader`]s.
pub struct SpeculativeLoader {
    pub target: Box<dyn Loader>,
    pub draft: Box<dyn Loader>,
    pub config: SpeculativeConfig,
}

impl Loader for SpeculativeLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_hf(
            revision.clone(),
            token_source.clone(),
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_hf(
            revision,
            token_source,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }
    fn get_id(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            self.target.get_id(),
            self.draft.get_id(),
            self.config.gamma,
        )
    }
    fn get_kind(&self) -> ModelKind {
        ModelKind::Speculative {
            target: Box::new(self.target.get_kind()),
            draft: Box::new(self.draft.get_kind()),
        }
    }
}

/// Speculative decoding pipeline: <https://arxiv.org/pdf/2211.17192>
///
/// # Algorithm
/// Given draft model q and target model p with probability distributions \
/// q_i(x) and p_i(x) for each token
///
/// - Keep the sample for token i if q_i(x) <= p_i(x)
///     - This means the target model agrees
/// - Else (q_i(x) > p_i(x)) accept that token with prob p_i(x)/q_i(x)
///     - If rejected, sample a token from p'_i(x) = norm(max(0, p(x) − q(x))) and do not take any more
///
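/// A minimal illustrative sketch of the acceptance rule for one drafted token x
/// (pseudocode only; the actual sampling is implemented in
/// `sample_target_sequence_speculative`):
///
/// ```text
/// // p(x), q(x): probabilities of x under the target and draft models
/// if q(x) <= p(x) {
///     accept(x)                             // target agrees with the draft
/// } else if uniform(0, 1) < p(x) / q(x) {
///     accept(x)                             // accepted with probability p(x)/q(x)
/// } else {
///     accept(sample(norm(max(0, p - q))))   // resample from the residual distribution
///     stop()                                // discard the remaining drafted tokens
/// }
/// ```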
pub struct SpeculativePipeline {
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    gamma: usize,
    metadata: Arc<GeneralMetadata>,
    category: ModelCategory,
}

#[derive(Copy, Clone)]
/// Metadata for a speculative pipeline
pub struct SpeculativeConfig {
    /// Number of tokens (γ) to generate with the draft model per speculative step
    pub gamma: usize,
}

impl SpeculativePipeline {
    pub fn new(
        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        config: SpeculativeConfig,
    ) -> Result<Self> {
        if get_mut_arcmutex!(target)
            .tokenizer()
            .as_ref()
            .ok_or(candle_core::Error::Msg(
                "`SpeculativePipeline::new` requires the target pipeline to have a tokenizer"
                    .to_string(),
            ))?
            .get_vocab(true)
            != get_mut_arcmutex!(draft)
                .tokenizer()
                .as_ref()
                .ok_or(candle_core::Error::Msg(
                    "`SpeculativePipeline::new` requires the draft pipeline to have a tokenizer"
                        .to_string(),
                ))?
                .get_vocab(true)
        {
            candle_core::bail!("Target and draft models' tokenizer vocab do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
            candle_core::bail!("Target and draft models' category do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target)
            .get_processor()
            .inputs_processor()
            .get_type()
            != get_mut_arcmutex!(draft)
                .get_processor()
                .inputs_processor()
                .get_type()
        {
            candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
        }
        let metadata = get_mut_arcmutex!(target).get_metadata().clone();
        let category = get_mut_arcmutex!(target).category();
        // TODO: some checks or relaxation here?
        Ok(Self {
            target,
            draft,
            gamma: config.gamma,
            metadata,
            category,
        })
    }
}

impl PreProcessingMixin for SpeculativePipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
}

impl IsqPipelineMixin for SpeculativePipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
        get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
        get_mut_arcmutex!(self.draft).re_isq_model(dtype)
    }
}

impl CacheManagerMixin for SpeculativePipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.draft),
            seqs,
            modify_draft_cache,
            load_preallocated_cache,
        );
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.target),
            seqs,
            false,
            load_preallocated_cache,
        );
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
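        // `SpeculativePipeline` holds no KV cache of its own; the draft and target
        // pipelines each manage their own caches (see the methods above).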
        unreachable!()
    }
    fn do_preallocated_cache(&self) -> bool {
        // KV cache size is not the same (necessarily)
        false
    }
}

impl MetadataMixin for SpeculativePipeline {
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    fn name(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            get_mut_arcmutex!(self.target).name(),
            get_mut_arcmutex!(self.draft).name(),
            self.gamma,
        )
    }
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state();
        get_mut_arcmutex!(self.draft).reset_non_granular_state();
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeculativePipeline {
    fn forward_inputs(
        &mut self,
        _inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> Result<ForwardInputsResult> {
        unreachable!()
    }
    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<()> {
        unreachable!()
    }
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        _return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata<'_>,
    ) -> Result<Duration> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                match pre_op {
                    CacheInstruction::In => self.clone_in_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable PRE cache op."),
                }

                let start = Instant::now();
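                // Speculative decoding currently operates on exactly one sequence per step.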
                assert_eq!(input_seqs.len(), 1);

                let seq = &mut input_seqs[0];

                // ======================= Run draft model gamma times producing tokens ============================
                // ======================= Sample the `gamma` logits. ============================
                let mut draft_samples = Vec::new();
                for i in 0..self.gamma {
                    let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
                    let device = get_mut_arcmutex!(self.draft).device();
                    let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
                    let inputs = self
                        .get_processor()
                        .inputs_processor()
                        .process_inputs(
                            self.tokenizer(),
                            &mut [seq],
                            is_prompt && i == 0, // Only prompt (no kv cache) if first
                            is_xlora,
                            &device,
                            no_kv_cache,
                            None,
                            false,
                            None,
                            None, // TODO: get block tables/handle it
                            None, // TODO: do we support???
                            get_mut_arcmutex!(self.draft).device_mapper(),
                        )
                        .nth(0)
                        .unwrap()
                        .unwrap()
                        .inputs;
                    let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
                    #[allow(irrefutable_let_patterns)]
                    let ForwardInputsResult::CausalGeneration { logits } = logits
                    else {
                        candle_core::bail!(
                            "Speculative decoding requires `CausalGeneration` forward results"
                        );
                    };

                    let sample = sample_sequence(
                        logits.clone(),
                        seq,
                        seq.return_logprobs(),
                        rng.clone(),
                        false, // todo tune
                        true,
                    )
                    .await?;
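                    // Temporarily append the drafted token so the next draft forward pass
                    // conditions on it; all temporary tokens are removed again after the loop.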
                    seq.add_tmp_tok(sample.token);
                    draft_samples.push(SpeculativeSample { sample });
                }
                seq.remove_tmp_tok(self.gamma);

                // ======================= Add all draft tokens but the last one. Add the last from the seq. ============================
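                // Only the first gamma - 1 drafted tokens are fed to the target model; its
                // logits at the final input position already score the last drafted token.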
                let mut draft_prefill_tokens = if is_prompt {
                    seq.get_toks().to_vec()
                } else {
                    vec![*seq.get_toks().last().unwrap()]
                };
                for (i, sample) in draft_samples.iter().enumerate() {
                    if i == draft_samples.len() - 1 {
                        continue;
                    }
                    draft_prefill_tokens.push(sample.sample.token);
                }
                seq.set_prefill_toks(draft_prefill_tokens);

                // ======================= Run the model with all draft tokens. ============================

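                // Record how long the target model's KV cache is before the verification pass;
                // this is handed to the inputs processor below together with `gamma`.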
                let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => full.lock()[0]
                        .as_ref()
                        .map(|(k, _)| k.dims()[2])
                        .unwrap_or(0),
                    EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
                };

                // ========= Run the model ============
                let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
                let device = get_mut_arcmutex!(self.target).device();
                let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
                let inputs = self
                    .get_processor()
                    .inputs_processor()
                    .process_inputs(
                        self.tokenizer(),
                        &mut [seq],
                        true, // use the "prefill" tokens
                        is_xlora,
                        &device,
                        no_kv_cache,
                        Some((self.gamma, initial_cache_len)), // Get the last gamma, see above
                        false,
                        None,
                        None, // TODO: get block tables/handle it
                        None, // TODO: do we support???
                        get_mut_arcmutex!(self.target).device_mapper(),
                    )
                    .nth(0)
                    .unwrap()
                    .unwrap()
                    .inputs;

                let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
                #[allow(irrefutable_let_patterns)]
                let ForwardInputsResult::CausalGeneration { logits } = logits
                else {
                    candle_core::bail!(
                        "Speculative decoding requires `CausalGeneration` forward results"
                    );
                };

                // Reset the prefill tokens
                seq.reset_prefill_toks();

                // ======================= Rejection sampling. ============================
                // Map each target sample to the corresponding draft sample.
                // This first rolls back LLG state, if any, and then advances it for the accepted tokens only.
                let samples = sample_target_sequence_speculative(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    &draft_samples,
                )
                .await?;

                let accepted_tokens = samples.into_iter().map(|s| s.sample).collect::<Vec<_>>();

                // ======================= Narrow caches to account for rejections ============================
                let n_not_accepted = self.gamma - accepted_tokens.len();

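                // Both models' KV caches grew while processing the drafted tokens; narrow them by
                // the number of rejected tokens so they stay consistent with the accepted sequence.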
                match get_mut_arcmutex!(self.draft).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                }
                if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.draft).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) => {
                            unreachable!()
                        }
                    }
                }
                match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                }
                if get_mut_arcmutex!(self.target).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.target).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) => {
                            unreachable!()
                        }
                    }
                }

                let eos_owned = get_mut_arcmutex!(self.target)
                    .get_metadata()
                    .eos_tok
                    .clone();
                let eos_tok = if disable_eos_stop {
                    None
                } else {
                    Some(&eos_owned[..])
                };
                // Add the tokens to the seq and the trie
                for accepted in accepted_tokens {
                    // Do not use the prefix cacher
                    finish_or_add_toks_to_seq(
                        self,
                        prefix_cacher,
                        seq,
                        accepted.clone(),
                        eos_tok,
                        false,
                    )
                    .await?;
                }

                // Trick to improve lower bounds. Sample last token in multinomial
                /*
                let sample = sample_sequence(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    false, // todo tune
                    true, // do not add to tok trie yet
                    true,
                )
                .await?;
                finish_or_add_toks_to_seq(self, prefix_cacher, seq, sample, eos_tok, false);
                */
                let end = Instant::now();
                let exec_duration = end.duration_since(start);

                match post_op {
                    CacheInstruction::Out => {
                        self.clone_out_cache(input_seqs);
                    }
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                // Done! We have:
                // - Ran the draft model gamma times
                // - Reset the draft model cache fully
                // - Sampled the draft model's distributions
                // - Ran the target model
                // - Executed the speculative decoding algorithm on the resulting distributions
                // - Added the accepted tokens to the buffer and trie
                // - Possibly fixed up the base model's cache based on the accepted tokens

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata: _,
                blocks_to_copy: _,
                blocks_to_swap_in: _,
                blocks_to_swap_out: _,
            } => unreachable!(),
        }
    }
    fn category(&self) -> ModelCategory {
        self.category.clone()
    }
}

impl AnyMoePipelineMixin for SpeculativePipeline {}