1use std::{
2 any::Any,
3 sync::{Arc, Mutex},
4 time::{Duration, Instant},
5};
6
7use anyhow::Result as anyhowResult;
8use candle_core::{Device, IndexOp, Result, Tensor};
9use mistralrs_quant::IsqType;
10use rand_isaac::Isaac64Rng;
11use tokenizers::Tokenizer;
12use tracing::warn;
13
14use crate::{
15 device_map::DeviceMapper,
16 get_mut_arcmutex,
17 kv_cache::NormalCacheManager,
18 pipeline::sampling::{
19 finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
20 },
21 prefix_cacher::PrefixCacheManagerV2,
22 sequence::Sequence,
23 DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
24};
25
26use crate::kv_cache::CacheManager;
27
28use super::{
29 chat_template::ChatTemplate, sampling::SpeculativeSample, AnyMoePipelineMixin,
30 CacheBackendMetadata, CacheInstruction, CacheManagerMixin, EitherCache, ForwardInputsResult,
31 GeneralMetadata, IsqPipelineMixin, MetadataMixin, ModelCategory, ModelPaths,
32 PreProcessingMixin,
33};
34
/// Loader that combines a target-model loader and a draft-model loader into a
/// single [`SpeculativePipeline`].
pub struct SpeculativeLoader {
    /// Loader used to build the target pipeline.
    pub target: Box<dyn Loader>,
    /// Loader used to build the draft pipeline.
    pub draft: Box<dyn Loader>,
    /// Speculative decoding settings handed to `SpeculativePipeline::new`.
    pub config: SpeculativeConfig,
}
41
42impl Loader for SpeculativeLoader {
43 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
44 fn load_model_from_hf(
45 &self,
46 revision: Option<String>,
47 token_source: TokenSource,
48 dtype: &dyn TryIntoDType,
49 device: &Device,
50 silent: bool,
51 mapper: DeviceMapSetting,
52 in_situ_quant: Option<IsqType>,
53 paged_attn_config: Option<PagedAttentionConfig>,
54 ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
55 let paged_attn_config = if paged_attn_config.is_none() {
56 warn!(
57 "Speculative decoding does not currently support PagedAttention, running without"
58 );
59 None
60 } else {
61 paged_attn_config
62 };
63
64 let target = self.target.load_model_from_hf(
65 revision.clone(),
66 token_source.clone(),
67 dtype,
68 device,
69 silent,
70 mapper.clone(),
71 in_situ_quant,
72 paged_attn_config,
73 )?;
74 let draft = self.draft.load_model_from_hf(
75 revision,
76 token_source,
77 dtype,
78 device,
79 silent,
80 mapper,
81 in_situ_quant,
82 paged_attn_config,
83 )?;
84 Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
85 target,
86 draft,
87 self.config,
88 )?)))
89 }
90
91 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
92 fn load_model_from_path(
93 &self,
94 paths: &Box<dyn ModelPaths>,
95 dtype: &dyn TryIntoDType,
96 device: &Device,
97 silent: bool,
98 mapper: DeviceMapSetting,
99 in_situ_quant: Option<IsqType>,
100 paged_attn_config: Option<PagedAttentionConfig>,
101 ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
102 let paged_attn_config = if paged_attn_config.is_none() {
103 warn!(
104 "Speculative decoding does not currently support PagedAttention, running without"
105 );
106 None
107 } else {
108 paged_attn_config
109 };
110
111 let target = self.target.load_model_from_path(
112 paths,
113 dtype,
114 device,
115 silent,
116 mapper.clone(),
117 in_situ_quant,
118 paged_attn_config,
119 )?;
120 let draft = self.draft.load_model_from_path(
121 paths,
122 dtype,
123 device,
124 silent,
125 mapper.clone(),
126 in_situ_quant,
127 paged_attn_config,
128 )?;
129 Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
130 target,
131 draft,
132 self.config,
133 )?)))
134 }
135 fn get_id(&self) -> String {
136 format!(
137 "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
138 self.target.get_id(),
139 self.draft.get_id(),
140 self.config.gamma,
141 )
142 }
143 fn get_kind(&self) -> ModelKind {
144 ModelKind::Speculative {
145 target: Box::new(self.target.get_kind()),
146 draft: Box::new(self.draft.get_kind()),
147 }
148 }
149}
150
/// Pipeline that pairs a draft pipeline with a target pipeline.
///
/// In `step`, the draft pipeline is run `gamma` times to propose tokens, the
/// target pipeline is run once over those proposals, and the tokens returned
/// by `sample_target_sequence_speculative` are appended to the sequence.
pub struct SpeculativePipeline {
    /// Target pipeline; also supplies the tokenizer, device, chat template and metadata.
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Draft pipeline used to propose tokens.
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Number of draft forward passes (and sampled draft tokens) per step.
    gamma: usize,
    /// Metadata captured from the target pipeline at construction time.
    metadata: Arc<GeneralMetadata>,
    /// Model category captured from the target pipeline at construction time.
    category: ModelCategory,
}
169
/// Settings for speculative decoding.
#[derive(Copy, Clone, Debug)]
pub struct SpeculativeConfig {
    /// Number of tokens the draft model samples per speculative step.
    pub gamma: usize,
}
176
177impl SpeculativePipeline {
178 pub fn new(
179 target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
180 draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
181 config: SpeculativeConfig,
182 ) -> Result<Self> {
183 if get_mut_arcmutex!(target)
184 .tokenizer()
185 .as_ref()
186 .ok_or(candle_core::Error::Msg(
187 "`SpeculativePipeline::new` requires the target pipeline to have a token trie"
188 .to_string(),
189 ))?
190 .get_vocab(true)
191 != get_mut_arcmutex!(draft)
192 .tokenizer()
193 .as_ref()
194 .ok_or(candle_core::Error::Msg(
195 "`SpeculativePipeline::new` requires the draft pipeline to have a token trie"
196 .to_string(),
197 ))?
198 .get_vocab(true)
199 {
200 candle_core::bail!("Target and draft models' tokenizer vocab do not match. This is required for speculative decoding.");
201 }
202 if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
203 candle_core::bail!("Target and draft models' category do not match. This is required for speculative decoding.");
204 }
205 if get_mut_arcmutex!(target)
206 .get_processor()
207 .inputs_processor()
208 .get_type()
209 != get_mut_arcmutex!(draft)
210 .get_processor()
211 .inputs_processor()
212 .get_type()
213 {
214 candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
215 }
216 let metadata = get_mut_arcmutex!(target).get_metadata().clone();
217 let category = get_mut_arcmutex!(target).category();
218 Ok(Self {
220 target,
221 draft,
222 gamma: config.gamma,
223 metadata,
224 category,
225 })
226 }
227}
228
229impl PreProcessingMixin for SpeculativePipeline {
230 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
231 get_mut_arcmutex!(self.target).get_chat_template()
232 }
233 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
234 get_mut_arcmutex!(self.target).get_input_processor_config()
235 }
236}
237
238impl IsqPipelineMixin for SpeculativePipeline {
239 fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
240 get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
241 get_mut_arcmutex!(self.draft).re_isq_model(dtype)
242 }
243}
244
245impl CacheManagerMixin for SpeculativePipeline {
246 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
247 NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
248 NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
249 }
250 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
251 NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
252 NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
253 }
254 fn set_none_cache(
255 &self,
256 seqs: &mut [&mut Sequence],
257 reset_non_granular: bool,
258 modify_draft_cache: bool,
259 load_preallocated_cache: bool,
260 ) {
261 NormalCacheManager.set_none_cache(
262 &*get_mut_arcmutex!(self.draft),
263 seqs,
264 modify_draft_cache,
265 load_preallocated_cache,
266 );
267 NormalCacheManager.set_none_cache(
268 &*get_mut_arcmutex!(self.target),
269 seqs,
270 false,
271 load_preallocated_cache,
272 );
273 if reset_non_granular {
274 self.reset_non_granular_state()
275 }
276 }
277 fn cache(&self) -> &EitherCache {
278 unreachable!()
279 }
280 fn do_preallocated_cache(&self) -> bool {
281 false
283 }
284}
285
286impl MetadataMixin for SpeculativePipeline {
287 fn device(&self) -> Device {
288 get_mut_arcmutex!(self.target).device()
289 }
290 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
291 get_mut_arcmutex!(self.target).tokenizer()
292 }
293 fn name(&self) -> String {
294 format!(
295 "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
296 get_mut_arcmutex!(self.target).name(),
297 get_mut_arcmutex!(self.draft).name(),
298 self.gamma,
299 )
300 }
301 fn reset_non_granular_state(&self) {
302 get_mut_arcmutex!(self.target).reset_non_granular_state();
303 get_mut_arcmutex!(self.draft).reset_non_granular_state();
304 }
305 fn get_metadata(&self) -> Arc<GeneralMetadata> {
306 self.metadata.clone()
307 }
308 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
309 None
310 }
311}
312
313#[async_trait::async_trait]
314impl Pipeline for SpeculativePipeline {
315 fn forward_inputs(
316 &mut self,
317 _inputs: Box<dyn Any>,
318 _return_raw_logits: bool,
319 ) -> Result<ForwardInputsResult> {
320 unreachable!()
321 }
322 async fn sample_causal_gen(
323 &self,
324 _seqs: &mut [&mut Sequence],
325 _logits: Vec<Tensor>,
326 _prefix_cacher: &mut PrefixCacheManagerV2,
327 _disable_eos_stop: bool,
328 _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
329 ) -> Result<()> {
330 unreachable!()
331 }
332 async fn step(
333 &mut self,
334 input_seqs: &mut [&mut Sequence],
335 is_prompt: bool,
336 _return_raw_logits: bool,
337 prefix_cacher: &mut PrefixCacheManagerV2,
338 disable_eos_stop: bool,
339 rng: Arc<Mutex<Isaac64Rng>>,
340 backend_metadata: CacheBackendMetadata,
341 ) -> Result<Duration> {
342 match backend_metadata {
343 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
344 match pre_op {
345 CacheInstruction::In => self.clone_in_cache(input_seqs),
346 CacheInstruction::Nothing => (),
347 CacheInstruction::Reset {
348 reset_non_granular,
349 load_preallocated_cache,
350 } => self.set_none_cache(
351 input_seqs,
352 reset_non_granular,
353 true,
354 load_preallocated_cache,
355 ),
356 _ => unreachable!("Unreachable PRE cache op."),
357 }
358
359 let start = Instant::now();
360 assert_eq!(input_seqs.len(), 1);
361
362 let seq = &mut input_seqs[0];
363
364 let mut draft_samples = Vec::new();
367 for i in 0..self.gamma {
368 let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
369 let device = get_mut_arcmutex!(self.draft).device();
370 let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
371 let inputs = self
372 .get_processor()
373 .inputs_processor()
374 .process_inputs(
375 self.tokenizer(),
376 &mut [seq],
377 is_prompt && i == 0, is_xlora,
379 &device,
380 no_kv_cache,
381 None,
382 false,
383 None,
384 None, None, get_mut_arcmutex!(self.draft).device_mapper(),
387 )
388 .nth(0)
389 .unwrap()
390 .unwrap()
391 .inputs;
392 let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
393 #[allow(irrefutable_let_patterns)]
394 let ForwardInputsResult::CausalGeneration { logits } = logits
395 else {
396 candle_core::bail!(
397 "Speculative decoding requires `CausalGeneration` forward results"
398 );
399 };
400
401 let sample = sample_sequence(
402 logits.clone(),
403 seq,
404 seq.return_logprobs(),
405 rng.clone(),
406 false, true,
408 false,
409 )
410 .await?;
411 seq.add_tmp_tok(sample.token);
412 draft_samples.push(SpeculativeSample { sample });
413 }
414 seq.remove_tmp_tok(self.gamma);
415
416 let mut draft_prefill_tokens = if is_prompt {
418 seq.get_toks().to_vec()
419 } else {
420 vec![*seq.get_toks().last().unwrap()]
421 };
422 for (i, sample) in draft_samples.iter().enumerate() {
423 if i == draft_samples.len() - 1 {
424 continue;
425 }
426 draft_prefill_tokens.push(sample.sample.token);
427 }
428 seq.set_prefill_toks(draft_prefill_tokens);
429
430 let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
433 EitherCache::Full(full) => full.lock()[0]
434 .as_ref()
435 .map(|(k, _)| k.dims()[2])
436 .unwrap_or(0),
437 EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
438 };
439
440 let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
442 let device = get_mut_arcmutex!(self.target).device();
443 let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
444 let inputs = self
445 .get_processor()
446 .inputs_processor()
447 .process_inputs(
448 self.tokenizer(),
449 &mut [seq],
450 true, is_xlora,
452 &device,
453 no_kv_cache,
454 Some((self.gamma, initial_cache_len)), false,
456 None,
457 None, None, get_mut_arcmutex!(self.target).device_mapper(),
460 )
461 .nth(0)
462 .unwrap()
463 .unwrap()
464 .inputs;
465
466 let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
467 #[allow(irrefutable_let_patterns)]
468 let ForwardInputsResult::CausalGeneration { logits } = logits
469 else {
470 candle_core::bail!(
471 "Speculative decoding requires `CausalGeneration` forward results"
472 );
473 };
474
475 seq.reset_prefill_toks();
477
478 let samples = sample_target_sequence_speculative(
482 logits.clone(),
483 seq,
484 seq.return_logprobs(),
485 rng.clone(),
486 &draft_samples,
487 )
488 .await?;
489
490 let accepted_tokens = samples.into_iter().map(|s| s.sample).collect::<Vec<_>>();
491
492 let n_not_accepted = self.gamma - accepted_tokens.len();
494
495 match get_mut_arcmutex!(self.draft).cache() {
496 EitherCache::Full(full) => {
497 for (k, v) in full.lock().iter_mut().flatten() {
498 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
499 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
500 }
501 }
502 EitherCache::Normal(normal) => {
503 for cache in &mut *normal.lock().unwrap().0 {
504 cache
505 .set_len(cache.current_seq_len() - n_not_accepted)
506 .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
507 }
508 }
509 }
510 if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
511 match get_mut_arcmutex!(self.draft).cache() {
512 EitherCache::Full(full) => {
513 for (k, v) in full.xlora_lock().iter_mut().flatten() {
514 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
515 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
516 }
517 }
518 EitherCache::Normal(_) => {
519 unreachable!()
520 }
521 }
522 }
523 match get_mut_arcmutex!(self.target).cache() {
524 EitherCache::Full(full) => {
525 for (k, v) in full.lock().iter_mut().flatten() {
526 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
527 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
528 }
529 }
530 EitherCache::Normal(normal) => {
531 for cache in &mut *normal.lock().unwrap().0 {
532 cache
533 .set_len(cache.current_seq_len() - n_not_accepted)
534 .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
535 }
536 }
537 }
538 if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
539 match get_mut_arcmutex!(self.target).cache() {
540 EitherCache::Full(full) => {
541 for (k, v) in full.xlora_lock().iter_mut().flatten() {
542 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
543 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
544 }
545 }
546 EitherCache::Normal(_) => {
547 unreachable!()
548 }
549 }
550 }
551
552 let eos_owned = get_mut_arcmutex!(self.target)
553 .get_metadata()
554 .eos_tok
555 .clone();
556 let eos_tok = if disable_eos_stop {
557 None
558 } else {
559 Some(&eos_owned[..])
560 };
561 for accepted in accepted_tokens {
563 finish_or_add_toks_to_seq(
565 self,
566 prefix_cacher,
567 seq,
568 accepted.clone(),
569 eos_tok,
570 false,
571 )
572 .await?;
573 }
574
575 let end = Instant::now();
590 let exec_duration = end.duration_since(start);
591
592 match post_op {
593 CacheInstruction::Out => {
594 self.clone_out_cache(input_seqs);
595 }
596 CacheInstruction::Nothing => (),
597 CacheInstruction::Reset {
598 reset_non_granular,
599 load_preallocated_cache,
600 } => self.set_none_cache(
601 input_seqs,
602 reset_non_granular,
603 true,
604 load_preallocated_cache,
605 ),
606 _ => unreachable!("Unreachable pre cache op."),
607 }
608
609 Ok(exec_duration)
619 }
620 CacheBackendMetadata::PagedAttention {
621 metadata: _,
622 blocks_to_copy: _,
623 blocks_to_swap_in: _,
624 blocks_to_swap_out: _,
625 } => unreachable!(),
626 }
627 }
628 fn category(&self) -> ModelCategory {
629 self.category.clone()
630 }
631}
632
// Empty impl: the speculative pipeline relies entirely on the trait's default methods.
impl AnyMoePipelineMixin for SpeculativePipeline {}