mistralrs_server_core/
mistralrs_for_server_builder.rs

1//! ## mistral.rs instance for server builder.
2
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use candle_core::Device;
7use mistralrs_core::{
8    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, paged_attn_supported,
9    parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
10    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
11    McpClientConfig, MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig,
12    PagedCacheType, SchedulerConfig, SearchCallback, TokenSource,
13};
14use tracing::{info, warn};
15
16use crate::types::{LoadedPipeline, SharedMistralRsState};
17use std::collections::HashMap;
18
/// Configuration for a single model in a multi-model setup.
///
/// Every optional field is a per-model override: when it is `None`, the multi-model build
/// falls back to the corresponding builder-level option of `MistralRsForServerBuilder`.
#[derive(Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Model identifier (used in API requests)
    pub model_id: String,
    /// Model selector
    pub model: ModelSelected,
    /// Model-specific chat template
    pub chat_template: Option<String>,
    /// Model-specific JINJA template
    pub jinja_explicit: Option<String>,
    /// Model-specific device layers
    pub num_device_layers: Option<Vec<String>>,
    /// Model-specific in-situ quantization
    pub in_situ_quant: Option<String>,
}
35
36impl ModelConfig {
37    pub fn new(model_id: String, model: ModelSelected) -> Self {
38        Self {
39            model_id,
40            model,
41            chat_template: None,
42            jinja_explicit: None,
43            num_device_layers: None,
44            in_situ_quant: None,
45        }
46    }
47
48    pub fn with_chat_template(mut self, chat_template: String) -> Self {
49        self.chat_template = Some(chat_template);
50        self
51    }
52
53    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
54        self.jinja_explicit = Some(jinja_explicit);
55        self
56    }
57
58    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
59        self.num_device_layers = Some(num_device_layers);
60        self
61    }
62
63    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
64        self.in_situ_quant = Some(in_situ_quant);
65        self
66    }
67}
68
69pub mod defaults {
70    //! Provides the default values used for the mistral.rs instance for server.
71    //! These defaults can be used for CLI argument fallbacks, config loading, or general initialization.
72
73    use std::sync::Arc;
74
75    use mistralrs_core::PagedCacheType;
76
77    pub const DEVICE: Option<candle_core::Device> = None;
78    pub const SEED: Option<u64> = None;
79    pub const LOG: Option<String> = None;
80    pub const TRUNCATE_SEQUENCE: bool = false;
81    pub const MODEL: Option<mistralrs_core::ModelSelected> = None;
82    pub const MAX_SEQS: usize = 16;
83    pub const NO_KV_CACHE: bool = false;
84    pub const CHAT_TEMPLATE: Option<String> = None;
85    pub const JINJA_EXPLICIT: Option<String> = None;
86    pub const INTERACTIVE_MODE: bool = false;
87    pub const PREFIX_CACHE_N: usize = 16;
88    pub const NUM_DEVICE_LAYERS: Option<Vec<String>> = None;
89    pub const IN_SITU_QUANT: Option<String> = None;
90    pub const PAGED_ATTN_GPU_MEM: Option<usize> = None;
91    pub const PAGED_ATTN_GPU_MEM_USAGE: Option<f32> = None;
92    pub const PAGED_CTXT_LEN: Option<usize> = None;
93    pub const PAGED_ATTN_BLOCK_SIZE: Option<usize> = None;
94    pub const PAGED_ATTN: Option<bool> = None;
95    pub const PAGED_ATTN_CPU: bool = false;
96    pub const PAGED_ATTN_CUDA: bool = true;
97    pub const PAGED_ATTN_METAL: bool = false;
98    pub const CPU: bool = false;
99    pub const ENABLE_SEARCH: bool = false;
100    pub const SEARCH_BERT_MODEL: Option<String> = None;
101    pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
102    pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
103    pub const PAGED_CACHE_TYPE: PagedCacheType = PagedCacheType::Auto;
104}
105
/// A builder for creating a mistral.rs instance with configured options for the mistral.rs server.
///
/// ### Examples
///
/// Basic usage:
/// ```ignore
/// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
///
/// let args = Args::parse();
///
/// let mistralrs = MistralRsForServerBuilder::new()
///        .with_truncate_sequence(args.truncate_sequence)
///        .with_model(args.model)
///        .with_max_seqs(args.max_seqs)
///        .with_no_kv_cache(args.no_kv_cache)
///        .with_token_source(args.token_source)
///        .with_interactive_mode(args.interactive_mode)
///        .with_prefix_cache_n(args.prefix_cache_n)
///        .with_paged_attn(args.paged_attn)
///        .with_cpu(args.cpu)
///        .with_enable_search(args.enable_search)
///        .with_seed_optional(args.seed)
///        .with_log_optional(args.log)
///        .with_chat_template_optional(args.chat_template)
///        .with_jinja_explicit_optional(args.jinja_explicit)
///        .with_num_device_layers_optional(args.num_device_layers)
///        .with_in_situ_quant_optional(args.in_situ_quant)
///        .with_paged_attn_gpu_mem_optional(args.paged_attn_gpu_mem)
///        .with_paged_attn_gpu_mem_usage_optional(args.paged_attn_gpu_mem_usage)
///        .with_paged_ctxt_len_optional(args.paged_ctxt_len)
///        .with_paged_attn_block_size_optional(args.paged_attn_block_size)
///        .build()
///        .await?;
/// ```
pub struct MistralRsForServerBuilder {
    /// The Candle device to use for model execution (CPU, CUDA, Metal, etc.).
    /// When unset, the device is initialized at build time from the `cpu` and `seed` options.
    device: Option<Device>,

    /// Integer seed to ensure reproducible random number generation.
    seed: Option<u64>,

    /// Log all responses and requests to this file.
    log: Option<String>,

    /// If a sequence is larger than the maximum model length, truncate the number
    /// of tokens such that the sequence will fit at most the maximum length.
    /// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
    truncate_sequence: bool,

    /// Model selector (for single-model mode, deprecated in favor of models)
    model: Option<ModelSelected>,

    /// Multiple model configurations (for multi-model mode).
    /// When this is non-empty, `build` uses the multi-model path.
    models: Vec<ModelConfig>,

    /// Default model ID to use when none is specified in requests
    default_model_id: Option<String>,

    /// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
    max_seqs: usize,

    /// Use no KV cache.
    no_kv_cache: bool,

    /// Chat template file with a JINJA file with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
    /// Used if the automatic deserialization fails. If this ends with `.json` (i.e., it is a file) then that template is loaded.
    chat_template: Option<String>,

    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
    jinja_explicit: Option<String>,

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
    /// Defaults to `cache`.
    token_source: TokenSource,

    /// Enter interactive mode instead of serving a chat server.
    interactive_mode: bool,

    /// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on a LRU strategy.
    prefix_cache_n: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows the following pattern:
    /// ORD:NUM;... Where ORD is a unique device ordinal and NUM is the number of layers for that device.
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for KV cache with PagedAttention in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    paged_ctxt_len: Option<usize>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    paged_attn_block_size: Option<usize>,

    /// Enables or disables PagedAttention. By default, PagedAttention will be enabled for CUDA and disabled for Metal (and is not supported for CPU). Use this to override the default behavior.
    paged_attn: Option<bool>,

    /// Use CPU only
    cpu: bool,

    /// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
    enable_search: bool,

    /// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
    search_bert_model: Option<String>,

    /// Optional override search callback
    search_callback: Option<Arc<SearchCallback>>,

    /// Optional MCP client configuration
    mcp_client_config: Option<McpClientConfig>,

    /// PagedAttention KV cache type
    paged_cache_type: PagedCacheType,
}
239
240impl Default for MistralRsForServerBuilder {
241    /// Creates a new builder with default configuration.
242    fn default() -> Self {
243        Self {
244            device: defaults::DEVICE,
245            seed: defaults::SEED,
246            log: defaults::LOG,
247            truncate_sequence: defaults::TRUNCATE_SEQUENCE,
248            model: defaults::MODEL,
249            models: Vec::new(),
250            default_model_id: None,
251            max_seqs: defaults::MAX_SEQS,
252            no_kv_cache: defaults::NO_KV_CACHE,
253            chat_template: defaults::CHAT_TEMPLATE,
254            jinja_explicit: defaults::JINJA_EXPLICIT,
255            token_source: defaults::TOKEN_SOURCE,
256            interactive_mode: defaults::INTERACTIVE_MODE,
257            prefix_cache_n: defaults::PREFIX_CACHE_N,
258            num_device_layers: defaults::NUM_DEVICE_LAYERS,
259            in_situ_quant: defaults::IN_SITU_QUANT,
260            paged_attn_gpu_mem: defaults::PAGED_ATTN_GPU_MEM,
261            paged_attn_gpu_mem_usage: defaults::PAGED_ATTN_GPU_MEM_USAGE,
262            paged_ctxt_len: defaults::PAGED_CTXT_LEN,
263            paged_attn_block_size: defaults::PAGED_ATTN_BLOCK_SIZE,
264            paged_attn: defaults::PAGED_ATTN,
265            cpu: defaults::CPU,
266            enable_search: defaults::ENABLE_SEARCH,
267            search_bert_model: defaults::SEARCH_BERT_MODEL,
268            search_callback: defaults::SEARCH_CALLBACK,
269            mcp_client_config: None,
270            paged_cache_type: defaults::PAGED_CACHE_TYPE,
271        }
272    }
273}
274
275impl MistralRsForServerBuilder {
276    /// Creates a new `MistralRsForServerBuilder` with default settings.
277    ///
278    /// This is equivalent to calling `Default::default()`.
279    ///
280    /// ### Examples
281    ///
282    /// ```ignore
283    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
284    ///
285    /// let builder = mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder::new();
286    /// ```
287    pub fn new() -> Self {
288        Default::default()
289    }
290
291    /// Sets the Candle device to use for model execution.
292    pub fn with_device(mut self, device: Device) -> Self {
293        self.device = Some(device);
294        self
295    }
296
297    /// Sets the random seed for deterministic model behavior.
298    pub fn with_seed(mut self, seed: u64) -> Self {
299        self.seed = Some(seed);
300        self
301    }
302
303    /// Sets the random seed if provided.
304    pub fn with_seed_optional(mut self, seed: Option<u64>) -> Self {
305        if let Some(seed) = seed {
306            self = self.with_seed(seed);
307        }
308        self
309    }
310
311    /// Sets the logging configuration.
312    pub fn with_log(mut self, log: String) -> Self {
313        self.log = Some(log);
314        self
315    }
316
317    /// Sets the logging configuration if provided.
318    pub fn with_log_optional(mut self, log: Option<String>) -> Self {
319        if let Some(log) = log {
320            self = self.with_log(log);
321        }
322        self
323    }
324
325    /// Sets whether to truncate sequences that exceed the maximum model length.
326    pub fn with_truncate_sequence(mut self, truncate_sequence: bool) -> Self {
327        self.truncate_sequence = truncate_sequence;
328        self
329    }
330
331    /// Sets the model to be used.
332    pub fn with_model(mut self, model: ModelSelected) -> Self {
333        self.model = Some(model);
334        self
335    }
336
337    /// Add a model to the multi-model configuration.
338    pub fn with_model_config(mut self, model_config: ModelConfig) -> Self {
339        self.models.push(model_config);
340        self
341    }
342
343    /// Add multiple models to the multi-model configuration.
344    pub fn with_model_configs(mut self, model_configs: Vec<ModelConfig>) -> Self {
345        self.models.extend(model_configs);
346        self
347    }
348
349    /// Set the default model ID to use when none is specified in requests.
350    pub fn with_default_model_id(mut self, default_model_id: String) -> Self {
351        self.default_model_id = Some(default_model_id);
352        self
353    }
354
355    /// Add a model configuration.
356    pub fn add_model_config(mut self, config: ModelConfig) -> Self {
357        self.models.push(config);
358        self
359    }
360
361    /// Add a model with just an ID and ModelSelected (convenience method).
362    pub fn add_model(mut self, model_id: String, model: ModelSelected) -> Self {
363        self.models.push(ModelConfig::new(model_id, model));
364        self
365    }
366
367    /// Sets the maximum number of concurrent sequences.
368    pub fn with_max_seqs(mut self, max_seqs: usize) -> Self {
369        self.max_seqs = max_seqs;
370        self
371    }
372
373    /// Sets whether to disable the key-value cache.
374    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
375        self.no_kv_cache = no_kv_cache;
376        self
377    }
378
379    /// Sets the chat template configuration.
380    pub fn with_chat_template(mut self, chat_template: String) -> Self {
381        self.chat_template = Some(chat_template);
382        self
383    }
384
385    /// Sets the chat template configuration if provided.
386    pub fn with_chat_template_optional(mut self, chat_template: Option<String>) -> Self {
387        if let Some(chat_template) = chat_template {
388            self = self.with_chat_template(chat_template);
389        }
390        self
391    }
392
393    /// Sets an explicit JINJA chat template file.
394    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
395        self.jinja_explicit = Some(jinja_explicit);
396        self
397    }
398
399    /// Sets an explicit JINJA chat template file if provided.
400    pub fn with_jinja_explicit_optional(mut self, jinja_explicit: Option<String>) -> Self {
401        if let Some(jinja_explicit) = jinja_explicit {
402            self = self.with_jinja_explicit(jinja_explicit);
403        }
404        self
405    }
406
407    /// Sets the token source for authentication.
408    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
409        self.token_source = token_source;
410        self
411    }
412
413    /// Sets whether to run in interactive mode.
414    pub fn with_interactive_mode(mut self, interactive_mode: bool) -> Self {
415        self.interactive_mode = interactive_mode;
416        self
417    }
418
419    /// Sets the number of prefix caches to hold on the device.
420    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
421        self.prefix_cache_n = prefix_cache_n;
422        self
423    }
424
425    /// Sets the device layer mapping
426    pub fn with_num_device_layers(mut self, num_device_layers: Vec<String>) -> Self {
427        self.num_device_layers = Some(num_device_layers);
428        self
429    }
430
431    /// Sets the device layer mapping if provided.
432    pub fn with_num_device_layers_optional(
433        mut self,
434        num_device_layers: Option<Vec<String>>,
435    ) -> Self {
436        if let Some(num_device_layers) = num_device_layers {
437            self = self.with_num_device_layers(num_device_layers);
438        }
439        self
440    }
441
442    /// Sets the in-situ quantization method.
443    pub fn with_in_situ_quant(mut self, in_situ_quant: String) -> Self {
444        self.in_situ_quant = Some(in_situ_quant);
445        self
446    }
447
448    /// Sets the in-situ quantization method if provided.
449    pub fn with_in_situ_quant_optional(mut self, in_situ_quant: Option<String>) -> Self {
450        if let Some(in_situ_quant) = in_situ_quant {
451            self = self.with_in_situ_quant(in_situ_quant);
452        }
453        self
454    }
455
456    /// Sets PagedAttention.
457    ///
458    /// Unlike other `with_PROP` or `with_PROP_optional` methods, this method
459    /// sets the value to whatever `Option<bool>` is passed in as `None`, `Some(true)`
460    /// and `Some(false)` have different implications.
461    ///
462    /// `None`: default behavior for target device (e.g. enable for CUDA, disable for Metal)
463    /// `Some(true)`: enable (if supported by target device)
464    /// `Some(false)`: disable
465    pub fn set_paged_attn(mut self, paged_attn: Option<bool>) -> Self {
466        self.paged_attn = paged_attn;
467        self
468    }
469
470    /// Sets the GPU memory allocation for PagedAttention KV cache.
471    pub fn with_paged_attn_gpu_mem(mut self, paged_attn_gpu_mem: usize) -> Self {
472        self.paged_attn_gpu_mem = Some(paged_attn_gpu_mem);
473        self
474    }
475
476    /// Sets the GPU memory allocation for PagedAttention KV cache if provided.
477    pub fn with_paged_attn_gpu_mem_optional(mut self, paged_attn_gpu_mem: Option<usize>) -> Self {
478        if let Some(paged_attn_gpu_mem) = paged_attn_gpu_mem {
479            self = self.with_paged_attn_gpu_mem(paged_attn_gpu_mem);
480        }
481        self
482    }
483
484    /// Sets the percentage of GPU memory to utilize for PagedAttention.
485    pub fn with_paged_attn_gpu_mem_usage(mut self, paged_attn_gpu_mem_usage: f32) -> Self {
486        self.paged_attn_gpu_mem_usage = Some(paged_attn_gpu_mem_usage);
487        self
488    }
489
490    /// Sets the percentage of GPU memory to utilize for PagedAttention if provided.
491    pub fn with_paged_attn_gpu_mem_usage_optional(
492        mut self,
493        paged_attn_gpu_mem_usage: Option<f32>,
494    ) -> Self {
495        if let Some(paged_attn_gpu_mem_usage) = paged_attn_gpu_mem_usage {
496            self = self.with_paged_attn_gpu_mem_usage(paged_attn_gpu_mem_usage);
497        }
498        self
499    }
500
501    /// Sets the total context length for KV cache allocation.
502    pub fn with_paged_ctxt_len(mut self, paged_ctxt_len: usize) -> Self {
503        self.paged_ctxt_len = Some(paged_ctxt_len);
504        self
505    }
506
507    /// Sets the total context length for KV cache allocation if provided.
508    pub fn with_paged_ctxt_len_optional(mut self, paged_ctxt_len: Option<usize>) -> Self {
509        if let Some(paged_ctxt_len) = paged_ctxt_len {
510            self = self.with_paged_ctxt_len(paged_ctxt_len);
511        }
512        self
513    }
514
515    /// Sets the block size for PagedAttention.
516    pub fn with_paged_attn_block_size(mut self, paged_attn_block_size: usize) -> Self {
517        self.paged_attn_block_size = Some(paged_attn_block_size);
518        self
519    }
520
521    /// Sets the block size for PagedAttention.
522    pub fn with_paged_attn_cache_type(mut self, cache_type: PagedCacheType) -> Self {
523        self.paged_cache_type = cache_type;
524        self
525    }
526
527    /// Sets the block size for PagedAttention if provided.
528    pub fn with_paged_attn_block_size_optional(
529        mut self,
530        paged_attn_block_size: Option<usize>,
531    ) -> Self {
532        if let Some(paged_attn_block_size) = paged_attn_block_size {
533            self = self.with_paged_attn_block_size(paged_attn_block_size);
534        }
535        self
536    }
537
538    /// Sets whether to force CPU-only execution.
539    pub fn with_cpu(mut self, cpu: bool) -> Self {
540        self.cpu = cpu;
541        self
542    }
543
544    /// Sets whether to enable web search functionality.
545    pub fn with_enable_search(mut self, enable_search: bool) -> Self {
546        self.enable_search = enable_search;
547        self
548    }
549
550    /// Sets the BERT model for web search assistance.
551    pub fn with_search_bert_model(mut self, search_bert_model: String) -> Self {
552        self.search_bert_model = Some(search_bert_model);
553        self
554    }
555
556    /// Override the search function used when `web_search_options` is enabled.
557    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
558        self.search_callback = Some(callback);
559        self
560    }
561
562    /// Sets the MCP client configuration.
563    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
564        self.mcp_client_config = Some(mcp_config);
565        self
566    }
567
568    /// Sets the MCP client configuration if provided.
569    pub fn with_mcp_config_optional(mut self, mcp_config: Option<McpClientConfig>) -> Self {
570        if let Some(mcp_config) = mcp_config {
571            self = self.with_mcp_config(mcp_config);
572        }
573        self
574    }
575
576    /// Builds the configured mistral.rs instance.
577    ///
578    /// ### Examples
579    ///
580    /// ```ignore
581    /// use mistralrs_server_core::mistralrs_for_server_builder::MistralRsForServerBuilder;
582    ///
583    /// let shared_mistralrs = MistralRsForServerBuilder::new()
584    ///     .with_model(model)
585    ///     .with_in_situ_quant("8".to_string())
586    ///     .set_paged_attn(Some(true))
587    ///     .build()
588    ///     .await?;
589    /// ```
590    pub async fn build(self) -> Result<SharedMistralRsState> {
591        // Determine if we're in single-model or multi-model mode
592        if !self.models.is_empty() {
593            self.build_multi_model().await
594        } else {
595            self.build_single_model().await
596        }
597    }
598
599    /// Build a single-model instance (legacy mode)
600    async fn build_single_model(mut self) -> Result<SharedMistralRsState> {
601        let model = self.model.context("Model was None")?;
602
603        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
604        let dtype = get_model_dtype(&model)?;
605        let auto_device_map_params = get_auto_device_map_params(&model)?;
606
607        if tgt_non_granular_index.is_some() {
608            self.max_seqs = 1;
609        }
610
611        let max_seq_len = auto_device_map_params.max_seq_len();
612
613        let device = if let Some(device) = self.device {
614            device
615        } else {
616            init_device(self.cpu, self.seed)?
617        };
618
619        let mapper = init_mapper(&self.num_device_layers, &auto_device_map_params);
620        let paged_attn = configure_paged_attn(&device, self.paged_attn);
621
622        // Allocate 0.5 GB of CPU memory just as a placeholder.
623        // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
624        let cache_config = init_cache_config(
625            self.paged_attn_block_size,
626            self.paged_attn_gpu_mem,
627            self.paged_attn_gpu_mem_usage,
628            self.paged_ctxt_len,
629            self.paged_cache_type,
630            !paged_attn,
631            max_seq_len,
632        )?;
633
634        // Configure this last to prevent arg moves
635        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
636            .with_no_kv_cache(self.no_kv_cache)
637            .with_chat_template(self.chat_template)
638            .with_jinja_explicit(self.jinja_explicit)
639            .build()?;
640
641        mistralrs_instance_info(&*loader);
642
643        let isq = self
644            .in_situ_quant
645            .as_ref()
646            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
647
648        let pipeline: LoadedPipeline = loader.load_model_from_hf(
649            None,
650            self.token_source,
651            &dtype,
652            &device,
653            false,
654            mapper,
655            isq,
656            cache_config,
657        )?;
658        info!("Model loaded.");
659
660        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
661
662        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);
663
664        let mut builder = MistralRsBuilder::new(
665            pipeline,
666            scheduler_config,
667            !self.interactive_mode,
668            bert_model,
669        )
670        .with_opt_log(self.log)
671        .with_truncate_sequence(self.truncate_sequence)
672        .with_no_kv_cache(self.no_kv_cache)
673        .with_prefix_cache_n(self.prefix_cache_n);
674
675        // Add MCP client configuration if provided
676        if let Some(mcp_config) = self.mcp_client_config {
677            builder = builder.with_mcp_client(mcp_config);
678        }
679
680        let mistralrs = builder.build().await;
681
682        Ok(mistralrs)
683    }
684
685    /// Build a multi-model instance
686    pub async fn build_multi_model(mut self) -> Result<SharedMistralRsState> {
687        if self.models.is_empty() {
688            anyhow::bail!("No models configured for multi-model mode");
689        }
690
691        // Use the first model as the base configuration
692        let first_model = &self.models[0];
693        let model = first_model.model.clone();
694
695        let tgt_non_granular_index = get_tgt_non_granular_index(&model);
696        let dtype = get_model_dtype(&model)?;
697        let auto_device_map_params = get_auto_device_map_params(&model)?;
698
699        if tgt_non_granular_index.is_some() {
700            self.max_seqs = 1;
701        }
702
703        let max_seq_len = auto_device_map_params.max_seq_len();
704
705        let device = if let Some(device) = self.device {
706            device
707        } else {
708            init_device(self.cpu, self.seed)?
709        };
710
711        // Create the first model's pipeline
712        let loader: Box<dyn Loader> = LoaderBuilder::new(model)
713            .with_no_kv_cache(self.no_kv_cache)
714            .with_chat_template(
715                first_model
716                    .chat_template
717                    .clone()
718                    .or(self.chat_template.clone()),
719            )
720            .with_jinja_explicit(
721                first_model
722                    .jinja_explicit
723                    .clone()
724                    .or(self.jinja_explicit.clone()),
725            )
726            .build()?;
727
728        mistralrs_instance_info(&*loader);
729
730        let mapper = init_mapper(
731            &first_model
732                .num_device_layers
733                .clone()
734                .or(self.num_device_layers.clone()),
735            &auto_device_map_params,
736        );
737        let paged_attn = configure_paged_attn(&device, self.paged_attn);
738
739        let cache_config = init_cache_config(
740            self.paged_attn_block_size,
741            self.paged_attn_gpu_mem,
742            self.paged_attn_gpu_mem_usage,
743            self.paged_ctxt_len,
744            self.paged_cache_type,
745            !paged_attn,
746            max_seq_len,
747        )?;
748
749        let isq = first_model
750            .in_situ_quant
751            .as_ref()
752            .or(self.in_situ_quant.as_ref())
753            .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
754
755        let mut pipeline_names = Vec::new();
756
757        let pipeline: LoadedPipeline = loader.load_model_from_hf(
758            None,
759            self.token_source.clone(),
760            &dtype,
761            &device,
762            false,
763            mapper,
764            isq,
765            cache_config,
766        )?;
767        let first_pipeline_name = pipeline.lock().await.name();
768        info!(
769            "First model loaded: `{first_pipeline_name}` (from config key: {})",
770            first_model.model_id
771        );
772        pipeline_names.push(first_pipeline_name);
773
774        let scheduler_config = init_scheduler_config(&cache_config, &pipeline, self.max_seqs).await;
775        let bert_model = get_bert_model(self.enable_search, self.search_bert_model);
776
777        // Create the first MistralRs instance with the first model
778        let mut builder = MistralRsBuilder::new(
779            pipeline,
780            scheduler_config.clone(),
781            !self.interactive_mode,
782            bert_model.clone(),
783        )
784        .with_opt_log(self.log.clone())
785        .with_truncate_sequence(self.truncate_sequence)
786        .with_no_kv_cache(self.no_kv_cache)
787        .with_prefix_cache_n(self.prefix_cache_n);
788
789        // Add MCP client configuration if provided
790        if let Some(mcp_config) = self.mcp_client_config.clone() {
791            builder = builder.with_mcp_client(mcp_config);
792        }
793
794        let mistralrs = builder.build().await;
795
796        // Load additional models
797        for model_config in self.models.iter().skip(1) {
798            info!(
799                "Loading additional model from config key: {}",
800                model_config.model_id
801            );
802
803            let model = model_config.model.clone();
804            let dtype = get_model_dtype(&model)?;
805            let auto_device_map_params = get_auto_device_map_params(&model)?;
806
807            let loader: Box<dyn Loader> = LoaderBuilder::new(model)
808                .with_no_kv_cache(self.no_kv_cache)
809                .with_chat_template(
810                    model_config
811                        .chat_template
812                        .clone()
813                        .or(self.chat_template.clone()),
814                )
815                .with_jinja_explicit(
816                    model_config
817                        .jinja_explicit
818                        .clone()
819                        .or(self.jinja_explicit.clone()),
820                )
821                .build()?;
822
823            let mapper = init_mapper(
824                &model_config
825                    .num_device_layers
826                    .clone()
827                    .or(self.num_device_layers.clone()),
828                &auto_device_map_params,
829            );
830
831            let isq = model_config
832                .in_situ_quant
833                .as_ref()
834                .or(self.in_situ_quant.as_ref())
835                .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
836
837            let pipeline: LoadedPipeline = loader.load_model_from_hf(
838                None,
839                self.token_source.clone(),
840                &dtype,
841                &device,
842                false,
843                mapper,
844                isq,
845                cache_config,
846            )?;
847
848            // Use the pipeline's name() as the model ID
849            let pipeline_name = pipeline.lock().await.name();
850
851            // Check for model ID conflicts
852            if pipeline_names.contains(&pipeline_name) {
853                anyhow::bail!(
854                    "Model ID conflict: '{}' is already registered. Models from config keys '{}' and previous models have the same pipeline identifier.",
855                    pipeline_name,
856                    model_config.model_id
857                );
858            }
859
860            // Add the model to the MistralRs instance
861            let engine_config = mistralrs_core::EngineConfig {
862                truncate_sequence: self.truncate_sequence,
863                no_kv_cache: self.no_kv_cache,
864                no_prefix_cache: false,
865                prefix_cache_n: self.prefix_cache_n,
866                disable_eos_stop: false,
867                throughput_logging_enabled: !self.interactive_mode,
868                search_embedding_model: bert_model.clone(),
869                search_callback: self.search_callback.clone(),
870                tool_callbacks: HashMap::new(),
871                tool_callbacks_with_tools: HashMap::new(),
872            };
873
874            let mut add_model_config = mistralrs_core::AddModelConfig::new(engine_config);
875            if let Some(mcp_config) = self.mcp_client_config.clone() {
876                add_model_config = add_model_config.with_mcp_config(mcp_config);
877            }
878
879            mistralrs
880                .add_model(
881                    pipeline_name.clone(),
882                    pipeline,
883                    scheduler_config.clone(),
884                    add_model_config,
885                )
886                .await
887                .map_err(|e| anyhow::anyhow!("Failed to add model {}: {}", pipeline_name, e))?;
888
889            info!(
890                "Model `{pipeline_name}` registered successfully (from config key: {})",
891                model_config.model_id
892            );
893            pipeline_names.push(pipeline_name);
894        }
895
896        // Set the default model if specified
897        if let Some(ref default_model_id) = self.default_model_id {
898            mistralrs
899                .set_default_model_id(default_model_id)
900                .map_err(|e| anyhow::anyhow!("Failed to set default model: {}", e))?;
901        }
902
903        // Log all models loaded
904        info!("All models loaded: `{}`", pipeline_names.join("`, `"));
905
906        // Log default model
907        if let Some(ref default_id) = self.default_model_id {
908            info!("Default model: {}", default_id);
909        } else {
910            info!(
911                "Default model: {} (first model, from config key: {})",
912                pipeline_names[0], self.models[0].model_id
913            );
914        }
915        Ok(mistralrs)
916    }
917}
918
919// TODO: replace with best device?
920/// Initializes the device to be used for computation, optionally forcing CPU usage and setting a seed.
921fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
922    #[cfg(feature = "metal")]
923    let device = if force_cpu {
924        Device::Cpu
925    } else {
926        Device::new_metal(0)?
927    };
928    #[cfg(not(feature = "metal"))]
929    #[allow(clippy::if_same_then_else)]
930    let device = if force_cpu {
931        Device::Cpu
932    } else if mistralrs_core::distributed::use_nccl() {
933        Device::Cpu
934    } else {
935        Device::cuda_if_available(0)?
936    };
937
938    if let Some(seed) = seed {
939        device.set_seed(seed)?;
940    }
941
942    Ok(device)
943}
944
945/// Initializes the device mapping configuration for distributing model layers.
946fn init_mapper(
947    num_device_layers: &Option<Vec<String>>,
948    auto_device_map_params: &AutoDeviceMapParams,
949) -> DeviceMapSetting {
950    // Parse device mapper
951    if let Some(device_layers) = num_device_layers {
952        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
953            let layers = device_layers[0].parse::<usize>().unwrap();
954            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
955                DeviceLayerMapMetadata { ordinal: 0, layers },
956            ]))
957        } else {
958            let mut mapping = Vec::new();
959            for layer in device_layers {
960                let split = layer.splitn(2, ':').collect::<Vec<_>>();
961                if split.len() < 2 {
962                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
963                }
964                let ord = split[0]
965                    .parse::<usize>()
966                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
967                let num = split[1]
968                    .parse::<usize>()
969                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
970                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
971                    if *ordinal == ord {
972                        panic!("Duplicate ordinal {ord}");
973                    }
974                }
975                mapping.push(DeviceLayerMapMetadata {
976                    ordinal: ord,
977                    layers: num,
978                });
979            }
980            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
981        }
982    } else {
983        DeviceMapSetting::Auto(auto_device_map_params.clone())
984    }
985}
986
987/// Logs hardware feature information and the model's sampling strategy and kind.
988fn mistralrs_instance_info(loader: &dyn Loader) {
989    info!(
990        "avx: {}, neon: {}, simd128: {}, f16c: {}",
991        candle_core::utils::with_avx(),
992        candle_core::utils::with_neon(),
993        candle_core::utils::with_simd128(),
994        candle_core::utils::with_f16c()
995    );
996
997    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
998    info!("Model kind is: {}", loader.get_kind().to_string());
999}
1000
1001/// Determines whether paged attention should be enabled based on device type and preferences.
1002fn configure_paged_attn(device: &Device, paged_attn: Option<bool>) -> bool {
1003    if device.is_cpu() {
1004        if paged_attn == Some(true) {
1005            warn!("Paged attention is not supported on CPU.");
1006        }
1007
1008        defaults::PAGED_ATTN_CPU
1009    } else if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
1010        paged_attn.unwrap_or(defaults::PAGED_ATTN_CUDA)
1011    } else if device.is_metal() {
1012        paged_attn.unwrap_or(defaults::PAGED_ATTN_METAL)
1013    } else {
1014        false
1015    }
1016}
1017
1018/// Initializes the cache configuration for paged attention based on provided parameters.
1019fn init_cache_config(
1020    paged_attn_block_size: Option<usize>,
1021    paged_attn_gpu_mem: Option<usize>,
1022    paged_attn_gpu_mem_usage: Option<f32>,
1023    paged_ctxt_len: Option<usize>,
1024    cache_type: PagedCacheType,
1025    no_paged_attn: bool,
1026    max_seq_len: usize,
1027) -> Result<Option<PagedAttentionConfig>> {
1028    match (
1029        paged_attn_block_size,
1030        paged_attn_gpu_mem,
1031        paged_attn_gpu_mem_usage,
1032        paged_ctxt_len,
1033        paged_attn_supported(),
1034        no_paged_attn,
1035    ) {
1036        (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
1037            block_size,
1038            512,
1039            MemoryGpuConfig::ContextSize(max_seq_len),
1040            cache_type,
1041        )?)),
1042        (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
1043            block_size,
1044            512,
1045            MemoryGpuConfig::ContextSize(ctxt),
1046            cache_type,
1047        )?)),
1048        (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
1049            block_size,
1050            512,
1051            MemoryGpuConfig::Utilization(f),
1052            cache_type,
1053        )?)),
1054        (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
1055            block_size,
1056            512,
1057            MemoryGpuConfig::MbAmount(m),
1058            cache_type,
1059        )?)),
1060        (block_size, Some(_m), Some(f), None, true, false) => {
1061            info!("Both memory size, and usage were specified, defaulting to the usage value.");
1062            Ok(Some(PagedAttentionConfig::new(
1063                block_size,
1064                512,
1065                MemoryGpuConfig::Utilization(f),
1066                cache_type,
1067            )?))
1068        }
1069        (block_size, Some(_m), None, Some(ctxt), true, false) => {
1070            info!("All memory size and ctxt len, defaulting to the context len value.");
1071            Ok(Some(PagedAttentionConfig::new(
1072                block_size,
1073                512,
1074                MemoryGpuConfig::ContextSize(ctxt),
1075                cache_type,
1076            )?))
1077        }
1078        (block_size, None, Some(f), Some(_ctxt), true, false) => {
1079            info!("Both ctxt len and usage were specified, defaulting to the usage value.");
1080            Ok(Some(PagedAttentionConfig::new(
1081                block_size,
1082                512,
1083                MemoryGpuConfig::Utilization(f),
1084                cache_type,
1085            )?))
1086        }
1087        (_, _, _, _, _, _) => Ok(None),
1088    }
1089}
1090
1091/// Initializes the scheduler configuration based on cache settings and pipeline metadata.
1092async fn init_scheduler_config(
1093    cache_config: &Option<PagedAttentionConfig>,
1094    pipeline: &LoadedPipeline,
1095    args_max_seqs: usize,
1096) -> SchedulerConfig {
1097    if cache_config.is_some() {
1098        // Handle case where we may have device mapping
1099        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
1100            SchedulerConfig::PagedAttentionMeta {
1101                max_num_seqs: args_max_seqs,
1102                config: cache_config.clone(),
1103            }
1104        } else {
1105            SchedulerConfig::DefaultScheduler {
1106                method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
1107            }
1108        }
1109    } else {
1110        SchedulerConfig::DefaultScheduler {
1111            method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
1112        }
1113    }
1114}
1115
1116/// Configures PagedAttention based on two flags.
1117///
1118/// This function resolves the tri-state PagedAttention configuration from
1119/// the mutually exclusive `paged_attn` and `no_paged_attn` flags.
1120pub fn configure_paged_attn_from_flags(
1121    paged_attn: bool,
1122    no_paged_attn: bool,
1123) -> Result<Option<bool>> {
1124    match (paged_attn, no_paged_attn) {
1125        (true, true) => {
1126            anyhow::bail!("Error: `--paged-attn` and `--no-paged-attn` cannot be used together.");
1127        }
1128        (true, false) => Ok(Some(true)),
1129        (false, true) => Ok(Some(false)),
1130        (false, false) => Ok(None),
1131    }
1132}
1133
1134/// Creates a BERT embedding model configuration for search functionality.
1135pub fn get_bert_model(
1136    enable_search: bool,
1137    search_bert_model: Option<String>,
1138) -> Option<BertEmbeddingModel> {
1139    if enable_search {
1140        Some(
1141            search_bert_model
1142                .map(BertEmbeddingModel::Custom)
1143                .unwrap_or_default(),
1144        )
1145    } else {
1146        None
1147    }
1148}