use std::{num::NonZeroUsize, sync::Arc};

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
    MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig, SchedulerConfig,
    SearchCallback, TokenSource,
};
use tracing::info;

use crate::types::{LoadedPipeline, SharedMistralRsState};

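/// Default values used by [`MistralRsForServerBuilder`] when an option is not set explicitly.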
pub mod defaults {
    use std::sync::Arc;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const TRUNCATE_SEQUENCE: bool = false;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const NO_PAGED_ATTN: bool = false;
    pub const PAGED_ATTN: bool = false;
    pub const PROMPT_CHUNKSIZE: Option<usize> = None;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_BERT_MODEL: Option<String> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
}

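/// Builder for configuring and constructing the shared `MistralRs` server state.
///
/// Options that are not set fall back to the values in [`defaults`].
///
/// Illustrative usage (a sketch, not taken from the crate docs; assumes a
/// `ModelSelected` value named `model` is already constructed):
///
/// ```ignore
/// let state = MistralRsForServerBuilder::new()
///     .with_model(model)
///     .with_max_seqs(32)
///     .with_paged_attn(true)
///     .build()
///     .await?;
/// ```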
pub struct MistralRsForServerBuilder {
    device: Option<Device>,
    seed: Option<u64>,
    log: Option<String>,
    truncate_sequence: bool,
    model: Option<ModelSelected>,
    max_seqs: usize,
    no_kv_cache: bool,
    chat_template: Option<String>,
    jinja_explicit: Option<String>,
    token_source: TokenSource,
    interactive_mode: bool,
    prefix_cache_n: usize,
    num_device_layers: Option<Vec<String>>,
    in_situ_quant: Option<String>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    paged_attn_block_size: Option<usize>,
    no_paged_attn: bool,
    paged_attn: bool,
    prompt_chunksize: Option<usize>,
    cpu: bool,
    enable_search: bool,
    search_bert_model: Option<String>,
    search_callback: Option<Arc<SearchCallback>>,
}

impl Default for MistralRsForServerBuilder {
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            truncate_sequence: defaults::TRUNCATE_SEQUENCE,
            model: defaults::MODEL,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            no_paged_attn: defaults::NO_PAGED_ATTN,
            paged_attn: defaults::PAGED_ATTN,
            prompt_chunksize: defaults::PROMPT_CHUNKSIZE,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_bert_model: defaults::SEARCH_BERT_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
        }
    }
}

impl MistralRsForServerBuilder {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = truncate_sequence;
        self
    }

    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    pub fn with_no_paged_attn(mut self, no_paged_attn: bool) -> Self {
        self.no_paged_attn = no_paged_attn;
        self
    }

    pub fn with_paged_attn(mut self, paged_attn: bool) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    pub fn with_prompt_chunksize(mut self, prompt_chunksize: usize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    pub fn with_prompt_chunksize_optional(mut self, prompt_chunksize: Option<usize>) -> Self {
        if let Some(prompt_chunksize) = prompt_chunksize {
            self = self.with_prompt_chunksize(prompt_chunksize);
        }
        self
    }

    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

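    /// Load the selected model and assemble the shared `MistralRs` server state.
    ///
    /// Resolves the device, device layer mapping, paged-attention cache, and
    /// scheduler from the configured options, then loads the pipeline and
    /// returns a [`SharedMistralRsState`]. Fails if no model was set.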
    pub async fn build(mut self) -> Result<SharedMistralRsState> {
        // Forcing CPU execution implies paged attention must be disabled.
        if self.cpu {
            self.no_paged_attn = true;
        }

        let model = self.model.context("Model was None")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        // A target non-granular index restricts scheduling to a single sequence.
        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let prompt_chunksize = match self.prompt_chunksize {
            Some(0) => {
                anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
            }
            Some(x) => Some(NonZeroUsize::new(x).unwrap()),
            None => None,
        };

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let no_paged_attn = configure_no_paged_attn(&device, self.no_paged_attn, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            no_paged_attn,
            max_seq_len,
        )?;

        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_prompt_chunksize(prompt_chunksize)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

        // Parse the in-situ quantization spec for the chosen device; values that
        // fail to parse are silently ignored.
        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mistralrs = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            bert_model,
        )
        .with_opt_log(self.log)
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n)
        .build();

        Ok(mistralrs)
    }
}

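/// Select the compute device: CPU when forced, Metal when compiled with the
/// `metal` feature, otherwise CUDA if available. CPU is also selected when
/// NCCL distributed mode is in use. Applies the optional RNG seed.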
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

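/// Build the device-mapping setting from the device-layer strings.
///
/// A single plain integer (e.g. `"32"`) assigns that many layers to device
/// ordinal 0; otherwise each entry must have the form `ORD:NUM` (e.g. `"0:16"`,
/// `"1:16"`), and duplicate ordinals cause a panic. With no entries, automatic
/// device mapping is used.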
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

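/// Log SIMD feature support, the sampling pipeline order, and the model kind.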
fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind().to_string());
}

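/// Decide whether paged attention is disabled for the given device: CUDA (and
/// NCCL) honors the explicit `no_paged_attn` flag, Metal enables paged
/// attention only when `paged_attn` is set, and any other device disables it.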
fn configure_no_paged_attn(device: &Device, no_paged_attn: bool, paged_attn: bool) -> bool {
    if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        no_paged_attn
    } else if device.is_metal() {
        !paged_attn
    } else {
        true
    }
}

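/// Build the paged-attention cache configuration from the block size and the
/// GPU memory settings. When several memory settings are given, utilization
/// takes precedence over a context length or an MB amount, and a context
/// length takes precedence over an MB amount; with no memory setting at all,
/// the model's maximum sequence length is used as the context size. Returns
/// `None` when paged attention is unsupported or disabled.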
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

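/// Choose the scheduler: a paged-attention scheduler when a cache configuration
/// is present both here and in the loaded pipeline's metadata, otherwise the
/// default fixed-size scheduler with `args_max_seqs` slots.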
async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
        }
    }
}

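/// Return the BERT embedding model used for web search: the custom model when
/// one is named, the default otherwise, and `None` when search is disabled.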
pub fn get_bert_model(
    enable_search: bool,
    search_bert_model: Option<String>,
) -> Option<BertEmbeddingModel> {
    if enable_search {
        Some(
            search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    }
}