mistralrs_core/pipeline/speculative.rs

use std::{
    any::Any,
    sync::{Arc, Mutex},
    time::{Duration, Instant},
};

use anyhow::Result as anyhowResult;
use candle_core::{Device, IndexOp, Result, Tensor};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    get_mut_arcmutex,
    kv_cache::NormalCacheManager,
    pipeline::sampling::{
        finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
    },
    prefix_cacher::PrefixCacheManagerV2,
    sequence::Sequence,
    DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
};

use crate::kv_cache::CacheManager;
use crate::utils::progress::ProgressScopeGuard;

use super::{
    chat_template::ChatTemplate, sampling::SpeculativeSample, AnyMoePipelineMixin,
    CacheBackendMetadata, CacheInstruction, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, IsqPipelineMixin, MetadataMixin, ModelCategory, ModelPaths,
    PreProcessingMixin,
};

/// A loader for a speculative pipeline using 2 [`Loader`]s.
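///
/// A minimal construction sketch (illustrative only: `target_loader` and
/// `draft_loader` stand in for two already-built, tokenizer-compatible
/// [`Loader`]s, and `gamma = 16` is an arbitrary example value):
///
/// ```ignore
/// let loader = SpeculativeLoader {
///     target: target_loader,
///     draft: draft_loader,
///     config: SpeculativeConfig { gamma: 16 },
/// };
/// ```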
pub struct SpeculativeLoader {
    pub target: Box<dyn Loader>,
    pub draft: Box<dyn Loader>,
    pub config: SpeculativeConfig,
}

impl Loader for SpeculativeLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_hf(
            revision.clone(),
            token_source.clone(),
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_hf(
            revision,
            token_source,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }
    fn get_id(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            self.target.get_id(),
            self.draft.get_id(),
            self.config.gamma,
        )
    }
    fn get_kind(&self) -> ModelKind {
        ModelKind::Speculative {
            target: Box::new(self.target.get_kind()),
            draft: Box::new(self.draft.get_kind()),
        }
    }
}

/// Speculative decoding pipeline: <https://arxiv.org/pdf/2211.17192>
///
/// # Algorithm
/// Given draft model q and target model p with probability distributions \
/// q_i(x) and p_i(x) for each token
///
/// - Keep the sample for token i if q_i(x) <= p_i(x)
///     - This means the target model agrees
/// - Else (q_i(x) > p_i(x)) accept that token with probability p_i(x)/q_i(x)
///     - If rejected, sample a token from p'_i(x) = norm(max(0, p_i(x) − q_i(x))) and do not generate any more tokens
///
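/// For example (illustrative numbers only): with q_i(x) = 0.2 and p_i(x) = 0.5 the draft token
/// is kept outright; with q_i(x) = 0.5 and p_i(x) = 0.2 it is kept with probability
/// 0.2/0.5 = 0.4, and on rejection the replacement token is drawn from
/// p'_i(x) = norm(max(0, p_i(x) − q_i(x))).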
pub struct SpeculativePipeline {
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    gamma: usize,
    metadata: Arc<GeneralMetadata>,
    category: ModelCategory,
}

#[derive(Copy, Clone)]
/// Configuration for a speculative pipeline
pub struct SpeculativeConfig {
    /// Number of tokens (γ) to generate with the draft model per speculative step
    pub gamma: usize,
}

impl SpeculativePipeline {
    pub fn new(
        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        config: SpeculativeConfig,
    ) -> Result<Self> {
        if get_mut_arcmutex!(target)
            .tokenizer()
            .as_ref()
            .ok_or(candle_core::Error::Msg(
                "`SpeculativePipeline::new` requires the target pipeline to have a tokenizer"
                    .to_string(),
            ))?
            .get_vocab(true)
            != get_mut_arcmutex!(draft)
                .tokenizer()
                .as_ref()
                .ok_or(candle_core::Error::Msg(
                    "`SpeculativePipeline::new` requires the draft pipeline to have a tokenizer"
                        .to_string(),
                ))?
                .get_vocab(true)
        {
            candle_core::bail!("Target and draft models' tokenizer vocabularies do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
            candle_core::bail!("Target and draft models' categories do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target)
            .get_processor()
            .inputs_processor()
            .get_type()
            != get_mut_arcmutex!(draft)
                .get_processor()
                .inputs_processor()
                .get_type()
        {
            candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
        }
        let metadata = get_mut_arcmutex!(target).get_metadata().clone();
        let category = get_mut_arcmutex!(target).category();
        // TODO: some checks or relaxation here?
        Ok(Self {
            target,
            draft,
            gamma: config.gamma,
            metadata,
            category,
        })
    }
}

impl PreProcessingMixin for SpeculativePipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
}

impl IsqPipelineMixin for SpeculativePipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
        get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
        get_mut_arcmutex!(self.draft).re_isq_model(dtype)
    }
}

impl CacheManagerMixin for SpeculativePipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
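        // The trailing boolean mirrors the `modify_draft_cache` flag of `set_none_cache` below:
        // `true` is passed for the draft pipeline's cache and `false` for the target's.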
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.draft),
            seqs,
            modify_draft_cache,
            load_preallocated_cache,
        );
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.target),
            seqs,
            false,
            load_preallocated_cache,
        );
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    fn do_preallocated_cache(&self) -> bool {
        // The draft and target KV caches do not necessarily have the same size
        false
    }
}

impl MetadataMixin for SpeculativePipeline {
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    fn name(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            get_mut_arcmutex!(self.target).name(),
            get_mut_arcmutex!(self.draft).name(),
            self.gamma,
        )
    }
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state();
        get_mut_arcmutex!(self.draft).reset_non_granular_state();
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeculativePipeline {
    fn forward_inputs(
        &mut self,
        _inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> Result<ForwardInputsResult> {
        unreachable!()
    }
    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<()> {
        unreachable!()
    }
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        _return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata,
    ) -> Result<Duration> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                match pre_op {
                    CacheInstruction::In => self.clone_in_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable PRE cache op."),
                }

                let start = Instant::now();
                assert_eq!(input_seqs.len(), 1);

                let seq = &mut input_seqs[0];

                // ======================= Run draft model gamma times producing tokens ============================
                // ======================= Sample the `gamma` logits. ============================
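                // Each iteration decodes one draft token: the draft runs in prompt mode only on the
                // first step of a prompt sequence, and every sampled token is appended as a
                // temporary token so the next draft step conditions on it.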
                let mut draft_samples = Vec::new();
                for i in 0..self.gamma {
                    let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
                    let device = get_mut_arcmutex!(self.draft).device();
                    let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
                    let inputs = self
                        .get_processor()
                        .inputs_processor()
                        .process_inputs(
                            self.tokenizer(),
                            &mut [seq],
                            is_prompt && i == 0, // Only prompt (no kv cache) if first
                            is_xlora,
                            &device,
                            no_kv_cache,
                            None,
                            false,
                            None,
                            None, // TODO: get block tables/handle it
                            get_mut_arcmutex!(self.draft).device_mapper(),
                        )
                        .unwrap()
                        .inputs;
                    let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
                    #[allow(irrefutable_let_patterns)]
                    let ForwardInputsResult::CausalGeneration { logits } = logits
                    else {
                        candle_core::bail!(
                            "Speculative decoding requires `CausalGeneration` forward results"
                        );
                    };

                    let sample = sample_sequence(
                        logits.clone(),
                        seq,
                        seq.return_logprobs(),
                        rng.clone(),
                        false, // todo tune
                        true,
                        false,
                    )
                    .await?;
                    seq.add_tmp_tok(sample.token);
                    draft_samples.push(SpeculativeSample { sample });
                }
                seq.remove_tmp_tok(self.gamma);
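                // The temporary draft tokens are stripped again here; below they are re-introduced
                // as "prefill" tokens so the target model can score all of the drafted positions in
                // a single forward pass.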

                // ======================= Build the target prefill: the prompt (or last committed token) plus all draft tokens except the last. ============================
                let mut draft_prefill_tokens = if is_prompt {
                    seq.get_toks().to_vec()
                } else {
                    vec![*seq.get_toks().last().unwrap()]
                };
                for (i, sample) in draft_samples.iter().enumerate() {
                    if i == draft_samples.len() - 1 {
                        continue;
                    }
                    draft_prefill_tokens.push(sample.sample.token);
                }
                seq.set_prefill_toks(draft_prefill_tokens);

                // ======================= Run the model with all draft tokens. ============================

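                // Snapshot the target's current KV cache length before the speculative prefill; it
                // is passed to the input processor below together with `gamma` so the logits for
                // the last `gamma` positions can be selected.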
                let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => full.lock()[0]
                        .as_ref()
                        .map(|(k, _)| k.dims()[2])
                        .unwrap_or(0),
                    EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
                    EitherCache::Hybrid(_) => {
                        unreachable!("Speculative decoding is not supported with hybrid caches")
                    }
                };

                // ========= Run the model ============
                let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
                let device = get_mut_arcmutex!(self.target).device();
                let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
                let inputs = self
                    .get_processor()
                    .inputs_processor()
                    .process_inputs(
                        self.tokenizer(),
                        &mut [seq],
                        true, // use the "prefill" tokens
                        is_xlora,
                        &device,
                        no_kv_cache,
                        Some((self.gamma, initial_cache_len)), // Get the last gamma, see above
                        false,
                        None,
                        None, // TODO: get block tables/handle it
                        get_mut_arcmutex!(self.target).device_mapper(),
                    )
                    .unwrap()
                    .inputs;

                let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
                #[allow(irrefutable_let_patterns)]
                let ForwardInputsResult::CausalGeneration { logits } = logits
                else {
                    candle_core::bail!(
                        "Speculative decoding requires `CausalGeneration` forward results"
                    );
                };

                // Reset the prefill tokens
                seq.reset_prefill_toks();

                // ======================= Rejection sampling. ============================
                // Map each target sample to the corresponding draft sample.
                // This first rolls back any LLG state, then advances it for the accepted tokens only.
                let samples = sample_target_sequence_speculative(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    &draft_samples,
                )
                .await?;

                let accepted_tokens = samples.into_iter().map(|s| s.sample).collect::<Vec<_>>();

                // ======================= Narrow caches to account for rejections ============================
                let n_not_accepted = self.gamma - accepted_tokens.len();

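                // Both the draft and target KV caches have already advanced over every drafted
                // token, so each is truncated along the sequence-length dimension by the number of
                // rejected tokens. For example (illustrative numbers), with gamma = 5 and 3
                // accepted tokens, 2 positions are cut from each cache.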
                match get_mut_arcmutex!(self.draft).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                    EitherCache::Hybrid(_) => {
                        unreachable!("Speculative decoding is not supported with hybrid caches")
                    }
                }
                if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.draft).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) | EitherCache::Hybrid(_) => {
                            unreachable!()
                        }
                    }
                }
                match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                    EitherCache::Hybrid(_) => {
                        unreachable!("Speculative decoding is not supported with hybrid caches")
                    }
                }
                if get_mut_arcmutex!(self.target).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.target).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) | EitherCache::Hybrid(_) => {
                            unreachable!()
                        }
                    }
                }

                let eos_owned = get_mut_arcmutex!(self.target)
                    .get_metadata()
                    .eos_tok
                    .clone();
                let eos_tok = if disable_eos_stop {
                    None
                } else {
                    Some(&eos_owned[..])
                };
                // Add the tokens to the seq and the trie
                for accepted in accepted_tokens {
                    // Do not use the prefix cacher
                    finish_or_add_toks_to_seq(
                        self,
                        prefix_cacher,
                        seq,
                        accepted.clone(),
                        eos_tok,
                        false,
                    )
                    .await?;
                }

                // Trick to improve lower bounds. Sample last token in multinomial
                /*
                let sample = sample_sequence(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    false, // todo tune
                    true, // do not add to tok trie yet
                    true,
                )
                .await?;
                finish_or_add_toks_to_seq(self, prefix_cacher, seq, sample, eos_tok, false);
                */
                let end = Instant::now();
                let exec_duration = end.duration_since(start);

                match post_op {
                    CacheInstruction::Out => {
                        self.clone_out_cache(input_seqs);
                    }
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                // Done! We have:
                // - Run the draft model gamma times
                // - Reset the draft model cache fully
                // - Sampled the draft model's distributions
                // - Run the target model
                // - Executed the speculative decoding algorithm on the resulting distributions
                // - Added the accepted tokens to the buffer and trie
                // - Maybe fixed up the cache of the target model based on the accepted tokens.

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata: _,
                blocks_to_copy: _,
            } => unreachable!(),
        }
    }
    fn category(&self) -> ModelCategory {
        self.category.clone()
    }
}

impl AnyMoePipelineMixin for SpeculativePipeline {}