use mistralrs_core::*;
use std::{
    num::NonZeroUsize,
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::{best_device, Model};

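/// Builder for a text (decoder-only) model. Configure loading, quantization, and
/// runtime options, then call [`TextModelBuilder::build`] to load the model.
///
/// A minimal usage sketch (assuming a Tokio runtime; the model ID and ISQ type are
/// placeholders):
/// ```ignore
/// use anyhow::Result;
/// use mistralrs::{IsqType, PagedAttentionMetaBuilder, TextModelBuilder};
///
/// #[tokio::main]
/// async fn main() -> Result<()> {
///     let model = TextModelBuilder::new("mistralai/Mistral-7B-Instruct-v0.3")
///         .with_isq(IsqType::Q4K)
///         .with_logging()
///         .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
///         .build()
///         .await?;
///     Ok(())
/// }
/// ```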
#[derive(Clone)]
pub struct TextModelBuilder {
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<PathBuf>,
    pub(crate) imatrix: Option<PathBuf>,
    pub(crate) calibration_file: Option<PathBuf>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,

    pub(crate) use_flash_attn: bool,
    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) topology: Option<Topology>,
    pub(crate) organization: IsqOrganization,
    pub(crate) loader_type: Option<NormalLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

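/// Convenience builder for a [`PagedAttentionConfig`], intended for use with
/// [`TextModelBuilder::with_paged_attn`]. Defaults: no explicit block size,
/// `mem_cpu = 64`, and GPU memory sized from a 4096-token context.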
pub struct PagedAttentionMetaBuilder {
    block_size: Option<usize>,
    mem_cpu: usize,
    mem_gpu: MemoryGpuConfig,
}

impl Default for PagedAttentionMetaBuilder {
    fn default() -> Self {
        Self {
            block_size: None,
            mem_cpu: 64,
            mem_gpu: MemoryGpuConfig::ContextSize(4096),
        }
    }
}

impl PagedAttentionMetaBuilder {
    pub fn with_block_size(mut self, block_size: usize) -> Self {
        self.block_size = Some(block_size);
        self
    }

    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
        self.mem_gpu = mem_gpu;
        self
    }

    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
        PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu)
    }
}

impl TextModelBuilder {
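    /// Create a builder for the given model ID. Defaults include `ModelDType::Auto`,
    /// `TokenSource::CacheToken`, a maximum of 32 concurrent sequences, and a prefix
    /// cache holding 16 sequences.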
    pub fn new(model_id: impl ToString) -> Self {
        Self {
            model_id: model_id.to_string(),
            use_flash_attn: cfg!(feature = "flash-attn"),
            prompt_chunksize: None,
            topology: None,
            organization: IsqOrganization::Default,
            write_uqff: None,
            from_uqff: None,
            chat_template: None,
            tokenizer_json: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            device_mapping: None,
            imatrix: None,
            calibration_file: None,
            jinja_explicit: None,
            throughput_logging: false,
            hf_cache_path: None,
            search_bert_model: None,
        }
    }

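    /// Enable search using the given BERT embedding model.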
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

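    /// Enable throughput logging.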
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

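    /// Set an explicit JINJA chat template file.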
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

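    /// Set the number of prompt tokens processed per chunk during prefill.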
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

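    /// Set a model topology (per-layer device and quantization settings).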
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

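    /// Organize ISQ to quantize only the MoE experts (`IsqOrganization::MoeExpertsOnly`).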
    pub fn with_mixture_qexperts_isq(mut self) -> Self {
        self.organization = IsqOrganization::MoeExpertsOnly;
        self
    }

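    /// Override the chat template used for this model.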
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

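    /// Use the specified path to a `tokenizer.json` file.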
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

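    /// Force a particular model architecture loader instead of auto-detecting it.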
    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

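    /// Set the model data type to load and run in (defaults to `ModelDType::Auto`).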
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

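    /// Force the model to run on the CPU even if an accelerator is available.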
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

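    /// Set the source of the Hugging Face token (defaults to the cached token).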
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

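    /// Use a specific revision (branch, tag, or commit) of the Hugging Face repository.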
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

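    /// Apply in-situ quantization (ISQ) of the given type to the model while loading.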
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

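    /// Use this imatrix file to guide ISQ.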
    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
        self.imatrix = Some(path);
        self
    }

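    /// Use this calibration file to collect quantization statistics during ISQ.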
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }

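    /// Enable PagedAttention with the configuration produced by `paged_attn_cfg`.
    /// If PagedAttention is not supported on this platform, the configuration is
    /// silently ignored. A typical call (sketch; the model ID is a placeholder):
    /// ```ignore
    /// let builder = TextModelBuilder::new("some/model-id")
    ///     .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// ```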
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

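    /// Set the maximum number of sequences the scheduler runs concurrently (default 32).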
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

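    /// Disable the KV cache.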
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

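    /// Set how many sequences the prefix cache may hold; pass `None` to disable
    /// prefix caching (default: `Some(16)`).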
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

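    /// Initialize logging when the model is built.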
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

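    /// Provide an explicit device mapping; otherwise automatic device mapping is used.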
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

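    /// Load a prequantized model from a UQFF file at this path.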
    pub fn from_uqff(mut self, path: PathBuf) -> Self {
        self.from_uqff = Some(path);
        self
    }

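    /// Write the (ISQ-quantized) model to a UQFF file at this path.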
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

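    /// Use a custom Hugging Face cache path.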
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

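    /// Load the model and construct a [`Model`] ready for inference, applying the
    /// scheduler and PagedAttention settings configured on this builder.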
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = NormalSpecificConfig {
            use_flash_attn: self.use_flash_attn,
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
            organization: self.organization,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            imatrix: self.imatrix,
            calibration_file: self.calibration_file,
            hf_cache_path: self.hf_cache_path,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = NormalLoaderBuilder::new(
            config,
            self.chat_template,
            self.tokenizer_json,
            Some(self.model_id),
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build(self.loader_type)?;

        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &best_device(self.force_cpu)?,
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            self.isq,
            self.paged_attn_cfg,
        )?;

        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .unwrap()
                    .clone();

                SchedulerConfig::PagedAttentionMeta {
                    max_num_seqs: self.max_num_seqs,
                    config,
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n)
        }

        Ok(Model::new(runner.build()))
    }
}

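/// Convenience wrapper around [`TextModelBuilder`] for loading a prequantized UQFF
/// model. It dereferences to [`TextModelBuilder`], so the usual builder methods
/// remain available.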
#[derive(Clone)]
pub struct UqffTextModelBuilder(TextModelBuilder);

impl UqffTextModelBuilder {
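    /// Create a builder for the given model ID that loads weights from the specified
    /// UQFF file.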
    pub fn new(model_id: impl ToString, uqff_file: PathBuf) -> Self {
        let mut inner = TextModelBuilder::new(model_id);
        inner = inner.from_uqff(uqff_file);
        Self(inner)
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

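    /// Unwrap into the underlying [`TextModelBuilder`].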
    pub fn into_inner(self) -> TextModelBuilder {
        self.0
    }
}

impl Deref for UqffTextModelBuilder {
    type Target = TextModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffTextModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffTextModelBuilder> for TextModelBuilder {
    fn from(value: UqffTextModelBuilder) -> Self {
        value.0
    }
}