mistralrs_server_core/mistralrs_for_server_builder.rs

//! ## mistral.rs instance for server builder.

use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Context, Result};
use candle_core::Device;
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
    parse_isq_value, AutoDeviceMapParams, DefaultSchedulerMethod, DeviceLayerMapMetadata,
    DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder, McpClientConfig, MemoryGpuConfig,
    MistralRsBuilder, ModelSelected, PagedAttentionConfig, PagedCacheType, SchedulerConfig,
    SearchCallback, SearchEmbeddingModel, TokenSource,
};
use tracing::{info, warn};

use crate::types::{LoadedPipeline, SharedMistralRsState};

/// Configuration for a single model in a multi-model setup
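///
/// ### Examples
///
/// A minimal sketch; `model_selected` stands in for a pre-built
/// [`ModelSelected`] value and the template path is illustrative:
/// ```ignore
/// use mistralrs_server_core::mistralrs_for_server_builder::ModelConfig;
///
/// let config = ModelConfig::new("my-model".to_string(), model_selected)
///     .with_chat_template("path/to/template.json".to_string())
///     .with_in_situ_quant("8".to_string());
/// ```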
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Model identifier (used in API requests)
    pub model_id: String,
    /// Model selector
    pub model: ModelSelected,
    /// Model-specific chat template
    pub chat_template: Option<String>,
    /// Model-specific JINJA template
    pub jinja_explicit: Option<String>,
    /// Model-specific device layers
    pub num_device_layers: Option<Vec<String>>,
    /// Model-specific in-situ quantization
    pub in_situ_quant: Option<String>,
}

impl ModelConfig {
    pub fn new(model_id: String, model: ModelSelected) -> Self {
        Self {
            model_id,
            model,
            chat_template: None,
            jinja_explicit: None,
            num_device_layers: None,
            in_situ_quant: None,
        }
    }

    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }
}

pub mod defaults {
    //! Provides the default values used for the mistral.rs instance for server.
    //! These defaults can be used for CLI argument fallbacks, config loading, or general initialization.

    use std::sync::Arc;

    use mistralrs_core::PagedCacheType;

    use super::SearchEmbeddingModel;

    pub const DEVICE: Option<candle_core::Device> = None;
    pub const SEED: Option<u64> = None;
    pub const LOG: Option<String> = None;
    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
    pub const MAX_SEQS: usize = 16;
    pub const NO_KV_CACHE: bool = false;
    pub const CHAT_TEMPLATE: Option<String> = None;
    pub const JINJA_EXPLICIT: Option<String> = None;
    pub const INTERACTIVE_MODE: bool = false;
    pub const PREFIX_CACHE_N: usize = 16;
    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
    pub const IN_SITU_QUANT: Option<String> = None;
    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
    pub const PAGED_CTXT_LEN: Option<usize> = None;
    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
    pub const PAGED_ATTN: Option<bool> = None;
    pub const PAGED_ATTN_CPU: bool = false;
    pub const PAGED_ATTN_CUDA: bool = true;
    pub const PAGED_ATTN_METAL: bool = false;
    pub const CPU: bool = false;
    pub const ENABLE_SEARCH: bool = false;
    pub const SEARCH_EMBEDDING_MODEL: Option<SearchEmbeddingModel> = None;
    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
}

/// A builder for creating a mistral.rs instance with configured options for the mistral.rs server.
///
/// ### Examples
///
/// Basic usage:
/// ```ignore
/// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
///
/// let args = Args::parse();
///
/// let mistralrs = MistralRsForServerBuilder::new()
///     .with_model(args.model)
///     .with_max_seqs(args.max_seqs)
///     .with_no_kv_cache(args.no_kv_cache)
///     .with_token_source(args.token_source)
///     .with_interactive_mode(args.interactive_mode)
///     .with_prefix_cache_n(args.prefix_cache_n)
///     .set_paged_attn(args.paged_attn)
///     .with_cpu(args.cpu)
///     .with_enable_search(args.enable_search)
///     .with_seed_optional(args.seed)
///     .with_log_optional(args.log)
///     .with_chat_template_optional(args.chat_template)
///     .with_jinja_explicit_optional(args.jinja_explicit)
///     .with_num_device_layers_optional(args.num_device_layers)
///     .with_in_situ_quant_optional(args.in_situ_quant)
///     .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
///     .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
///     .with_paged_ctxt_len_optional(args.paged_ctxt_len)
///     .with_paged_attn_block_size_optional(args.paged_attn_block_size)
///     .build()
///     .await?;
/// ```
pub struct MistralRsForServerBuilder {
    /// The Candle device to use for model execution (CPU, CUDA, Metal, etc.).
    device: Option<Device>,

    /// Integer seed to ensure reproducible random number generation.
    seed: Option<u64>,

    /// Log all responses and requests to this file
    log: Option<String>,

    /// Model selector (for single-model mode, deprecated in favor of models)
    model: Option<ModelSelected>,

    /// Multiple model configurations (for multi-model mode)
    models: Vec<ModelConfig>,

    /// Default model ID to use when none is specified in requests
    default_model_id: Option<String>,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    max_seqs: usize,

    /// Use no KV cache.
    no_kv_cache: bool,

    /// Chat template file with a JINJA file with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file), then that template is loaded.
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on an LRU strategy.
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows the
    /// pattern `ORD:NUM;...`, where ORD is a unique device ordinal and NUM is the number of layers for that device.
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    paged_attn_block_size: Option<usize>,

    /// Enables or disables PagedAttention. By default, PagedAttention will be enabled for CUDA and disabled for Metal (and is not supported for CPU). Use this to override the default behavior.
    paged_attn: Option<bool>,

    /// Use CPU only
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This loads the selected search embedding reranker (EmbeddingGemma by default).
    enable_search: bool,

    /// Specify which built-in search embedding model to load.
    search_embedding_model: Option<SearchEmbeddingModel>,

    /// Optional override search callback
    search_callback: Option<Arc<SearchCallback>>,

    /// Optional MCP client configuration
    mcp_client_config: Option<McpClientConfig>,

    /// PagedAttention KV cache type
    paged_cache_type: PagedCacheType,
}

impl Default for MistralRsForServerBuilder {
    /// Creates a new builder with default configuration.
    fn default() -> Self {
        Self {
            device: defaults::DEVICE,
            seed: defaults::SEED,
            log: defaults::LOG,
            model: defaults::MODEL,
            models: Vec::new(),
            default_model_id: None,
            max_seqs: defaults::MAX_SEQS,
            no_kv_cache: defaults::NO_KV_CACHE,
            chat_template: defaults::CHAT_TEMPLATE,
            jinja_explicit: defaults::JINJA_EXPLICIT,
            token_source: defaults::TOKEN_SOURCE,
            interactive_mode: defaults::INTERACTIVE_MODE,
            prefix_cache_n: defaults::PREFIX_CACHE_N,
            num_device_layers: defaults::NUM_DEVICE_LAYERS,
            in_situ_quant: defaults::IN_SITU_QUANT,
            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
            paged_attn: defaults::PAGED_ATTN,
            cpu: defaults::CPU,
            enable_search: defaults::ENABLE_SEARCH,
            search_embedding_model: defaults::SEARCH_EMBEDDING_MODEL,
            search_callback: defaults::SEARCH_CALLBACK,
            mcp_client_config: None,
            paged_cache_type: defaults::PAGED_CACHE_TYPE,
        }
    }
}

impl MistralRsForServerBuilder {
    /// Creates a new `MistralRsForServerBuilder` with default settings.
    ///
    /// This is equivalent to calling `Default::default()`.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let builder = MistralRsForServerBuilder::new();
    /// ```
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the Candle device to use for model execution.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Sets the random seed for deterministic model behavior.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the random seed if provided.
    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
        if let Some(seed) = seed {
            self = self.with_seed(seed);
        }
        self
    }

    /// Sets the logging configuration.
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }

    /// Sets the logging configuration if provided.
    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
        if let Some(log) = log {
            self = self.with_log(log);
        }
        self
    }

    /// Sets the model to be used.
    pub fn with_model(mut self, model: ModelSelected) -> Self {
        self.model = Some(model);
        self
    }

    /// Add a model to the multi-model configuration.
    pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
        self.models.push(model_config);
        self
    }

    /// Add multiple models to the multi-model configuration.
    pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
        self.models.extend(model_configs);
        self
    }

    /// Set the default model ID to use when none is specified in requests.
    pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
        self.default_model_id = Some(default_model_id);
        self
    }

    /// Add a model configuration.
    pub fn add_model_config(mut self, config: ModelConfig) -> Self {
        self.models.push(config);
        self
    }

    /// Add a model with just an ID and ModelSelected (convenience method).
    pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
        self.models.push(ModelConfig::new(model_id, model));
        self
    }

    /// Sets the maximum number of concurrent sequences.
    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
        self.max_seqs = max_seqs;
        self
    }

    /// Sets whether to disable the key-value cache.
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = no_kv_cache;
        self
    }

    /// Sets the chat template configuration.
    pub fn with_chat_template(mut self, chat_template: String) -> Self {
        self.chat_template = Some(chat_template);
        self
    }

    /// Sets the chat template configuration if provided.
    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
        if let Some(chat_template) = chat_template {
            self = self.with_chat_template(chat_template);
        }
        self
    }

    /// Sets an explicit JINJA chat template file.
    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
        self.jinja_explicit = Some(jinja_explicit);
        self
    }

    /// Sets an explicit JINJA chat template file if provided.
    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
        if let Some(jinja_explicit) = jinja_explicit {
            self = self.with_jinja_explicit(jinja_explicit);
        }
        self
    }

    /// Sets the token source for authentication.
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

    /// Sets whether to run in interactive mode.
    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
        self.interactive_mode = interactive_mode;
        self
    }

    /// Sets the number of prefix caches to hold on the device.
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = prefix_cache_n;
        self
    }

    /// Sets the device layer mapping
    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
        self.num_device_layers = Some(num_device_layers);
        self
    }

    /// Sets the device layer mapping if provided.
    pub fn with_num_device_layers_optional(
        mut self,
        num_device_layers: Option<Vec<String>>,
    ) -> Self {
        if let Some(num_device_layers) = num_device_layers {
            self = self.with_num_device_layers(num_device_layers);
        }
        self
    }

    /// Sets the in-situ quantization method.
    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
        self.in_situ_quant = Some(in_situ_quant);
        self
    }

    /// Sets the in-situ quantization method if provided.
    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
        if let Some(in_situ_quant) = in_situ_quant {
            self = self.with_in_situ_quant(in_situ_quant);
        }
        self
    }

    /// Sets PagedAttention.
    ///
    /// Unlike other `with_PROP` or `with_PROP_optional` methods, this method
    /// sets the value to whatever `Option<bool>` is passed in, since `None`,
    /// `Some(true)`, and `Some(false)` have different implications:
    ///
    /// - `None`: default behavior for the target device (e.g. enabled for CUDA, disabled for Metal)
    /// - `Some(true)`: enable (if supported by the target device)
    /// - `Some(false)`: disable
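    ///
    /// ### Examples
    ///
    /// A minimal sketch:
    /// ```ignore
    /// // Force-enable PagedAttention where the device supports it.
    /// let builder = MistralRsForServerBuilder::new().set_paged_attn(Some(true));
    /// ```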
    pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
        self.paged_attn = paged_attn;
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache.
    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
        self
    }

    /// Sets the GPU memory allocation for PagedAttention KV cache if provided.
    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
        }
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention.
    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
        self
    }

    /// Sets the percentage of GPU memory to utilize for PagedAttention if provided.
    pub fn with_paged_attn_gpu_mem_usage_optional(
        mut self,
        paged_attn_gpu_mem_usage: Option<f32>,
    ) -> Self {
        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
        }
        self
    }

    /// Sets the total context length for KV cache allocation.
    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
        self.paged_ctxt_len = Some(paged_ctxt_len);
        self
    }

    /// Sets the total context length for KV cache allocation if provided.
    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
        if let Some(paged_ctxt_len) = paged_ctxt_len {
            self = self.with_paged_ctxt_len(paged_ctxt_len);
        }
        self
    }

    /// Sets the block size for PagedAttention.
    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
        self.paged_attn_block_size = Some(paged_attn_block_size);
        self
    }

    /// Sets the PagedAttention KV cache type.
    pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
        self.paged_cache_type = cache_type;
        self
    }

    /// Sets the block size for PagedAttention if provided.
    pub fn with_paged_attn_block_size_optional(
        mut self,
        paged_attn_block_size: Option<usize>,
    ) -> Self {
        if let Some(paged_attn_block_size) = paged_attn_block_size {
            self = self.with_paged_attn_block_size(paged_attn_block_size);
        }
        self
    }

    /// Sets whether to force CPU-only execution.
    pub fn with_cpu(mut self, cpu: bool) -> Self {
        self.cpu = cpu;
        self
    }

    /// Sets whether to enable web search functionality.
    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
        self.enable_search = enable_search;
        self
    }

    /// Sets the embedding model used for web search assistance.
    pub fn with_search_embedding_model(
        mut self,
        search_embedding_model: SearchEmbeddingModel,
    ) -> Self {
        self.search_embedding_model = Some(search_embedding_model);
        self
    }

    /// Override the search function used when `web_search_options` is enabled.
    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(callback);
        self
    }

    /// Sets the MCP client configuration.
    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }

    /// Sets the MCP client configuration if provided.
    pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
        if let Some(mcp_config) = mcp_config {
            self = self.with_mcp_config(mcp_config);
        }
        self
    }

    /// Builds the configured mistral.rs instance.
    ///
    /// ### Examples
    ///
    /// ```ignore
    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
    ///
    /// let shared_mistralrs = MistralRsForServerBuilder::new()
    ///     .with_model(model)
    ///     .with_in_situ_quant("8".to_string())
    ///     .set_paged_attn(Some(true))
    ///     .build()
    ///     .await?;
    /// ```
    pub async fn build(self) -> Result<SharedMistralRsState> {
        // Determine if we're in single-model or multi-model mode
        if !self.models.is_empty() {
            self.build_multi_model().await
        } else {
            self.build_single_model().await
        }
    }

    /// Build a single-model instance (legacy mode)
    async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
        let model = self.model.context("Model was None")?;

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        // Configure this last to prevent arg moves
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(self.chat_template)
            .with_jinja_explicit(self.jinja_explicit)
            .build()?;

        mistralrs_instance_info(&*loader);

        let isq = self
            .in_situ_quant
            .as_ref()
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source,
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        info!("Model loaded.");

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;

        let search_embedding_model =
            get_search_embedding_model(self.enable_search, self.search_embedding_model);

        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config,
            !self.interactive_mode,
            search_embedding_model,
        )
        .with_opt_log(self.log)
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        // Add MCP client configuration if provided
        if let Some(mcp_config) = self.mcp_client_config {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        Ok(mistralrs)
    }

    /// Build a multi-model instance
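    ///
    /// ### Examples
    ///
    /// A minimal sketch; the model IDs and `model_a`/`model_b` selectors are
    /// illustrative:
    /// ```ignore
    /// let mistralrs = MistralRsForServerBuilder::new()
    ///     .add_model("model-a".to_string(), model_a)
    ///     .add_model("model-b".to_string(), model_b)
    ///     .with_default_model_id("model-a".to_string())
    ///     .build_multi_model()
    ///     .await?;
    /// ```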
    pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
        if self.models.is_empty() {
            anyhow::bail!("No models configured for multi-model mode");
        }

        // Use the first model as the base configuration
        let first_model = &self.models[0];
        let model = first_model.model.clone();

        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
        let dtype = get_model_dtype(&model)?;
        let auto_device_map_params = get_auto_device_map_params(&model)?;

        if tgt_non_granular_index.is_some() {
            self.max_seqs = 1;
        }

        let max_seq_len = auto_device_map_params.max_seq_len();

        let device = if let Some(device) = self.device {
            device
        } else {
            init_device(self.cpu, self.seed)?
        };

        // Create the first model's pipeline
        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
            .with_no_kv_cache(self.no_kv_cache)
            .with_chat_template(
                first_model
                    .chat_template
                    .clone()
                    .or(self.chat_template.clone()),
            )
            .with_jinja_explicit(
                first_model
                    .jinja_explicit
                    .clone()
                    .or(self.jinja_explicit.clone()),
            )
            .build()?;

        mistralrs_instance_info(&*loader);

        let mapper = init_mapper(
            &first_model
                .num_device_layers
                .clone()
                .or(self.num_device_layers.clone()),
            &auto_device_map_params,
        );
        let paged_attn = configure_paged_attn(&device, self.paged_attn);

        let cache_config = init_cache_config(
            self.paged_attn_block_size,
            self.paged_attn_gpu_mem,
            self.paged_attn_gpu_mem_usage,
            self.paged_ctxt_len,
            self.paged_cache_type,
            !paged_attn,
            max_seq_len,
        )?;

        let isq = first_model
            .in_situ_quant
            .as_ref()
            .or(self.in_situ_quant.as_ref())
            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

        let mut pipeline_names = Vec::new();

        let pipeline: LoadedPipeline = loader.load_model_from_hf(
            None,
            self.token_source.clone(),
            &dtype,
            &device,
            false,
            mapper,
            isq,
            cache_config,
        )?;
        let first_pipeline_name = pipeline.lock().await.name();
        info!(
            "First model loaded: `{first_pipeline_name}` (from config key: {})",
            first_model.model_id
        );
        pipeline_names.push(first_pipeline_name);

        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
        let search_embedding_model =
            get_search_embedding_model(self.enable_search, self.search_embedding_model);

        // Create the first MistralRs instance with the first model
        let mut builder = MistralRsBuilder::new(
            pipeline,
            scheduler_config.clone(),
            !self.interactive_mode,
            search_embedding_model,
        )
        .with_opt_log(self.log.clone())
        .with_no_kv_cache(self.no_kv_cache)
        .with_prefix_cache_n(self.prefix_cache_n);

        // Add MCP client configuration if provided
        if let Some(mcp_config) = self.mcp_client_config.clone() {
            builder = builder.with_mcp_client(mcp_config);
        }

        let mistralrs = builder.build().await;

        // Load additional models
        for model_config in self.models.iter().skip(1) {
            info!(
                "Loading additional model from config key: {}",
                model_config.model_id
            );

            let model = model_config.model.clone();
            let dtype = get_model_dtype(&model)?;
            let auto_device_map_params = get_auto_device_map_params(&model)?;

            let loader: Box<dyn Loader> = LoaderBuilder::new(model)
                .with_no_kv_cache(self.no_kv_cache)
                .with_chat_template(
                    model_config
                        .chat_template
                        .clone()
                        .or(self.chat_template.clone()),
                )
                .with_jinja_explicit(
                    model_config
                        .jinja_explicit
                        .clone()
                        .or(self.jinja_explicit.clone()),
                )
                .build()?;

            let mapper = init_mapper(
                &model_config
                    .num_device_layers
                    .clone()
                    .or(self.num_device_layers.clone()),
                &auto_device_map_params,
            );

            let isq = model_config
                .in_situ_quant
                .as_ref()
                .or(self.in_situ_quant.as_ref())
                .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

            let pipeline: LoadedPipeline = loader.load_model_from_hf(
                None,
                self.token_source.clone(),
                &dtype,
                &device,
                false,
                mapper,
                isq,
                cache_config,
            )?;

            // Use the pipeline's name() as the model ID
            let pipeline_name = pipeline.lock().await.name();

            // Check for model ID conflicts
            if pipeline_names.contains(&pipeline_name) {
                anyhow::bail!(
                    "Model ID conflict: '{}' is already registered. The model from config key '{}' resolves to the same pipeline identifier as a previously loaded model.",
                    pipeline_name,
                    model_config.model_id
                );
            }

            // Add the model to the MistralRs instance
            let engine_config = mistralrs_core::EngineConfig {
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: false,
                prefix_cache_n: self.prefix_cache_n,
                disable_eos_stop: false,
                throughput_logging_enabled: !self.interactive_mode,
                search_embedding_model,
                search_callback: self.search_callback.clone(),
                tool_callbacks: HashMap::new(),
                tool_callbacks_with_tools: HashMap::new(),
            };

            let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
            if let Some(mcp_config) = self.mcp_client_config.clone() {
                add_model_config = add_model_config.with_mcp_config(mcp_config);
            }

            mistralrs
                .add_model(
                    pipeline_name.clone(),
                    pipeline,
                    scheduler_config.clone(),
                    add_model_config,
                )
                .await
                .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;

            info!(
                "Model `{pipeline_name}` registered successfully (from config key: {})",
                model_config.model_id
            );
            pipeline_names.push(pipeline_name);
        }

        // Set the default model if specified
        if let Some(ref default_model_id) = self.default_model_id {
            mistralrs
                .set_default_model_id(default_model_id)
                .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
        }

        // Log all models loaded
        info!("All models loaded: `{}`", pipeline_names.join("`, `"));

        // Log default model
        if let Some(ref default_id) = self.default_model_id {
            info!("Default model: {}", default_id);
        } else {
            info!(
                "Default model: {} (first model, from config key: {})",
                pipeline_names[0], self.models[0].model_id
            );
        }
        Ok(mistralrs)
    }
}

// TODO: replace with best device?
/// Initializes the device to be used for computation, optionally forcing CPU usage and setting a seed.
fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
    #[cfg(feature = "metal")]
    let device = if force_cpu {
        Device::Cpu
    } else {
        Device::new_metal(0)?
    };
    #[cfg(not(feature = "metal"))]
    #[allow(clippy::if_same_then_else)]
    let device = if force_cpu {
        Device::Cpu
    } else if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }

    Ok(device)
}

/// Initializes the device mapping configuration for distributing model layers.
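///
/// Accepted forms for `num_device_layers` (see the two parse paths below; the
/// concrete values here are illustrative):
/// - `["16"]`: a single integer places 16 layers on device ordinal 0.
/// - `["0:16", "1:24"]`: `ORD:NUM` pairs place 16 layers on ordinal 0 and 24 on ordinal 1.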
fn init_mapper(
    num_device_layers: &Option<Vec<String>>,
    auto_device_map_params: &AutoDeviceMapParams,
) -> DeviceMapSetting {
    // Parse device mapper
    if let Some(device_layers) = num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params.clone())
    }
}

/// Logs hardware feature information and the model's sampling strategy and kind.
fn mistralrs_instance_info(loader: &dyn Loader) {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind().to_string());
}

/// Determines whether paged attention should be enabled based on device type and preferences.
fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
    if device.is_cpu() {
        if paged_attn == Some(true) {
            warn!("Paged attention is not supported on CPU.");
        }

        defaults::PAGED_ATTN_CPU
    } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
    } else if device.is_metal() {
        paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
    } else {
        false
    }
}

/// Initializes the cache configuration for paged attention based on provided parameters.
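///
/// Resolution priority mirrors the builder docs: `paged_ctxt_len` >
/// `paged_attn_gpu_mem_usage` > `paged_attn_gpu_mem`. When none of them are
/// set, the KV cache is sized for `max_seq_len` tokens; when PagedAttention is
/// unsupported or disabled, `None` is returned.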
fn init_cache_config(
    paged_attn_block_size: Option<usize>,
    paged_attn_gpu_mem: Option<usize>,
    paged_attn_gpu_mem_usage: Option<f32>,
    paged_ctxt_len: Option<usize>,
    cache_type: PagedCacheType,
    no_paged_attn: bool,
    max_seq_len: usize,
) -> Result<Option<PagedAttentionConfig>> {
    match (
        paged_attn_block_size,
        paged_attn_gpu_mem,
        paged_attn_gpu_mem_usage,
        paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(max_seq_len),
            cache_type,
        )?)),
        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(ctxt),
            cache_type,
        )?)),
        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::Utilization(f),
            cache_type,
        )?)),
        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::MbAmount(m),
            cache_type,
        )?)),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified, defaulting to the context length value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::ContextSize(ctxt),
                cache_type,
            )?))
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified, defaulting to the usage value.");
            Ok(Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                cache_type,
            )?))
        }
        (_, _, _, _, _, _) => Ok(None),
    }
}

/// Initializes the scheduler configuration based on cache settings and pipeline metadata.
async fn init_scheduler_config(
    cache_config: &Option<PagedAttentionConfig>,
    pipeline: &LoadedPipeline,
    args_max_seqs: usize,
) -> SchedulerConfig {
    if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: args_max_seqs,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
        }
    }
}

/// Configures PagedAttention based on two flags.
///
/// This function resolves the tri-state PagedAttention configuration from
/// the mutually exclusive `paged_attn` and `no_paged_attn` flags.
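///
/// ### Examples
///
/// ```ignore
/// assert_eq!(configure_paged_attn_from_flags(true, false)?, Some(true));
/// assert_eq!(configure_paged_attn_from_flags(false, true)?, Some(false));
/// assert_eq!(configure_paged_attn_from_flags(false, false)?, None);
/// // (true, true) bails: the flags are mutually exclusive.
/// ```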
pub fn configure_paged_attn_from_flags(
    paged_attn: bool,
    no_paged_attn: bool,
) -> Result<Option<bool>> {
    match (paged_attn, no_paged_attn) {
        (true, true) => {
            anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
        }
        (true, false) => Ok(Some(true)),
        (false, true) => Ok(Some(false)),
        (false, false) => Ok(None),
    }
}

/// Creates a search embedding model configuration for agentic search reranking.
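///
/// Returns `None` when search is disabled; otherwise falls back to the default
/// search embedding model (EmbeddingGemma, per the builder docs) when none is
/// specified.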
pub fn get_search_embedding_model(
    enable_search: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
) -> Option<SearchEmbeddingModel> {
    if enable_search {
        Some(search_embedding_model.unwrap_or_default())
    } else {
        None
    }
}