// mistralrs/text_model.rs

use candle_core::Device;
use mistralrs_core::*;
use std::{
    num::NonZeroUsize,
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::{best_device, Model};

#[derive(Clone)]
/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
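///
/// A minimal usage sketch, modeled on this crate's examples. The model ID is
/// illustrative, and `TextMessages`, `TextMessageRole`, and `send_chat_request`
/// are assumed to be re-exported at the crate root as in those examples.
///
/// ```no_run
/// use mistralrs::{
///     IsqType, PagedAttentionMetaBuilder, TextMessageRole, TextMessages, TextModelBuilder,
/// };
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     // Model ID is illustrative; any chat model on the Hugging Face Hub should work.
///     let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
///         .with_isq(IsqType::Q8_0)
///         .with_logging()
///         .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
///         .build()
///         .await?;
///
///     let messages = TextMessages::new()
///         .add_message(TextMessageRole::User, "What is graphene?");
///
///     let response = model.send_chat_request(messages).await?;
///     println!("{}", response.choices[0].message.content.as_ref().unwrap());
///     Ok(())
/// }
/// ```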
pub struct TextModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    pub(crate) imatrix: Option<PathBuf>,
    pub(crate) calibration_file: Option<PathBuf>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) device: Option<Device>,

    // Model running
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) topology: Option<Topology>,
    pub(crate) organization: IsqOrganization,
    pub(crate) loader_type: Option<NormalLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

/// Builder for PagedAttention metadata.
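///
/// A minimal sketch of producing a [`PagedAttentionConfig`], assuming both types
/// are re-exported at the crate root; the block size and context length below
/// are illustrative values, not tuned recommendations:
///
/// ```no_run
/// use mistralrs::{MemoryGpuConfig, PagedAttentionMetaBuilder};
///
/// # fn main() -> anyhow::Result<()> {
/// let cfg = PagedAttentionMetaBuilder::default()
///     .with_block_size(32)
///     .with_gpu_memory(MemoryGpuConfig::ContextSize(8192))
///     .build()?;
/// # Ok(())
/// # }
/// ```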
pub struct PagedAttentionMetaBuilder {
    block_size: Option<usize>,
    mem_cpu: usize,
    mem_gpu: MemoryGpuConfig,
}

impl Default for PagedAttentionMetaBuilder {
    fn default() -> Self {
        Self {
            block_size: None,
            mem_cpu: 64,
            mem_gpu: MemoryGpuConfig::ContextSize(4096),
        }
    }
}

impl PagedAttentionMetaBuilder {
    pub fn with_block_size(mut self, block_size: usize) -> Self {
        self.block_size = Some(block_size);
        self
    }

    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
        self.mem_gpu = mem_gpu;
        self
    }

    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
        PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu)
    }
}

impl TextModelBuilder {
    /// A few defaults are applied here:
    /// - The default ISQ organization is used (MoQE can be enabled with `with_mixture_qexperts_isq`)
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Number of sequences to hold in prefix cache is 16
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    /// - Web searching compatible with the OpenAI `web_search_options` setting is disabled
    pub fn new(model_id: impl ToString) -> Self {
        Self {
            model_id: model_id.to_string(),
            prompt_chunksize: None,
            topology: None,
            organization: IsqOrganization::Default,
            write_uqff: None,
            from_uqff: None,
            chat_template: None,
            tokenizer_json: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            device_mapping: None,
            imatrix: None,
            calibration_file: None,
            jinja_explicit: None,
            throughput_logging: false,
            hf_cache_path: None,
            search_bert_model: None,
            device: None,
        }
    }

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Explicit Jinja chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Set the prompt batch size to use for inference.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology setting takes precedence over the ISQ setting.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Organize ISQ to enable MoQE (Mixture of Quantized Experts, <https://arxiv.org/abs/2310.02410>).
    pub fn with_mixture_qexperts_isq(mut self) -> Self {
        self.organization = IsqOrganization::MoeExpertsOnly;
        self
    }

    /// A literal Jinja chat template, or a path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Path to a discrete `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Manually set the model loader type. Otherwise, it will attempt to automatically
    /// determine the loader type.
    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

    /// Load the model in a certain dtype.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Use ISQ of a certain type. If there is an overlap, the topology setting takes precedence over the ISQ setting.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Utilise this imatrix file during ISQ. Incompatible with specifying a calibration file.
    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
        self.imatrix = Some(path);
        self
    }

    /// Utilise this calibration file to collect an imatrix. Incompatible with specifying an imatrix file.
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }

    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
    ///
    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
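    ///
    /// A minimal sketch of the expected call shape (the model ID is a
    /// placeholder); note that the closure is only invoked when PagedAttention
    /// is actually supported:
    ///
    /// ```no_run
    /// # use mistralrs::{PagedAttentionMetaBuilder, TextModelBuilder};
    /// # fn demo() -> anyhow::Result<()> {
    /// let builder = TextModelBuilder::new("some/model-id")
    ///     .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// # Ok(())
    /// # }
    /// ```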
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Disable the KV cache, trading performance for lower memory usage.
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Path(s) to read UQFF files from.
    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Path to write a UQFF file to.
    ///
    /// The parent (part of the path excluding the filename) will determine where any other files
    /// generated are written to. These can be used to load UQFF models standalone, and may include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - And others
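    ///
    /// A hedged sketch of writing a UQFF file while applying ISQ; the model ID,
    /// output path, and quantization type are illustrative:
    ///
    /// ```no_run
    /// # use std::path::PathBuf;
    /// # use mistralrs::{IsqType, TextModelBuilder};
    /// let builder = TextModelBuilder::new("some/model-id")
    ///     .with_isq(IsqType::Q4K)
    ///     .write_uqff(PathBuf::from("uqff-out/model-q4k.uqff"));
    /// ```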
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Cache path for Hugging Face models downloaded locally.
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        let config = NormalSpecificConfig {
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
            organization: self.organization,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            imatrix: self.imatrix,
            calibration_file: self.calibration_file,
            hf_cache_path: self.hf_cache_path,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = NormalLoaderBuilder::new(
            config,
            self.chat_template,
            self.tokenizer_json,
            Some(self.model_id),
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build(self.loader_type)?;

        // Load the model into a Pipeline.
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &self.device.unwrap_or(best_device(self.force_cpu)?),
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            self.isq,
            self.paged_attn_cfg,
        )?;

        // Use the PagedAttention scheduler when the pipeline produced a KV cache
        // config; otherwise fall back to the default fixed scheduler.
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .cloned();

                if let Some(config) = config {
                    SchedulerConfig::PagedAttentionMeta {
                        max_num_seqs: self.max_num_seqs,
                        config,
                    }
                } else {
                    SchedulerConfig::DefaultScheduler {
                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
                    }
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}

#[derive(Clone)]
/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
/// This wraps and implements `DerefMut` for the [`TextModelBuilder`], so users should take care not to call UQFF-related methods.
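///
/// A minimal sketch of loading from UQFF, assuming the builder is re-exported at
/// the crate root; the model ID and file name are illustrative placeholders:
///
/// ```no_run
/// # use std::path::PathBuf;
/// # use mistralrs::UqffTextModelBuilder;
/// # async fn demo() -> anyhow::Result<()> {
/// let model = UqffTextModelBuilder::new(
///     "some/model-UQFF",
///     vec![PathBuf::from("model-q4k.uqff")],
/// )
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```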
pub struct UqffTextModelBuilder(TextModelBuilder);

impl UqffTextModelBuilder {
    /// A few defaults are applied here:
    /// - The default ISQ organization is used (MoQE can be enabled with `with_mixture_qexperts_isq`)
    /// - Token source is from the cache (.cache/huggingface/token)
    /// - Maximum number of sequences running is 32
    /// - Number of sequences to hold in prefix cache is 16
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
        let mut inner = TextModelBuilder::new(model_id);
        inner = inner.from_uqff(uqff_file);
        Self(inner)
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Consume this wrapper, returning the inner [`TextModelBuilder`]. Take care not to call UQFF-related methods on it.
    pub fn into_inner(self) -> TextModelBuilder {
        self.0
    }
}

impl Deref for UqffTextModelBuilder {
    type Target = TextModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffTextModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffTextModelBuilder> for TextModelBuilder {
    fn from(value: UqffTextModelBuilder) -> Self {
        value.0
    }
}