// mistralrs/speculative.rs

use std::sync::Arc;

use mistralrs_core::{
    initialize_logging, AutoDeviceMapParams, DefaultSchedulerMethod, DeviceMapSetting,
    MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig, Pipeline, SchedulerConfig,
    SpeculativeConfig, SpeculativePipeline,
};
use tokio::sync::Mutex;

use crate::{best_device, Model, TextModelBuilder};

12pub struct TextSpeculativeBuilder {
13    target: TextModelBuilder,
14    draft: TextModelBuilder,
15    speculative_config: SpeculativeConfig,
16}
17
18impl TextSpeculativeBuilder {
19    /// Create a builder for a speculative decoding pipeline.
20    ///
21    /// - PagedAttention settings are ignored as our impl of speculative decoding does not support this yet.
22    /// - Prefix caching settings are ignored as our impl of speculative decoding does not support this yet.
23    ///
24    /// Otherwise, scheduling parameters such as `max_num_seqs` are sourced from the target model.
25    pub fn new(
26        target: TextModelBuilder,
27        draft: TextModelBuilder,
28        speculative_config: SpeculativeConfig,
29    ) -> anyhow::Result<Self> {
30        if target.no_kv_cache || draft.no_kv_cache {
31            anyhow::bail!("Both target and draft must have KV cache enabled.");
32        }
33
34        Ok(Self {
35            target,
36            draft,
37            speculative_config,
38        })
39    }
40
41    fn build_pipeline(builder: TextModelBuilder) -> anyhow::Result<Arc<Mutex<dyn Pipeline>>> {
42        let config = NormalSpecificConfig {
43            prompt_chunksize: builder.prompt_chunksize,
44            topology: builder.topology,
45            organization: builder.organization,
46            write_uqff: builder.write_uqff,
47            from_uqff: builder.from_uqff,
48            imatrix: builder.imatrix,
49            calibration_file: builder.calibration_file,
50            hf_cache_path: builder.hf_cache_path,
51            matformer_config_path: None,
52            matformer_slice_name: None,
53        };
54
55        if builder.with_logging {
56            initialize_logging();
57        }
58
59        let loader = NormalLoaderBuilder::new(
60            config,
61            builder.chat_template,
62            builder.tokenizer_json,
63            Some(builder.model_id),
64            builder.no_kv_cache,
65            builder.jinja_explicit,
66        )
67        .build(builder.loader_type)?;
68
69        // Load, into a Pipeline
70        let pipeline = loader.load_model_from_hf(
71            builder.hf_revision,
72            builder.token_source,
73            &builder.dtype,
74            &best_device(builder.force_cpu)?,
75            !builder.with_logging,
76            builder
77                .device_mapping
78                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
79            builder.isq,
80            builder.paged_attn_cfg,
81        )?;
82        Ok(pipeline)
83    }
84
85    pub async fn build(self) -> anyhow::Result<Model> {
86        let target = Self::build_pipeline(self.target.clone())?;
87        let draft = Self::build_pipeline(self.draft.clone())?;
88
89        let scheduler_method = SchedulerConfig::DefaultScheduler {
90            method: DefaultSchedulerMethod::Fixed(self.target.max_num_seqs.try_into()?),
91        };
92
93        let pipeline = Arc::new(Mutex::new(SpeculativePipeline::new(
94            target,
95            draft,
96            self.speculative_config,
97        )?));
98
99        let mut runner = MistralRsBuilder::new(
100            pipeline,
101            scheduler_method,
102            self.target.throughput_logging,
103            self.target.search_bert_model,
104        );
105        if let Some(cb) = self.target.search_callback.clone() {
106            runner = runner.with_search_callback(cb);
107        }
108        for (name, cb) in &self.target.tool_callbacks {
109            runner = runner.with_tool_callback(name.clone(), cb.clone());
110        }
111
112        Ok(Model::new(runner.build().await))
113    }
114}