use candle_core::Device;
use mistralrs_core::*;
use std::{
    num::NonZeroUsize,
    ops::{Deref, DerefMut},
    path::PathBuf,
};

use crate::{best_device, Model};
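/// Configure a text model with the various parameters for loading and inference.
///
/// A minimal usage sketch. The model ID is illustrative, and the example assumes the
/// builder and `IsqType` are re-exported at the crate root:
///
/// ```no_run
/// use mistralrs::{IsqType, TextModelBuilder};
///
/// # async fn run() -> anyhow::Result<()> {
/// let model = TextModelBuilder::new("some-org/some-text-model")
///     .with_isq(IsqType::Q4K)
///     .with_logging()
///     .build()
///     .await?;
/// # Ok(())
/// # }
/// ```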
#[derive(Clone)]
pub struct TextModelBuilder {
    pub(crate) model_id: String,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) write_uqff: Option<PathBuf>,
    pub(crate) from_uqff: Option<Vec<PathBuf>>,
    pub(crate) imatrix: Option<PathBuf>,
    pub(crate) calibration_file: Option<PathBuf>,
    pub(crate) chat_template: Option<String>,
    pub(crate) jinja_explicit: Option<String>,
    pub(crate) tokenizer_json: Option<String>,
    pub(crate) device_mapping: Option<DeviceMapSetting>,
    pub(crate) hf_cache_path: Option<PathBuf>,
    pub(crate) search_bert_model: Option<BertEmbeddingModel>,
    pub(crate) device: Option<Device>,

    pub(crate) prompt_chunksize: Option<NonZeroUsize>,
    pub(crate) topology: Option<Topology>,
    pub(crate) organization: IsqOrganization,
    pub(crate) loader_type: Option<NormalLoaderType>,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,
    pub(crate) isq: Option<IsqType>,
    pub(crate) throughput_logging: bool,

    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
    pub(crate) max_num_seqs: usize,
    pub(crate) no_kv_cache: bool,
    pub(crate) with_logging: bool,
    pub(crate) prefix_cache_n: Option<usize>,
}

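/// Builder for PagedAttention metadata.
///
/// A minimal sketch of intended use, assuming the usual crate-root re-exports; the values
/// shown are illustrative, not tuned recommendations:
///
/// ```no_run
/// use mistralrs::{MemoryGpuConfig, PagedAttentionMetaBuilder};
///
/// # fn run() -> anyhow::Result<()> {
/// let paged_attn_cfg = PagedAttentionMetaBuilder::default()
///     .with_block_size(32)
///     .with_gpu_memory(MemoryGpuConfig::ContextSize(8192))
///     .build()?;
/// # Ok(())
/// # }
/// ```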
pub struct PagedAttentionMetaBuilder {
    block_size: Option<usize>,
    mem_cpu: usize,
    mem_gpu: MemoryGpuConfig,
}

impl Default for PagedAttentionMetaBuilder {
    fn default() -> Self {
        Self {
            block_size: None,
            mem_cpu: 64,
            mem_gpu: MemoryGpuConfig::ContextSize(4096),
        }
    }
}

impl PagedAttentionMetaBuilder {
    pub fn with_block_size(mut self, block_size: usize) -> Self {
        self.block_size = Some(block_size);
        self
    }

    pub fn with_gpu_memory(mut self, mem_gpu: MemoryGpuConfig) -> Self {
        self.mem_gpu = mem_gpu;
        self
    }

    pub fn build(self) -> anyhow::Result<PagedAttentionConfig> {
        PagedAttentionConfig::new(self.block_size, self.mem_cpu, self.mem_gpu)
    }
}

impl TextModelBuilder {
    /// Create a builder for the given model ID (typically a Hugging Face model ID), with
    /// all other settings at their defaults.
    pub fn new(model_id: impl ToString) -> Self {
        Self {
            model_id: model_id.to_string(),
            prompt_chunksize: None,
            topology: None,
            organization: IsqOrganization::Default,
            write_uqff: None,
            from_uqff: None,
            chat_template: None,
            tokenizer_json: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            isq: None,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            no_kv_cache: false,
            prefix_cache_n: Some(16),
            with_logging: false,
            device_mapping: None,
            imatrix: None,
            calibration_file: None,
            jinja_explicit: None,
            throughput_logging: false,
            hf_cache_path: None,
            search_bert_model: None,
            device: None,
        }
    }

    /// Enable web search, using the given BERT model for embeddings.
    pub fn with_search(mut self, search_bert_model: BertEmbeddingModel) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    /// Log throughput statistics during inference.
    pub fn with_throughput_logging(mut self) -> Self {
        self.throughput_logging = true;
        self
    }

    /// Use an explicit Jinja chat template file, overriding any other chat template.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Set the number of prompt tokens processed per chunk during prefill.
    pub fn with_prompt_chunksize(mut self, prompt_chunksize: NonZeroUsize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    /// Set the model topology, which controls per-layer device placement and quantization.
    pub fn with_topology(mut self, topology: Topology) -> Self {
        self.topology = Some(topology);
        self
    }

    /// When applying ISQ to a mixture-of-experts model, quantize only the experts.
    pub fn with_mixture_qexperts_isq(mut self) -> Self {
        self.organization = IsqOrganization::MoeExpertsOnly;
        self
    }

    /// Override the chat template.
    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
        self.chat_template = Some(chat_template.to_string());
        self
    }

    /// Use a specific `tokenizer.json` file instead of the model's own.
    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
        self.tokenizer_json = Some(tokenizer_json.to_string());
        self
    }

    /// Force a particular model architecture loader instead of auto-detecting it.
    pub fn with_loader_type(mut self, loader_type: NormalLoaderType) -> Self {
        self.loader_type = Some(loader_type);
        self
    }

    /// Set the model data type (defaults to [`ModelDType::Auto`]).
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Force the model to run on the CPU.
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

    /// Set the source of the Hugging Face access token.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Set the Hugging Face revision (branch, tag, or commit) to load from.
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

    /// Apply in-situ quantization (ISQ) of the given type to the model weights.
    pub fn with_isq(mut self, isq: IsqType) -> Self {
        self.isq = Some(isq);
        self
    }

    /// Use an importance matrix (imatrix) file to guide ISQ quantization.
    pub fn with_imatrix(mut self, path: PathBuf) -> Self {
        self.imatrix = Some(path);
        self
    }

    /// Use a calibration file to compute an imatrix for ISQ quantization.
    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
        self.calibration_file = Some(path);
        self
    }

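    /// Enable PagedAttention. The configuration is applied only when PagedAttention is
    /// supported on the current platform; otherwise it is silently skipped and the
    /// default KV cache is used.
    ///
    /// A minimal sketch of the closure-based API (the model ID is illustrative):
    ///
    /// ```no_run
    /// use mistralrs::{PagedAttentionMetaBuilder, TextModelBuilder};
    ///
    /// # fn run() -> anyhow::Result<()> {
    /// let builder = TextModelBuilder::new("some-org/some-text-model")
    ///     .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
    /// # Ok(())
    /// # }
    /// ```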
    pub fn with_paged_attn(
        mut self,
        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
    ) -> anyhow::Result<Self> {
        if paged_attn_supported() {
            self.paged_attn_cfg = Some(paged_attn_cfg()?);
        } else {
            self.paged_attn_cfg = None;
        }
        Ok(self)
    }

    /// Set the maximum number of sequences that can be scheduled concurrently.
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

    /// Disable the KV cache entirely.
    pub fn with_no_kv_cache(mut self) -> Self {
        self.no_kv_cache = true;
        self
    }

    /// Set the number of sequences to hold in the prefix cache; pass `None` to disable
    /// prefix caching.
    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
        self.prefix_cache_n = n_seqs;
        self
    }

    /// Enable logging.
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

    /// Provide an explicit device mapping; otherwise an automatic mapping is used.
    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
        self.device_mapping = Some(device_mapping);
        self
    }

    /// Load the model from the given pre-quantized UQFF file(s).
    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
        self.from_uqff = Some(path);
        self
    }

    /// Write the quantized model weights to a UQFF file at this path.
    pub fn write_uqff(mut self, path: PathBuf) -> Self {
        self.write_uqff = Some(path);
        self
    }

    /// Use the specified Hugging Face cache directory instead of the default.
    pub fn from_hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    /// Run the model on a specific device, overriding automatic device selection.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

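    /// Load the model and construct a [`Model`] ready for inference. When a
    /// PagedAttention configuration was supplied and the loaded pipeline reports a KV
    /// cache configuration, the PagedAttention scheduler is used; otherwise a
    /// fixed-capacity default scheduler is used.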
    pub async fn build(self) -> anyhow::Result<Model> {
        let config = NormalSpecificConfig {
            prompt_chunksize: self.prompt_chunksize,
            topology: self.topology,
            organization: self.organization,
            write_uqff: self.write_uqff,
            from_uqff: self.from_uqff,
            imatrix: self.imatrix,
            calibration_file: self.calibration_file,
            hf_cache_path: self.hf_cache_path,
        };

        if self.with_logging {
            initialize_logging();
        }

        let loader = NormalLoaderBuilder::new(
            config,
            self.chat_template,
            self.tokenizer_json,
            Some(self.model_id),
            self.no_kv_cache,
            self.jinja_explicit,
        )
        .build(self.loader_type)?;

        // Select the target device, propagating any error instead of panicking.
        let device = match self.device {
            Some(device) => device,
            None => best_device(self.force_cpu)?,
        };

        let pipeline = loader.load_model_from_hf(
            self.hf_revision,
            self.token_source,
            &self.dtype,
            &device,
            !self.with_logging,
            self.device_mapping
                .unwrap_or(DeviceMapSetting::Auto(AutoDeviceMapParams::default_text())),
            self.isq,
            self.paged_attn_cfg,
        )?;

        // Use the PagedAttention scheduler only if the pipeline produced a cache config;
        // otherwise fall back to the fixed-capacity default scheduler.
        let scheduler_method = match self.paged_attn_cfg {
            Some(_) => {
                let config = pipeline
                    .lock()
                    .await
                    .get_metadata()
                    .cache_config
                    .as_ref()
                    .cloned();

                if let Some(config) = config {
                    SchedulerConfig::PagedAttentionMeta {
                        max_num_seqs: self.max_num_seqs,
                        config,
                    }
                } else {
                    SchedulerConfig::DefaultScheduler {
                        method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
                    }
                }
            }
            None => SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
            },
        };

        let mut runner = MistralRsBuilder::new(
            pipeline,
            scheduler_method,
            self.throughput_logging,
            self.search_bert_model,
        )
        .with_no_kv_cache(self.no_kv_cache)
        .with_no_prefix_cache(self.prefix_cache_n.is_none());

        if let Some(n) = self.prefix_cache_n {
            runner = runner.with_prefix_cache_n(n);
        }

        Ok(Model::new(runner.build()))
    }
}

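/// Wrapper of [`TextModelBuilder`] for pre-quantized UQFF models.
///
/// A minimal sketch; the model ID and file path are illustrative:
///
/// ```no_run
/// use std::path::PathBuf;
///
/// use mistralrs::UqffTextModelBuilder;
///
/// # async fn run() -> anyhow::Result<()> {
/// let model = UqffTextModelBuilder::new(
///     "some-org/some-model-UQFF",
///     vec![PathBuf::from("model-q4k.uqff")],
/// )
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```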
#[derive(Clone)]
pub struct UqffTextModelBuilder(TextModelBuilder);

impl UqffTextModelBuilder {
    /// Create a UQFF-based builder for the given model ID and UQFF file(s).
    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
        let mut inner = TextModelBuilder::new(model_id);
        inner = inner.from_uqff(uqff_file);
        Self(inner)
    }

    pub async fn build(self) -> anyhow::Result<Model> {
        self.0.build().await
    }

    /// Consume the wrapper and return the inner [`TextModelBuilder`].
    pub fn into_inner(self) -> TextModelBuilder {
        self.0
    }
}

impl Deref for UqffTextModelBuilder {
    type Target = TextModelBuilder;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for UqffTextModelBuilder {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<UqffTextModelBuilder> for TextModelBuilder {
    fn from(value: UqffTextModelBuilder) -> Self {
        value.0
    }
}