mistralrs_core/pipeline/speculative.rs

use std::{
    any::Any,
    sync::{Arc, Mutex},
    time::{Duration, Instant},
};

use anyhow::Result as anyhowResult;
use candle_core::{Device, IndexOp, Result, Tensor};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    get_mut_arcmutex,
    kv_cache::NormalCacheManager,
    pipeline::sampling::{
        finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
    },
    prefix_cacher::PrefixCacheManagerV2,
    sequence::Sequence,
    DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
};

use crate::kv_cache::CacheManager;

use super::{
    chat_template::ChatTemplate, sampling::SpeculativeSample, AnyMoePipelineMixin,
    CacheBackendMetadata, CacheInstruction, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, IsqPipelineMixin, MetadataMixin, ModelCategory, ModelPaths,
    PreProcessingMixin,
};
/// A loader for a speculative pipeline using 2 [`Loader`]s.
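///
/// A minimal construction sketch (the two inner loaders and the choice of
/// `gamma = 4` are illustrative, not prescriptive):
///
/// ```ignore
/// let loader = SpeculativeLoader {
///     target: target_loader, // any Box<dyn Loader> for the large model
///     draft: draft_loader,   // any Box<dyn Loader> for the small model
///     config: SpeculativeConfig { gamma: 4 },
/// };
/// ```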
pub struct SpeculativeLoader {
    pub target: Box<dyn Loader>,
    pub draft: Box<dyn Loader>,
    pub config: SpeculativeConfig,
}

impl Loader for SpeculativeLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_some() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_hf(
            revision.clone(),
            token_source.clone(),
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_hf(
            revision,
            token_source,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_some() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }
    fn get_id(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            self.target.get_id(),
            self.draft.get_id(),
            self.config.gamma,
        )
    }
    fn get_kind(&self) -> ModelKind {
        ModelKind::Speculative {
            target: Box::new(self.target.get_kind()),
            draft: Box::new(self.draft.get_kind()),
        }
    }
}

/// Speculative decoding pipeline: <https://arxiv.org/pdf/2211.17192>
///
/// # Algorithm
/// Given draft model q and target model p with probability distributions \
/// q_i(x) and p_i(x) for each token
///
/// - Keep the sample for token i if q_i(x) <= p_i(x)
///     - This means the target model agrees
/// - Else (q_i(x) > p_i(x)) accept that token with probability p_i(x)/q_i(x)
///     - If rejected, sample a token from p'_i(x) = norm(max(0, p_i(x) − q_i(x))) and do not accept any further draft tokens
///
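/// # Worked example (illustrative numbers)
///
/// Suppose the draft proposes token x with q_i(x) = 0.6 while the target
/// assigns p_i(x) = 0.3. Since q_i(x) > p_i(x), x is kept with probability
/// 0.3/0.6 = 0.5; if it is rejected, a replacement token is drawn from
/// norm(max(0, p_i − q_i)) and the remaining draft tokens are discarded.
///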
pub struct SpeculativePipeline {
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    gamma: usize,
    metadata: Arc<GeneralMetadata>,
    category: ModelCategory,
}

#[derive(Copy, Clone)]
/// Metadata for a speculative pipeline
pub struct SpeculativeConfig {
    /// Number of tokens (γ) to generate with the draft model per verification step
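    ///
    /// Larger γ lets each target forward pass verify more draft tokens, at the
    /// cost of more wasted draft work when tokens are rejected.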
    pub gamma: usize,
}

impl SpeculativePipeline {
    pub fn new(
        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        config: SpeculativeConfig,
    ) -> Result<Self> {
        if get_mut_arcmutex!(target)
            .tokenizer()
            .as_ref()
            .ok_or(candle_core::Error::Msg(
                "`SpeculativePipeline::new` requires the target pipeline to have a tokenizer"
                    .to_string(),
            ))?
            .get_vocab(true)
            != get_mut_arcmutex!(draft)
                .tokenizer()
                .as_ref()
                .ok_or(candle_core::Error::Msg(
                    "`SpeculativePipeline::new` requires the draft pipeline to have a tokenizer"
                        .to_string(),
                ))?
                .get_vocab(true)
        {
            candle_core::bail!("Target and draft models' tokenizer vocabs do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
            candle_core::bail!("Target and draft models' categories do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target)
            .get_processor()
            .inputs_processor()
            .get_type()
            != get_mut_arcmutex!(draft)
                .get_processor()
                .inputs_processor()
                .get_type()
        {
            candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
        }
        let metadata = get_mut_arcmutex!(target).get_metadata().clone();
        let category = get_mut_arcmutex!(target).category();
        // TODO: some checks or relaxation here?
        Ok(Self {
            target,
            draft,
            gamma: config.gamma,
            metadata,
            category,
        })
    }
}

impl PreProcessingMixin for SpeculativePipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
}

impl IsqPipelineMixin for SpeculativePipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
        get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
        get_mut_arcmutex!(self.draft).re_isq_model(dtype)
    }
}

impl CacheManagerMixin for SpeculativePipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.draft),
            seqs,
            modify_draft_cache,
            load_preallocated_cache,
        );
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.target),
            seqs,
            false,
            load_preallocated_cache,
        );
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    fn do_preallocated_cache(&self) -> bool {
        // The draft and target KV caches are not necessarily the same size
        false
    }
}

impl MetadataMixin for SpeculativePipeline {
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    fn name(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            get_mut_arcmutex!(self.target).name(),
            get_mut_arcmutex!(self.draft).name(),
            self.gamma,
        )
    }
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state();
        get_mut_arcmutex!(self.draft).reset_non_granular_state();
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeculativePipeline {
    fn forward_inputs(
        &mut self,
        _inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> Result<ForwardInputsResult> {
        unreachable!()
    }
    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<()> {
        unreachable!()
    }
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        _return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata,
    ) -> Result<Duration> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                match pre_op {
                    CacheInstruction::In => self.clone_in_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable PRE cache op."),
                }

                let start = Instant::now();
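                // Speculative decoding currently operates on exactly one sequence per step.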
                assert_eq!(input_seqs.len(), 1);

                let seq = &mut input_seqs[0];

                // ======================= Run draft model gamma times producing tokens ============================
                // ======================= Sample the `gamma` logits. ============================
                let mut draft_samples = Vec::new();
                for i in 0..self.gamma {
                    let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
                    let device = get_mut_arcmutex!(self.draft).device();
                    let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
                    let inputs = self
                        .get_processor()
                        .inputs_processor()
                        .process_inputs(
                            self.tokenizer(),
                            &mut [seq],
                            is_prompt && i == 0, // Only prompt (no kv cache) if first
                            is_xlora,
                            &device,
                            no_kv_cache,
                            None,
                            false,
                            None,
                            None, // TODO: get block tables/handle it
                            None, // TODO: do we support???
                            get_mut_arcmutex!(self.draft).device_mapper(),
                        )
                        .next()
                        .unwrap()
                        .unwrap()
                        .inputs;
                    let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
                    #[allow(irrefutable_let_patterns)]
                    let ForwardInputsResult::CausalGeneration { logits } = logits
                    else {
                        candle_core::bail!(
                            "Speculative decoding requires `CausalGeneration` forward results"
                        );
                    };

                    let sample = sample_sequence(
                        logits.clone(),
                        seq,
                        seq.return_logprobs(),
                        rng.clone(),
                        false, // todo tune
                        true,
                        false,
                    )
                    .await?;
                    seq.add_tmp_tok(sample.token);
                    draft_samples.push(SpeculativeSample { sample });
                }
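                // Drop the γ temporary draft tokens; they are re-fed to the target below as "prefill" tokens for verification.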
                seq.remove_tmp_tok(self.gamma);

                // ======================= Build the target prefill: the sequence's last token plus all draft tokens but the last one. ============================
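                // The last draft token is never fed as input: the logits that verify it come from the position holding the previous token, so feeding it would only compute an extra position this implementation does not use.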
                let mut draft_prefill_tokens = if is_prompt {
                    seq.get_toks().to_vec()
                } else {
                    vec![*seq.get_toks().last().unwrap()]
                };
                for (i, sample) in draft_samples.iter().enumerate() {
                    if i == draft_samples.len() - 1 {
                        continue;
                    }
                    draft_prefill_tokens.push(sample.sample.token);
                }
                seq.set_prefill_toks(draft_prefill_tokens);

                // ======================= Run the model with all draft tokens. ============================

                let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => full.lock()[0]
                        .as_ref()
                        .map(|(k, _)| k.dims()[2])
                        .unwrap_or(0),
                    EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
                };
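                // Record the target KV cache length before verification; paired with gamma below, it tells the input processor which positions are new.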

                // ========= Run the model ============
                let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
                let device = get_mut_arcmutex!(self.target).device();
                let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
                let inputs = self
                    .get_processor()
                    .inputs_processor()
                    .process_inputs(
                        self.tokenizer(),
                        &mut [seq],
                        true, // use the "prefill" tokens
                        is_xlora,
                        &device,
                        no_kv_cache,
                        Some((self.gamma, initial_cache_len)), // Get the last gamma, see above
                        false,
                        None,
                        None, // TODO: get block tables/handle it
                        None, // TODO: do we support???
                        get_mut_arcmutex!(self.target).device_mapper(),
                    )
                    .next()
                    .unwrap()
                    .unwrap()
                    .inputs;

                let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
                #[allow(irrefutable_let_patterns)]
                let ForwardInputsResult::CausalGeneration { logits } = logits
                else {
                    candle_core::bail!(
                        "Speculative decoding requires `CausalGeneration` forward results"
                    );
                };

                // Reset the prefill tokens
                seq.reset_prefill_toks();

                // ======================= Rejection sampling. ============================
                // Map each target sample to its corresponding draft sample. This first
                // rolls back the LLG state, if any, then advances it for the accepted tokens only.
                let samples = sample_target_sequence_speculative(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    &draft_samples,
                )
                .await?;

                let accepted_tokens = samples.into_iter().map(|s| s.sample).collect::<Vec<_>>();

                // ======================= Narrow caches to account for rejections ============================
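                // Both KV caches advanced by γ positions during drafting/verification; trim the rejected suffix so only the accepted prefix remains.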
                let n_not_accepted = self.gamma - accepted_tokens.len();

                match get_mut_arcmutex!(self.draft).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                }
                if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.draft).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) => {
                            unreachable!()
                        }
                    }
                }
                match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                }
                if get_mut_arcmutex!(self.target).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.target).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) => {
                            unreachable!()
                        }
                    }
                }

                let eos_owned = get_mut_arcmutex!(self.target)
                    .get_metadata()
                    .eos_tok
                    .clone();
                let eos_tok = if disable_eos_stop {
                    None
                } else {
                    Some(&eos_owned[..])
                };
                // Add the tokens to the seq and the trie
                for accepted in accepted_tokens {
                    // Do not use the prefix cacher
                    finish_or_add_toks_to_seq(
                        self,
                        prefix_cacher,
                        seq,
                        accepted.clone(),
                        eos_tok,
                        false,
                    )
                    .await?;
                }

                // Trick to improve lower bounds. Sample last token in multinomial
                /*
                let sample = sample_sequence(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    false, // todo tune
                    true, // do not add to tok trie yet
                    true,
                )
                .await?;
                finish_or_add_toks_to_seq(self, prefix_cacher, seq, sample, eos_tok, false);
                */
                let end = Instant::now();
                let exec_duration = end.duration_since(start);

                match post_op {
                    CacheInstruction::Out => {
                        self.clone_out_cache(input_seqs);
                    }
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                // Done! We have:
                // - Run the draft model gamma times
                // - Sampled the draft model's distributions
                // - Run the target model once over the drafted tokens
                // - Executed the speculative decoding algorithm on the resulting distributions
                // - Added the accepted tokens to the buffer and trie
                // - Narrowed both models' KV caches based on the accepted tokens

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata: _,
                blocks_to_copy: _,
                blocks_to_swap_in: _,
                blocks_to_swap_out: _,
            } => unreachable!(),
        }
    }
    fn category(&self) -> ModelCategory {
        self.category.clone()
    }
}

impl AnyMoePipelineMixin for SpeculativePipeline {}