1use std::sync::Arc;
2
3use mistralrs_core::{
4 initialize_logging, AutoDeviceMapParams, DefaultSchedulerMethod, DeviceMapSetting,
5 MistralRsBuilder, NormalLoaderBuilder, NormalSpecificConfig, Pipeline, SchedulerConfig,
6 SpeculativeConfig, SpeculativePipeline,
7};
8use tokio::sync::Mutex;
9
10use crate::{best_device, Model, TextModelBuilder};
11
12pub struct TextSpeculativeBuilder {
13 target: TextModelBuilder,
14 draft: TextModelBuilder,
15 speculative_config: SpeculativeConfig,
16}
17
18impl TextSpeculativeBuilder {
19 pub fn new(
26 target: TextModelBuilder,
27 draft: TextModelBuilder,
28 speculative_config: SpeculativeConfig,
29 ) -> anyhow::Result<Self> {
30 if target.no_kv_cache || draft.no_kv_cache {
31 anyhow::bail!("Both target and draft must have KV cache enabled.");
32 }
33
34 Ok(Self {
35 target,
36 draft,
37 speculative_config,
38 })
39 }
40
41 fn build_pipeline(builder: TextModelBuilder) -> anyhow::Result<Arc<Mutex<dyn Pipeline>>> {
42 let config = NormalSpecificConfig {
43 prompt_chunksize: builder.prompt_chunksize,
44 topology: builder.topology,
45 organization: builder.organization,
46 write_uqff: builder.write_uqff,
47 from_uqff: builder.from_uqff,
48 imatrix: builder.imatrix,
49 calibration_file: builder.calibration_file,
50 hf_cache_path: builder.hf_cache_path,
51 matformer_config_path: None,
52 matformer_slice_name: None,
53 };
54
55 if builder.with_logging {
56 initialize_logging();
57 }
58
59 let loader = NormalLoaderBuilder::new(
60 config,
61 builder.chat_template,
62 builder.tokenizer_json,
63 Some(builder.model_id),
64 builder.no_kv_cache,
65 builder.jinja_explicit,
66 )
67 .build(builder.loader_type)?;
68
69 let pipeline = loader.load_model_from_hf(
71 builder.hf_revision,
72 builder.token_source,
73 &builder.dtype,
74 &best_device(builder.force_cpu)?,
75 !builder.with_logging,
76 builder
77 .device_mapping
78 .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
79 builder.isq,
80 builder.paged_attn_cfg,
81 )?;
82 Ok(pipeline)
83 }
84
85 pub async fn build(self) -> anyhow::Result<Model> {
86 let target = Self::build_pipeline(self.target.clone())?;
87 let draft = Self::build_pipeline(self.draft.clone())?;
88
89 let scheduler_method = SchedulerConfig::DefaultScheduler {
90 method: DefaultSchedulerMethod::Fixed(self.target.max_num_seqs.try_into()?),
91 };
92
93 let pipeline = Arc::new(Mutex::new(SpeculativePipeline::new(
94 target,
95 draft,
96 self.speculative_config,
97 )?));
98
99 let mut runner = MistralRsBuilder::new(
100 pipeline,
101 scheduler_method,
102 self.target.throughput_logging,
103 self.target.search_bert_model,
104 );
105 if let Some(cb) = self.target.search_callback.clone() {
106 runner = runner.with_search_callback(cb);
107 }
108 for (name, cb) in &self.target.tool_callbacks {
109 runner = runner.with_tool_callback(name.clone(), cb.clone());
110 }
111
112 Ok(Model::new(runner.build().await))
113 }
114}