mistralrs/text_model.rs

use mistralrs_core::*;
use std::{
    num::NonZeroUsize,
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::{best_device, Model};

#[derive(Clone)]
/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
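///
/// # Example
///
/// A minimal sketch; the model ID and ISQ choice here are illustrative:
/// ```no_run
/// use mistralrs::{IsqType, PagedAttentionMetaBuilder, TextModelBuilder};
///
/// # async fn demo() -> anyhow::Result<()> {
/// let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
///     .with_isq(IsqType::Q8_0)
///     .with_logging()
///     .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
///     .build()
///     .await?;
/// # Ok(())
/// # }
/// ```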
pub struct TextModelBuilder {
    // Loading model
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<PathBuf>,
    pub(crate) imatrix: Option<PathBuf>,
    pub(crate) calibration_file: Option<PathBuf>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,

    // Model running
    pub(crate) use_flash_attn: bool,
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) topology: Option<Topology>,
    pub(crate) organization: IsqOrganization,
    pub(crate) loader_type: Option<NormalLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    // Other things
    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

/// Builder for PagedAttention metadata.
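///
/// A minimal sketch of producing a [`PagedAttentionConfig`]; the values are illustrative:
/// ```no_run
/// # use mistralrs::{MemoryGpuConfig, PagedAttentionMetaBuilder};
/// # fn demo() -> anyhow::Result<()> {
/// let _cfg = PagedAttentionMetaBuilder::default()
///     .with_block_size(32)
///     .with_gpu_memory(MemoryGpuConfig::ContextSize(8192))
///     .build()?;
/// # Ok(())
/// # }
/// ```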
pub struct PagedAttentionMetaBuilder {
    block_size: Option<usize>,
    mem_cpu: usize,
    mem_gpu: MemoryGpuConfig,
}

impl Default for PagedAttentionMetaBuilder {
    fn default() -> Self {
        Self {
            block_size: None,
            mem_cpu: 64,
            mem_gpu: MemoryGpuConfig::ContextSize(4096),
        }
    }
}

impl PagedAttentionMetaBuilder {
    pub fn with_block_size(mut self, block_size: usize) -> Self {
        self.block_size = Some(block_size);
        self
    }

    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
        self.mem_gpu = mem_gpu;
        self
    }

    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
        PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu)
    }
}

impl TextModelBuilder {
    /// A few defaults are applied here:
    /// - Default ISQ organization (enable MoQE via `with_mixture_qexperts_isq`)
    /// - Token source is from the cache (`.cache/huggingface/token`)
    /// - Maximum number of sequences running at once is 32
    /// - Number of sequences to hold in the prefix cache is 16
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    /// - Web searching compatible with the OpenAI `web_search_options` setting is disabled
    pub fn new(model_id: impl ToString) -> Self {
        Self {
            model_id: model_id.to_string(),
            use_flash_attn: cfg!(feature = "flash-attn"),
            prompt_chunksize: None,
            topology: None,
            organization: IsqOrganization::Default,
            write_uqff: None,
            from_uqff: None,
            chat_template: None,
            tokenizer_json: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            device_mapping: None,
            imatrix: None,
            calibration_file: None,
            jinja_explicit: None,
            throughput_logging: false,
            hf_cache_path: None,
            search_bert_model: None,
        }
    }

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified or the default.
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Enable runner throughput logging.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Explicit Jinja chat template file (`.jinja`) to be used. If specified, this overrides all other chat templates.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Set the prompt chunk size: prompts are processed in batches of this size during inference.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// Organize ISQ to enable MoQE (Mixture of Quantized Experts, <https://arxiv.org/abs/2310.02410>)
    pub fn with_mixture_qexperts_isq(mut self) -> Self {
        self.organization = IsqOrganization::MoeExpertsOnly;
        self
    }

    /// A literal Jinja chat template, or a path (ending in `.json`) to one.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Path to a standalone `tokenizer.json` file.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Manually set the model loader type. Otherwise, it will attempt to automatically
    /// determine the loader type.
    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

    /// Load the model in a certain dtype.
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force usage of the CPU device. Do not use PagedAttention with this.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Source of the Hugging Face token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the revision to use for a Hugging Face remote model.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Utilise this imatrix file during ISQ. Incompatible with specifying a calibration file.
    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
        self.imatrix = Some(path);
        self
    }

    /// Utilise this calibration file to collect an imatrix. Incompatible with specifying an imatrix file.
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }

    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`].
    ///
    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
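    ///
    /// A minimal sketch of the intended call shape:
    /// ```no_run
    /// # use mistralrs::{PagedAttentionMetaBuilder, TextModelBuilder};
    /// # fn demo(builder: TextModelBuilder) -> anyhow::Result<TextModelBuilder> {
    /// builder.with_paged_attn(|| PagedAttentionMetaBuilder::default().build())
    /// # }
    /// ```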
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences which can be run at once.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Disable the KV cache. This trades performance for lower memory usage.
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide metadata to initialize the device mapper.
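    ///
    /// A minimal sketch, mirroring the automatic default applied in [`Self::build`]:
    /// ```no_run
    /// # use mistralrs::{AutoDeviceMapParams, DeviceMapSetting, TextModelBuilder};
    /// # fn demo(builder: TextModelBuilder) -> TextModelBuilder {
    /// builder.with_device_mapping(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()))
    /// # }
    /// ```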
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Path to read a UQFF file from.
    pub fn from_uqff(mut self, path: PathBuf) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Path to write a UQFF file to.
    ///
    /// The parent (part of the path excluding the filename) will determine where any other files
    /// generated are written to. These can be used to load UQFF models standalone, and may include:
    /// - `residual.safetensors`
    /// - `tokenizer.json`
    /// - `config.json`
    /// - And others
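    ///
    /// For example, given an illustrative path of `uqff/model-q4k.uqff`, these support files
    /// are written alongside it in `uqff/`.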
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Cache path for Hugging Face models downloaded locally.
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        let config = NormalSpecificConfig {
            use_flash_attn: self.use_flash_attn,
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
            organization: self.organization,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            imatrix: self.imatrix,
            calibration_file: self.calibration_file,
            hf_cache_path: self.hf_cache_path,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = NormalLoaderBuilder::new(
            config,
            self.chat_template,
            self.tokenizer_json,
            Some(self.model_id),
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build(self.loader_type)?;

        // Load the model into a `Pipeline`.
        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &best_device(self.force_cpu)?,
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            self.isq,
            self.paged_attn_cfg,
        )?;

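        // If PagedAttention was enabled above, the pipeline's cache config is
        // guaranteed to have been populated during model load, so the `unwrap`
        // below cannot fail; otherwise, fall back to the fixed-capacity default scheduler.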
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}

#[derive(Clone)]
/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
/// This wraps the `TextModelBuilder` and implements `Deref`/`DerefMut` to it, so users should take care not to call UQFF-related methods.
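///
/// A minimal sketch; the model ID and UQFF filename are illustrative:
/// ```no_run
/// use mistralrs::UqffTextModelBuilder;
///
/// # async fn demo() -> anyhow::Result<()> {
/// let model = UqffTextModelBuilder::new(
///     "EricB/Phi-3.5-mini-instruct-UQFF",
///     "phi3.5-mini-instruct-q4k.uqff".into(),
/// )
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```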
pub struct UqffTextModelBuilder(TextModelBuilder);

impl UqffTextModelBuilder {
    /// A few defaults are applied here:
    /// - Default ISQ organization (enable MoQE via `with_mixture_qexperts_isq`)
    /// - Token source is from the cache (`.cache/huggingface/token`)
    /// - Maximum number of sequences running at once is 32
    /// - Number of sequences to hold in the prefix cache is 16
    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
    pub fn new(model_id: impl ToString, uqff_file: PathBuf) -> Self {
        let mut inner = TextModelBuilder::new(model_id);
        inner = inner.from_uqff(uqff_file);
        Self(inner)
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Return the wrapped `TextModelBuilder`; users should take care not to call UQFF-related methods on it.
    pub fn into_inner(self) -> TextModelBuilder {
        self.0
    }
}

impl Deref for UqffTextModelBuilder {
    type Target = TextModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffTextModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffTextModelBuilder> for TextModelBuilder {
    fn from(value: UqffTextModelBuilder) -> Self {
        value.0
    }
}