use std::{
    any::Any,
    sync::{Arc, Mutex},
    time::{Duration, Instant},
};

use anyhow::Result as anyhowResult;
use candle_core::{Device, IndexOp, Result, Tensor};
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    get_mut_arcmutex,
    kv_cache::NormalCacheManager,
    pipeline::sampling::{
        finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
    },
    prefix_cacher::PrefixCacheManagerV2,
    sequence::Sequence,
    DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
};

use crate::kv_cache::CacheManager;
use crate::utils::progress::ProgressScopeGuard;

use super::{
    chat_template::ChatTemplate, sampling::SpeculativeSample, AnyMoePipelineMixin,
    CacheBackendMetadata, CacheInstruction, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, IsqPipelineMixin, MetadataMixin, ModelCategory, ModelPaths,
    PreProcessingMixin,
};
pub struct SpeculativeLoader {
    pub target: Box<dyn Loader>,
    pub draft: Box<dyn Loader>,
    pub config: SpeculativeConfig,
}

impl Loader for SpeculativeLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paged_attn_config = if paged_attn_config.is_some() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_hf(
            revision.clone(),
            token_source.clone(),
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_hf(
            revision,
            token_source,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paged_attn_config = if paged_attn_config.is_some() {
            warn!(
                "Speculative decoding does not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let target = self.target.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        let draft = self.draft.load_model_from_path(
            paths,
            dtype,
            device,
            silent,
            mapper.clone(),
            in_situ_quant,
            paged_attn_config,
        )?;
        Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
            target,
            draft,
            self.config,
        )?)))
    }
    fn get_id(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            self.target.get_id(),
            self.draft.get_id(),
            self.config.gamma,
        )
    }
    fn get_kind(&self) -> ModelKind {
        ModelKind::Speculative {
            target: Box::new(self.target.get_kind()),
            draft: Box::new(self.draft.get_kind()),
        }
    }
}

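/// A speculative decoding pipeline in the style of "Fast Inference from Transformers via
/// Speculative Decoding" (<https://arxiv.org/abs/2211.17192>): a small draft model proposes
/// `gamma` tokens per step and the target model verifies them in a single forward pass,
/// keeping only the accepted prefix.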
pub struct SpeculativePipeline {
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    gamma: usize,
    metadata: Arc<GeneralMetadata>,
    category: ModelCategory,
}

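/// Configuration for a [`SpeculativePipeline`].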
#[derive(Copy, Clone)]
pub struct SpeculativeConfig {
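    /// Number of tokens to draft with the smaller model before each verification pass
    /// by the target model (γ in the speculative decoding paper).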
    pub gamma: usize,
}

impl SpeculativePipeline {
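    /// Build a speculative pipeline from already-loaded target and draft pipelines,
    /// checking that both share a tokenizer vocabulary, model category, and input
    /// processor type.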
    pub fn new(
        target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        config: SpeculativeConfig,
    ) -> Result<Self> {
        if get_mut_arcmutex!(target)
            .tokenizer()
            .as_ref()
            .ok_or(candle_core::Error::Msg(
                "`SpeculativePipeline::new` requires the target pipeline to have a tokenizer"
                    .to_string(),
            ))?
            .get_vocab(true)
            != get_mut_arcmutex!(draft)
                .tokenizer()
                .as_ref()
                .ok_or(candle_core::Error::Msg(
                    "`SpeculativePipeline::new` requires the draft pipeline to have a tokenizer"
                        .to_string(),
                ))?
                .get_vocab(true)
        {
            candle_core::bail!("Target and draft models' tokenizer vocab do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
            candle_core::bail!("Target and draft models' category do not match. This is required for speculative decoding.");
        }
        if get_mut_arcmutex!(target)
            .get_processor()
            .inputs_processor()
            .get_type()
            != get_mut_arcmutex!(draft)
                .get_processor()
                .inputs_processor()
                .get_type()
        {
            candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
        }
        let metadata = get_mut_arcmutex!(target).get_metadata().clone();
        let category = get_mut_arcmutex!(target).category();
        Ok(Self {
            target,
            draft,
            gamma: config.gamma,
            metadata,
            category,
        })
    }
}

impl PreProcessingMixin for SpeculativePipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
}

impl IsqPipelineMixin for SpeculativePipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
        get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
        get_mut_arcmutex!(self.draft).re_isq_model(dtype)
    }
}

impl CacheManagerMixin for SpeculativePipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.draft),
            seqs,
            modify_draft_cache,
            load_preallocated_cache,
        );
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.target),
            seqs,
            false,
            load_preallocated_cache,
        );
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    fn do_preallocated_cache(&self) -> bool {
        false
    }
}

impl MetadataMixin for SpeculativePipeline {
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    fn name(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            get_mut_arcmutex!(self.target).name(),
            get_mut_arcmutex!(self.draft).name(),
            self.gamma,
        )
    }
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state();
        get_mut_arcmutex!(self.draft).reset_non_granular_state();
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeculativePipeline {
    fn forward_inputs(
        &mut self,
        _inputs: Box<dyn Any>,
        _return_raw_logits: bool,
    ) -> Result<ForwardInputsResult> {
        unreachable!()
    }
    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<()> {
        unreachable!()
    }
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        _return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata,
    ) -> Result<Duration> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                match pre_op {
                    CacheInstruction::In => self.clone_in_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable PRE cache op."),
                }

                let start = Instant::now();
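                // Speculative decoding currently operates on exactly one sequence per step.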
                assert_eq!(input_seqs.len(), 1);

                let seq = &mut input_seqs[0];

                let mut draft_samples = Vec::new();
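                // Draft phase: run the draft model gamma times, sampling one token per pass
                // and tentatively appending it to the sequence.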
                for i in 0..self.gamma {
                    let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
                    let device = get_mut_arcmutex!(self.draft).device();
                    let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
                    let inputs = self
                        .get_processor()
                        .inputs_processor()
                        .process_inputs(
                            self.tokenizer(),
                            &mut [seq],
                            is_prompt && i == 0,
                            is_xlora,
                            &device,
                            no_kv_cache,
                            None,
                            false,
                            None,
                            None,
                            get_mut_arcmutex!(self.draft).device_mapper(),
                        )
                        .unwrap()
                        .inputs;
                    let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
                    #[allow(irrefutable_let_patterns)]
                    let ForwardInputsResult::CausalGeneration { logits } = logits
                    else {
                        candle_core::bail!(
                            "Speculative decoding requires `CausalGeneration` forward results"
                        );
                    };

                    let sample = sample_sequence(
                        logits.clone(),
                        seq,
                        seq.return_logprobs(),
                        rng.clone(),
                        false,
                        true,
                        false,
                    )
                    .await?;
                    seq.add_tmp_tok(sample.token);
                    draft_samples.push(SpeculativeSample { sample });
                }
                seq.remove_tmp_tok(self.gamma);

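                // Build the prefill tokens for the verification pass: the full prompt (or just
                // the last committed token when decoding) followed by all but the last draft token.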
                let mut draft_prefill_tokens = if is_prompt {
                    seq.get_toks().to_vec()
                } else {
                    vec![*seq.get_toks().last().unwrap()]
                };
                for (i, sample) in draft_samples.iter().enumerate() {
                    if i == draft_samples.len() - 1 {
                        continue;
                    }
                    draft_prefill_tokens.push(sample.sample.token);
                }
                seq.set_prefill_toks(draft_prefill_tokens);

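                // Snapshot the length of the target KV cache so the gamma speculated tokens can
                // be positioned after the existing context during verification.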
                let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => full.lock()[0]
                        .as_ref()
                        .map(|(k, _)| k.dims()[2])
                        .unwrap_or(0),
                    EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
                    EitherCache::Hybrid(_) => {
                        unreachable!("Speculative decoding is not supported with hybrid caches")
                    }
                };

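                // Verification phase: run the target model once over the prefill tokens so it
                // scores all gamma draft tokens in a single forward pass.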
                let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
                let device = get_mut_arcmutex!(self.target).device();
                let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
                let inputs = self
                    .get_processor()
                    .inputs_processor()
                    .process_inputs(
                        self.tokenizer(),
                        &mut [seq],
                        true,
                        is_xlora,
                        &device,
                        no_kv_cache,
                        Some((self.gamma, initial_cache_len)),
                        false,
                        None,
                        None,
                        get_mut_arcmutex!(self.target).device_mapper(),
                    )
                    .unwrap()
                    .inputs;

                let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
                #[allow(irrefutable_let_patterns)]
                let ForwardInputsResult::CausalGeneration { logits } = logits
                else {
                    candle_core::bail!(
                        "Speculative decoding requires `CausalGeneration` forward results"
                    );
                };

                seq.reset_prefill_toks();

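                // Verify the draft tokens against the target model's logits, keeping only the
                // accepted ones.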
                let samples = sample_target_sequence_speculative(
                    logits.clone(),
                    seq,
                    seq.return_logprobs(),
                    rng.clone(),
                    &draft_samples,
                )
                .await?;

                let accepted_tokens = samples.into_iter().map(|s| s.sample).collect::<Vec<_>>();

                let n_not_accepted = self.gamma - accepted_tokens.len();

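                // Roll the draft and target KV caches back so they only contain entries for the
                // accepted tokens.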
                match get_mut_arcmutex!(self.draft).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                    EitherCache::Hybrid(_) => {
                        unreachable!("Speculative decoding is not supported with hybrid caches")
                    }
                }
                if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.draft).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) | EitherCache::Hybrid(_) => {
                            unreachable!()
                        }
                    }
                }
                match get_mut_arcmutex!(self.target).cache() {
                    EitherCache::Full(full) => {
                        for (k, v) in full.lock().iter_mut().flatten() {
                            *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                            *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for cache in &mut *normal.lock().unwrap().0 {
                            cache
                                .set_len(cache.current_seq_len() - n_not_accepted)
                                .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
                        }
                    }
                    EitherCache::Hybrid(_) => {
                        unreachable!("Speculative decoding is not supported with hybrid caches")
                    }
                }
                if get_mut_arcmutex!(self.target).get_metadata().is_xlora {
                    match get_mut_arcmutex!(self.target).cache() {
                        EitherCache::Full(full) => {
                            for (k, v) in full.xlora_lock().iter_mut().flatten() {
                                *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
                                *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
                            }
                        }
                        EitherCache::Normal(_) | EitherCache::Hybrid(_) => {
                            unreachable!()
                        }
                    }
                }

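                // Commit the accepted tokens to the sequence, stopping early on EOS or other
                // completion conditions.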
                let eos_owned = get_mut_arcmutex!(self.target)
                    .get_metadata()
                    .eos_tok
                    .clone();
                let eos_tok = if disable_eos_stop {
                    None
                } else {
                    Some(&eos_owned[..])
                };
                for accepted in accepted_tokens {
                    finish_or_add_toks_to_seq(
                        self,
                        prefix_cacher,
                        seq,
                        accepted.clone(),
                        eos_tok,
                        false,
                    )
                    .await?;
                }

                let end = Instant::now();
                let exec_duration = end.duration_since(start);

                match post_op {
                    CacheInstruction::Out => {
                        self.clone_out_cache(input_seqs);
                    }
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        reset_non_granular,
                        load_preallocated_cache,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        true,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata: _,
                blocks_to_copy: _,
            } => unreachable!(),
        }
    }
    fn category(&self) -> ModelCategory {
        self.category.clone()
    }
}

impl AnyMoePipelineMixin for SpeculativePipeline {}