use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
    McpClientConfig, MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig,
    PagedCacheType, SchedulerConfig, SearchCallback, TokenSource,
};
use tracing::{info, warn};

use crate::types::{LoadedPipeline, SharedMistralRsState};
/// Configuration for a single model in multi-model mode.
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Identifier used to route requests to this model.
    pub model_id: String,
    /// The selected model and its source.
    pub model: ModelSelected,
    /// Per-model chat template, overriding the builder-wide setting.
    pub chat_template: Option<String>,
    /// Per-model explicit JINJA chat template file, overriding the builder-wide setting.
    pub jinja_explicit: Option<String>,
    /// Per-model device layer mapping, overriding the builder-wide setting.
    pub num_device_layers: Option<Vec<String>>,
    /// Per-model in-situ quantization, overriding the builder-wide setting.
    pub in_situ_quant: Option<String>,
}

impl ModelConfig {
    pub fn new(model_id: String, model: ModelSelected) -> Self {
        Self {
            model_id,
            model,
            chat_template: None,
            jinja_explicit: None,
            num_device_layers: None,
            in_situ_quant: None,
        }
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }
}
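
// Illustrative sketch (not part of the original API surface): composing a
// per-model override. `model_selected` stands in for an already-parsed
// `ModelSelected`; the template path and ISQ level below are assumptions.
#[allow(dead_code)]
fn example_model_config(model_selected: ModelSelected) -> ModelConfig {
    ModelConfig::new("main".to_string(), model_selected)
        .with_chat_template("chat_templates/llama3.json".to_string())
        .with_in_situ_quant("Q4K".to_string())
}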

/// Defaults shared by the server builder and its callers.
pub mod defaults {
    use std::sync::Arc;

    use mistralrs_core::PagedCacheType;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const TRUNCATE_SEQUENCE: bool = false;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const PAGED_ATTN: Option<bool> = None;
    pub const PAGED_ATTN_CPU: bool = false;
    pub const PAGED_ATTN_CUDA: bool = true;
    pub const PAGED_ATTN_METAL: bool = false;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_BERT_MODEL: Option<String> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
}

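/// Builder for constructing a mistral.rs instance for server use.
///
/// A minimal usage sketch (assumes an already-parsed `ModelSelected` in
/// `model_selected`; the seed and sequence count are illustrative):
///
/// ```ignore
/// let state = MistralRsForServerBuilder::new()
///     .with_model(model_selected)
///     .with_seed_optional(Some(42))
///     .with_max_seqs(32)
///     .build()
///     .await?;
/// ```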
pub struct MistralRsForServerBuilder {
    /// Device to load the model onto; auto-detected when `None`.
    device: Option<Device>,

    /// Seed for reproducible sampling.
    seed: Option<u64>,

    /// Optional log file for requests and responses.
    log: Option<String>,

    /// Truncate sequences that exceed the maximum model length.
    truncate_sequence: bool,

    /// The model to serve in single-model mode.
    model: Option<ModelSelected>,

    /// Models to serve in multi-model mode.
    models: Vec<ModelConfig>,

    /// Model ID that requests are routed to when none is specified.
    default_model_id: Option<String>,

    /// Maximum number of concurrent sequences.
    max_seqs: usize,

    /// Disable the KV cache.
    no_kv_cache: bool,

    /// Builder-wide chat template.
    chat_template: Option<String>,

    /// Builder-wide explicit JINJA chat template file.
    jinja_explicit: Option<String>,

    /// Source of the Hugging Face access token.
    token_source: TokenSource,

    /// Run in interactive mode.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device.
    prefix_cache_n: usize,

    /// Device layer mapping in `ORD:NUM` format (or a single layer count).
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply while loading.
    in_situ_quant: Option<String>,

    /// GPU memory for the PagedAttention KV cache, in MB.
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory for the PagedAttention KV cache, from 0 to 1.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the PagedAttention KV cache for.
    paged_ctxt_len: Option<usize>,

    /// PagedAttention block size.
    paged_attn_block_size: Option<usize>,

    /// Explicit PagedAttention override; `None` uses the device default.
    paged_attn: Option<bool>,

    /// Force the model onto the CPU.
    cpu: bool,

    /// Enable web search tooling.
    enable_search: bool,

    /// Embedding model used for search.
    search_bert_model: Option<String>,

    /// Custom callback invoked when a search is performed.
    search_callback: Option<Arc<SearchCallback>>,

    /// MCP client configuration.
    mcp_client_config: Option<McpClientConfig>,

    /// KV cache type for PagedAttention.
    paged_cache_type: PagedCacheType,
}

impl Default for MistralRsForServerBuilder {
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            truncate_sequence: defaults::TRUNCATE_SEQUENCE,
            model: defaults::MODEL,
            models: Vec::new(),
            default_model_id: None,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            paged_attn: defaults::PAGED_ATTN,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_bert_model: defaults::SEARCH_BERT_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
            mcp_client_config: None,
            paged_cache_type: defaults::PAGED_CACHE_TYPE,
        }
    }
}

impl MistralRsForServerBuilder {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
        self.truncate_sequence = truncate_sequence;
        self
    }

    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
        self.models.push(model_config);
        self
    }

    pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
        self.models.extend(model_configs);
        self
    }

    pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
        self.default_model_id = Some(default_model_id);
        self
    }

    pub fn add_model_config(mut self, config: ModelConfig) -> Self {
        self.models.push(config);
        self
    }

    pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
        self.models.push(ModelConfig::new(model_id, model));
        self
    }

    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    /// Sets the PagedAttention override directly from an `Option`; `None`
    /// defers to the per-device default.
    pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
        self.paged_cache_type = cache_type;
        self
    }

    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
        self.search_bert_model = Some(search_bert_model);
        self
    }

    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }

    pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
        if let Some(mcp_config) = mcp_config {
            self = self.with_mcp_config(mcp_config);
        }
        self
    }

    /// Builds the shared state, dispatching to multi-model mode when any
    /// `ModelConfig`s were added.
    pub async fn build(self) -> Result<SharedMistralRsState> {
        if !self.models.is_empty() {
            self.build_multi_model().await
        } else {
            self.build_single_model().await
        }
    }

    async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
        let model = self.model.context("No model was configured")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        // When a non-granular target index is set, only one sequence can run at a time.
        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            bert_model,
        )
        .with_opt_log(self.log)
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        if let Some(mcp_config) = self.mcp_client_config {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        Ok(mistralrs)
    }

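    /// Builds a multi-model server state. The first configured model becomes
    /// the default unless `with_default_model_id` is set.
    ///
    /// A minimal sketch (model selectors elided; config keys are illustrative):
    ///
    /// ```ignore
    /// let state = MistralRsForServerBuilder::new()
    ///     .add_model("chat".to_string(), chat_model_selected)
    ///     .add_model("code".to_string(), code_model_selected)
    ///     .with_default_model_id("chat".to_string())
    ///     .build()
    ///     .await?;
    /// ```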
    pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
        if self.models.is_empty() {
            anyhow::bail!("No models configured for multi-model mode");
        }

        // The first configured model bootstraps the engine; the rest are added to it.
        let first_model = &self.models[0];
        let model = first_model.model.clone();

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        // When a non-granular target index is set, only one sequence can run at a time.
        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(
                first_model
                    .chat_template
                    .clone()
                    .or(self.chat_template.clone()),
            )
            .with_jinja_explicit(
                first_model
                    .jinja_explicit
                    .clone()
                    .or(self.jinja_explicit.clone()),
            )
            .build()?;

        mistralrs_instance_info(&*loader);

        let mapper = init_mapper(
            &first_model
                .num_device_layers
                .clone()
                .or(self.num_device_layers.clone()),
            &auto_device_map_params,
        );
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let isq = first_model
            .in_situ_quant
            .as_ref()
            .or(self.in_situ_quant.as_ref())
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let mut pipeline_names = Vec::new();

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source.clone(),
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        let first_pipeline_name = pipeline.lock().await.name();
        info!(
            "First model loaded: `{first_pipeline_name}` (from config key: {})",
            first_model.model_id
        );
        pipeline_names.push(first_pipeline_name);

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config.clone(),
            !self.interactive_mode,
            bert_model.clone(),
        )
        .with_opt_log(self.log.clone())
        .with_truncate_sequence(self.truncate_sequence)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        if let Some(mcp_config) = self.mcp_client_config.clone() {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        // Load and register the remaining models with the running engine.
        for model_config in self.models.iter().skip(1) {
            info!(
                "Loading additional model from config key: {}",
                model_config.model_id
            );

            let model = model_config.model.clone();
            let dtype = get_model_dtype(&model)?;
            let auto_device_map_params = get_auto_device_map_params(&model)?;

            let loader: Box<dyn Loader> = LoaderBuilder::new(model)
                .with_no_kv_cache(self.no_kv_cache)
                .with_chat_template(
                    model_config
                        .chat_template
                        .clone()
                        .or(self.chat_template.clone()),
                )
                .with_jinja_explicit(
                    model_config
                        .jinja_explicit
                        .clone()
                        .or(self.jinja_explicit.clone()),
                )
                .build()?;

            let mapper = init_mapper(
                &model_config
                    .num_device_layers
                    .clone()
                    .or(self.num_device_layers.clone()),
                &auto_device_map_params,
            );

            let isq = model_config
                .in_situ_quant
                .as_ref()
                .or(self.in_situ_quant.as_ref())
                .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

            let pipeline: LoadedPipeline = loader.load_model_from_hf(
                None,
                self.token_source.clone(),
                &dtype,
                &device,
                false,
                mapper,
                isq,
                cache_config,
            )?;

            let pipeline_name = pipeline.lock().await.name();

            if pipeline_names.contains(&pipeline_name) {
                anyhow::bail!(
                    "Model ID conflict: '{}' (from config key '{}') is already registered under the same pipeline identifier.",
                    pipeline_name,
                    model_config.model_id
                );
            }

            let engine_config = mistralrs_core::EngineConfig {
                truncate_sequence: self.truncate_sequence,
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                prefix_cache_n: self.prefix_cache_n,
                disable_eos_stop: false,
                throughput_logging_enabled: !self.interactive_mode,
                search_embedding_model: bert_model.clone(),
                search_callback: self.search_callback.clone(),
                tool_callbacks: HashMap::new(),
                tool_callbacks_with_tools: HashMap::new(),
            };

            let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
            if let Some(mcp_config) = self.mcp_client_config.clone() {
                add_model_config = add_model_config.with_mcp_config(mcp_config);
            }

            mistralrs
                .add_model(
                    pipeline_name.clone(),
                    pipeline,
                    scheduler_config.clone(),
                    add_model_config,
                )
                .await
                .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;

            info!(
                "Model `{pipeline_name}` registered successfully (from config key: {})",
                model_config.model_id
            );
            pipeline_names.push(pipeline_name);
        }

        if let Some(ref default_model_id) = self.default_model_id {
            mistralrs
                .set_default_model_id(default_model_id)
                .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
        }

        info!("All models loaded: `{}`", pipeline_names.join("`, `"));

        if let Some(ref default_id) = self.default_model_id {
            info!("Default model: {}", default_id);
        } else {
            info!(
                "Default model: {} (first model, from config key: {})",
                pipeline_names[0], self.models[0].model_id
            );
        }
        Ok(mistralrs)
    }
}

/// Picks the compute device: CPU when forced (or when NCCL distributed mode is
/// in use), otherwise Metal or CUDA depending on compile-time features.
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        // With NCCL, per-rank device selection happens in the distributed backend.
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

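/// Builds the device-layer mapping from `ORD:NUM` specifications, falling back
/// to automatic mapping when none are given. A single bare number is shorthand
/// for placing that many layers on ordinal 0.
///
/// A sketch of the accepted formats (values are illustrative):
///
/// ```ignore
/// // 16 layers on device ordinal 0 and 16 on ordinal 1:
/// init_mapper(&Some(vec!["0:16".to_string(), "1:16".to_string()]), &params);
/// // Shorthand: 32 layers, all on ordinal 0:
/// init_mapper(&Some(vec!["32".to_string()]), &params);
/// ```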
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind());
}

/// Resolves whether PagedAttention should be enabled for the given device,
/// applying per-backend defaults when no explicit override is set.
fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
    if device.is_cpu() {
        if paged_attn == Some(true) {
            warn!("PagedAttention is not supported on CPU.");
        }

        defaults::PAGED_ATTN_CPU
    } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
    } else if device.is_metal() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
    } else {
        false
    }
}

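/// Builds the PagedAttention KV cache configuration, or `None` when
/// PagedAttention is unsupported or disabled. When several memory options are
/// set, precedence is: utilization, then context length, then MB amount; with
/// none set, the model's maximum sequence length is used as the context size.
///
/// An illustrative sketch of the precedence (argument values are made up):
///
/// ```ignore
/// // Utilization takes precedence over an explicit MB amount:
/// let cfg = init_cache_config(None, Some(4096), Some(0.9), None, PagedCacheType::Auto, false, 8192)?;
/// // -> Some(config using MemoryGpuConfig::Utilization(0.9))
/// ```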
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    cache_type: PagedCacheType,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
            cache_type,
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
            cache_type,
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
            cache_type,
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
            cache_type,
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both the memory size and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both the memory size and context length were specified; defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
                cache_type,
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both the context length and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

/// Chooses the scheduler: PagedAttention metadata when a cache config was
/// built and the pipeline reports one, otherwise the fixed default scheduler.
async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
        }
    }
}

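/// Resolves the mutually exclusive `--paged-attn` / `--no-paged-attn` CLI flags
/// into an override: `None` means "use the device default".
///
/// ```ignore
/// assert_eq!(configure_paged_attn_from_flags(true, false)?, Some(true));
/// assert_eq!(configure_paged_attn_from_flags(false, false)?, None);
/// assert!(configure_paged_attn_from_flags(true, true).is_err()); // conflicting flags
/// ```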
pub fn configure_paged_attn_from_flags(
    paged_attn: bool,
    no_paged_attn: bool,
) -> Result<Option<bool>> {
    match (paged_attn, no_paged_attn) {
        (true, true) => {
            anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
        }
        (true, false) => Ok(Some(true)),
        (false, true) => Ok(Some(false)),
        (false, false) => Ok(None),
    }
}

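/// Selects the embedding model used for web search: `None` when search is
/// disabled, a custom model when one is named, else the built-in default.
///
/// ```ignore
/// assert!(get_bert_model(false, None).is_none());
/// // Falls back to `BertEmbeddingModel::default()` when no custom model is named:
/// let model = get_bert_model(true, None);
/// ```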
pub fn get_bert_model(
    enable_search: bool,
    search_bert_model: Option<String>,
) -> Option<BertEmbeddingModel> {
    if enable_search {
        Some(
            search_bert_model
                .map(BertEmbeddingModel::Custom)
                .unwrap_or_default(),
        )
    } else {
        None
    }
}
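
// Minimal sanity tests for the two pure helpers above. These are an added
// sketch for illustration; the test names and cases are assumptions, not part
// of the original crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn paged_attn_flags_resolve() {
        assert_eq!(
            configure_paged_attn_from_flags(true, false).unwrap(),
            Some(true)
        );
        assert_eq!(
            configure_paged_attn_from_flags(false, true).unwrap(),
            Some(false)
        );
        assert_eq!(configure_paged_attn_from_flags(false, false).unwrap(), None);
        // Conflicting flags must be rejected.
        assert!(configure_paged_attn_from_flags(true, true).is_err());
    }

    #[test]
    fn bert_model_selection() {
        // Search disabled: no embedding model, even if one is named.
        assert!(get_bert_model(false, Some("custom/model".to_string())).is_none());
        // Search enabled with a custom model: the custom variant is used.
        assert!(matches!(
            get_bert_model(true, Some("custom/model".to_string())),
            Some(BertEmbeddingModel::Custom(_))
        ));
        // Search enabled without a custom model: falls back to the default.
        assert!(get_bert_model(true, None).is_some());
    }
}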