1use std::{
2 any::Any,
3 iter::zip,
4 sync::{Arc, Mutex},
5 time::{Duration, Instant},
6};
7
8use anyhow::Result as anyhowResult;
9use candle_core::{Device, IndexOp, Result, Tensor};
10use mistralrs_quant::IsqType;
11use rand_isaac::Isaac64Rng;
12use tokenizers::Tokenizer;
13use tracing::warn;
14
15use crate::{
16 device_map::DeviceMapper,
17 get_mut_arcmutex,
18 pipeline::sampling::{
19 finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
20 },
21 prefix_cacher::PrefixCacheManagerV2,
22 sequence::{Sequence, SequenceRecognizer},
23 DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
24};
25
26use super::{
27 cache_manager::NormalCacheManager, chat_template::ChatTemplate, sampling::SpeculativeSample,
28 AnyMoePipelineMixin, CacheBackendMetadata, CacheInstruction, CacheManager, CacheManagerMixin,
29 EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, MetadataMixin,
30 ModelCategory, ModelPaths, PreProcessingMixin,
31};
32
/// A [`Loader`] that loads a target model and a draft model and wires them
/// together into a speculative-decoding pipeline (see `SpeculativePipeline`).
pub struct SpeculativeLoader {
    /// Loader for the full-size target model, whose samples are authoritative.
    pub target: Box<dyn Loader>,
    /// Loader for the draft model used to propose candidate tokens.
    pub draft: Box<dyn Loader>,
    /// Speculative-decoding settings (number of draft tokens per step).
    pub config: SpeculativeConfig,
}
39
40impl Loader for SpeculativeLoader {
41 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
42 fn load_model_from_hf(
43 &self,
44 revision: Option<String>,
45 token_source: TokenSource,
46 dtype: &dyn TryIntoDType,
47 device: &Device,
48 silent: bool,
49 mapper: DeviceMapSetting,
50 in_situ_quant: Option<IsqType>,
51 paged_attn_config: Option<PagedAttentionConfig>,
52 ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
53 let paged_attn_config = if paged_attn_config.is_none() {
54 warn!(
55 "Speculative decoding does not currently support PagedAttention, running without"
56 );
57 None
58 } else {
59 paged_attn_config
60 };
61
62 let target = self.target.load_model_from_hf(
63 revision.clone(),
64 token_source.clone(),
65 dtype,
66 device,
67 silent,
68 mapper.clone(),
69 in_situ_quant,
70 paged_attn_config,
71 )?;
72 let draft = self.draft.load_model_from_hf(
73 revision,
74 token_source,
75 dtype,
76 device,
77 silent,
78 mapper,
79 in_situ_quant,
80 paged_attn_config,
81 )?;
82 Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
83 target,
84 draft,
85 self.config,
86 )?)))
87 }
88
89 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
90 fn load_model_from_path(
91 &self,
92 paths: &Box<dyn ModelPaths>,
93 dtype: &dyn TryIntoDType,
94 device: &Device,
95 silent: bool,
96 mapper: DeviceMapSetting,
97 in_situ_quant: Option<IsqType>,
98 paged_attn_config: Option<PagedAttentionConfig>,
99 ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
100 let paged_attn_config = if paged_attn_config.is_none() {
101 warn!(
102 "Speculative decoding does not currently support PagedAttention, running without"
103 );
104 None
105 } else {
106 paged_attn_config
107 };
108
109 let target = self.target.load_model_from_path(
110 paths,
111 dtype,
112 device,
113 silent,
114 mapper.clone(),
115 in_situ_quant,
116 paged_attn_config,
117 )?;
118 let draft = self.draft.load_model_from_path(
119 paths,
120 dtype,
121 device,
122 silent,
123 mapper.clone(),
124 in_situ_quant,
125 paged_attn_config,
126 )?;
127 Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
128 target,
129 draft,
130 self.config,
131 )?)))
132 }
133 fn get_id(&self) -> String {
134 format!(
135 "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
136 self.target.get_id(),
137 self.draft.get_id(),
138 self.config.gamma,
139 )
140 }
141 fn get_kind(&self) -> ModelKind {
142 ModelKind::Speculative {
143 target: Box::new(self.target.get_kind()),
144 draft: Box::new(self.draft.get_kind()),
145 }
146 }
147}
148
/// A pipeline that performs speculative decoding: a draft model proposes
/// `gamma` candidate tokens per step and the target model verifies them in a
/// single forward pass (see `Pipeline::step` below).
pub struct SpeculativePipeline {
    // Full-size model; its samples are authoritative.
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    // Model used to propose candidate tokens.
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    // Number of draft tokens sampled per speculative step.
    gamma: usize,
    // Metadata taken from the target pipeline at construction time.
    metadata: Arc<GeneralMetadata>,
    // Model category; `new` verifies target and draft agree on it.
    category: ModelCategory,
}
167
/// Configuration for speculative decoding.
// `Debug` added so the config can be logged/inspected; `Copy` is kept because
// the loader passes it by value into `SpeculativePipeline::new`.
#[derive(Copy, Clone, Debug)]
pub struct SpeculativeConfig {
    /// Number of candidate tokens (gamma) sampled from the draft model per step.
    pub gamma: usize,
}
174
175impl SpeculativePipeline {
176 pub fn new(
177 target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
178 draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
179 config: SpeculativeConfig,
180 ) -> Result<Self> {
181 if get_mut_arcmutex!(target)
182 .tokenizer()
183 .as_ref()
184 .ok_or(candle_core::Error::Msg(
185 "`SpeculativePipeline::new` requires the target pipeline to have a token trie"
186 .to_string(),
187 ))?
188 .get_vocab(true)
189 != get_mut_arcmutex!(draft)
190 .tokenizer()
191 .as_ref()
192 .ok_or(candle_core::Error::Msg(
193 "`SpeculativePipeline::new` requires the draft pipeline to have a token trie"
194 .to_string(),
195 ))?
196 .get_vocab(true)
197 {
198 candle_core::bail!("Target and draft models' tokenizer vocab do not match. This is required for speculative decoding.");
199 }
200 if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
201 candle_core::bail!("Target and draft models' category do not match. This is required for speculative decoding.");
202 }
203 if get_mut_arcmutex!(target)
204 .get_processor()
205 .inputs_processor()
206 .get_type()
207 != get_mut_arcmutex!(draft)
208 .get_processor()
209 .inputs_processor()
210 .get_type()
211 {
212 candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
213 }
214 let metadata = get_mut_arcmutex!(target).get_metadata().clone();
215 let category = get_mut_arcmutex!(target).category();
216 Ok(Self {
218 target,
219 draft,
220 gamma: config.gamma,
221 metadata,
222 category,
223 })
224 }
225}
226
227impl PreProcessingMixin for SpeculativePipeline {
228 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
229 get_mut_arcmutex!(self.target).get_chat_template()
230 }
231 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
232 get_mut_arcmutex!(self.target).get_input_processor_config()
233 }
234}
235
236impl IsqPipelineMixin for SpeculativePipeline {
237 fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
238 get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
239 get_mut_arcmutex!(self.draft).re_isq_model(dtype)
240 }
241}
242
impl CacheManagerMixin for SpeculativePipeline {
    /// Restore per-sequence KV caches into both models before a step.
    ///
    /// The trailing bool selects which per-sequence cache copy is used
    /// (`true` for the draft model, `false` for the target) — it appears to
    /// mirror the `modify_draft_cache` parameter of `set_none_cache`; confirm
    /// against `NormalCacheManager`.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    /// Save both models' KV caches back into the sequences after a step.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    /// Clear the KV caches of both models.
    ///
    /// `modify_draft_cache` is forwarded only to the draft-model reset; the
    /// target reset always passes `false` in that slot.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.draft),
            seqs,
            modify_draft_cache,
            load_preallocated_cache,
        );
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.target),
            seqs,
            false,
            load_preallocated_cache,
        );
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    /// The combined pipeline has no single cache of its own; `step` locks the
    /// target/draft pipelines and accesses their caches directly.
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    /// Preallocated caches are not used by the speculative wrapper itself.
    fn do_preallocated_cache(&self) -> bool {
        false
    }
}
283
284impl MetadataMixin for SpeculativePipeline {
285 fn device(&self) -> Device {
286 get_mut_arcmutex!(self.target).device()
287 }
288 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
289 get_mut_arcmutex!(self.target).tokenizer()
290 }
291 fn name(&self) -> String {
292 format!(
293 "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
294 get_mut_arcmutex!(self.target).name(),
295 get_mut_arcmutex!(self.draft).name(),
296 self.gamma,
297 )
298 }
299 fn reset_non_granular_state(&self) {
300 get_mut_arcmutex!(self.target).reset_non_granular_state();
301 get_mut_arcmutex!(self.draft).reset_non_granular_state();
302 }
303 fn get_metadata(&self) -> Arc<GeneralMetadata> {
304 self.metadata.clone()
305 }
306 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
307 None
308 }
309}
310
311#[async_trait::async_trait]
312impl Pipeline for SpeculativePipeline {
313 fn forward_inputs(
314 &mut self,
315 _inputs: Box<dyn Any>,
316 _return_raw_logits: bool,
317 ) -> Result<ForwardInputsResult> {
318 unreachable!()
319 }
320 async fn sample_causal_gen(
321 &self,
322 _seqs: &mut [&mut Sequence],
323 _logits: Vec<Tensor>,
324 _prefix_cacher: &mut PrefixCacheManagerV2,
325 _disable_eos_stop: bool,
326 _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
327 ) -> Result<()> {
328 unreachable!()
329 }
330 async fn step(
331 &mut self,
332 input_seqs: &mut [&mut Sequence],
333 is_prompt: bool,
334 _return_raw_logits: bool,
335 prefix_cacher: &mut PrefixCacheManagerV2,
336 disable_eos_stop: bool,
337 rng: Arc<Mutex<Isaac64Rng>>,
338 backend_metadata: CacheBackendMetadata<'_>,
339 ) -> Result<Duration> {
340 match backend_metadata {
341 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
342 match pre_op {
343 CacheInstruction::In => self.clone_in_cache(input_seqs),
344 CacheInstruction::Nothing => (),
345 CacheInstruction::Reset {
346 reset_non_granular,
347 load_preallocated_cache,
348 } => self.set_none_cache(
349 input_seqs,
350 reset_non_granular,
351 true,
352 load_preallocated_cache,
353 ),
354 _ => unreachable!("Unreachable PRE cache op."),
355 }
356
357 let start = Instant::now();
358 assert_eq!(input_seqs.len(), 1);
359
360 let seq = &mut input_seqs[0];
361
362 let mut draft_samples = Vec::new();
365 for i in 0..self.gamma {
366 let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
367 let device = get_mut_arcmutex!(self.draft).device();
368 let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
369 let inputs = self
370 .get_processor()
371 .inputs_processor()
372 .process_inputs(
373 self.tokenizer(),
374 &mut [seq],
375 is_prompt && i == 0, is_xlora,
377 &device,
378 no_kv_cache,
379 None,
380 false,
381 None,
382 None, None, get_mut_arcmutex!(self.draft).device_mapper(),
385 )
386 .nth(0)
387 .unwrap()
388 .unwrap()
389 .inputs;
390 let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
391 #[allow(irrefutable_let_patterns)]
392 let ForwardInputsResult::CausalGeneration { logits } = logits
393 else {
394 candle_core::bail!(
395 "Speculative decoding requires `CausalGeneration` forward results"
396 );
397 };
398
399 let sample = sample_sequence(
400 logits.clone(),
401 seq,
402 seq.return_logprobs(),
403 rng.clone(),
404 false, false, true,
407 )
408 .await?;
409 seq.add_tmp_tok(sample.token);
410 draft_samples.push(SpeculativeSample { sample });
411 }
412 seq.remove_tmp_tok(self.gamma);
413
414 let mut draft_prefill_tokens = if is_prompt {
416 seq.get_toks().to_vec()
417 } else {
418 vec![*seq.get_toks().last().unwrap()]
419 };
420 for (i, sample) in draft_samples.iter().enumerate() {
421 if i == draft_samples.len() - 1 {
422 continue;
423 }
424 draft_prefill_tokens.push(sample.sample.token);
425 }
426 seq.set_prefill_toks(draft_prefill_tokens);
427
428 let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
431 EitherCache::Full(full) => full.lock()[0]
432 .as_ref()
433 .map(|(k, _)| k.dims()[2])
434 .unwrap_or(0),
435 EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
436 };
437
438 let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
440 let device = get_mut_arcmutex!(self.target).device();
441 let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
442 let inputs = self
443 .get_processor()
444 .inputs_processor()
445 .process_inputs(
446 self.tokenizer(),
447 &mut [seq],
448 true, is_xlora,
450 &device,
451 no_kv_cache,
452 Some((self.gamma, initial_cache_len)), false,
454 None,
455 None, None, get_mut_arcmutex!(self.target).device_mapper(),
458 )
459 .nth(0)
460 .unwrap()
461 .unwrap()
462 .inputs;
463
464 let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
465 #[allow(irrefutable_let_patterns)]
466 let ForwardInputsResult::CausalGeneration { logits } = logits
467 else {
468 candle_core::bail!(
469 "Speculative decoding requires `CausalGeneration` forward results"
470 );
471 };
472
473 seq.reset_prefill_toks();
475
476 let samples = sample_target_sequence_speculative(
479 logits.clone(),
480 seq,
481 seq.return_logprobs(),
482 rng.clone(),
483 self.gamma,
484 )
485 .await?;
486
487 let mut accepted_tokens = Vec::new();
488 for (target_sample, draft_sample) in zip(samples, draft_samples) {
489 let tok = target_sample.sample.token;
490 accepted_tokens.push(target_sample.sample);
491 if draft_sample.sample.token != tok {
492 break;
493 }
494 }
495
496 let n_not_accepted = self.gamma - accepted_tokens.len();
498 match get_mut_arcmutex!(self.draft).cache() {
499 EitherCache::Full(full) => {
500 for (k, v) in full.lock().iter_mut().flatten() {
501 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
502 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
503 }
504 }
505 EitherCache::Normal(normal) => {
506 for cache in &mut *normal.lock().unwrap().0 {
507 cache
508 .set_len(cache.current_seq_len() - n_not_accepted)
509 .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
510 }
511 }
512 }
513 if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
514 match get_mut_arcmutex!(self.draft).cache() {
515 EitherCache::Full(full) => {
516 for (k, v) in full.xlora_lock().iter_mut().flatten() {
517 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
518 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
519 }
520 }
521 EitherCache::Normal(_) => {
522 unreachable!()
523 }
524 }
525 }
526 match get_mut_arcmutex!(self.target).cache() {
527 EitherCache::Full(full) => {
528 for (k, v) in full.lock().iter_mut().flatten() {
529 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
530 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
531 }
532 }
533 EitherCache::Normal(normal) => {
534 for cache in &mut *normal.lock().unwrap().0 {
535 cache
536 .set_len(cache.current_seq_len() - n_not_accepted)
537 .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
538 }
539 }
540 }
541 if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
542 match get_mut_arcmutex!(self.target).cache() {
543 EitherCache::Full(full) => {
544 for (k, v) in full.xlora_lock().iter_mut().flatten() {
545 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
546 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
547 }
548 }
549 EitherCache::Normal(_) => {
550 unreachable!()
551 }
552 }
553 }
554
555 let eos_owned = get_mut_arcmutex!(self.target)
556 .get_metadata()
557 .eos_tok
558 .clone();
559 let eos_tok = if disable_eos_stop {
560 None
561 } else {
562 Some(&eos_owned[..])
563 };
564 for accepted in accepted_tokens {
566 finish_or_add_toks_to_seq(
568 self,
569 prefix_cacher,
570 seq,
571 accepted.clone(),
572 eos_tok,
573 false,
574 )
575 .await?;
576 match seq.recognizer {
577 SequenceRecognizer::Llguidance(ref mut llg) => {
578 llg.commit_token(Some(accepted.token))
579 .map_err(candle_core::Error::msg)?;
580 }
581 SequenceRecognizer::None => {}
582 }
583 }
584
585 let end = Instant::now();
600 let exec_duration = end.duration_since(start);
601
602 match post_op {
603 CacheInstruction::Out => {
604 self.clone_out_cache(input_seqs);
605 }
606 CacheInstruction::Nothing => (),
607 CacheInstruction::Reset {
608 reset_non_granular,
609 load_preallocated_cache,
610 } => self.set_none_cache(
611 input_seqs,
612 reset_non_granular,
613 true,
614 load_preallocated_cache,
615 ),
616 _ => unreachable!("Unreachable pre cache op."),
617 }
618
619 Ok(exec_duration)
629 }
630 CacheBackendMetadata::PagedAttention {
631 metadata: _,
632 blocks_to_copy: _,
633 blocks_to_swap_in: _,
634 blocks_to_swap_out: _,
635 } => unreachable!(),
636 }
637 }
638 fn category(&self) -> ModelCategory {
639 self.category.clone()
640 }
641}
642
643impl AnyMoePipelineMixin for SpeculativePipeline {}