mistralrs/gguf.rs

use candle_core::Device;
use mistralrs_core::*;
use std::num::NonZeroUsize;

use crate::{best_device, Model};

/// Configure a text GGUF model with the various parameters for loading, running, and other inference behaviors.
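///
/// # Example
///
/// A minimal usage sketch (the model ID and GGUF filename below are placeholders, and the
/// builder is assumed to be re-exported at the crate root):
///
/// ```no_run
/// use mistralrs::GgufModelBuilder;
///
/// # async fn run() -> anyhow::Result<()> {
/// let model = GgufModelBuilder::new("path/or/repo-id", vec!["model-q4_k_m.gguf"])
///     .with_logging()
///     .build()
///     .await?;
/// # Ok(())
/// # }
/// ```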
pub struct GgufModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) files: Vec<String>,
    pub(crate) tok_model_id: Option<String>,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) device: Option<Device>,

    // Model running
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) force_cpu: bool,
    pub(crate) topology: Option<Topology>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

impl GgufModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (`.cache/huggingface/token`).
    /// - Maximum number of sequences running is 32.
    /// - Number of sequences to hold in the prefix cache is 16.
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`.
    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
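    ///
    /// Each of these can be overridden with the corresponding builder method, for example
    /// (illustrative values; the model ID and filename are placeholders):
    ///
    /// ```no_run
    /// use mistralrs::GgufModelBuilder;
    ///
    /// let builder = GgufModelBuilder::new("repo-or-path", vec!["model.gguf"])
    ///     .with_max_num_seqs(64)
    ///     .with_prefix_cache_n(Some(8));
    /// ```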
    pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
        Self {
            model_id: model_id.to_string(),
            files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
            prompt_chunksize: None,
            chat_template: None,
            tokenizer_json: None,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            topology: None,
            tok_model_id: None,
            device_mapping: None,
            jinja_explicit: None,
            throughput_logging: false,
            search_bert_model: None,
            device: None,
        }
    }

    /// Enable web searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified, or the default.
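    ///
    /// For example (an illustrative sketch; it assumes `BertEmbeddingModel` provides a `Default`
    /// implementation selecting the default embedding model):
    ///
    /// ```ignore
    /// let builder = GgufModelBuilder::new("repo-or-path", vec!["model.gguf"])
    ///     .with_search(BertEmbeddingModel::default());
    /// ```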
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Explicit Jinja chat template file (`.jinja`) to be used. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Source the tokenizer and chat template from this model ID (must contain `tokenizer.json` and `tokenizer_config.json`).
    pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
        self.tok_model_id = Some(tok_model_id.to_string());
        self
    }

    /// Set the prompt batch size to use for inference.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Literal Jinja chat template, or a path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
    /// can be created with sensible defaults using a [`PagedAttentionMetaBuilder`].
    ///
    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
    ///
    /// [`PagedAttentionMetaBuilder`]: crate::PagedAttentionMetaBuilder
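    ///
    /// For example (an illustrative sketch; it assumes `PagedAttentionMetaBuilder::default().build()`
    /// yields a [`PagedAttentionConfig`]):
    ///
    /// ```ignore
    /// let builder = GgufModelBuilder::new("repo-or-path", vec!["model.gguf"])
    ///     .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// ```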
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Disable the KV cache. This trades performance for reduced memory usage.
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
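    ///
    /// For example (an illustrative sketch; it assumes a CUDA-enabled build of candle):
    ///
    /// ```ignore
    /// let builder = GgufModelBuilder::new("repo-or-path", vec!["model.gguf"])
    ///     .with_device(Device::cuda_if_available(0)?);
    /// ```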
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

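    /// Load the model according to the configured parameters, returning a [`Model`] ready for inference.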
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = GGUFSpecificConfig {
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = GGUFLoaderBuilder::new(
            self.chat_template,
            self.tok_model_id,
            self.model_id,
            self.files,
            config,
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build();

        // Load the model into a Pipeline.
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &ModelDType::Auto,
            &self.device.unwrap_or(best_device(self.force_cpu).unwrap()),
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            None,
            self.paged_attn_cfg,
        )?;

        // Choose the scheduler: PagedAttention metadata when enabled, otherwise the fixed default scheduler.
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                // The pipeline sets a cache config whenever PagedAttention is enabled above.
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}