// mistralrs/gguf.rs

use mistralrs_core::*;
use std::num::NonZeroUsize;

use crate::{best_device, Model};

/// Configure a text GGUF model with various parameters for loading, running, and other inference behaviors.
pub struct GgufModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) files: Vec<String>,
    pub(crate) tok_model_id: Option<String>,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,

    // Model running
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) force_cpu: bool,
    pub(crate) topology: Option<Topology>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

impl GgufModelBuilder {
    /// A few defaults are applied here:
    /// - Token source is from the cache (`~/.cache/huggingface/token`)
    /// - Maximum number of concurrent sequences is 32
    /// - Number of sequences to hold in the prefix cache is 16
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    /// - Web search compatible with the OpenAI `web_search_options` setting is disabled
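    ///
    /// A minimal construction sketch; the model ID and GGUF filename below are
    /// placeholders, not files this crate ships:
    /// ```no_run
    /// use mistralrs::GgufModelBuilder;
    ///
    /// let builder = GgufModelBuilder::new(
    ///     "author/some-model-GGUF",
    ///     vec!["some-model-q4_k_m.gguf"],
    /// );
    /// ```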
    pub fn new(model_id: impl ToString, files: Vec<impl ToString>) -> Self {
        Self {
            model_id: model_id.to_string(),
            files: files.into_iter().map(|f| f.to_string()).collect::<Vec<_>>(),
            prompt_chunksize: None,
            chat_template: None,
            tokenizer_json: None,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            topology: None,
            tok_model_id: None,
            device_mapping: None,
            jinja_explicit: None,
            throughput_logging: false,
            search_bert_model: None,
        }
    }

    /// Enable web search compatible with the OpenAI `web_search_options` setting, using the given BERT embedding model.
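    ///
    /// For example (a sketch; assumes `BertEmbeddingModel` provides a `Default` implementation):
    /// ```no_run
    /// use mistralrs::{BertEmbeddingModel, GgufModelBuilder};
    ///
    /// let builder = GgufModelBuilder::new("author/some-model-GGUF", vec!["some-model-q4_k_m.gguf"])
    ///     .with_search(BertEmbeddingModel::default());
    /// ```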
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Explicit Jinja chat template file (`.jinja`) to use. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Source the tokenizer and chat template from this model ID (must contain `tokenizer.json` and `tokenizer_config.json`).
    pub fn with_tok_model_id(mut self, tok_model_id: impl ToString) -> Self {
        self.tok_model_id = Some(tok_model_id.to_string());
        self
    }

    /// Set the prompt chunk size to use for inference; long prompts are processed in chunks of this size.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type takes precedence over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// A literal Jinja chat template, or a path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Path to a standalone `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
    /// can be created with sensible defaults via a [`PagedAttentionMetaBuilder`].
    ///
    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
    ///
    /// [`PagedAttentionMetaBuilder`]: crate::PagedAttentionMetaBuilder
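    ///
    /// A usage sketch, assuming `PagedAttentionMetaBuilder::default().build()` yields an
    /// `anyhow::Result<PagedAttentionConfig>`:
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use mistralrs::{GgufModelBuilder, PagedAttentionMetaBuilder};
    ///
    /// let builder = GgufModelBuilder::new("author/some-model-GGUF", vec!["some-model-q4_k_m.gguf"])
    ///     .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// # Ok(())
    /// # }
    /// ```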
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Disable the KV cache, trading performance for lower memory usage.
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
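    ///
    /// For example (a sketch mirroring the automatic default that [`GgufModelBuilder::build`]
    /// falls back to when no mapping is given):
    /// ```no_run
    /// use mistralrs::{AutoDeviceMapParams, DeviceMapSetting, GgufModelBuilder};
    ///
    /// let builder = GgufModelBuilder::new("author/some-model-GGUF", vec!["some-model-q4_k_m.gguf"])
    ///     .with_device_mapping(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()));
    /// ```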
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

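    /// Load the model and construct a runnable [`Model`].
    ///
    /// An end-to-end sketch; the model ID, filename, and message are placeholders,
    /// and the chat-request types are assumed to be the crate's `TextMessages` and
    /// `TextMessageRole` re-exports:
    /// ```no_run
    /// use mistralrs::{GgufModelBuilder, TextMessageRole, TextMessages};
    ///
    /// #[tokio::main]
    /// async fn main() -> anyhow::Result<()> {
    ///     let model = GgufModelBuilder::new(
    ///         "author/some-model-GGUF",
    ///         vec!["some-model-q4_k_m.gguf"],
    ///     )
    ///     .with_logging()
    ///     .build()
    ///     .await?;
    ///
    ///     let messages =
    ///         TextMessages::new().add_message(TextMessageRole::User, "Hello!");
    ///     let response = model.send_chat_request(messages).await?;
    ///     println!("{:?}", response.choices[0].message.content);
    ///     Ok(())
    /// }
    /// ```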
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = GGUFSpecificConfig {
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
        };

        if self.with_logging {
            initialize_logging();
        }

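        // Build the GGUF loader from the chat template, tokenizer source,
        // model ID, and the list of GGUF files.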
        let loader = GGUFLoaderBuilder::new(
            self.chat_template,
            self.tok_model_id,
            self.model_id,
            self.files,
            config,
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build();

        // Load the model from Hugging Face into a Pipeline.
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &ModelDType::Auto,
            &best_device(self.force_cpu)?,
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            None,
            self.paged_attn_cfg,
        )?;

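        // Pick the scheduler: PagedAttention needs the cache config computed during
        // loading; otherwise fall back to the fixed-capacity default scheduler.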
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

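        // Assemble the runner, forwarding the KV-cache and prefix-cache settings.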
        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}