// mistralrs/gguf.rs

use candle_core::Device;
use mistralrs_core::*;
use mistralrs_core::{SearchCallback, Tool, ToolCallback};
use std::collections::HashMap;

use crate::{best_device, Model};
use std::sync::Arc;

/// A tool callback with its associated Tool definition.
#[derive(Clone)]
pub struct ToolCallbackWithTool {
    pub callback: Arc<ToolCallback>,
    pub tool: Tool,
}

/// Configure a text GGUF model with the various parameters for loading, running, and other inference behaviors.
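///
/// A minimal usage sketch. The model ID and file name below are hypothetical
/// placeholders, not files shipped with this crate:
///
/// ```no_run
/// use mistralrs::{GgufModelBuilder, TextMessageRole, TextMessages};
///
/// # async fn run() -> anyhow::Result<()> {
/// let model = GgufModelBuilder::new(
///     "author/model-GGUF",       // hypothetical Hugging Face repo
///     vec!["model-q4_k_m.gguf"], // hypothetical GGUF file in that repo
/// )
/// .with_logging()
/// .build()
/// .await?;
///
/// let messages = TextMessages::new().add_message(TextMessageRole::User, "Hello!");
/// let response = model.send_chat_request(messages).await?;
/// println!("{:?}", response.choices[0].message.content);
/// # Ok(())
/// # }
/// ```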
pub struct GgufModelBuilder {
    // Model loading
    pub(crate) model_id: String,
    pub(crate) files: Vec<String>,
    pub(crate) tok_model_id: Option<String>,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) search_callback: Option<Arc<SearchCallback>>,
    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
    pub(crate) device: Option<Device>,

    // Model running
    pub(crate) force_cpu: bool,
    pub(crate) topology: Option<Topology>,
    pub(crate) throughput_logging: bool,

    // Other configuration
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

impl GgufModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (`.cache/huggingface/token`).
    /// - Maximum number of sequences running at once is 32.
    /// - Number of sequences to hold in the prefix cache is 16.
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`.
    /// - Web searching compatible with the OpenAI `web_search_options` setting is disabled.
    pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
        Self {
            model_id: model_id.to_string(),
            files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
            chat_template: None,
            tokenizer_json: None,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            topology: None,
            tok_model_id: None,
            device_mapping: None,
            jinja_explicit: None,
            throughput_logging: false,
            search_bert_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            device: None,
        }
    }

    /// Enable web searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
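    ///
    /// A small sketch, assuming the default embedding model is acceptable:
    ///
    /// ```no_run
    /// # use mistralrs::{BertEmbeddingModel, GgufModelBuilder};
    /// # fn sketch(builder: GgufModelBuilder) -> GgufModelBuilder {
    /// builder.with_search(BertEmbeddingModel::default())
    /// # }
    /// ```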
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Override the search function used when `web_search_options` is enabled.
    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    /// Register a callback to run when the model calls the tool with this name.
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), callback);
        self
    }

    /// Register a callback with an associated Tool definition that will be automatically
    /// added to requests when tool callbacks are active.
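    ///
    /// A hedged sketch: the callback shape follows the [`ToolCallback`] alias,
    /// and `echo_tool` stands in for a hypothetical OpenAI-style `Tool`
    /// definition built elsewhere:
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use mistralrs::{CalledFunction, GgufModelBuilder, Tool};
    /// # fn sketch(builder: GgufModelBuilder, echo_tool: Tool) -> GgufModelBuilder {
    /// builder.with_tool_callback_and_tool(
    ///     "echo",
    ///     // Assumed callback signature: Fn(&CalledFunction) -> anyhow::Result<String>.
    ///     Arc::new(|called: &CalledFunction| -> anyhow::Result<String> {
    ///         Ok(called.arguments.clone())
    ///     }),
    ///     echo_tool,
    /// )
    /// # }
    /// ```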
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools
            .insert(name, ToolCallbackWithTool { callback, tool });
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Explicit Jinja chat template file (`.jinja`) to be used. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Source the tokenizer and chat template from this model ID (must contain `tokenizer.json` and `tokenizer_config.json`).
    pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
        self.tok_model_id = Some(tok_model_id.to_string());
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Literal Jinja chat template, or a path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
    ///
    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
    ///
    /// [`PagedAttentionMetaBuilder`]: crate::PagedAttentionMetaBuilder
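    ///
    /// A typical invocation looks like the following sketch, using the
    /// [`PagedAttentionMetaBuilder`] defaults:
    ///
    /// ```no_run
    /// # use mistralrs::{GgufModelBuilder, PagedAttentionMetaBuilder};
    /// # fn sketch(builder: GgufModelBuilder) -> anyhow::Result<GgufModelBuilder> {
    /// builder.with_paged_attn(|| PagedAttentionMetaBuilder::default().build())
    /// # }
    /// ```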
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Disable the KV cache, trading performance for lower memory usage.
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
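    ///
    /// For example, to request the same automatic mapping the builder would
    /// otherwise apply by default (a sketch):
    ///
    /// ```no_run
    /// # use mistralrs::{AutoDeviceMapParams, DeviceMapSetting, GgufModelBuilder};
    /// # fn sketch(builder: GgufModelBuilder) -> GgufModelBuilder {
    /// builder.with_device_mapping(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()))
    /// # }
    /// ```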
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Consume the builder, load the model, and return a ready-to-use [`Model`].
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = GGUFSpecificConfig {
            topology: self.topology,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = GGUFLoaderBuilder::new(
            self.chat_template,
            self.tok_model_id,
            self.model_id,
            self.files,
            config,
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build();

        // Resolve the main device up front so a device-selection failure
        // surfaces as an error rather than a panic.
        let device = match self.device {
            Some(device) => device,
            None => best_device(self.force_cpu)?,
        };

        // Load the model into a `Pipeline`.
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &ModelDType::Auto,
            &device,
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            None,
            self.paged_attn_cfg,
        )?;

        // Choose a scheduler: PagedAttention metadata when a KV cache config was
        // produced above, otherwise the default fixed-size scheduler.
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .expect("cache config should be present when PagedAttention is enabled")
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        );
        if let Some(cb) = self.search_callback.clone() {
            runner = runner.with_search_callback(cb);
        }
        for (name, cb) in &self.tool_callbacks {
            runner = runner.with_tool_callback(name.clone(), cb.clone());
        }
        for (name, callback_with_tool) in &self.tool_callbacks_with_tools {
            runner = runner.with_tool_callback_and_tool(
                name.clone(),
                callback_with_tool.callback.clone(),
                callback_with_tool.tool.clone(),
            );
        }
        runner = runner
            .with_no_kv_cache(self.no_kv_cache)
            .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build().await))
    }
}