mistralrs_core/pipeline/speculative.rs

use std::{
    any::Any,
    iter::zip,
    sync::{Arc, Mutex},
    time::{Duration, Instant},
};

use anyhow::Result as anyhowResult;
use candle_core::{Device, IndexOp, Result, Tensor};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    get_mut_arcmutex,
    pipeline::sampling::{
        finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
    },
    prefix_cacher::PrefixCacheManagerV2,
    sequence::{Sequence, SequenceRecognizer},
    DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
};

use super::{
    cache_manager::NormalCacheManager, chat_template::ChatTemplate, sampling::SpeculativeSample,
    AnyMoePipelineMixin, CacheBackendMetadata, CacheInstruction, CacheManager, CacheManagerMixin,
    EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, MetadataMixin,
    ModelCategory, ModelPaths, PreProcessingMixin,
};

/// A loader for a speculative pipeline using 2 [`Loader`]s.
pub struct SpeculativeLoader {
    pub target: Box<dyn Loader>,
    pub draft: Box<dyn Loader>,
    pub config: SpeculativeConfig,
}

impl Loader for SpeculativeLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_hf(
            revision.clone(),
            token_source.clone(),
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_hf(
            revision,
            token_source,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let paged_attn_config = if paged_attn_config.is_none() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }
    fn get_id(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            self.target.get_id(),
            self.draft.get_id(),
            self.config.gamma,
        )
    }
    fn get_kind(&self) -> ModelKind {
        ModelKind::Speculative {
            target: Box::new(self.target.get_kind()),
            draft: Box::new(self.draft.get_kind()),
        }
    }
}

/// Speculative decoding pipeline: <https://arxiv.org/pdf/2211.17192>
///
/// # Algorithm
/// Given a draft model q and a target model p with probability distributions \
/// q_i(x) and p_i(x) for each token
///
/// - Keep the sample for token i if q_i(x) <= p_i(x)
///     - This means the target model agrees
/// - Else (q_i(x) > p_i(x)), accept that token with probability p_i(x)/q_i(x)
///     - If rejected, sample a token from p'_i(x) = norm(max(0, p_i(x) − q_i(x))) and do not accept any more draft tokens
///
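/// # Illustrative example
/// If the draft assigns q_i(x) = 0.8 to its sampled token while the target assigns p_i(x) = 0.2,
/// the token is kept only with probability 0.2 / 0.8 = 0.25; on rejection, a replacement token is
/// drawn from p'_i(x) and no further draft tokens are accepted in this step.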
pub struct SpeculativePipeline {
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    gamma: usize,
    metadata: Arc<GeneralMetadata>,
    category: ModelCategory,
}

#[derive(Copy, Clone)]
/// Metadata for a speculative pipeline
pub struct SpeculativeConfig {
    /// The number of tokens (γ) to generate with the draft model per verification step
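    /// (for example, `gamma = 16` asks the draft model to propose 16 candidate tokens each step).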
    pub gamma: usize,
}

impl SpeculativePipeline {
    pub fn new(
        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        config: SpeculativeConfig,
    ) -> Result<Self> {
        if get_mut_arcmutex!(target)
            .tokenizer()
            .as_ref()
            .ok_or(candle_core::Error::Msg(
                "`SpeculativePipeline::new` requires the target pipeline to have a tokenizer"
                    .to_string(),
            ))?
            .get_vocab(true)
            != get_mut_arcmutex!(draft)
                .tokenizer()
                .as_ref()
                .ok_or(candle_core::Error::Msg(
                    "`SpeculativePipeline::new` requires the draft pipeline to have a tokenizer"
                        .to_string(),
                ))?
                .get_vocab(true)
        {
            candle_core::bail!("Target and draft models' tokenizer vocabs do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
            candle_core::bail!("Target and draft models' categories do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target)
            .get_processor()
            .inputs_processor()
            .get_type()
            != get_mut_arcmutex!(draft)
                .get_processor()
                .inputs_processor()
                .get_type()
        {
            candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
        }
        let metadata = get_mut_arcmutex!(target).get_metadata().clone();
        let category = get_mut_arcmutex!(target).category();
        // TODO: some checks or relaxation here?
        Ok(Self {
            target,
            draft,
            gamma: config.gamma,
            metadata,
            category,
        })
    }
}

impl PreProcessingMixin for SpeculativePipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
}

impl IsqPipelineMixin for SpeculativePipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
        get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
        get_mut_arcmutex!(self.draft).re_isq_model(dtype)
    }
}

impl CacheManagerMixin for SpeculativePipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
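        // The trailing boolean selects the sequence's draft cache (true) rather than its target cache (false).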
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.draft),
            seqs,
            modify_draft_cache,
            load_preallocated_cache,
        );
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.target),
            seqs,
            false,
            load_preallocated_cache,
        );
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    fn do_preallocated_cache(&self) -> bool {
        // KV cache size is not the same (necessarily)
        false
    }
}

impl MetadataMixin for SpeculativePipeline {
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    fn name(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            get_mut_arcmutex!(self.target).name(),
            get_mut_arcmutex!(self.draft).name(),
            self.gamma,
        )
    }
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state();
        get_mut_arcmutex!(self.draft).reset_non_granular_state();
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeculativePipeline {
    fn forward_inputs(
        &mut self,
        _inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> Result<ForwardInputsResult> {
        unreachable!()
    }
    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<()> {
        unreachable!()
    }
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        _return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata<'_>,
    ) -> Result<Duration> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                match pre_op {
                    CacheInstruction::In => self.clone_in_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable PRE cache op."),
                }

                let start = Instant::now();
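                // Speculative decoding currently handles exactly one sequence per step.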
                assert_eq!(input_seqs.len(), 1);

                let seq = &mut input_seqs[0];

                // ======================= Run draft model gamma times producing tokens ============================
                // ======================= Sample the `gamma` logits. ============================
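                // Each draft step samples one token and appends it to the sequence as a temporary
                // token so that the next draft forward pass conditions on it; the temporaries are
                // removed again once all `gamma` candidates have been produced.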
                let mut draft_samples = Vec::new();
                for i in 0..self.gamma {
                    let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
                    let device = get_mut_arcmutex!(self.draft).device();
                    let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
                    let inputs = self
                        .get_processor()
                        .inputs_processor()
                        .process_inputs(
                            self.tokenizer(),
                            &mut [seq],
                            is_prompt && i == 0, // Only prompt (no kv cache) if first
                            is_xlora,
                            &device,
                            no_kv_cache,
                            None,
                            false,
                            None,
                            None, // TODO: get block tables/handle it
                            None, // TODO: do we support???
                            get_mut_arcmutex!(self.draft).device_mapper(),
                        )
                        .nth(0)
                        .unwrap()
                        .unwrap()
                        .inputs;
                    let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
                    #[allow(irrefutable_let_patterns)]
                    let ForwardInputsResult::CausalGeneration { logits } = logits
                    else {
                        candle_core::bail!(
                            "Speculative decoding requires `CausalGeneration` forward results"
                        );
                    };

                    let sample = sample_sequence(
                        logits.clone(),
                        seq,
                        seq.return_logprobs(),
                        rng.clone(),
                        false, // todo tune
                        false, // do not add to tok trie yet
                        true,
                    )
                    .await?;
                    seq.add_tmp_tok(sample.token);
                    draft_samples.push(SpeculativeSample { sample });
                }
                seq.remove_tmp_tok(self.gamma);

                // ======================= Add all draft tokens but the last one. Add the last from the seq. ============================
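                // The target is prefilled over the last committed token (or the full prompt) plus
                // the first `gamma - 1` draft tokens: that yields `gamma` verification logits, one
                // per draft position, so the final draft token is only checked, never fed as input.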
                let mut draft_prefill_tokens = if is_prompt {
                    seq.get_toks().to_vec()
                } else {
                    vec![*seq.get_toks().last().unwrap()]
                };
                for (i, sample) in draft_samples.iter().enumerate() {
                    if i == draft_samples.len() - 1 {
                        continue;
                    }
                    draft_prefill_tokens.push(sample.sample.token);
                }
                seq.set_prefill_toks(draft_prefill_tokens);

                // ======================= Run the model with all draft tokens. ============================

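                // Snapshot the target model's current KV cache length; it is passed below together
                // with `gamma` so the processor knows which trailing positions correspond to the
                // draft tokens being verified.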
                let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => full.lock()[0]
                        .as_ref()
                        .map(|(k, _)| k.dims()[2])
                        .unwrap_or(0),
                    EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
                };

                // ========= Run the model ============
                let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
                let device = get_mut_arcmutex!(self.target).device();
                let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
                let inputs = self
                    .get_processor()
                    .inputs_processor()
                    .process_inputs(
                        self.tokenizer(),
                        &mut [seq],
                        true, // use the "prefill" tokens
                        is_xlora,
                        &device,
                        no_kv_cache,
                        Some((self.gamma, initial_cache_len)), // Get the last gamma, see above
                        false,
                        None,
                        None, // TODO: get block tables/handle it
                        None, // TODO: do we support???
                        get_mut_arcmutex!(self.target).device_mapper(),
                    )
                    .nth(0)
                    .unwrap()
                    .unwrap()
                    .inputs;

                let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
                #[allow(irrefutable_let_patterns)]
                let ForwardInputsResult::CausalGeneration { logits } = logits
                else {
                    candle_core::bail!(
                        "Speculative decoding requires `CausalGeneration` forward results"
                    );
                };

                // Reset the prefill tokens
                seq.reset_prefill_toks();

                // ======================= Rejection sampling. ============================
                // Map each target sample to the corresponding draft sample
                let samples = sample_target_sequence_speculative(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    self.gamma,
                )
                .await?;

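                // Keep the target's sample at each position and stop at the first position where
                // the draft token disagrees: the target's token there acts as the corrected sample,
                // and every later draft token was conditioned on a rejected prefix.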
                let mut accepted_tokens = Vec::new();
                for (target_sample, draft_sample) in zip(samples, draft_samples) {
                    let tok = target_sample.sample.token;
                    accepted_tokens.push(target_sample.sample);
                    if draft_sample.sample.token != tok {
                        break;
                    }
                }

                // ======================= Narrow caches to account for rejections ============================
                let n_not_accepted = self.gamma - accepted_tokens.len();
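                // Both models have cached KV entries for every speculated position, including the
                // rejected ones, so trim `n_not_accepted` entries from the end of each cache.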
                match get_mut_arcmutex!(self.draft).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                }
                if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.draft).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) => {
                            unreachable!()
                        }
                    }
                }
                match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                }
                if get_mut_arcmutex!(self.target).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.target).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) => {
                            unreachable!()
                        }
                    }
                }

                let eos_owned = get_mut_arcmutex!(self.target)
                    .get_metadata()
                    .eos_tok
                    .clone();
                let eos_tok = if disable_eos_stop {
                    None
                } else {
                    Some(&eos_owned[..])
                };
                // Add the tokens to the seq and the trie
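                // Each accepted token is committed individually so that EOS / stop handling and any
                // active constrained-generation recognizer observe the tokens in order.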
                for accepted in accepted_tokens {
                    // Do not use the prefix cacher
                    finish_or_add_toks_to_seq(
                        self,
                        prefix_cacher,
                        seq,
                        accepted.clone(),
                        eos_tok,
                        false,
                    )
                    .await?;
                    match seq.recognizer {
                        SequenceRecognizer::Llguidance(ref mut llg) => {
                            llg.commit_token(Some(accepted.token))
                                .map_err(candle_core::Error::msg)?;
                        }
                        SequenceRecognizer::None => {}
                    }
                }

                // Trick to improve lower bounds. Sample last token in multinomial
                /*
                let sample = sample_sequence(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    false, // todo tune
                    true, // do not add to tok trie yet
                    true,
                )
                .await?;
                finish_or_add_toks_to_seq(self, prefix_cacher, seq, sample, eos_tok, false);
                */
                let end = Instant::now();
                let exec_duration = end.duration_since(start);

                match post_op {
                    CacheInstruction::Out => {
                        self.clone_out_cache(input_seqs);
                    }
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                // Done! We have:
                // - Run the draft model gamma times
                // - Sampled the draft model's distributions
                // - Run the target model over the draft tokens
                // - Executed the speculative decoding algorithm on the resulting distributions
                // - Added the accepted tokens to the buffer and trie
                // - Trimmed the draft and target model caches back to the accepted tokens

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata: _,
                blocks_to_copy: _,
                blocks_to_swap_in: _,
                blocks_to_swap_out: _,
            } => unreachable!(),
        }
    }
    fn category(&self) -> ModelCategory {
        self.category.clone()
    }
}

impl AnyMoePipelineMixin for SpeculativePipeline {}