1use std::{
2 any::Any,
3 sync::{Arc, Mutex},
4 time::{Duration, Instant},
5};
6
7use anyhow::Result as anyhowResult;
8use candle_core::{Device, IndexOp, Result, Tensor};
9use mistralrs_quant::IsqType;
10use rand_isaac::Isaac64Rng;
11use tokenizers::Tokenizer;
12use tracing::warn;
13
14use crate::{
15 device_map::DeviceMapper,
16 get_mut_arcmutex,
17 pipeline::sampling::{
18 finish_or_add_toks_to_seq, sample_sequence, sample_target_sequence_speculative,
19 },
20 prefix_cacher::PrefixCacheManagerV2,
21 sequence::Sequence,
22 DeviceMapSetting, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource, TryIntoDType,
23};
24
25use super::{
26 cache_manager::NormalCacheManager, chat_template::ChatTemplate, sampling::SpeculativeSample,
27 AnyMoePipelineMixin, CacheBackendMetadata, CacheInstruction, CacheManager, CacheManagerMixin,
28 EitherCache, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, MetadataMixin,
29 ModelCategory, ModelPaths, PreProcessingMixin,
30};
31
/// A [`Loader`] that builds a speculative-decoding pipeline from two inner
/// loaders: a large "target" (verification) model and a smaller "draft"
/// (proposal) model.
pub struct SpeculativeLoader {
    /// Loader for the target (verification) model.
    pub target: Box<dyn Loader>,
    /// Loader for the draft (proposal) model.
    pub draft: Box<dyn Loader>,
    /// Speculative decoding settings (notably `gamma`, the draft length per step).
    pub config: SpeculativeConfig,
}
38
39impl Loader for SpeculativeLoader {
40 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
41 fn load_model_from_hf(
42 &self,
43 revision: Option<String>,
44 token_source: TokenSource,
45 dtype: &dyn TryIntoDType,
46 device: &Device,
47 silent: bool,
48 mapper: DeviceMapSetting,
49 in_situ_quant: Option<IsqType>,
50 paged_attn_config: Option<PagedAttentionConfig>,
51 ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
52 let paged_attn_config = if paged_attn_config.is_none() {
53 warn!(
54 "Speculative decoding does not currently support PagedAttention, running without"
55 );
56 None
57 } else {
58 paged_attn_config
59 };
60
61 let target = self.target.load_model_from_hf(
62 revision.clone(),
63 token_source.clone(),
64 dtype,
65 device,
66 silent,
67 mapper.clone(),
68 in_situ_quant,
69 paged_attn_config,
70 )?;
71 let draft = self.draft.load_model_from_hf(
72 revision,
73 token_source,
74 dtype,
75 device,
76 silent,
77 mapper,
78 in_situ_quant,
79 paged_attn_config,
80 )?;
81 Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
82 target,
83 draft,
84 self.config,
85 )?)))
86 }
87
88 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
89 fn load_model_from_path(
90 &self,
91 paths: &Box<dyn ModelPaths>,
92 dtype: &dyn TryIntoDType,
93 device: &Device,
94 silent: bool,
95 mapper: DeviceMapSetting,
96 in_situ_quant: Option<IsqType>,
97 paged_attn_config: Option<PagedAttentionConfig>,
98 ) -> anyhowResult<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
99 let paged_attn_config = if paged_attn_config.is_none() {
100 warn!(
101 "Speculative decoding does not currently support PagedAttention, running without"
102 );
103 None
104 } else {
105 paged_attn_config
106 };
107
108 let target = self.target.load_model_from_path(
109 paths,
110 dtype,
111 device,
112 silent,
113 mapper.clone(),
114 in_situ_quant,
115 paged_attn_config,
116 )?;
117 let draft = self.draft.load_model_from_path(
118 paths,
119 dtype,
120 device,
121 silent,
122 mapper.clone(),
123 in_situ_quant,
124 paged_attn_config,
125 )?;
126 Ok(Arc::new(tokio::sync::Mutex::new(SpeculativePipeline::new(
127 target,
128 draft,
129 self.config,
130 )?)))
131 }
132 fn get_id(&self) -> String {
133 format!(
134 "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
135 self.target.get_id(),
136 self.draft.get_id(),
137 self.config.gamma,
138 )
139 }
140 fn get_kind(&self) -> ModelKind {
141 ModelKind::Speculative {
142 target: Box::new(self.target.get_kind()),
143 draft: Box::new(self.draft.get_kind()),
144 }
145 }
146}
147
/// Pipeline that performs speculative decoding: a draft model proposes tokens
/// which a target model then verifies in a single forward pass.
pub struct SpeculativePipeline {
    // Target (verification) model pipeline.
    target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    // Draft (proposal) model pipeline.
    draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    // Number of draft tokens sampled per speculative step.
    gamma: usize,
    // Snapshot of the *target* pipeline's metadata, taken at construction.
    metadata: Arc<GeneralMetadata>,
    // Snapshot of the *target* pipeline's category, taken at construction.
    category: ModelCategory,
}
166
/// Configuration for speculative decoding.
#[derive(Copy, Clone)]
pub struct SpeculativeConfig {
    /// Number of tokens to sample from the draft model per speculative step (γ).
    pub gamma: usize,
}
173
174impl SpeculativePipeline {
175 pub fn new(
176 target: Arc<tokio::sync::Mutex<dyn Pipeline>>,
177 draft: Arc<tokio::sync::Mutex<dyn Pipeline>>,
178 config: SpeculativeConfig,
179 ) -> Result<Self> {
180 if get_mut_arcmutex!(target)
181 .tokenizer()
182 .as_ref()
183 .ok_or(candle_core::Error::Msg(
184 "`SpeculativePipeline::new` requires the target pipeline to have a token trie"
185 .to_string(),
186 ))?
187 .get_vocab(true)
188 != get_mut_arcmutex!(draft)
189 .tokenizer()
190 .as_ref()
191 .ok_or(candle_core::Error::Msg(
192 "`SpeculativePipeline::new` requires the draft pipeline to have a token trie"
193 .to_string(),
194 ))?
195 .get_vocab(true)
196 {
197 candle_core::bail!("Target and draft models' tokenizer vocab do not match. This is required for speculative decoding.");
198 }
199 if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() {
200 candle_core::bail!("Target and draft models' category do not match. This is required for speculative decoding.");
201 }
202 if get_mut_arcmutex!(target)
203 .get_processor()
204 .inputs_processor()
205 .get_type()
206 != get_mut_arcmutex!(draft)
207 .get_processor()
208 .inputs_processor()
209 .get_type()
210 {
211 candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding.");
212 }
213 let metadata = get_mut_arcmutex!(target).get_metadata().clone();
214 let category = get_mut_arcmutex!(target).category();
215 Ok(Self {
217 target,
218 draft,
219 gamma: config.gamma,
220 metadata,
221 category,
222 })
223 }
224}
225
impl PreProcessingMixin for SpeculativePipeline {
    // Pre-processing is delegated entirely to the target pipeline; `new`
    // guarantees the draft's input processor type matches.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        get_mut_arcmutex!(self.target).get_chat_template()
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        get_mut_arcmutex!(self.target).get_input_processor_config()
    }
}
234
235impl IsqPipelineMixin for SpeculativePipeline {
236 fn re_isq_model(&mut self, dtype: IsqType) -> anyhow::Result<()> {
237 get_mut_arcmutex!(self.target).re_isq_model(dtype)?;
238 get_mut_arcmutex!(self.draft).re_isq_model(dtype)
239 }
240}
241
impl CacheManagerMixin for SpeculativePipeline {
    // Cache operations are applied to BOTH sub-pipelines, draft first. The
    // trailing boolean selects the draft-side cache variant: `true` for the
    // draft call, `false` for the target call (cf. `modify_draft_cache` in
    // `set_none_cache`).
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_in_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.draft), seqs, true);
        NormalCacheManager.clone_out_cache(&*get_mut_arcmutex!(self.target), seqs, false);
    }
    // Clears the caches of both sub-pipelines and optionally resets any
    // non-granular (cross-sequence) state afterwards.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.draft),
            seqs,
            modify_draft_cache,
            load_preallocated_cache,
        );
        // The target never uses the draft-cache variant.
        NormalCacheManager.set_none_cache(
            &*get_mut_arcmutex!(self.target),
            seqs,
            false,
            load_preallocated_cache,
        );
        if reset_non_granular {
            self.reset_non_granular_state()
        }
    }
    // There is no single cache for the combined pipeline; callers must go
    // through the sub-pipelines. This accessor is never called.
    fn cache(&self) -> &EitherCache {
        unreachable!()
    }
    // Preallocated-cache handling is managed per sub-pipeline, not here.
    fn do_preallocated_cache(&self) -> bool {
        false
    }
}
282
impl MetadataMixin for SpeculativePipeline {
    // Device and tokenizer are taken from the target model (`new` guarantees
    // the draft's tokenizer vocabulary matches).
    fn device(&self) -> Device {
        get_mut_arcmutex!(self.target).device()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        get_mut_arcmutex!(self.target).tokenizer()
    }
    // Human-readable name combining both sub-models and the speculation depth;
    // mirrors `SpeculativeLoader::get_id`.
    fn name(&self) -> String {
        format!(
            "Speculative: tgt = `{}`, draft = `{}`, gamma = `{}`",
            get_mut_arcmutex!(self.target).name(),
            get_mut_arcmutex!(self.draft).name(),
            self.gamma,
        )
    }
    // Non-granular state must be reset on BOTH sub-pipelines.
    fn reset_non_granular_state(&self) {
        get_mut_arcmutex!(self.target).reset_non_granular_state();
        get_mut_arcmutex!(self.draft).reset_non_granular_state();
    }
    // Returns the metadata snapshot taken from the target at construction.
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    // No device mapper is exposed for the combined pipeline.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}
309
310#[async_trait::async_trait]
311impl Pipeline for SpeculativePipeline {
312 fn forward_inputs(
313 &mut self,
314 _inputs: Box<dyn Any>,
315 _return_raw_logits: bool,
316 ) -> Result<ForwardInputsResult> {
317 unreachable!()
318 }
319 async fn sample_causal_gen(
320 &self,
321 _seqs: &mut [&mut Sequence],
322 _logits: Vec<Tensor>,
323 _prefix_cacher: &mut PrefixCacheManagerV2,
324 _disable_eos_stop: bool,
325 _rng: Arc<std::sync::Mutex<Isaac64Rng>>,
326 ) -> Result<()> {
327 unreachable!()
328 }
329 async fn step(
330 &mut self,
331 input_seqs: &mut [&mut Sequence],
332 is_prompt: bool,
333 _return_raw_logits: bool,
334 prefix_cacher: &mut PrefixCacheManagerV2,
335 disable_eos_stop: bool,
336 rng: Arc<Mutex<Isaac64Rng>>,
337 backend_metadata: CacheBackendMetadata<'_>,
338 ) -> Result<Duration> {
339 match backend_metadata {
340 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
341 match pre_op {
342 CacheInstruction::In => self.clone_in_cache(input_seqs),
343 CacheInstruction::Nothing => (),
344 CacheInstruction::Reset {
345 reset_non_granular,
346 load_preallocated_cache,
347 } => self.set_none_cache(
348 input_seqs,
349 reset_non_granular,
350 true,
351 load_preallocated_cache,
352 ),
353 _ => unreachable!("Unreachable PRE cache op."),
354 }
355
356 let start = Instant::now();
357 assert_eq!(input_seqs.len(), 1);
358
359 let seq = &mut input_seqs[0];
360
361 let mut draft_samples = Vec::new();
364 for i in 0..self.gamma {
365 let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
366 let device = get_mut_arcmutex!(self.draft).device();
367 let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
368 let inputs = self
369 .get_processor()
370 .inputs_processor()
371 .process_inputs(
372 self.tokenizer(),
373 &mut [seq],
374 is_prompt && i == 0, is_xlora,
376 &device,
377 no_kv_cache,
378 None,
379 false,
380 None,
381 None, None, get_mut_arcmutex!(self.draft).device_mapper(),
384 )
385 .nth(0)
386 .unwrap()
387 .unwrap()
388 .inputs;
389 let logits = get_mut_arcmutex!(self.draft).forward_inputs(inputs, false)?;
390 #[allow(irrefutable_let_patterns)]
391 let ForwardInputsResult::CausalGeneration { logits } = logits
392 else {
393 candle_core::bail!(
394 "Speculative decoding requires `CausalGeneration` forward results"
395 );
396 };
397
398 let sample = sample_sequence(
399 logits.clone(),
400 seq,
401 seq.return_logprobs(),
402 rng.clone(),
403 false, true,
405 )
406 .await?;
407 seq.add_tmp_tok(sample.token);
408 draft_samples.push(SpeculativeSample { sample });
409 }
410 seq.remove_tmp_tok(self.gamma);
411
412 let mut draft_prefill_tokens = if is_prompt {
414 seq.get_toks().to_vec()
415 } else {
416 vec![*seq.get_toks().last().unwrap()]
417 };
418 for (i, sample) in draft_samples.iter().enumerate() {
419 if i == draft_samples.len() - 1 {
420 continue;
421 }
422 draft_prefill_tokens.push(sample.sample.token);
423 }
424 seq.set_prefill_toks(draft_prefill_tokens);
425
426 let initial_cache_len = match get_mut_arcmutex!(self.target).cache() {
429 EitherCache::Full(full) => full.lock()[0]
430 .as_ref()
431 .map(|(k, _)| k.dims()[2])
432 .unwrap_or(0),
433 EitherCache::Normal(normal) => normal.lock().unwrap().0[0].current_seq_len(),
434 };
435
436 let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
438 let device = get_mut_arcmutex!(self.target).device();
439 let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
440 let inputs = self
441 .get_processor()
442 .inputs_processor()
443 .process_inputs(
444 self.tokenizer(),
445 &mut [seq],
446 true, is_xlora,
448 &device,
449 no_kv_cache,
450 Some((self.gamma, initial_cache_len)), false,
452 None,
453 None, None, get_mut_arcmutex!(self.target).device_mapper(),
456 )
457 .nth(0)
458 .unwrap()
459 .unwrap()
460 .inputs;
461
462 let logits = get_mut_arcmutex!(self.target).forward_inputs(inputs, false)?;
463 #[allow(irrefutable_let_patterns)]
464 let ForwardInputsResult::CausalGeneration { logits } = logits
465 else {
466 candle_core::bail!(
467 "Speculative decoding requires `CausalGeneration` forward results"
468 );
469 };
470
471 seq.reset_prefill_toks();
473
474 let samples = sample_target_sequence_speculative(
478 logits.clone(),
479 seq,
480 seq.return_logprobs(),
481 rng.clone(),
482 &draft_samples,
483 )
484 .await?;
485
486 let accepted_tokens = samples.into_iter().map(|s| s.sample).collect::<Vec<_>>();
487
488 let n_not_accepted = self.gamma - accepted_tokens.len();
490
491 match get_mut_arcmutex!(self.draft).cache() {
492 EitherCache::Full(full) => {
493 for (k, v) in full.lock().iter_mut().flatten() {
494 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
495 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
496 }
497 }
498 EitherCache::Normal(normal) => {
499 for cache in &mut *normal.lock().unwrap().0 {
500 cache
501 .set_len(cache.current_seq_len() - n_not_accepted)
502 .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
503 }
504 }
505 }
506 if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
507 match get_mut_arcmutex!(self.draft).cache() {
508 EitherCache::Full(full) => {
509 for (k, v) in full.xlora_lock().iter_mut().flatten() {
510 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
511 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
512 }
513 }
514 EitherCache::Normal(_) => {
515 unreachable!()
516 }
517 }
518 }
519 match get_mut_arcmutex!(self.target).cache() {
520 EitherCache::Full(full) => {
521 for (k, v) in full.lock().iter_mut().flatten() {
522 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
523 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
524 }
525 }
526 EitherCache::Normal(normal) => {
527 for cache in &mut *normal.lock().unwrap().0 {
528 cache
529 .set_len(cache.current_seq_len() - n_not_accepted)
530 .map_err(|_| candle_core::Error::msg("KV cache set_len failed."))?;
531 }
532 }
533 }
534 if get_mut_arcmutex!(self.draft).get_metadata().is_xlora {
535 match get_mut_arcmutex!(self.target).cache() {
536 EitherCache::Full(full) => {
537 for (k, v) in full.xlora_lock().iter_mut().flatten() {
538 *k = k.i((.., .., ..k.dims()[2] - n_not_accepted, ..))?;
539 *v = v.i((.., .., ..v.dims()[2] - n_not_accepted, ..))?;
540 }
541 }
542 EitherCache::Normal(_) => {
543 unreachable!()
544 }
545 }
546 }
547
548 let eos_owned = get_mut_arcmutex!(self.target)
549 .get_metadata()
550 .eos_tok
551 .clone();
552 let eos_tok = if disable_eos_stop {
553 None
554 } else {
555 Some(&eos_owned[..])
556 };
557 for accepted in accepted_tokens {
559 finish_or_add_toks_to_seq(
561 self,
562 prefix_cacher,
563 seq,
564 accepted.clone(),
565 eos_tok,
566 false,
567 )
568 .await?;
569 }
570
571 let end = Instant::now();
586 let exec_duration = end.duration_since(start);
587
588 match post_op {
589 CacheInstruction::Out => {
590 self.clone_out_cache(input_seqs);
591 }
592 CacheInstruction::Nothing => (),
593 CacheInstruction::Reset {
594 reset_non_granular,
595 load_preallocated_cache,
596 } => self.set_none_cache(
597 input_seqs,
598 reset_non_granular,
599 true,
600 load_preallocated_cache,
601 ),
602 _ => unreachable!("Unreachable pre cache op."),
603 }
604
605 Ok(exec_duration)
615 }
616 CacheBackendMetadata::PagedAttention {
617 metadata: _,
618 blocks_to_copy: _,
619 blocks_to_swap_in: _,
620 blocks_to_swap_out: _,
621 } => unreachable!(),
622 }
623 }
624 fn category(&self) -> ModelCategory {
625 self.category.clone()
626 }
627}
628
// Use the trait's default (no-op) AnyMoE behavior for speculative pipelines.
impl AnyMoePipelineMixin for SpeculativePipeline {}