1use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use candle_core::Device;
7use mistralrs_core::{
8 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
9 parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
10 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
11 McpClientConfig, MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig,
12 PagedCacheType, SchedulerConfig, SearchCallback, TokenSource,
13};
14use tracing::{info, warn};
15
16use crate::types::{LoadedPipeline, SharedMistralRsState};
17use std::collections::HashMap;
18
/// Configuration for one model served in multi-model mode.
///
/// Every `Option` field is a per-model override: when it is `None`, the corresponding
/// builder-wide setting of `MistralRsForServerBuilder` is used instead.
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Config key identifying this entry in log messages. Note that the model is
    /// registered with the engine under its pipeline name, not under this key.
    pub model_id: String,
    /// Model selection handed to `LoaderBuilder::new`.
    pub model: ModelSelected,
    /// Per-model value for `LoaderBuilder::with_chat_template`.
    pub chat_template: Option<String>,
    /// Per-model value for `LoaderBuilder::with_jinja_explicit`.
    pub jinja_explicit: Option<String>,
    /// Per-model device layer mapping: either a single layer count (placed on device
    /// ordinal 0) or a list of `ORD:NUM` entries.
    pub num_device_layers: Option<Vec<String>>,
    /// Per-model in-situ quantization setting, parsed with `parse_isq_value`.
    pub in_situ_quant: Option<String>,
}
35
36impl ModelConfig {
37 pub fn new(model_id: String, model: ModelSelected) -> Self {
38 Self {
39 model_id,
40 model,
41 chat_template: None,
42 jinja_explicit: None,
43 num_device_layers: None,
44 in_situ_quant: None,
45 }
46 }
47
48 pub fn with_chat_template(mut self, chat_template: String) -> Self {
49 self.chat_template = Some(chat_template);
50 self
51 }
52
53 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
54 self.jinja_explicit = Some(jinja_explicit);
55 self
56 }
57
58 pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
59 self.num_device_layers = Some(num_device_layers);
60 self
61 }
62
63 pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
64 self.in_situ_quant = Some(in_situ_quant);
65 self
66 }
67}
68
/// Default values used by `MistralRsForServerBuilder::default`.
pub mod defaults {
    use std::sync::Arc;

    use mistralrs_core::PagedCacheType;

    /// No explicit device: one is initialized from the `cpu` and `seed` settings.
    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const TRUNCATE_SEQUENCE: bool = false;
    /// No single-model selection.
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    /// Maximum number of concurrently scheduled sequences.
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    /// No explicit layer mapping: automatic device mapping is used.
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    /// Paged attention not forced on or off: the device-type default below decides.
    pub const PAGED_ATTN: Option<bool> = None;
    /// Value used on CPU devices regardless of the requested setting.
    pub const PAGED_ATTN_CPU: bool = false;
    /// Value used on CUDA (or NCCL) devices when paged attention was not set explicitly.
    pub const PAGED_ATTN_CUDA: bool = true;
    /// Value used on Metal devices when paged attention was not set explicitly.
    pub const PAGED_ATTN_METAL: bool = false;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_BERT_MODEL: Option<String> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
}
105
/// Builder that loads one or more models and produces the shared `MistralRs` state
/// used by the server.
///
/// If any [`ModelConfig`] has been added (`models` non-empty), the builder runs in
/// multi-model mode; otherwise it loads the single model set in `model`.
pub struct MistralRsForServerBuilder {
    /// Device to load models on. When `None`, one is initialized from `cpu` and `seed`.
    device: Option<Device>,

    /// Seed applied to the device when it is auto-initialized.
    seed: Option<u64>,

    /// Value passed to `MistralRsBuilder::with_opt_log`.
    log: Option<String>,

    /// Forwarded to the engine's `truncate_sequence` setting.
    truncate_sequence: bool,

    /// Model used in single-model mode.
    model: Option<ModelSelected>,

    /// Models used in multi-model mode. The first entry determines the shared device,
    /// paged-attention cache configuration and scheduler configuration.
    models: Vec<ModelConfig>,

    /// Model ID passed to `set_default_model_id` in multi-model mode.
    default_model_id: Option<String>,

    /// Maximum number of concurrently scheduled sequences. It is forced to 1 when the
    /// (first) model has a target non-granular index.
    max_seqs: usize,

    /// Disables the KV cache in both the loader and the engine.
    no_kv_cache: bool,

    /// Builder-wide chat template. A `ModelConfig::chat_template` takes precedence.
    chat_template: Option<String>,

    /// Builder-wide explicit Jinja template. A `ModelConfig::jinja_explicit` takes precedence.
    jinja_explicit: Option<String>,

    /// Token source used by `load_model_from_hf`.
    token_source: TokenSource,

    /// When true, throughput logging is disabled.
    interactive_mode: bool,

    /// Forwarded to the engine's `prefix_cache_n` setting.
    prefix_cache_n: usize,

    /// Builder-wide device layer mapping: a single layer count (device ordinal 0) or
    /// `ORD:NUM` entries. A `ModelConfig::num_device_layers` takes precedence.
    num_device_layers: Option<Vec<String>>,

    /// Builder-wide in-situ quantization setting, parsed with `parse_isq_value`.
    /// A `ModelConfig::in_situ_quant` takes precedence.
    in_situ_quant: Option<String>,

    /// Paged KV-cache size, used as `MemoryGpuConfig::MbAmount`.
    paged_attn_gpu_mem: Option<usize>,

    /// Paged KV-cache size, used as `MemoryGpuConfig::Utilization`.
    /// It takes precedence over `paged_attn_gpu_mem` and `paged_ctxt_len`.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Paged KV-cache size, used as `MemoryGpuConfig::ContextSize`.
    /// It takes precedence over `paged_attn_gpu_mem`.
    paged_ctxt_len: Option<usize>,

    /// Block size passed to `PagedAttentionConfig::new`.
    paged_attn_block_size: Option<usize>,

    /// `Some` forces paged attention on or off (it is never enabled on CPU).
    /// `None` uses the device default from `defaults`.
    paged_attn: Option<bool>,

    /// Forces the auto-initialized device to be the CPU.
    cpu: bool,

    /// Enables search and the BERT embedding model used by it.
    enable_search: bool,

    /// Custom BERT embedding model for search. The default model is used when `None`.
    search_bert_model: Option<String>,

    /// Custom search callback for the engine.
    search_callback: Option<Arc<SearchCallback>>,

    /// MCP client configuration applied to every loaded model.
    mcp_client_config: Option<McpClientConfig>,

    /// KV-cache type passed to `PagedAttentionConfig::new`.
    paged_cache_type: PagedCacheType,
}
239
240impl Default for MistralRsForServerBuilder {
241 fn default() -> Self {
243 Self {
244 device: defaults::DEVICE,
245 seed: defaults::SEED,
246 log: defaults::LOG,
247 truncate_sequence: defaults::TRUNCATE_SEQUENCE,
248 model: defaults::MODEL,
249 models: Vec::new(),
250 default_model_id: None,
251 max_seqs: defaults::MAX_SEQS,
252 no_kv_cache: defaults::NO_KV_CACHE,
253 chat_template: defaults::CHAT_TEMPLATE,
254 jinja_explicit: defaults::JINJA_EXPLICIT,
255 token_source: defaults::TOKEN_SOURCE,
256 interactive_mode: defaults::INTERACTIVE_MODE,
257 prefix_cache_n: defaults::PREFIX_CACHE_N,
258 num_device_layers: defaults::NUM_DEVICE_LAYERS,
259 in_situ_quant: defaults::IN_SITU_QUANT,
260 paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
261 paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
262 paged_ctxt_len: defaults::PAGED_CTXT_LEN,
263 paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
264 paged_attn: defaults::PAGED_ATTN,
265 cpu: defaults::CPU,
266 enable_search: defaults::ENABLE_SEARCH,
267 search_bert_model: defaults::SEARCH_BERT_MODEL,
268 search_callback: defaults::SEARCH_CALLBACK,
269 mcp_client_config: None,
270 paged_cache_type: defaults::PAGED_CACHE_TYPE,
271 }
272 }
273}
274
275impl MistralRsForServerBuilder {
276 pub fn new() -> Self {
288 Default::default()
289 }
290
291 pub fn with_device(mut self, device: Device) -> Self {
293 self.device = Some(device);
294 self
295 }
296
297 pub fn with_seed(mut self, seed: u64) -> Self {
299 self.seed = Some(seed);
300 self
301 }
302
303 pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
305 if let Some(seed) = seed {
306 self = self.with_seed(seed);
307 }
308 self
309 }
310
311 pub fn with_log(mut self, log: String) -> Self {
313 self.log = Some(log);
314 self
315 }
316
317 pub fn with_log_optional(mut self, log: Option<String>) -> Self {
319 if let Some(log) = log {
320 self = self.with_log(log);
321 }
322 self
323 }
324
325 pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
327 self.truncate_sequence = truncate_sequence;
328 self
329 }
330
331 pub fn with_model(mut self, model: ModelSelected) -> Self {
333 self.model = Some(model);
334 self
335 }
336
337 pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
339 self.models.push(model_config);
340 self
341 }
342
343 pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
345 self.models.extend(model_configs);
346 self
347 }
348
349 pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
351 self.default_model_id = Some(default_model_id);
352 self
353 }
354
355 pub fn add_model_config(mut self, config: ModelConfig) -> Self {
357 self.models.push(config);
358 self
359 }
360
361 pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
363 self.models.push(ModelConfig::new(model_id, model));
364 self
365 }
366
367 pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
369 self.max_seqs = max_seqs;
370 self
371 }
372
373 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
375 self.no_kv_cache = no_kv_cache;
376 self
377 }
378
379 pub fn with_chat_template(mut self, chat_template: String) -> Self {
381 self.chat_template = Some(chat_template);
382 self
383 }
384
385 pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
387 if let Some(chat_template) = chat_template {
388 self = self.with_chat_template(chat_template);
389 }
390 self
391 }
392
393 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
395 self.jinja_explicit = Some(jinja_explicit);
396 self
397 }
398
399 pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
401 if let Some(jinja_explicit) = jinja_explicit {
402 self = self.with_jinja_explicit(jinja_explicit);
403 }
404 self
405 }
406
407 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
409 self.token_source = token_source;
410 self
411 }
412
413 pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
415 self.interactive_mode = interactive_mode;
416 self
417 }
418
419 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
421 self.prefix_cache_n = prefix_cache_n;
422 self
423 }
424
425 pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
427 self.num_device_layers = Some(num_device_layers);
428 self
429 }
430
431 pub fn with_num_device_layers_optional(
433 mut self,
434 num_device_layers: Option<Vec<String>>,
435 ) -> Self {
436 if let Some(num_device_layers) = num_device_layers {
437 self = self.with_num_device_layers(num_device_layers);
438 }
439 self
440 }
441
442 pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
444 self.in_situ_quant = Some(in_situ_quant);
445 self
446 }
447
448 pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
450 if let Some(in_situ_quant) = in_situ_quant {
451 self = self.with_in_situ_quant(in_situ_quant);
452 }
453 self
454 }
455
456 pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
466 self.paged_attn = paged_attn;
467 self
468 }
469
470 pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
472 self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
473 self
474 }
475
476 pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
478 if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
479 self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
480 }
481 self
482 }
483
484 pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
486 self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
487 self
488 }
489
490 pub fn with_paged_attn_gpu_mem_usage_optional(
492 mut self,
493 paged_attn_gpu_mem_usage: Option<f32>,
494 ) -> Self {
495 if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
496 self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
497 }
498 self
499 }
500
501 pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
503 self.paged_ctxt_len = Some(paged_ctxt_len);
504 self
505 }
506
507 pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
509 if let Some(paged_ctxt_len) = paged_ctxt_len {
510 self = self.with_paged_ctxt_len(paged_ctxt_len);
511 }
512 self
513 }
514
515 pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
517 self.paged_attn_block_size = Some(paged_attn_block_size);
518 self
519 }
520
521 pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
523 self.paged_cache_type = cache_type;
524 self
525 }
526
527 pub fn with_paged_attn_block_size_optional(
529 mut self,
530 paged_attn_block_size: Option<usize>,
531 ) -> Self {
532 if let Some(paged_attn_block_size) = paged_attn_block_size {
533 self = self.with_paged_attn_block_size(paged_attn_block_size);
534 }
535 self
536 }
537
538 pub fn with_cpu(mut self, cpu: bool) -> Self {
540 self.cpu = cpu;
541 self
542 }
543
544 pub fn with_enable_search(mut self, enable_search: bool) -> Self {
546 self.enable_search = enable_search;
547 self
548 }
549
550 pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
552 self.search_bert_model = Some(search_bert_model);
553 self
554 }
555
556 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
558 self.search_callback = Some(callback);
559 self
560 }
561
562 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
564 self.mcp_client_config = Some(mcp_config);
565 self
566 }
567
568 pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
570 if let Some(mcp_config) = mcp_config {
571 self = self.with_mcp_config(mcp_config);
572 }
573 self
574 }
575
576 pub async fn build(self) -> Result<SharedMistralRsState> {
591 if !self.models.is_empty() {
593 self.build_multi_model().await
594 } else {
595 self.build_single_model().await
596 }
597 }
598
599 async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
601 let model = self.model.context("Model was None")?;
602
603 let tgt_non_granular_index = get_tgt_non_granular_index(&model);
604 let dtype = get_model_dtype(&model)?;
605 let auto_device_map_params = get_auto_device_map_params(&model)?;
606
607 if tgt_non_granular_index.is_some() {
608 self.max_seqs = 1;
609 }
610
611 let max_seq_len = auto_device_map_params.max_seq_len();
612
613 let device = if let Some(device) = self.device {
614 device
615 } else {
616 init_device(self.cpu, self.seed)?
617 };
618
619 let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
620 let paged_attn = configure_paged_attn(&device, self.paged_attn);
621
622 let cache_config = init_cache_config(
623 self.paged_attn_block_size,
624 self.paged_attn_gpu_mem,
625 self.paged_attn_gpu_mem_usage,
626 self.paged_ctxt_len,
627 self.paged_cache_type,
628 !paged_attn,
629 max_seq_len,
630 )?;
631
632 let loader: Box<dyn Loader> = LoaderBuilder::new(model)
634 .with_no_kv_cache(self.no_kv_cache)
635 .with_chat_template(self.chat_template)
636 .with_jinja_explicit(self.jinja_explicit)
637 .build()?;
638
639 mistralrs_instance_info(&*loader);
640
641 let isq = self
642 .in_situ_quant
643 .as_ref()
644 .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
645
646 let pipeline: LoadedPipeline = loader.load_model_from_hf(
647 None,
648 self.token_source,
649 &dtype,
650 &device,
651 false,
652 mapper,
653 isq,
654 cache_config,
655 )?;
656 info!("Model loaded.");
657
658 let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
659
660 let bert_model = get_bert_model(self.enable_search, self.search_bert_model);
661
662 let mut builder = MistralRsBuilder::new(
663 pipeline,
664 scheduler_config,
665 !self.interactive_mode,
666 bert_model,
667 )
668 .with_opt_log(self.log)
669 .with_truncate_sequence(self.truncate_sequence)
670 .with_no_kv_cache(self.no_kv_cache)
671 .with_prefix_cache_n(self.prefix_cache_n);
672
673 if let Some(mcp_config) = self.mcp_client_config {
675 builder = builder.with_mcp_client(mcp_config);
676 }
677
678 let mistralrs = builder.build().await;
679
680 Ok(mistralrs)
681 }
682
683 pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
685 if self.models.is_empty() {
686 anyhow::bail!("No models configured for multi-model mode");
687 }
688
689 let first_model = &self.models[0];
691 let model = first_model.model.clone();
692
693 let tgt_non_granular_index = get_tgt_non_granular_index(&model);
694 let dtype = get_model_dtype(&model)?;
695 let auto_device_map_params = get_auto_device_map_params(&model)?;
696
697 if tgt_non_granular_index.is_some() {
698 self.max_seqs = 1;
699 }
700
701 let max_seq_len = auto_device_map_params.max_seq_len();
702
703 let device = if let Some(device) = self.device {
704 device
705 } else {
706 init_device(self.cpu, self.seed)?
707 };
708
709 let loader: Box<dyn Loader> = LoaderBuilder::new(model)
711 .with_no_kv_cache(self.no_kv_cache)
712 .with_chat_template(
713 first_model
714 .chat_template
715 .clone()
716 .or(self.chat_template.clone()),
717 )
718 .with_jinja_explicit(
719 first_model
720 .jinja_explicit
721 .clone()
722 .or(self.jinja_explicit.clone()),
723 )
724 .build()?;
725
726 mistralrs_instance_info(&*loader);
727
728 let mapper = init_mapper(
729 &first_model
730 .num_device_layers
731 .clone()
732 .or(self.num_device_layers.clone()),
733 &auto_device_map_params,
734 );
735 let paged_attn = configure_paged_attn(&device, self.paged_attn);
736
737 let cache_config = init_cache_config(
738 self.paged_attn_block_size,
739 self.paged_attn_gpu_mem,
740 self.paged_attn_gpu_mem_usage,
741 self.paged_ctxt_len,
742 self.paged_cache_type,
743 !paged_attn,
744 max_seq_len,
745 )?;
746
747 let isq = first_model
748 .in_situ_quant
749 .as_ref()
750 .or(self.in_situ_quant.as_ref())
751 .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
752
753 let mut pipeline_names = Vec::new();
754
755 let pipeline: LoadedPipeline = loader.load_model_from_hf(
756 None,
757 self.token_source.clone(),
758 &dtype,
759 &device,
760 false,
761 mapper,
762 isq,
763 cache_config,
764 )?;
765 let first_pipeline_name = pipeline.lock().await.name();
766 info!(
767 "First model loaded: `{first_pipeline_name}` (from config key: {})",
768 first_model.model_id
769 );
770 pipeline_names.push(first_pipeline_name);
771
772 let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
773 let bert_model = get_bert_model(self.enable_search, self.search_bert_model);
774
775 let mut builder = MistralRsBuilder::new(
777 pipeline,
778 scheduler_config.clone(),
779 !self.interactive_mode,
780 bert_model.clone(),
781 )
782 .with_opt_log(self.log.clone())
783 .with_truncate_sequence(self.truncate_sequence)
784 .with_no_kv_cache(self.no_kv_cache)
785 .with_prefix_cache_n(self.prefix_cache_n);
786
787 if let Some(mcp_config) = self.mcp_client_config.clone() {
789 builder = builder.with_mcp_client(mcp_config);
790 }
791
792 let mistralrs = builder.build().await;
793
794 for model_config in self.models.iter().skip(1) {
796 info!(
797 "Loading additional model from config key: {}",
798 model_config.model_id
799 );
800
801 let model = model_config.model.clone();
802 let dtype = get_model_dtype(&model)?;
803 let auto_device_map_params = get_auto_device_map_params(&model)?;
804
805 let loader: Box<dyn Loader> = LoaderBuilder::new(model)
806 .with_no_kv_cache(self.no_kv_cache)
807 .with_chat_template(
808 model_config
809 .chat_template
810 .clone()
811 .or(self.chat_template.clone()),
812 )
813 .with_jinja_explicit(
814 model_config
815 .jinja_explicit
816 .clone()
817 .or(self.jinja_explicit.clone()),
818 )
819 .build()?;
820
821 let mapper = init_mapper(
822 &model_config
823 .num_device_layers
824 .clone()
825 .or(self.num_device_layers.clone()),
826 &auto_device_map_params,
827 );
828
829 let isq = model_config
830 .in_situ_quant
831 .as_ref()
832 .or(self.in_situ_quant.as_ref())
833 .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
834
835 let pipeline: LoadedPipeline = loader.load_model_from_hf(
836 None,
837 self.token_source.clone(),
838 &dtype,
839 &device,
840 false,
841 mapper,
842 isq,
843 cache_config,
844 )?;
845
846 let pipeline_name = pipeline.lock().await.name();
848
849 if pipeline_names.contains(&pipeline_name) {
851 anyhow::bail!(
852 "Model ID conflict: '{}' is already registered. Models from config keys '{}' and previous models have the same pipeline identifier.",
853 pipeline_name,
854 model_config.model_id
855 );
856 }
857
858 let engine_config = mistralrs_core::EngineConfig {
860 truncate_sequence: self.truncate_sequence,
861 no_kv_cache: self.no_kv_cache,
862 no_prefix_cache: false,
863 prefix_cache_n: self.prefix_cache_n,
864 disable_eos_stop: false,
865 throughput_logging_enabled: !self.interactive_mode,
866 search_embedding_model: bert_model.clone(),
867 search_callback: self.search_callback.clone(),
868 tool_callbacks: HashMap::new(),
869 tool_callbacks_with_tools: HashMap::new(),
870 };
871
872 let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
873 if let Some(mcp_config) = self.mcp_client_config.clone() {
874 add_model_config = add_model_config.with_mcp_config(mcp_config);
875 }
876
877 mistralrs
878 .add_model(
879 pipeline_name.clone(),
880 pipeline,
881 scheduler_config.clone(),
882 add_model_config,
883 )
884 .await
885 .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;
886
887 info!(
888 "Model `{pipeline_name}` registered successfully (from config key: {})",
889 model_config.model_id
890 );
891 pipeline_names.push(pipeline_name);
892 }
893
894 if let Some(ref default_model_id) = self.default_model_id {
896 mistralrs
897 .set_default_model_id(default_model_id)
898 .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
899 }
900
901 info!("All models loaded: `{}`", pipeline_names.join("`, `"));
903
904 if let Some(ref default_id) = self.default_model_id {
906 info!("Default model: {}", default_id);
907 } else {
908 info!(
909 "Default model: {} (first model, from config key: {})",
910 pipeline_names[0], self.models[0].model_id
911 );
912 }
913 Ok(mistralrs)
914 }
915}
916
917fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
920 #[cfg(feature = "metal")]
921 let device = if force_cpu {
922 Device::Cpu
923 } else {
924 Device::new_metal(0)?
925 };
926 #[cfg(not(feature = "metal"))]
927 #[allow(clippy::if_same_then_else)]
928 let device = if force_cpu {
929 Device::Cpu
930 } else if mistralrs_core::distributed::use_nccl() {
931 Device::Cpu
932 } else {
933 Device::cuda_if_available(0)?
934 };
935
936 if let Some(seed) = seed {
937 device.set_seed(seed)?;
938 }
939
940 Ok(device)
941}
942
943fn init_mapper(
945 num_device_layers: &Option<Vec<String>>,
946 auto_device_map_params: &AutoDeviceMapParams,
947) -> DeviceMapSetting {
948 if let Some(device_layers) = num_device_layers {
950 if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
951 let layers = device_layers[0].parse::<usize>().unwrap();
952 DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
953 DeviceLayerMapMetadata { ordinal: 0, layers },
954 ]))
955 } else {
956 let mut mapping = Vec::new();
957 for layer in device_layers {
958 let split = layer.splitn(2, ':').collect::<Vec<_>>();
959 if split.len() < 2 {
960 panic!("Expected layer to be of format ORD:NUM, got {layer}");
961 }
962 let ord = split[0]
963 .parse::<usize>()
964 .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
965 let num = split[1]
966 .parse::<usize>()
967 .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
968 for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
969 if *ordinal == ord {
970 panic!("Duplicate ordinal {ord}");
971 }
972 }
973 mapping.push(DeviceLayerMapMetadata {
974 ordinal: ord,
975 layers: num,
976 });
977 }
978 DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
979 }
980 } else {
981 DeviceMapSetting::Auto(auto_device_map_params.clone())
982 }
983}
984
985fn mistralrs_instance_info(loader: &dyn Loader) {
987 info!(
988 "avx: {}, neon: {}, simd128: {}, f16c: {}",
989 candle_core::utils::with_avx(),
990 candle_core::utils::with_neon(),
991 candle_core::utils::with_simd128(),
992 candle_core::utils::with_f16c()
993 );
994
995 info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
996 info!("Model kind is: {}", loader.get_kind().to_string());
997}
998
999fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
1001 if device.is_cpu() {
1002 if paged_attn == Some(true) {
1003 warn!("Paged attention is not supported on CPU.");
1004 }
1005
1006 defaults::PAGED_ATTN_CPU
1007 } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
1008 paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
1009 } else if device.is_metal() {
1010 paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
1011 } else {
1012 false
1013 }
1014}
1015
1016fn init_cache_config(
1018 paged_attn_block_size: Option<usize>,
1019 paged_attn_gpu_mem: Option<usize>,
1020 paged_attn_gpu_mem_usage: Option<f32>,
1021 paged_ctxt_len: Option<usize>,
1022 cache_type: PagedCacheType,
1023 no_paged_attn: bool,
1024 max_seq_len: usize,
1025) -> Result<Option<PagedAttentionConfig>> {
1026 match (
1027 paged_attn_block_size,
1028 paged_attn_gpu_mem,
1029 paged_attn_gpu_mem_usage,
1030 paged_ctxt_len,
1031 paged_attn_supported(),
1032 no_paged_attn,
1033 ) {
1034 (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
1035 block_size,
1036 MemoryGpuConfig::ContextSize(max_seq_len),
1037 cache_type,
1038 )?)),
1039 (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
1040 block_size,
1041 MemoryGpuConfig::ContextSize(ctxt),
1042 cache_type,
1043 )?)),
1044 (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
1045 block_size,
1046 MemoryGpuConfig::Utilization(f),
1047 cache_type,
1048 )?)),
1049 (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
1050 block_size,
1051 MemoryGpuConfig::MbAmount(m),
1052 cache_type,
1053 )?)),
1054 (block_size, Some(_m), Some(f), None, true, false) => {
1055 info!("Both memory size, and usage were specified, defaulting to the usage value.");
1056 Ok(Some(PagedAttentionConfig::new(
1057 block_size,
1058 MemoryGpuConfig::Utilization(f),
1059 cache_type,
1060 )?))
1061 }
1062 (block_size, Some(_m), None, Some(ctxt), true, false) => {
1063 info!("All memory size and ctxt len, defaulting to the context len value.");
1064 Ok(Some(PagedAttentionConfig::new(
1065 block_size,
1066 MemoryGpuConfig::ContextSize(ctxt),
1067 cache_type,
1068 )?))
1069 }
1070 (block_size, None, Some(f), Some(_ctxt), true, false) => {
1071 info!("Both ctxt len and usage were specified, defaulting to the usage value.");
1072 Ok(Some(PagedAttentionConfig::new(
1073 block_size,
1074 MemoryGpuConfig::Utilization(f),
1075 cache_type,
1076 )?))
1077 }
1078 (_, _, _, _, _, _) => Ok(None),
1079 }
1080}
1081
1082async fn init_scheduler_config(
1084 cache_config: &Option<PagedAttentionConfig>,
1085 pipeline: &LoadedPipeline,
1086 args_max_seqs: usize,
1087) -> SchedulerConfig {
1088 if cache_config.is_some() {
1089 if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
1091 SchedulerConfig::PagedAttentionMeta {
1092 max_num_seqs: args_max_seqs,
1093 config: cache_config.clone(),
1094 }
1095 } else {
1096 SchedulerConfig::DefaultScheduler {
1097 method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
1098 }
1099 }
1100 } else {
1101 SchedulerConfig::DefaultScheduler {
1102 method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
1103 }
1104 }
1105}
1106
1107pub fn configure_paged_attn_from_flags(
1112 paged_attn: bool,
1113 no_paged_attn: bool,
1114) -> Result<Option<bool>> {
1115 match (paged_attn, no_paged_attn) {
1116 (true, true) => {
1117 anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
1118 }
1119 (true, false) => Ok(Some(true)),
1120 (false, true) => Ok(Some(false)),
1121 (false, false) => Ok(None),
1122 }
1123}
1124
1125pub fn get_bert_model(
1127 enable_search: bool,
1128 search_bert_model: Option<String>,
1129) -> Option<BertEmbeddingModel> {
1130 if enable_search {
1131 Some(
1132 search_bert_model
1133 .map(BertEmbeddingModel::Custom)
1134 .unwrap_or_default(),
1135 )
1136 } else {
1137 None
1138 }
1139}