use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, DefaultSchedulerMethod, DeviceLayerMapMetadata,
    DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder, McpClientConfig, MemoryGpuConfig,
    MistralRsBuilder, ModelSelected, PagedAttentionConfig, PagedCacheType, SchedulerConfig,
    SearchCallback, SearchEmbeddingModel, TokenSource,
};
use tracing::{info, warn};

use crate::types::{LoadedPipeline, SharedMistralRsState};

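/// Per-model configuration for multi-model serving. Per-model values override
/// the builder-wide equivalents when both are set.
///
/// A minimal construction sketch (`model_selected` is a placeholder for
/// whatever `ModelSelected` value the caller has already parsed):
///
/// ```ignore
/// let cfg = ModelConfig::new("my-model".to_string(), model_selected)
///     .with_chat_template("templates/chat.json".to_string())
///     .with_in_situ_quant("Q4K".to_string());
/// ```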
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Identifier under which this model is registered with the server.
    pub model_id: String,
    /// The model selection (architecture, weights, and loading options).
    pub model: ModelSelected,
    /// Per-model chat template, overriding the builder-wide one.
    pub chat_template: Option<String>,
    /// Per-model explicit JINJA chat template file.
    pub jinja_explicit: Option<String>,
    /// Per-model device-layer mapping (`ORD:NUM` entries, or a single count).
    pub num_device_layers: Option<Vec<String>>,
    /// Per-model in-situ quantization, e.g. `Q4K`.
    pub in_situ_quant: Option<String>,
}

impl ModelConfig {
    pub fn new(model_id: String, model: ModelSelected) -> Self {
        Self {
            model_id,
            model,
            chat_template: None,
            jinja_explicit: None,
            num_device_layers: None,
            in_situ_quant: None,
        }
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }
}

/// Default values shared by the builder and its `Default` impl.
pub mod defaults {
    use std::sync::Arc;

    use mistralrs_core::PagedCacheType;

    use super::SearchEmbeddingModel;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const PAGED_ATTN: Option<bool> = None;
    pub const PAGED_ATTN_CPU: bool = false;
    pub const PAGED_ATTN_CUDA: bool = true;
    pub const PAGED_ATTN_METAL: bool = false;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_EMBEDDING_MODEL: Option<SearchEmbeddingModel> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
}

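/// Builder for the shared mistral.rs server state, covering device selection,
/// paged-attention tuning, quantization, and single- or multi-model loading.
///
/// A minimal single-model usage sketch (hedged: `model_selected` stands in
/// for a parsed `ModelSelected` value, and the token source shown is one of
/// several possibilities):
///
/// ```ignore
/// let state = MistralRsForServerBuilder::new()
///     .with_model(model_selected)
///     .with_max_seqs(32)
///     .with_token_source(TokenSource::CacheToken)
///     .build()
///     .await?;
/// ```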
pub struct MistralRsForServerBuilder {
    /// Device to load the model onto; auto-detected when unset.
    device: Option<Device>,

    /// Integer seed for deterministic sampling.
    seed: Option<u64>,

    /// Optional log file for requests and responses.
    log: Option<String>,

    /// The model to serve in single-model mode.
    model: Option<ModelSelected>,

    /// Models to serve in multi-model mode.
    models: Vec<ModelConfig>,

    /// Model id requests fall back to when they do not name one (multi-model mode).
    default_model_id: Option<String>,

    /// Maximum number of sequences that may run at any time.
    max_seqs: usize,

    /// Disable the KV cache.
    no_kv_cache: bool,

    /// Builder-wide chat template override.
    chat_template: Option<String>,

    /// Builder-wide explicit JINJA chat template file.
    jinja_explicit: Option<String>,

    /// Where to source the Hugging Face access token from.
    token_source: TokenSource,

    /// Run in interactive mode instead of serving.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device.
    prefix_cache_n: usize,

    /// Device-layer mapping: `ORD:NUM` entries, or a single layer count for device 0.
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply at load time.
    in_situ_quant: Option<String>,

    /// GPU memory to reserve for the paged-attention KV cache, in MB.
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory to use for the KV cache, in `[0, 1]`.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Context length to size the paged-attention KV cache for.
    paged_ctxt_len: Option<usize>,

    /// Paged-attention block size (tokens per block).
    paged_attn_block_size: Option<usize>,

    /// Explicit paged-attention on/off override; `None` keeps per-backend defaults.
    paged_attn: Option<bool>,

    /// Force CPU execution.
    cpu: bool,

    /// Enable web-search tooling.
    enable_search: bool,

    /// Embedding model used to rank web-search results.
    search_embedding_model: Option<SearchEmbeddingModel>,

    /// Custom callback used to perform searches.
    search_callback: Option<Arc<SearchCallback>>,

    /// MCP client configuration for connecting to external tool servers.
    mcp_client_config: Option<McpClientConfig>,

    /// Element type of the paged-attention KV cache.
    paged_cache_type: PagedCacheType,
}

impl Default for MistralRsForServerBuilder {
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            model: defaults::MODEL,
            models: Vec::new(),
            default_model_id: None,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            paged_attn: defaults::PAGED_ATTN,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_embedding_model: defaults::SEARCH_EMBEDDING_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
            mcp_client_config: None,
            paged_cache_type: defaults::PAGED_CACHE_TYPE,
        }
    }
}

impl MistralRsForServerBuilder {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
        self.models.push(model_config);
        self
    }

    pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
        self.models.extend(model_configs);
        self
    }

    pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
        self.default_model_id = Some(default_model_id);
        self
    }

    pub fn add_model_config(mut self, config: ModelConfig) -> Self {
        self.models.push(config);
        self
    }

    pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
        self.models.push(ModelConfig::new(model_id, model));
        self
    }

    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
        self.paged_cache_type = cache_type;
        self
    }

    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    pub fn with_search_embedding_model(
        mut self,
        search_embedding_model: SearchEmbeddingModel,
    ) -> Self {
        self.search_embedding_model = Some(search_embedding_model);
        self
    }

    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }

    pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
        if let Some(mcp_config) = mcp_config {
            self = self.with_mcp_config(mcp_config);
        }
        self
    }

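    /// Build the shared server state, dispatching to multi-model setup when
    /// any `ModelConfig`s were registered and to single-model setup otherwise.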
    pub async fn build(self) -> Result<SharedMistralRsState> {
        if !self.models.is_empty() {
            self.build_multi_model().await
        } else {
            self.build_single_model().await
        }
    }

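    /// Load the single configured model and wrap it in a running engine.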
    async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
        let model = self.model.context("No model was configured")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        // Models with a target non-granular index are limited to one running sequence.
        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let search_embedding_model =
            get_search_embedding_model(self.enable_search, self.search_embedding_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            search_embedding_model,
        )
        .with_opt_log(self.log)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        if let Some(mcp_config) = self.mcp_client_config {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        Ok(mistralrs)
    }

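    /// Load every configured model: the first becomes the initial pipeline,
    /// the rest are registered with the running engine, and the optional
    /// default model id is applied last.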
    pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
        if self.models.is_empty() {
            anyhow::bail!("No models configured for multi-model mode");
        }

        let first_model = &self.models[0];
        let model = first_model.model.clone();

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        // Per-model settings take precedence over the builder-wide fallbacks.
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(
                first_model
                    .chat_template
                    .clone()
                    .or(self.chat_template.clone()),
            )
            .with_jinja_explicit(
                first_model
                    .jinja_explicit
                    .clone()
                    .or(self.jinja_explicit.clone()),
            )
            .build()?;

        mistralrs_instance_info(&*loader);

        let mapper = init_mapper(
            &first_model
                .num_device_layers
                .clone()
                .or(self.num_device_layers.clone()),
            &auto_device_map_params,
        );
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let isq = first_model
            .in_situ_quant
            .as_ref()
            .or(self.in_situ_quant.as_ref())
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let mut pipeline_names = Vec::new();

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source.clone(),
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        let first_pipeline_name = pipeline.lock().await.name();
        info!(
            "First model loaded: `{first_pipeline_name}` (from config key: {})",
            first_model.model_id
        );
        pipeline_names.push(first_pipeline_name);

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
        let search_embedding_model =
            get_search_embedding_model(self.enable_search, self.search_embedding_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config.clone(),
            !self.interactive_mode,
            search_embedding_model,
        )
        .with_opt_log(self.log.clone())
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        if let Some(mcp_config) = self.mcp_client_config.clone() {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        // Register the remaining models with the running engine.
        for model_config in self.models.iter().skip(1) {
            info!(
                "Loading additional model from config key: {}",
                model_config.model_id
            );

            let model = model_config.model.clone();
            let dtype = get_model_dtype(&model)?;
            let auto_device_map_params = get_auto_device_map_params(&model)?;

            let loader: Box<dyn Loader> = LoaderBuilder::new(model)
                .with_no_kv_cache(self.no_kv_cache)
                .with_chat_template(
                    model_config
                        .chat_template
                        .clone()
                        .or(self.chat_template.clone()),
                )
                .with_jinja_explicit(
                    model_config
                        .jinja_explicit
                        .clone()
                        .or(self.jinja_explicit.clone()),
                )
                .build()?;

            let mapper = init_mapper(
                &model_config
                    .num_device_layers
                    .clone()
                    .or(self.num_device_layers.clone()),
                &auto_device_map_params,
            );

            let isq = model_config
                .in_situ_quant
                .as_ref()
                .or(self.in_situ_quant.as_ref())
                .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

            let pipeline: LoadedPipeline = loader.load_model_from_hf(
                None,
                self.token_source.clone(),
                &dtype,
                &device,
                false,
                mapper,
                isq,
                cache_config,
            )?;

            let pipeline_name = pipeline.lock().await.name();

            if pipeline_names.contains(&pipeline_name) {
                anyhow::bail!(
                    "Model ID conflict: '{}' is already registered; config key '{}' resolves to the same pipeline identifier as an earlier model.",
                    pipeline_name,
                    model_config.model_id
                );
            }

            let engine_config = mistralrs_core::EngineConfig {
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                prefix_cache_n: self.prefix_cache_n,
                disable_eos_stop: false,
                throughput_logging_enabled: !self.interactive_mode,
                search_embedding_model,
                search_callback: self.search_callback.clone(),
                tool_callbacks: HashMap::new(),
                tool_callbacks_with_tools: HashMap::new(),
            };

            let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
            if let Some(mcp_config) = self.mcp_client_config.clone() {
                add_model_config = add_model_config.with_mcp_config(mcp_config);
            }

            mistralrs
                .add_model(
                    pipeline_name.clone(),
                    pipeline,
                    scheduler_config.clone(),
                    add_model_config,
                )
                .await
                .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;

            info!(
                "Model `{pipeline_name}` registered successfully (from config key: {})",
                model_config.model_id
            );
            pipeline_names.push(pipeline_name);
        }

        if let Some(ref default_model_id) = self.default_model_id {
            mistralrs
                .set_default_model_id(default_model_id)
                .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
        }

        info!("All models loaded: `{}`", pipeline_names.join("`, `"));

        if let Some(ref default_id) = self.default_model_id {
            info!("Default model: {}", default_id);
        } else {
            info!(
                "Default model: {} (first model, from config key: {})",
                pipeline_names[0], self.models[0].model_id
            );
        }
        Ok(mistralrs)
    }
}

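/// Select the compute device: CPU when forced, otherwise Metal or CUDA
/// depending on compile-time features, seeding the device RNG if requested.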
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        // With NCCL distribution in use, per-rank device placement happens elsewhere.
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

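/// Build the device-mapping setting from `--num-device-layers`-style values,
/// falling back to automatic mapping when none are given. A single plain
/// integer maps that many layers onto device ordinal 0; otherwise each entry
/// must be `ORD:NUM`. Illustrative values (not from any real config):
///
/// ```ignore
/// // 28 layers on GPU 0 and 4 on GPU 1; unmapped layers stay on the CPU.
/// let mapper = init_mapper(
///     &Some(vec!["0:28".to_string(), "1:4".to_string()]),
///     &auto_device_map_params,
/// );
/// ```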
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

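/// Log build-time SIMD capabilities, the fixed sampling order, and the kind
/// of model this loader will produce.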
fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind());
}

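/// Decide whether paged attention should be enabled for this device. CPU
/// never supports it; CUDA (and NCCL) defaults to on and Metal to off unless
/// an explicit override was given.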
fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
    if device.is_cpu() {
        if paged_attn == Some(true) {
            warn!("Paged attention is not supported on CPU.");
        }

        defaults::PAGED_ATTN_CPU
    } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
    } else if device.is_metal() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
    } else {
        false
    }
}

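/// Build the paged-attention KV cache configuration. At most one memory bound
/// is honored, with precedence: GPU utilization fraction, then context length,
/// then absolute MB amount; with no overrides the cache is sized for the
/// model's maximum sequence length. Returns `None` when paged attention is
/// unsupported or disabled.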
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    cache_type: PagedCacheType,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(max_seq_len),
            cache_type,
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(ctxt),
            cache_type,
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::Utilization(f),
            cache_type,
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::MbAmount(m),
            cache_type,
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::ContextSize(ctxt),
                cache_type,
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

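/// Choose the scheduler: paged-attention metadata when the loaded pipeline
/// actually allocated a paged KV cache, otherwise a fixed-capacity default
/// scheduler sized to `args_max_seqs`.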
async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        // The pipeline carries a cache config only if paged attention was actually set up.
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            return SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            };
        }
    }
    SchedulerConfig::DefaultScheduler {
        method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
    }
}

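/// Translate the mutually exclusive `--paged-attn`/`--no-paged-attn` CLI
/// flags into an optional override, where `None` keeps the per-backend
/// default. A small illustration:
///
/// ```ignore
/// assert_eq!(configure_paged_attn_from_flags(false, true)?, Some(false));
/// assert_eq!(configure_paged_attn_from_flags(false, false)?, None);
/// ```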
pub fn configure_paged_attn_from_flags(
    paged_attn: bool,
    no_paged_attn: bool,
) -> Result<Option<bool>> {
    match (paged_attn, no_paged_attn) {
        (true, true) => {
            anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
        }
        (true, false) => Ok(Some(true)),
        (false, true) => Ok(Some(false)),
        (false, false) => Ok(None),
    }
}

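/// Resolve the embedding model used for web search: the explicit choice when
/// given, the default model when search is enabled without one, and `None`
/// when search is disabled.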
pub fn get_search_embedding_model(
    enable_search: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
) -> Option<SearchEmbeddingModel> {
    if enable_search {
        Some(search_embedding_model.unwrap_or_default())
    } else {
        None
    }
}