mistralrs/
gguf.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::num::NonZeroUsize;
6
7use crate::{best_device, Model};
8use std::sync::Arc;
9
10/// A tool callback with its associated Tool definition.
11#[derive(Clone)]
12pub struct ToolCallbackWithTool {
13    pub callback: Arc<ToolCallback>,
14    pub tool: Tool,
15}
16
17/// Configure a text GGUF model with the various parameters for loading, running, and other inference behaviors.
18pub struct GgufModelBuilder {
19    // Loading model
20    pub(crate) model_id: String,
21    pub(crate) files: Vec<String>,
22    pub(crate) tok_model_id: Option<String>,
23    pub(crate) token_source: TokenSource,
24    pub(crate) hf_revision: Option<String>,
25    pub(crate) chat_template: Option<String>,
26    pub(crate) jinja_explicit: Option<String>,
27    pub(crate) tokenizer_json: Option<String>,
28    pub(crate) device_mapping: Option<DeviceMapSetting>,
29    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
30    pub(crate) search_callback: Option<Arc<SearchCallback>>,
31    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
32    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
33    pub(crate) device: Option<Device>,
34
35    // Model running
36    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
37    pub(crate) force_cpu: bool,
38    pub(crate) topology: Option<Topology>,
39    pub(crate) throughput_logging: bool,
40
41    // Other things
42    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
43    pub(crate) max_num_seqs: usize,
44    pub(crate) no_kv_cache: bool,
45    pub(crate) with_logging: bool,
46    pub(crate) prefix_cache_n: Option<usize>,
47}
48
49impl GgufModelBuilder {
50    /// A few defaults are applied here:
51    /// - Token source is from the cache (.cache/huggingface/token)
52    /// - Maximum number of sequences running is 32
53    /// - Number of sequences to hold in prefix cache is 16.
54    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
55    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
56    pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
57        Self {
58            model_id: model_id.to_string(),
59            files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
60            prompt_chunksize: None,
61            chat_template: None,
62            tokenizer_json: None,
63            force_cpu: false,
64            token_source: TokenSource::CacheToken,
65            hf_revision: None,
66            paged_attn_cfg: None,
67            max_num_seqs: 32,
68            no_kv_cache: false,
69            prefix_cache_n: Some(16),
70            with_logging: false,
71            topology: None,
72            tok_model_id: None,
73            device_mapping: None,
74            jinja_explicit: None,
75            throughput_logging: false,
76            search_bert_model: None,
77            search_callback: None,
78            tool_callbacks: HashMap::new(),
79            tool_callbacks_with_tools: HashMap::new(),
80            device: None,
81        }
82    }
83
84    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
85    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
86        self.search_bert_model = Some(search_bert_model);
87        self
88    }
89
90    /// Override the search function used when `web_search_options` is enabled.
91    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
92        self.search_callback = Some(callback);
93        self
94    }
95
96    pub fn with_tool_callback(
97        mut self,
98        name: impl Into<String>,
99        callback: Arc<ToolCallback>,
100    ) -> Self {
101        self.tool_callbacks.insert(name.into(), callback);
102        self
103    }
104
105    /// Register a callback with an associated Tool definition that will be automatically
106    /// added to requests when tool callbacks are active.
107    pub fn with_tool_callback_and_tool(
108        mut self,
109        name: impl Into<String>,
110        callback: Arc<ToolCallback>,
111        tool: Tool,
112    ) -> Self {
113        let name = name.into();
114        self.tool_callbacks_with_tools
115            .insert(name, ToolCallbackWithTool { callback, tool });
116        self
117    }
118
119    /// Enable runner throughput logging.
120    pub fn with_throughput_logging(mut self) -> Self {
121        self.throughput_logging = true;
122        self
123    }
124
125    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
126    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
127        self.jinja_explicit = Some(jinja_explicit);
128        self
129    }
130
131    /// Source the tokenizer and chat template from this model ID (must contain `tokenizer.json` and `tokenizer_config.json`).
132    pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
133        self.tok_model_id = Some(tok_model_id.to_string());
134        self
135    }
136
137    /// Set the prompt batchsize to use for inference.
138    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
139        self.prompt_chunksize = Some(prompt_chunksize);
140        self
141    }
142
143    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
144    pub fn with_topology(mut self, topology: Topology) -> Self {
145        self.topology = Some(topology);
146        self
147    }
148
149    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
150    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
151        self.chat_template = Some(chat_template.to_string());
152        self
153    }
154
155    /// Path to a discrete `tokenizer.json` file.
156    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
157        self.tokenizer_json = Some(tokenizer_json.to_string());
158        self
159    }
160
161    /// Force usage of the CPU device. Do not use PagedAttention with this.
162    pub fn with_force_cpu(mut self) -> Self {
163        self.force_cpu = true;
164        self
165    }
166
167    /// Source of the Hugging Face token.
168    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
169        self.token_source = token_source;
170        self
171    }
172
173    /// Set the revision to use for a Hugging Face remote model.
174    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
175        self.hf_revision = Some(revision.to_string());
176        self
177    }
178
179    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
180    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
181    ///
182    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
183    ///
184    /// [`PagedAttentionMetaBuilder`]: crate::PagedAttentionMetaBuilder
185    pub fn with_paged_attn(
186        mut self,
187        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
188    ) -> anyhow::Result<Self> {
189        if paged_attn_supported() {
190            self.paged_attn_cfg = Some(paged_attn_cfg()?);
191        } else {
192            self.paged_attn_cfg = None;
193        }
194        Ok(self)
195    }
196
197    /// Set the maximum number of sequences which can be run at once.
198    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
199        self.max_num_seqs = max_num_seqs;
200        self
201    }
202
203    /// Disable KV cache. Trade performance for memory usage.
204    pub fn with_no_kv_cache(mut self) -> Self {
205        self.no_kv_cache = true;
206        self
207    }
208
209    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
210    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
211        self.prefix_cache_n = n_seqs;
212        self
213    }
214
215    /// Enable logging.
216    pub fn with_logging(mut self) -> Self {
217        self.with_logging = true;
218        self
219    }
220
221    /// Provide metadata to initialize the device mapper.
222    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
223        self.device_mapping = Some(device_mapping);
224        self
225    }
226
227    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
228    pub fn with_device(mut self, device: Device) -> Self {
229        self.device = Some(device);
230        self
231    }
232
233    pub async fn build(self) -> anyhow::Result<Model> {
234        let config = GGUFSpecificConfig {
235            prompt_chunksize: self.prompt_chunksize,
236            topology: self.topology,
237        };
238
239        if self.with_logging {
240            initialize_logging();
241        }
242
243        let loader = GGUFLoaderBuilder::new(
244            self.chat_template,
245            self.tok_model_id,
246            self.model_id,
247            self.files,
248            config,
249            self.no_kv_cache,
250            self.jinja_explicit,
251        )
252        .build();
253
254        // Load, into a Pipeline
255        let pipeline = loader.load_model_from_hf(
256            self.hf_revision,
257            self.token_source,
258            &ModelDType::Auto,
259            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
260            !self.with_logging,
261            self.device_mapping
262                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
263            None,
264            self.paged_attn_cfg,
265        )?;
266
267        let scheduler_method = match self.paged_attn_cfg {
268            Some(_) => {
269                let config = pipeline
270                    .lock()
271                    .await
272                    .get_metadata()
273                    .cache_config
274                    .as_ref()
275                    .unwrap()
276                    .clone();
277
278                SchedulerConfig::PagedAttentionMeta {
279                    max_num_seqs: self.max_num_seqs,
280                    config,
281                }
282            }
283            None => SchedulerConfig::DefaultScheduler {
284                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
285            },
286        };
287
288        let mut runner = MistralRsBuilder::new(
289            pipeline,
290            scheduler_method,
291            self.throughput_logging,
292            self.search_bert_model,
293        );
294        if let Some(cb) = self.search_callback.clone() {
295            runner = runner.with_search_callback(cb);
296        }
297        for (name, cb) in &self.tool_callbacks {
298            runner = runner.with_tool_callback(name.clone(), cb.clone());
299        }
300        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
301            runner = runner.with_tool_callback_and_tool(
302                name.clone(),
303                callback_with_tool.callback.clone(),
304                callback_with_tool.tool.clone(),
305            );
306        }
307        runner = runner
308            .with_no_kv_cache(self.no_kv_cache)
309            .with_no_prefix_cache(self.prefix_cache_n.is_none());
310
311        if let Some(n) = self.prefix_cache_n {
312            runner = runner.with_prefix_cache_n(n)
313        }
314
315        Ok(Model::new(runner.build().await))
316    }
317}