use std::{collections::HashMap, num::NonZeroUsize, sync::Arc};

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
    McpClientConfig, MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig,
    PagedCacheType, SchedulerConfig, SearchCallback, TokenSource,
};
use tracing::{info, warn};

use crate::types::{LoadedPipeline, SharedMistralRsState};

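/// Configuration for a single model in multi-model serving. Per-model settings
/// override the builder-wide ones when present.
///
/// A minimal construction sketch; `model` stands for a `ModelSelected` value
/// obtained elsewhere, and the ISQ string is illustrative:
///
/// ```ignore
/// let config = ModelConfig::new("my-model".to_string(), model)
///     .with_in_situ_quant("Q4K".to_string());
/// ```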
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Identifier under which this model is registered (e.g. the config key).
    pub model_id: String,
    /// The model selection (architecture, weights, etc.) to load.
    pub model: ModelSelected,
    /// Optional chat template, overriding the builder-wide setting.
    pub chat_template: Option<String>,
    /// Optional explicit JINJA chat template file, overriding the builder-wide setting.
    pub jinja_explicit: Option<String>,
    /// Optional device layer mapping, overriding the builder-wide setting.
    pub num_device_layers: Option<Vec<String>>,
    /// Optional in-situ quantization, overriding the builder-wide setting.
    pub in_situ_quant: Option<String>,
}

impl ModelConfig {
    pub fn new(model_id: String, model: ModelSelected) -> Self {
        Self {
            model_id,
            model,
            chat_template: None,
            jinja_explicit: None,
            num_device_layers: None,
            in_situ_quant: None,
        }
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }
}

/// Default values for [`MistralRsForServerBuilder`] fields.
pub mod defaults {
    use std::sync::Arc;

    use mistralrs_core::PagedCacheType;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const TRUNCATE_SEQUENCE: bool = false;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const PAGED_ATTN: Option<bool> = None;
    pub const PAGED_ATTN_CPU: bool = false;
    pub const PAGED_ATTN_CUDA: bool = true;
    pub const PAGED_ATTN_METAL: bool = false;
    pub const PROMPT_CHUNKSIZE: Option<usize> = None;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_BERT_MODEL: Option<String> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
}

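/// Builder for a shared mistral.rs server instance.
///
/// A minimal usage sketch; `model` stands for a `ModelSelected` value parsed
/// elsewhere (hypothetical here), and the settings shown are illustrative:
///
/// ```ignore
/// let mistralrs = MistralRsForServerBuilder::new()
///     .with_model(model)
///     .with_max_seqs(32)
///     .with_seed_optional(Some(42))
///     .build()
///     .await?;
/// ```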
pub struct MistralRsForServerBuilder {
    /// The device to load the model onto; auto-detected when `None`.
    device: Option<Device>,

    /// RNG seed for reproducible sampling.
    seed: Option<u64>,

    /// Optional log destination.
    log: Option<String>,

    /// Whether to truncate sequences that exceed the maximum model length.
    truncate_sequence: bool,

    /// The model to load (single-model mode).
    model: Option<ModelSelected>,

    /// Model configurations (multi-model mode).
    models: Vec<ModelConfig>,

    /// Default model to use when none is specified in a request.
    default_model_id: Option<String>,

    /// Maximum number of concurrent sequences.
    max_seqs: usize,

    /// Whether to disable the KV cache.
    no_kv_cache: bool,

    /// Optional chat template.
    chat_template: Option<String>,

    /// Optional explicit JINJA chat template file.
    jinja_explicit: Option<String>,

    /// Source of the Hugging Face token.
    token_source: TokenSource,

    /// Whether the server runs in interactive mode.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device.
    prefix_cache_n: usize,

    /// Device layer mapping: a single layer count or `ORD:NUM` pairs.
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MB.
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory (0.0..=1.0) to use for the PagedAttention KV cache.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Context length to size the PagedAttention KV cache for.
    paged_ctxt_len: Option<usize>,

    /// PagedAttention block size.
    paged_attn_block_size: Option<usize>,

    /// Tri-state PagedAttention toggle; `None` uses the platform default.
    paged_attn: Option<bool>,

    /// Prompt chunking size.
    prompt_chunksize: Option<usize>,

    /// Force CPU execution.
    cpu: bool,

    /// Enable web search tooling.
    enable_search: bool,

    /// BERT model used for search embeddings.
    search_bert_model: Option<String>,

    /// Custom callback to run when a search is performed.
    search_callback: Option<Arc<SearchCallback>>,

    /// MCP client configuration.
    mcp_client_config: Option<McpClientConfig>,

    /// KV cache element type for PagedAttention.
    paged_cache_type: PagedCacheType,
}

impl Default for MistralRsForServerBuilder {
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            truncate_sequence: defaults::TRUNCATE_SEQUENCE,
            model: defaults::MODEL,
            models: Vec::new(),
            default_model_id: None,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            paged_attn: defaults::PAGED_ATTN,
            prompt_chunksize: defaults::PROMPT_CHUNKSIZE,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_bert_model: defaults::SEARCH_BERT_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
            mcp_client_config: None,
            paged_cache_type: defaults::PAGED_CACHE_TYPE,
        }
    }
}

impl MistralRsForServerBuilder {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = truncate_sequence;
        self
    }

    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
        self.models.push(model_config);
        self
    }

    pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
        self.models.extend(model_configs);
        self
    }

    pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
        self.default_model_id = Some(default_model_id);
        self
    }

    pub fn add_model_config(mut self, config: ModelConfig) -> Self {
        self.models.push(config);
        self
    }

    pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
        self.models.push(ModelConfig::new(model_id, model));
        self
    }

    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    /// Sets the PagedAttention toggle: `Some(true)` forces it on, `Some(false)`
    /// forces it off, and `None` defers to the platform default.
    pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
        self.paged_cache_type = cache_type;
        self
    }

    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    pub fn with_prompt_chunksize(mut self, prompt_chunksize: usize) -> Self {
        self.prompt_chunksize = Some(prompt_chunksize);
        self
    }

    pub fn with_prompt_chunksize_optional(mut self, prompt_chunksize: Option<usize>) -> Self {
        if let Some(prompt_chunksize) = prompt_chunksize {
            self = self.with_prompt_chunksize(prompt_chunksize);
        }
        self
    }

    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }

    pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
        if let Some(mcp_config) = mcp_config {
            self = self.with_mcp_config(mcp_config);
        }
        self
    }

    /// Builds the shared mistral.rs instance, using multi-model mode when model
    /// configurations were added and single-model mode otherwise.
    pub async fn build(self) -> Result<SharedMistralRsState> {
        if !self.models.is_empty() {
            self.build_multi_model().await
        } else {
            self.build_single_model().await
        }
    }

    async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
        let model = self.model.context("Model was None")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let prompt_chunksize = match self.prompt_chunksize {
            Some(0) => {
                anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
            }
            Some(x) => Some(NonZeroUsize::new(x).unwrap()),
            None => None,
        };

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_prompt_chunksize(prompt_chunksize)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config =
            init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            bert_model,
        )
        .with_opt_log(self.log)
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        if let Some(mcp_config) = self.mcp_client_config {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        Ok(mistralrs)
    }

    /// Builds a multi-model mistral.rs instance. The first configured model is
    /// loaded first and becomes the default unless a default model ID is set.
    pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
        if self.models.is_empty() {
            anyhow::bail!("No models configured for multi-model mode");
        }

        let first_model = &self.models[0];
        let model = first_model.model.clone();

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let prompt_chunksize = match self.prompt_chunksize {
            Some(0) => {
                anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
            }
            Some(x) => Some(NonZeroUsize::new(x).unwrap()),
            None => None,
        };

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(
                first_model
                    .chat_template
                    .clone()
                    .or(self.chat_template.clone()),
            )
            .with_prompt_chunksize(prompt_chunksize)
            .with_jinja_explicit(
                first_model
                    .jinja_explicit
                    .clone()
                    .or(self.jinja_explicit.clone()),
            )
            .build()?;

        mistralrs_instance_info(&*loader);

        let mapper = init_mapper(
            &first_model
                .num_device_layers
                .clone()
                .or(self.num_device_layers.clone()),
            &auto_device_map_params,
        );
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let isq = first_model
            .in_situ_quant
            .as_ref()
            .or(self.in_situ_quant.as_ref())
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let mut pipeline_names = Vec::new();

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source.clone(),
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        let first_pipeline_name = pipeline.lock().await.name();
        info!(
            "First model loaded: `{first_pipeline_name}` (from config key: {})",
            first_model.model_id
        );
        pipeline_names.push(first_pipeline_name);

        let scheduler_config =
            init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config.clone(),
            !self.interactive_mode,
            bert_model.clone(),
        )
        .with_opt_log(self.log.clone())
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        if let Some(mcp_config) = self.mcp_client_config.clone() {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        for model_config in self.models.iter().skip(1) {
            info!(
                "Loading additional model from config key: {}",
                model_config.model_id
            );

            let model = model_config.model.clone();
            let dtype = get_model_dtype(&model)?;
            let auto_device_map_params = get_auto_device_map_params(&model)?;

            let loader: Box<dyn Loader> = LoaderBuilder::new(model)
                .with_no_kv_cache(self.no_kv_cache)
                .with_chat_template(
                    model_config
                        .chat_template
                        .clone()
                        .or(self.chat_template.clone()),
                )
                .with_prompt_chunksize(prompt_chunksize)
                .with_jinja_explicit(
                    model_config
                        .jinja_explicit
                        .clone()
                        .or(self.jinja_explicit.clone()),
                )
                .build()?;

            let mapper = init_mapper(
                &model_config
                    .num_device_layers
                    .clone()
                    .or(self.num_device_layers.clone()),
                &auto_device_map_params,
            );

            let isq = model_config
                .in_situ_quant
                .as_ref()
                .or(self.in_situ_quant.as_ref())
                .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

            let pipeline: LoadedPipeline = loader.load_model_from_hf(
                None,
                self.token_source.clone(),
                &dtype,
                &device,
                false,
                mapper,
                isq,
                cache_config,
            )?;

            let pipeline_name = pipeline.lock().await.name();

            if pipeline_names.contains(&pipeline_name) {
                anyhow::bail!(
                    "Model ID conflict: '{}' is already registered. The model from config key '{}' resolves to the same pipeline identifier as a previously loaded model.",
                    pipeline_name,
                    model_config.model_id
                );
            }

            let engine_config = mistralrs_core::EngineConfig {
                truncate_sequence: self.truncate_sequence,
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                prefix_cache_n: self.prefix_cache_n,
                disable_eos_stop: false,
                throughput_logging_enabled: !self.interactive_mode,
                search_embedding_model: bert_model.clone(),
                search_callback: self.search_callback.clone(),
                tool_callbacks: HashMap::new(),
                tool_callbacks_with_tools: HashMap::new(),
            };

            let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
            if let Some(mcp_config) = self.mcp_client_config.clone() {
                add_model_config = add_model_config.with_mcp_config(mcp_config);
            }

            mistralrs
                .add_model(
                    pipeline_name.clone(),
                    pipeline,
                    scheduler_config.clone(),
                    add_model_config,
                )
                .await
                .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;

            info!(
                "Model `{pipeline_name}` registered successfully (from config key: {})",
                model_config.model_id
            );
            pipeline_names.push(pipeline_name);
        }

        if let Some(ref default_model_id) = self.default_model_id {
            mistralrs
                .set_default_model_id(default_model_id)
                .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
        }

        info!("All models loaded: `{}`", pipeline_names.join("`, `"));

        if let Some(ref default_id) = self.default_model_id {
            info!("Default model: {}", default_id);
        } else {
            info!(
                "Default model: {} (first model, from config key: {})",
                pipeline_names[0], self.models[0].model_id
            );
        }
        Ok(mistralrs)
    }
}

/// Initializes the compute device: Metal when compiled with the `metal`
/// feature, otherwise CUDA if available, with `force_cpu` overriding both.
/// Optionally seeds the device RNG.
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        // With NCCL distribution, per-rank device setup happens elsewhere.
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

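/// Builds the device mapping setting from the user-supplied layer list: a
/// single integer puts that many layers on ordinal 0, while `ORD:NUM` entries
/// map `NUM` layers onto device ordinal `ORD`. With no list, automatic device
/// mapping is used.
///
/// Illustrative (hypothetical) values:
///
/// ```ignore
/// // 16 layers on device 0 and 16 layers on device 1:
/// let layers = Some(vec!["0:16".to_string(), "1:16".to_string()]);
/// let mapper = init_mapper(&layers, &auto_device_map_params);
/// ```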
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind().to_string());
}

/// Resolves whether PagedAttention should be enabled on `device`, falling back
/// to the per-platform defaults (on for CUDA, off for CPU and Metal).
fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
    if device.is_cpu() {
        if paged_attn == Some(true) {
            warn!("Paged attention is not supported on CPU.");
        }

        defaults::PAGED_ATTN_CPU
    } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
    } else if device.is_metal() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
    } else {
        false
    }
}

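/// Builds the PagedAttention cache configuration from the overlapping memory
/// knobs. As implemented below, when several are set the precedence is:
/// utilization fraction, then context length, then MB amount; with none set,
/// the cache is sized for the model's maximum sequence length. Returns `None`
/// when PagedAttention is unsupported or disabled.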
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    cache_type: PagedCacheType,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
            cache_type,
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
            cache_type,
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
            cache_type,
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
            cache_type,
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
                cache_type,
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        // The pipeline only carries a cache config if PagedAttention was actually set up.
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
        }
    }
}

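/// Converts the mutually exclusive `--paged-attn` / `--no-paged-attn` CLI flags
/// into the tri-state expected by [`MistralRsForServerBuilder::set_paged_attn`].
///
/// A small illustrative mapping:
///
/// ```ignore
/// assert_eq!(configure_paged_attn_from_flags(true, false)?, Some(true));
/// assert_eq!(configure_paged_attn_from_flags(false, true)?, Some(false));
/// assert_eq!(configure_paged_attn_from_flags(false, false)?, None); // platform default
/// ```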
pub fn configure_paged_attn_from_flags(
    paged_attn: bool,
    no_paged_attn: bool,
) -> Result<Option<bool>> {
    match (paged_attn, no_paged_attn) {
        (true, true) => {
            anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
        }
        (true, false) => Ok(Some(true)),
        (false, true) => Ok(Some(false)),
        (false, false) => Ok(None),
    }
}

/// Returns the BERT embedding model to use for web search when search is
/// enabled: the custom model if one was given, otherwise the default.
pub fn get_bert_model(
    enable_search: bool,
    search_bert_model: Option<String>,
) -> Option<BertEmbeddingModel> {
    if enable_search {
        Some(
            search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    }
}